package dr.inference.glm;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import dr.inference.distribution.DensityModel;
import dr.inference.distribution.ParametricMultivariateDistributionModel;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.DesignMatrix;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inferencexml.glm.GeneralizedLinearModelParser;
import dr.util.Transform;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Logger;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

/**
 * Likelihood of a generalized linear model (GLM) with optional random effects.
 *
 * <p>For each observation {@code j} the linear predictor is
 * <pre>
 *   eta[j] = sum_k sum_c X_k(j, c) * beta_k[c] * delta_k[c]  +  sum_r re_r[j]
 * </pre>
 * where each {@code X_k} is a registered {@link DesignMatrix}, {@code beta_k} its coefficient
 * parameter, {@code delta_k} an optional per-coefficient multiplier (treated as 1 when absent)
 * and {@code re_r} a random-effect parameter. The inverse of the link function maps
 * {@code eta[j]} onto the location variable of the supplied density, which is then evaluated at
 * the observed response {@code Y[j]}.
 *
 * <p>Only univariate densities are evaluated; when the density is a
 * {@link ParametricMultivariateDistributionModel} the log-likelihood is currently 0.
 */
public final class GeneralizedLinearModel extends AbstractModelLikelihood {
    private final Transform linkFunction;
    private final DensityModel density;
    private final boolean isMultivariateDensity;
    /** Observed responses; may be {@code null}, in which case there is no data term. */
    private final Parameter dependentParameter;
    /** Fixed-effect coefficients, index-aligned with {@link #designMatrix}. */
    private final List<Parameter> independentParameter = new ArrayList<>();
    /** Optional per-coefficient multipliers (entries may be {@code null}), index-aligned. */
    private final List<Parameter> independentParameterDelta = new ArrayList<>();
    private final List<DesignMatrix> designMatrix = new ArrayList<>();
    private int numIndependentVariables = 0;
    private int numRandomEffects = 0;
    /** Number of observations; 0 until fixed by the responses, a design matrix or a random effect. */
    private int N;
    protected List<Parameter> randomEffects = null;

    /** Inverse-link-transformed linear predictor, valid when {@link #transformedXBetaKnown}. */
    private double[] transformedXBeta;
    private double[] storedTransformedXBeta;
    private boolean transformedXBetaKnown = false;
    private boolean storedTransformedXBetaKnown = false;
    /** Cached copy of the response values. */
    private double[] Y;
    private double logLikelihood;
    private double storedLogLikelihood;
    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown = false;

    /** Supported link functions; each wraps the {@link Transform} whose inverse is applied to the linear predictor. */
    public enum LinkFunction {
        IDENTITY(new Transform.NoTransform()),
        LOG(new Transform.LogTransform()),
        LOGIT(new Transform.LogitTransform());

        private final Transform transform;

        LinkFunction(Transform transform) {
            this.transform = transform;
        }

        public Transform getTransform() {
            return this.transform;
        }
    }

    /** Log column reporting one entry of the (untransformed) linear predictor. */
    private class NumberArrayColumn extends NumberColumn {
        private final int index;

        public NumberArrayColumn(String label, int index) {
            super(label);
            this.index = index;
        }

        @Override
        public double getDoubleValue() {
            return GeneralizedLinearModel.this.getXBeta()[this.index];
        }
    }

    /**
     * Creates a GLM likelihood.
     *
     * @param parameter    observed responses; may be {@code null}, in which case the number of
     *                     observations is taken from the first design matrix or random effect added
     * @param densityModel density evaluated at each response; its location variable is overwritten
     *                     while the likelihood is computed
     * @param linkFunction link relating the linear predictor to the density's location
     */
    public GeneralizedLinearModel(Parameter parameter, DensityModel densityModel, LinkFunction linkFunction) {
        super(GeneralizedLinearModelParser.GLM_LIKELIHOOD);
        this.dependentParameter = parameter;
        this.linkFunction = linkFunction.getTransform();
        this.density = densityModel;
        this.isMultivariateDensity = densityModel instanceof ParametricMultivariateDistributionModel;
        addModel(densityModel);
        if (parameter != null) {
            addVariable(parameter);
            this.Y = parameter.getParameterValues();
            setDimension(parameter.getDimension());
        } else {
            // No observed responses: the data term is empty and N is fixed later by the predictors.
            this.Y = new double[0];
            setDimension(0);
        }
    }

    /** Sets the number of observations and (re)allocates the per-observation buffers. */
    private void setDimension(int dimension) {
        this.N = dimension;
        this.transformedXBeta = new double[dimension];
        this.storedTransformedXBeta = new double[dimension];
        this.transformedXBetaKnown = false;
        this.likelihoodKnown = false;
    }

