package dr.inferencexml.distribution;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import dr.inference.distribution.GeneralizedLinearModel;
import dr.inference.distribution.LinearRegression;
import dr.inference.distribution.LogLinearModel;
import dr.inference.distribution.LogisticRegression;
import dr.inference.model.DesignMatrix;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/* loaded from: input_file:dr/inferencexml/distribution/GeneralizedLinearModelParser.class */
/**
 * XML parser for the {@code glmModel} element.
 *
 * <p>Builds a {@link GeneralizedLinearModel} whose concrete type is chosen by the
 * {@code family} attribute ({@code logistic}, {@code normal}, {@code logNormal} or
 * {@code logLinear}), then attaches the optional scale parameters, every
 * {@code independentVariables} block and every {@code randomEffects} block found
 * among the element's children.
 *
 * <p>The parser keeps the {@code checkFullRank} setting of the element currently being
 * parsed in an instance field, so a single instance must not parse two elements
 * concurrently.
 */
public class GeneralizedLinearModelParser extends AbstractXMLObjectParser {
    public static final String GLM_LIKELIHOOD = "glmModel";
    public static final String DEPENDENT_VARIABLES = "dependentVariables";
    public static final String INDEPENDENT_VARIABLES = "independentVariables";
    public static final String BASIS_MATRIX = "basis";
    public static final String FAMILY = "family";
    public static final String SCALE_VARIABLES = "scaleVariables";
    public static final String INDICATOR = "indicator";
    public static final String LOGISTIC_REGRESSION = "logistic";
    public static final String NORMAL_REGRESSION = "normal";
    public static final String LOG_NORMAL_REGRESSION = "logNormal";
    public static final String LOG_LINEAR = "logLinear";
    public static final String RANDOM_EFFECTS = "randomEffects";
    public static final String CHECK_IDENTIFIABILITY = "checkIdentifiability";
    public static final String CHECK_FULL_RANK = "checkFullRank";

    /**
     * Whether {@link #addIndependentParameters} runs the full-rank check on each design
     * matrix. Set by {@link #parseXMLObject} from the {@code checkFullRank} attribute
     * before any independent block is added.
     */
    private boolean checkFullRankOfMatrix;

    private final XMLSyntaxRule[] rules = {
            AttributeRule.newStringRule(FAMILY),
            AttributeRule.newBooleanRule(CHECK_IDENTIFIABILITY, true),
            AttributeRule.newBooleanRule(CHECK_FULL_RANK, true),
            new ElementRule(DEPENDENT_VARIABLES,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
            new ElementRule(INDEPENDENT_VARIABLES,
                    new XMLSyntaxRule[]{
                            new ElementRule(Parameter.class, true),
                            new ElementRule(DesignMatrix.class),
                            new ElementRule(INDICATOR,
                                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)
                    }, 1, 3),
            // parseXMLObject reads this element for families that require a scale, so it is
            // declared here as well. The trailing 'true' mirrors the dependentVariables and
            // indicator rules above, whose elements the parsing code treats as optional.
            new ElementRule(SCALE_VARIABLES,
                    new XMLSyntaxRule[]{
                            new ElementRule(Parameter.class),
                            new ElementRule(INDICATOR,
                                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)
                    }, true),
            new ElementRule(RANDOM_EFFECTS,
                    new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, 0, 3)
    };

    @Override
    public String getParserName() {
        return GLM_LIKELIHOOD;
    }

    /**
     * Parses a {@code glmModel} element into a {@link GeneralizedLinearModel}.
     *
     * @param xMLObject the {@code glmModel} element
     * @return the configured model
     * @throws XMLParseException if the family is unknown, a required scale or dependent
     *                           parameter is missing, dimensions disagree, a design matrix
     *                           fails the full-rank check, or the identifiability check fails
     */
    @Override
    public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
        XMLObject dependentElement = xMLObject.getChild(DEPENDENT_VARIABLES);
        Parameter dependent = dependentElement != null
                ? (Parameter) dependentElement.getChild(Parameter.class)
                : null;

        String family = xMLObject.getStringAttribute(FAMILY);
        GeneralizedLinearModel glm = createModel(family, dependent);

        if (glm.requiresScale()) {
            addScaleParameters(xMLObject, glm, dependent, family);
        }

        // Must be read BEFORE addIndependentParameters, which consults the flag. Reading it
        // afterwards meant the check used whatever value the previous parse left behind
        // (false on a fresh instance), so the attribute of the current element was ignored.
        this.checkFullRankOfMatrix = ((Boolean) xMLObject.getAttribute(CHECK_FULL_RANK, true)).booleanValue();

        addIndependentParameters(xMLObject, glm, dependent);
        addRandomEffects(xMLObject, glm, dependent);

        boolean checkIdentifiability = ((Boolean) xMLObject.getAttribute(CHECK_IDENTIFIABILITY, true)).booleanValue();
        if (checkIdentifiability && !glm.getAllIndependentVariablesIdentifiable()) {
            throw new XMLParseException("All design matrix predictors are not identifiable in " + xMLObject.getId());
        }
        return glm;
    }

    /** Instantiates the model implementation selected by the {@code family} attribute. */
    private GeneralizedLinearModel createModel(String family, Parameter dependent) throws XMLParseException {
        if (LOGISTIC_REGRESSION.equals(family)) {
            return new LogisticRegression(dependent);
        }
        if (NORMAL_REGRESSION.equals(family)) {
            return new LinearRegression(dependent, false);
        }
        if (LOG_NORMAL_REGRESSION.equals(family)) {
            return new LinearRegression(dependent, true);
        }
        if (LOG_LINEAR.equals(family)) {
            return new LogLinearModel(dependent);
        }
        throw new XMLParseException("Family '" + family + "' is not currently implemented");
    }

    /**
     * Reads the {@code scaleVariables} element and registers the scale parameter together
     * with its design (indicator) parameter on the model.
     */
    private void addScaleParameters(XMLObject xMLObject, GeneralizedLinearModel glm,
                                    Parameter dependent, String family) throws XMLParseException {
        XMLObject scaleElement = xMLObject.getChild(SCALE_VARIABLES);
        Parameter scale = null;
        Parameter scaleDesign = null;
        if (scaleElement != null) {
            scale = (Parameter) scaleElement.getChild(Parameter.class);
            XMLObject indicatorElement = scaleElement.getChild(INDICATOR);
            if (indicatorElement != null) {
                scaleDesign = (Parameter) indicatorElement.getChild(Parameter.class);
            }
        }
        if (scale == null) {
            throw new XMLParseException("Family '" + family + "' requires scale parameters");
        }
        // The scale design is sized from the dependent parameter, so a missing
        // dependentVariables element is reported as a parse error instead of surfacing
        // as a NullPointerException below.
        if (dependent == null) {
            throw new XMLParseException("Family '" + family + "' requires dependent variables");
        }

        if (scaleDesign == null) {
            // No explicit design: every entry refers to scale index 0.
            scaleDesign = new Parameter.Default(dependent.getDimension(), 0.0d);
        } else {
            // Note: the first number printed is the dimension of the dependent parameter,
            // which is what the scale design is compared against.
            if (scaleDesign.getDimension() != dependent.getDimension()) {
                throw new XMLParseException("Scale (" + dependent.getDimension() + ") and scaleDesign parameters (" + scaleDesign.getDimension() + ") must be the same dimension");
            }
            // Values in the XML are 1-based indices into the scale parameter; they are
            // validated and rewritten in place as 0-based indices.
            for (int i = 0; i < scaleDesign.getDimension(); i++) {
                double value = scaleDesign.getParameterValue(i);
                if (value < 1.0d || value > scale.getDimension()) {
                    throw new XMLParseException("Invalid scaleDesign value");
                }
                scaleDesign.setParameterValue(i, value - 1.0d);
            }
        }
        glm.addScaleParameter(scale, scaleDesign);
    }

    /**
     * Adds the parameter of every {@code randomEffects} child to the model, after checking
     * (and, for parameters of dimension &lt;= 1, adjusting) its dimension against the
     * dependent parameter.
     *
     * @param xMLObject              the {@code glmModel} element
     * @param generalizedLinearModel the model receiving the random effects
     * @param parameter              the dependent parameter, may be {@code null}
     * @throws XMLParseException if a random-effects dimension disagrees with the dependent one
     */
    public void addRandomEffects(XMLObject xMLObject, GeneralizedLinearModel generalizedLinearModel, Parameter parameter) throws XMLParseException {
        int childCount = xMLObject.getChildCount();
        for (int i = 0; i < childCount; i++) {
            if (!RANDOM_EFFECTS.equals(xMLObject.getChildName(i))) {
                continue;
            }
            XMLObject effectsElement = (XMLObject) xMLObject.getChild(i);
            Parameter randomEffects = (Parameter) effectsElement.getChild(Parameter.class);
            checkRandomEffectsDimensions(randomEffects, parameter);
            generalizedLinearModel.addRandomEffectsParameter(randomEffects);
        }
    }