    /**
     * Fixes the number of observations if it is still unset (0); otherwise checks that
     * {@code dimension} agrees with it.
     *
     * @throws RuntimeException with {@code message} if the dimensions disagree
     */
    private void ensureDimension(int dimension, String message) {
        if (this.N == 0) {
            setDimension(dimension);
        } else if (dimension != this.N) {
            throw new RuntimeException(message);
        }
    }

    /**
     * Adds a random-effect parameter; its j-th entry is added to the linear predictor of observation j.
     *
     * @throws RuntimeException if its dimension differs from the number of observations
     */
    public void addRandomEffectsParameter(Parameter parameter) {
        if (this.randomEffects == null) {
            this.randomEffects = new ArrayList<>();
        }
        ensureDimension(parameter.getDimension(), "Random effects have the wrong dimension");
        addVariable(parameter);
        this.randomEffects.add(parameter);
        this.numRandomEffects++;
    }

    /**
     * Adds a block of fixed effects.
     *
     * @param parameter    coefficients; entry c multiplies column c of {@code designMatrix}
     * @param designMatrix predictors, one row per observation
     * @param parameter2   optional per-coefficient multiplier for {@code parameter}; may be {@code null}
     * @throws RuntimeException if the design matrix row count differs from the number of observations
     */
    public void addIndependentParameter(Parameter parameter, DesignMatrix designMatrix, Parameter parameter2) {
        final int rows = designMatrix.getRowDimension();
        ensureDimension(rows, "Design matrix '" + designMatrix.getStatisticName() + "' has " + rows
                + " rows but the model has " + this.N + " observations");
        this.designMatrix.add(designMatrix);
        this.independentParameter.add(parameter);
        this.independentParameterDelta.add(parameter2);
        addVariable(parameter);
        addVariable(designMatrix);
        if (parameter2 != null) {
            addVariable(parameter2);
        }
        this.numIndependentVariables++;
        Logger.getLogger("dr.inference").info("\tAdding independent predictors '" + parameter.getStatisticName() + "' with design matrix '" + designMatrix.getStatisticName() + "'");
    }

    /**
     * Checks whether the column-concatenation of all design matrices has full column rank.
     *
     * @return {@code true} if the total number of predictor columns equals the matrix rank
     */
    public boolean getAllIndependentVariablesIdentifiable() {
        int totalColumns = 0;
        for (DesignMatrix matrix : this.designMatrix) {
            totalColumns += matrix.getColumnDimension();
        }

        final int rank;
        if (this.N == 0 || totalColumns == 0) {
            // Empty matrix: rank 0 by definition; avoid building and decomposing a degenerate array.
            rank = 0;
        } else {
            // Colt's SVD requires rows >= columns, so store the transpose when there are more
            // predictors than observations (the rank is unchanged by transposition).
            final boolean transpose = this.N < totalColumns;
            final double[][] combined = transpose
                    ? new double[totalColumns][this.N]
                    : new double[this.N][totalColumns];
            int offset = 0;
            for (DesignMatrix matrix : this.designMatrix) {
                final int columns = matrix.getColumnDimension();
                for (int row = 0; row < this.N; row++) {
                    for (int col = 0; col < columns; col++) {
                        final double value = matrix.getParameterValue(row, col);
                        if (transpose) {
                            combined[offset + col][row] = value;
                        } else {
                            combined[row][offset + col] = value;
                        }
                    }
                }
                offset += columns;
            }
            rank = new SingularValueDecomposition(new DenseDoubleMatrix2D(combined)).rank();
        }

        Logger.getLogger("dr.inference").info("\tTotal # of predictors = " + totalColumns + " and rank = " + rank);
        return totalColumns == rank;
    }

    /**
     * Adds the contribution of fixed-effect block {@code k}, {@code X_k * (beta_k .* delta_k)},
     * into {@code xBeta}.
     */
    private void addFixedEffect(int k, double[] xBeta) {
        final Parameter beta = this.independentParameter.get(k);
        final Parameter delta = this.independentParameterDelta.get(k);
        final DesignMatrix x = this.designMatrix.get(k);
        final int dimension = beta.getDimension();
        for (int c = 0; c < dimension; c++) {
            double coefficient = beta.getParameterValue(c);
            if (delta != null) {
                coefficient *= delta.getParameterValue(c);
            }
            for (int j = 0; j < this.N; j++) {
                xBeta[j] += x.getParameterValue(j, c) * coefficient;
            }
        }
    }