    /**
     * Adds every {@code independentVariables} child (effects parameter, design matrix and
     * optional indicator) to the model. Dimensions are validated first, and the design
     * matrix is tested for full rank when the {@code checkFullRank} setting is enabled.
     *
     * @param xMLObject              the {@code glmModel} element
     * @param generalizedLinearModel the model receiving the independent blocks
     * @param parameter              the dependent parameter, may be {@code null}
     * @throws XMLParseException if the effects parameter is missing, dimensions disagree,
     *                           or a design matrix is not of full rank
     */
    public void addIndependentParameters(XMLObject xMLObject, GeneralizedLinearModel generalizedLinearModel, Parameter parameter) throws XMLParseException {
        int childCount = xMLObject.getChildCount();
        for (int i = 0; i < childCount; i++) {
            if (!INDEPENDENT_VARIABLES.equals(xMLObject.getChildName(i))) {
                continue;
            }
            XMLObject block = (XMLObject) xMLObject.getChild(i);
            Parameter effects = (Parameter) block.getChild(Parameter.class);
            DesignMatrix designMatrix = (DesignMatrix) block.getChild(DesignMatrix.class);

            // The syntax rules mark the effects parameter as optional, but everything
            // below dereferences it; fail with a parse error rather than an NPE.
            if (effects == null) {
                throw new XMLParseException("An '" + INDEPENDENT_VARIABLES + "' element in " + xMLObject.getId() + " has no effects parameter");
            }
            checkDimensions(effects, parameter, designMatrix);

            Parameter indicator = null;
            XMLObject indicatorElement = block.getChild(INDICATOR);
            if (indicatorElement != null) {
                indicator = (Parameter) indicatorElement.getChild(Parameter.class);
                // An indicator of dimension <= 1 is expanded to match the effects.
                if (indicator.getDimension() <= 1) {
                    indicator.setDimension(effects.getDimension());
                }
                if (indicator.getDimension() != effects.getDimension()) {
                    throw new XMLParseException("dim(" + effects.getId() + ") != dim(" + indicator.getId() + ")");
                }
            }

            if (this.checkFullRankOfMatrix) {
                checkFullRank(designMatrix);
            }
            generalizedLinearModel.addIndependentParameter(effects, designMatrix, indicator);
        }
    }

    /**
     * Verifies that the rank of the design matrix equals its number of columns.
     *
     * @throws XMLParseException if the matrix is rank deficient
     */
    private void checkFullRank(DesignMatrix designMatrix) throws XMLParseException {
        int columnDimension = designMatrix.getColumnDimension();
        int rowDimension = designMatrix.getRowDimension();
        // With fewer rows than columns the rank is at most the row count, so the matrix
        // cannot have full column rank. Reject it up front.
        // NOTE(review): Colt's SingularValueDecomposition is believed to throw an unchecked
        // exception for such matrices; confirm against the Colt documentation.
        if (rowDimension < columnDimension) {
            throw new XMLParseException("Matrix is not of full rank as rowDim(" + designMatrix.getId() + ") = " + rowDimension + " < colDim(" + designMatrix.getId() + ") = " + columnDimension);
        }
        DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(designMatrix.getParameterAsMatrix());
        int rank = new SingularValueDecomposition(matrix).rank();
        if (rank != columnDimension) {
            throw new XMLParseException("rank(" + designMatrix.getId() + ") = " + rank + ".\nMatrix is not of full rank as colDim(" + designMatrix.getId() + ") = " + columnDimension);
        }
    }

    /**
     * Checks a random-effects parameter against the dependent parameter. Nothing is checked
     * when the dependent parameter is {@code null}. A random-effects parameter of dimension
     * &lt;= 1 is first resized to the dependent dimension.
     *
     * @param parameter  the random-effects parameter
     * @param parameter2 the dependent parameter, may be {@code null}
     */
    private void checkRandomEffectsDimensions(Parameter parameter, Parameter parameter2) throws XMLParseException {
        if (parameter2 == null) {
            return;
        }
        if (parameter.getDimension() <= 1) {
            parameter.setDimension(parameter2.getDimension());
        }
        if (parameter.getDimension() != parameter2.getDimension()) {
            throw new XMLParseException("dim(" + parameter2.getId() + ") != dim(" + parameter.getId() + ")");
        }
    }

    /**
     * Checks that effects, dependent parameter and design matrix are conformable:
     * the effects dimension must equal the column count and, when a dependent parameter is
     * present, its dimension must equal the row count. Parameters of dimension &lt;= 1 are
     * resized to fit before the comparison (the effects only when there is no dependent
     * parameter, otherwise the dependent parameter).
     *
     * @param parameter    the effects parameter
     * @param parameter2   the dependent parameter, may be {@code null}
     * @param designMatrix the design matrix of the block
     */
    private void checkDimensions(Parameter parameter, Parameter parameter2, DesignMatrix designMatrix) throws XMLParseException {
        if (parameter2 == null) {
            if (parameter.getDimension() <= 1) {
                parameter.setDimension(designMatrix.getColumnDimension());
            }
            if (parameter.getDimension() != designMatrix.getColumnDimension()) {
                throw new XMLParseException("dim(" + parameter.getId() + ") is incompatible with dim (" + designMatrix.getId() + ")");
            }
            return;
        }
        if (parameter2.getDimension() <= 1) {
            parameter2.setDimension(designMatrix.getRowDimension());
        }
        boolean rowsMatch = parameter2.getDimension() == designMatrix.getRowDimension();
        if (!rowsMatch || parameter.getDimension() != designMatrix.getColumnDimension()) {
            throw new XMLParseException("dim(" + parameter2.getId() + ") != dim(" + designMatrix.getId() + " %*% " + parameter.getId() + ")");
        }
    }

    @Override
    public XMLSyntaxRule[] getSyntaxRules() {
        return this.rules;
    }

    @Override
    public String getParserDescription() {
        return "Calculates the generalized linear model likelihood of the dependent parameters given one or more blocks of independent parameters and their design matrix.";
    }

    @Override
    public Class getReturnType() {
        return Likelihood.class;
    }
}