    /** Returns the untransformed linear predictor (all fixed effects plus all random effects). */
    public double[] getXBeta() {
        final double[] xBeta = new double[this.N];
        for (int k = 0; k < this.numIndependentVariables; k++) {
            addFixedEffect(k, xBeta);
        }
        for (int r = 0; r < this.numRandomEffects; r++) {
            final Parameter effect = this.randomEffects.get(r);
            for (int j = 0; j < this.N; j++) {
                xBeta[j] += effect.getParameterValue(j);
            }
        }
        return xBeta;
    }

    /**
     * Returns the contribution of fixed-effect block {@code i} alone.
     *
     * @throws RuntimeException if the model contains random effects
     */
    public double[] getXBeta(int i) {
        if (this.numRandomEffects != 0) {
            throw new RuntimeException("Attempting to retrieve fixed effects without controlling for random effects");
        }
        final double[] xBeta = new double[this.N];
        addFixedEffect(i, xBeta);
        return xBeta;
    }

    public int getNumberOfFixedEffects() {
        return this.numIndependentVariables;
    }

    public int getNumberOfRandomEffects() {
        return this.numRandomEffects;
    }

    public Parameter getFixedEffect(int i) {
        return this.independentParameter.get(i);
    }

    /** Returns random effect {@code i}; throws {@code NullPointerException} if none were added. */
    public Parameter getRandomEffect(int i) {
        return this.randomEffects.get(i);
    }

    public Parameter getDependentVariable() {
        return this.dependentParameter;
    }

    /** Returns the index of {@code parameter} among the fixed effects, or -1 if absent. */
    public int getEffectNumber(Parameter parameter) {
        return this.independentParameter.indexOf(parameter);
    }

    public double[][] getX(int i) {
        return this.designMatrix.get(i).getParameterAsMatrix();
    }

    private void calculateTransformedXBeta() {
        final double[] xBeta = getXBeta();
        for (int j = 0; j < this.N; j++) {
            this.transformedXBeta[j] = this.linkFunction.inverse(xBeta[j]);
        }
        this.transformedXBetaKnown = true;
    }

    private double calculateLogLikelihood() {
        if (!this.transformedXBetaKnown) {
            calculateTransformedXBeta();
        }
        if (this.isMultivariateDensity) {
            // NOTE(review): multivariate densities are not evaluated, so the data contribute
            // nothing to the likelihood. Confirm whether this is intended.
            return 0.0d;
        }
        double logL = 0.0d;
        for (int j = 0; j < this.Y.length; j++) {
            // Re-centre the shared density at this observation's mean before evaluating it.
            this.density.getLocationVariable().setValue(0, Double.valueOf(this.transformedXBeta[j]));
            logL += this.density.logPdf(new double[]{this.Y[j]});
        }
        return logL;
    }

    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.logLikelihood = calculateLogLikelihood();
            // Set only after evaluation: updating the density's location above presumably notifies
            // this model through handleModelChangedEvent, which would clear the flag.
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        // Only the density changed; the linear predictor is still valid.
        this.likelihoodKnown = false;
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        if (variable == this.dependentParameter) {
            this.Y = this.dependentParameter.getParameterValues();
        }
        this.transformedXBetaKnown = false;
        this.likelihoodKnown = false;
    }

    @Override
    protected void storeState() {
        this.storedLogLikelihood = this.logLikelihood;
        this.storedLikelihoodKnown = this.likelihoodKnown;
        this.storedTransformedXBetaKnown = this.transformedXBetaKnown;
        System.arraycopy(this.transformedXBeta, 0, this.storedTransformedXBeta, 0, this.transformedXBeta.length);
    }

    @Override
    protected void restoreState() {
        this.logLikelihood = this.storedLogLikelihood;
        this.likelihoodKnown = this.storedLikelihoodKnown;
        this.transformedXBetaKnown = this.storedTransformedXBetaKnown;
        System.arraycopy(this.storedTransformedXBeta, 0, this.transformedXBeta, 0, this.transformedXBeta.length);
    }

    @Override
    protected void acceptState() {
        // Nothing to do: the current state simply becomes the new baseline.
    }

    @Override
    public Model getModel() {
        return this;
    }

    @Override
    public String toString() {
        return super.toString() + ": " + getLogLikelihood();
    }

    /** Forces the linear predictor and the log-likelihood to be recomputed on next access. */
    @Override
    public void makeDirty() {
        this.transformedXBetaKnown = false;
        this.likelihoodKnown = false;
    }

    /** One column per observation, each reporting the untransformed linear predictor. */
    @Override
    public LogColumn[] getColumns() {
        final LogColumn[] columns = new LogColumn[this.N];
        for (int j = 0; j < this.N; j++) {
            columns[j] = new NumberArrayColumn(getId() + j, j);
        }
        return columns;
    }

    @Override
    public Element createElement(Document document) {
        throw new RuntimeException("Not implemented yet!");
    }
}
