package dr.evomodel.coalescent;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.OldAbstractCoalescentLikelihood;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.MathUtils;
import dr.util.Author;
import dr.util.Citable;
import dr.util.Citation;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.SymmTridiagMatrix;

/* loaded from: input_file:dr/evomodel/coalescent/GMRFSkyrideLikelihood.class */
public class GMRFSkyrideLikelihood extends OldAbstractCoalescentLikelihood implements CoalescentIntervalProvider, Citable {
    /** Truncated value of ln(2*pi); the same numeric value appears in {@code calculateLogFieldLikelihood()}. */
    public static final double LOG_TWO_TIMES_PI = 1.837877d;
    /** Not referenced in this file; the constructors here initialise {@link #timeAwareSmoothing} to {@code true} directly. */
    public static final boolean TIME_AWARE_IS_ON_BY_DEFAULT = true;
    /** Field values, one per coalescent event; read as exp(-value) in the coalescent likelihood, i.e. treated on a log scale. */
    protected Parameter popSizeParameter;
    /** Optional (may be null); when present every entry is reset to 1.0 at the end of the main constructor. */
    protected Parameter groupSizeParameter;
    /** Element 0 scales the GMRF weight matrix and enters the field log-likelihood through its logarithm. */
    protected Parameter precisionParameter;
    /** Element 0 is the mixing weight passed to {@code getScaledWeightMatrix(double, double)}; 1.0 selects the plain scaled matrix. */
    protected Parameter lambdaParameter;
    /** Optional (may be null); only registered as a variable and exposed via {@code getBetaParameter()} in this class. */
    protected Parameter betaParameter;
    /** Dimension of {@link #popSizeParameter}; must equal {@code getCorrectFieldLength()}. */
    protected int fieldLength;
    /** Accumulated interval duration leading up to each coalescent event (filled by {@code setupSufficientStatistics()}). */
    protected double[] coalescentIntervals;
    /** Snapshot of {@link #coalescentIntervals} taken in {@code storeState()}. */
    protected double[] storedCoalescentIntervals;
    /** Per coalescent event: sum over its sub-intervals of duration * k * (k - 1) / 2, with k the lineage count. */
    protected double[] sufficientStatistics;
    /** Snapshot of {@link #sufficientStatistics} taken in {@code storeState()}. */
    protected double[] storedSufficientStatistics;
    /** Cached result of {@code calculateLogFieldLikelihood()}. */
    protected double logFieldLikelihood;
    /** Snapshot of {@link #logFieldLikelihood} taken in {@code storeState()}. */
    protected double storedLogFieldLikelihood;
    /** Unscaled tridiagonal GMRF weight matrix; rebuilt from scratch by {@code setupGMRFWeights()}. */
    protected SymmTridiagMatrix weightMatrix;
    /** Copy of {@link #weightMatrix} taken in {@code storeState()}. */
    protected SymmTridiagMatrix storedWeightMatrix;
    /** Matrix returned by {@code getDesignMatrix()}; not otherwise used in this class. */
    protected MatrixParameter dMatrix;
    /** When true, off-diagonal weights are -2 / (sum of the two adjacent coalescent intervals); otherwise they are -1. */
    protected boolean timeAwareSmoothing;
    /** When true, time-aware weights are additionally multiplied by the root height of the tree. */
    protected boolean rescaleByRootHeight;
    /** Mapping returned by {@code getIntervalNodeMapping()}; populated in {@code setupSufficientStatistics()}. (Identifier spelling kept as-is.) */
    OldAbstractCoalescentLikelihood.IntervalNodeMapping coalesentIntervalNodeMapping;
    /** Citation returned by {@code getCitations()}. */
    public static Citation CITATION = new Citation(new Author[]{new Author("VN", "Minin"), new Author("EW", "Bloomquist"), new Author("MA", "Suchard")}, "Smooth skyride through a rough skyline: Bayesian coalescent-based inference of population dynamics", 2008, "Mol Biol Evol", 25, 1459, 1471, "10.1093/molbev/msn090");

    /**
     * Creates an unconfigured likelihood named after
     * {@code GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD}.
     * No tree or parameter field is assigned by this constructor; only
     * time-aware smoothing is switched on.
     */
    public GMRFSkyrideLikelihood() {
        super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD);
        timeAwareSmoothing = true;
    }

    /**
     * Creates an unconfigured likelihood with a caller-supplied name.
     * No tree or parameter field is assigned by this constructor; only
     * time-aware smoothing is switched on.
     *
     * @param str name forwarded to the superclass constructor
     */
    public GMRFSkyrideLikelihood(String str) {
        super(str);
        timeAwareSmoothing = true;
    }

    /**
     * Single-tree convenience constructor: wraps {@code tree} in a one-element
     * list and delegates to the full constructor with interval-to-node mapping
     * disabled.
     *
     * @param tree            the tree the likelihood is computed on
     * @param parameter       population size (field) parameter
     * @param parameter2      group size parameter, may be null
     * @param parameter3      precision parameter
     * @param parameter4      lambda parameter
     * @param parameter5      beta parameter, may be null
     * @param matrixParameter design matrix
     * @param z               whether to use time-aware smoothing
     * @param z2              whether to rescale weights by the root height
     */
    public GMRFSkyrideLikelihood(Tree tree, Parameter parameter, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, MatrixParameter matrixParameter, boolean z, boolean z2) {
        this(
                wrapTree(tree),
                parameter,
                parameter2,
                parameter3,
                parameter4,
                parameter5,
                matrixParameter,
                z,
                z2,
                false);
    }

    /**
     * Wraps a single tree in a new, mutable one-element list.
     *
     * @param tree the tree to wrap (stored as-is, no null check)
     * @return a fresh {@code ArrayList} containing only {@code tree}
     */
    private static List<Tree> wrapTree(Tree tree) {
        // Parameterised type instead of the raw ArrayList, so no unchecked conversion is needed.
        List<Tree> trees = new ArrayList<Tree>();
        trees.add(tree);
        return trees;
    }

    /**
     * Delegates to the full constructor with interval-to-node mapping disabled.
     *
     * @param list            trees; the full constructor's {@code setTree} accepts exactly one
     * @param parameter       population size (field) parameter
     * @param parameter2      group size parameter, may be null
     * @param parameter3      precision parameter
     * @param parameter4      lambda parameter
     * @param parameter5      beta parameter, may be null
     * @param matrixParameter design matrix
     * @param z               whether to use time-aware smoothing
     * @param z2              whether to rescale weights by the root height
     */
    public GMRFSkyrideLikelihood(List<Tree> list, Parameter parameter, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, MatrixParameter matrixParameter, boolean z, boolean z2) {
        this(
                list,
                parameter,
                parameter2,
                parameter3,
                parameter4,
                parameter5,
                matrixParameter,
                z,
                z2,
                false);
    }

    /**
     * Full constructor. Stores the parameters, registers them as variables,
     * attaches the tree, sizes the field, builds the intervals and the GMRF
     * weight matrix, and prints the initialization report.
     *
     * @param list            trees; {@code setTree} in this class requires exactly one
     * @param parameter       population size (field) parameter; if its dimension is
     *                        at most 1 it is resized to the expected field length
     * @param parameter2      group size parameter, may be null; all entries are set to 1.0
     * @param parameter3      precision parameter
     * @param parameter4      lambda parameter
     * @param parameter5      beta parameter, may be null
     * @param matrixParameter design matrix
     * @param z               whether to use time-aware smoothing
     * @param z2              whether to rescale weights by the root height
     * @param z3              whether to build the interval-to-node mapping
     * @throws IllegalArgumentException if the population size parameter has a
     *                                  dimension greater than 1 that differs from
     *                                  {@code getCorrectFieldLength()}
     */
    public GMRFSkyrideLikelihood(List<Tree> list, Parameter parameter, Parameter parameter2, Parameter parameter3, Parameter parameter4, Parameter parameter5, MatrixParameter matrixParameter, boolean z, boolean z2, boolean z3) {
        super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD);
        timeAwareSmoothing = true;
        addKeyword("skyride");

        // Remember the model components.
        popSizeParameter = parameter;
        groupSizeParameter = parameter2;
        precisionParameter = parameter3;
        lambdaParameter = parameter4;
        betaParameter = parameter5;
        dMatrix = matrixParameter;
        timeAwareSmoothing = z;
        rescaleByRootHeight = z2;

        // Register the variables of this model (beta is optional).
        addVariable(popSizeParameter);
        addVariable(precisionParameter);
        addVariable(lambdaParameter);
        if (betaParameter != null) {
            addVariable(betaParameter);
        }

        setTree(list);

        // Size the field: a dimension of 0 or 1 is resized, anything else must already match.
        final int expectedLength = getCorrectFieldLength();
        if (popSizeParameter.getDimension() <= 1) {
            popSizeParameter.setDimension(expectedLength);
        }
        fieldLength = popSizeParameter.getDimension();
        if (fieldLength != expectedLength) {
            throw new IllegalArgumentException("Population size parameter should have length " + expectedLength);
        }

        buildIntervalNodeMapping = z3;
        wrapSetupIntervals();

        coalescentIntervals = new double[fieldLength];
        storedCoalescentIntervals = new double[fieldLength];
        sufficientStatistics = new double[fieldLength];
        storedSufficientStatistics = new double[fieldLength];

        coalesentIntervalNodeMapping = z3
                ? new OldAbstractCoalescentLikelihood.IntervalNodeMapping.Default(tree.getNodeCount(), tree)
                : new OldAbstractCoalescentLikelihood.IntervalNodeMapping.None();

        setupGMRFWeights();
        addStatistic(new OldAbstractCoalescentLikelihood.DeltaStatistic());
        initializationReport();

        // Reset every group size to 1 when a group size parameter was supplied.
        if (groupSizeParameter != null) {
            for (int g = 0; g < groupSizeParameter.getDimension(); g++) {
                groupSizeParameter.setParameterValue(g, 1.0d);
            }
        }
    }

    /**
     * Returns the field length the population size parameter must have:
     * the number of external nodes of the tree minus one.
     * <p>Decompiler note: access was widened from {@code protected}.
     *
     * @return external node count of the tree, minus one
     */
    public int getCorrectFieldLength() {
        final int tipCount = tree.getExternalNodeCount();
        return tipCount - 1;
    }

    /**
     * Overridable hook around the inherited {@code setupIntervals()}; this
     * class simply forwards to it.
     */
    protected void wrapSetupIntervals() {
        this.setupIntervals();
    }

    /**
     * Installs the single tree this likelihood works on. If the tree is a
     * {@code TreeModel} it is also registered as a sub-model.
     *
     * @param list list that must contain exactly one tree
     * @return always 1
     * @throws RuntimeException if {@code list} does not hold exactly one tree
     */
    protected int setTree(List<Tree> list) {
        if (list.size() != 1) {
            throw new RuntimeException("GMRFSkyrideLikelihood only implemented for one tree");
        }
        tree = list.get(0);
        treesSet = null;
        if (tree instanceof TreeModel) {
            addModel((TreeModel) tree);
        }
        return 1;
    }

    /**
     * Returns the internal coalescent interval array itself (not a copy) and
     * does not trigger a recomputation of stale intervals.
     *
     * @return the live {@code coalescentIntervals} array
     */
    public double[] getCoalescentIntervals() {
        return coalescentIntervals;
    }

    /**
     * Prints a three-line summary of the model (header, number of population
     * sizes, citation request) to standard output.
     */
    public void initializationReport() {
        final int populationSizeCount = popSizeParameter.getDimension();
        System.out.println("Creating a GMRF smoothed skyride model:");
        System.out.println("\tPopulation sizes: " + populationSizeCount);
        System.out.println("\tIf you publish results using this model, please reference: Minin, Bloomquist and Suchard (2008) Molecular Biology and Evolution, 25, 1459-1471.");
    }

    /**
     * Re-draws the height of every non-root internal node uniformly between
     * the taller of its first two children and its parent, visiting nodes in
     * internal-node index order, then fires a tree-changed event.
     * Only children 0 and 1 are inspected.
     *
     * @param treeModel the tree whose internal node heights are re-drawn in place
     */
    public static void checkTree(TreeModel treeModel) {
        for (int idx = 0; idx < treeModel.getInternalNodeCount(); idx++) {
            final NodeRef node = treeModel.getInternalNode(idx);
            if (node == treeModel.getRoot()) {
                continue; // the root height is left untouched
            }
            final double upperBound = treeModel.getNodeHeight(treeModel.getParent(node));
            final double firstChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 0));
            final double secondChildHeight = treeModel.getNodeHeight(treeModel.getChild(node, 1));
            final double lowerBound = (secondChildHeight > firstChildHeight) ? secondChildHeight : firstChildHeight;
            treeModel.setNodeHeight(node, lowerBound + (MathUtils.nextDouble() * (upperBound - lowerBound)));
        }
        treeModel.pushTreeChangedEvent();
    }

    /**
     * Returns the coalescent log-likelihood plus the field log-likelihood,
     * recomputing and caching both when the cached values are stale.
     *
     * @return {@code logLikelihood + logFieldLikelihood}
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood, dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (likelihoodKnown) {
            return logLikelihood + logFieldLikelihood;
        }
        logLikelihood = calculateLogCoalescentLikelihood();
        logFieldLikelihood = calculateLogFieldLikelihood();
        likelihoodKnown = true;
        return logLikelihood + logFieldLikelihood;
    }

    /**
     * Returns the cached coalescent log-likelihood without recomputing it.
     * <p>Decompiler note: access was widened from {@code protected}.
     *
     * @return the current value of {@code logLikelihood}
     */
    public double peakLogCoalescentLikelihood() {
        return logLikelihood;
    }

    /**
     * Returns the cached field log-likelihood without recomputing it.
     * <p>Decompiler note: access was widened from {@code protected}.
     *
     * @return the current value of {@code logFieldLikelihood}
     */
    public double peakLogFieldLikelihood() {
        return logFieldLikelihood;
    }

    /**
     * Returns the internal sufficient statistics array itself (not a copy)
     * and does not trigger a recomputation of stale intervals.
     *
     * @return the live {@code sufficientStatistics} array
     */
    public double[] getSufficientStatistics() {
        return sufficientStatistics;
    }

    /**
     * Formats this likelihood as {@code id(logLikelihood)}. Calling it may
     * trigger a likelihood computation through {@link #getLogLikelihood()}.
     *
     * @return the id followed by the total log-likelihood in parentheses
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    public String toString() {
        // String concatenation renders a double exactly as Double.toString does.
        return getId() + "(" + getLogLikelihood() + ")";
    }

    /**
     * Collapses the underlying intervals into one entry per coalescent event.
     * For each coalescent event it stores the accumulated duration since the
     * previous coalescent event in {@code coalescentIntervals} and half the
     * accumulated {@code duration * k * (k - 1)} (k = lineage count) in
     * {@code sufficientStatistics}. When {@code buildIntervalNodeMapping} is
     * set, node numbers of the intervals are also fed into
     * {@code coalesentIntervalNodeMapping}.
     */
    protected void setupSufficientStatistics() {
        int eventIndex = 0;
        double accumulatedLength = 0.0d;
        double accumulatedWeight = 0.0d;

        coalesentIntervalNodeMapping.initializeMaps();
        for (int interval = 0; interval < getIntervalCount(); interval++) {
            accumulatedLength += getInterval(interval);
            accumulatedWeight += getInterval(interval) * getLineageCount(interval) * (getLineageCount(interval) - 1);

            // -1 is passed to the mapping when no node mapping is being built.
            int closingNode = -1;
            if (buildIntervalNodeMapping) {
                final int[] nodeNumbers = intervalNodeMapping.getNodeNumbersForInterval(interval);
                final int lastPosition = nodeNumbers.length - 1;
                // All but the last node are added now; the last one is only
                // added if this interval ends in a coalescent event.
                for (int pos = 0; pos < lastPosition; pos++) {
                    coalesentIntervalNodeMapping.addNode(nodeNumbers[pos]);
                }
                closingNode = nodeNumbers[lastPosition];
            }

            if (getIntervalType(interval) != OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
                continue; // keep accumulating until the next coalescent event
            }

            coalescentIntervals[eventIndex] = accumulatedLength;
            sufficientStatistics[eventIndex] = accumulatedWeight / 2.0d;
            coalesentIntervalNodeMapping.addNode(closingNode);
            eventIndex++;
            accumulatedLength = 0.0d;
            accumulatedWeight = 0.0d;
        }
        coalesentIntervalNodeMapping.setIntervalStartIndices(eventIndex);
    }

    /**
     * Returns the mapping maintained by this class
     * ({@code coalesentIntervalNodeMapping}), not the inherited
     * {@code intervalNodeMapping} field.
     *
     * @return the node mapping populated in {@code setupSufficientStatistics()}
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood
    public OldAbstractCoalescentLikelihood.IntervalNodeMapping getIntervalNodeMapping() {
        return coalesentIntervalNodeMapping;
    }

    /**
     * Returns the factor applied to the time-aware GMRF weights.
     *
     * @return the root height of the tree when {@code rescaleByRootHeight} is
     *         set, otherwise 1.0
     */
    protected double getFieldScalar() {
        if (rescaleByRootHeight) {
            return tree.getNodeHeight(tree.getRoot());
        }
        return 1.0d;
    }

    /**
     * Recomputes the sufficient statistics and rebuilds the unscaled
     * tridiagonal GMRF weight matrix.
     * <p>
     * Off-diagonal entry {@code i} is
     * {@code -2 / (interval[i] + interval[i + 1]) * getFieldScalar()} with
     * time-aware smoothing, or {@code -1} without it. Each diagonal entry is
     * minus the sum of the off-diagonal entries adjacent to it, so every row
     * of the matrix sums to zero.
     * <p>
     * A field of length 1 has no off-diagonal entries; its single diagonal
     * entry is left at 0 instead of failing with an
     * {@code ArrayIndexOutOfBoundsException} as before. This assumes
     * {@code SymmTridiagMatrix} accepts an empty off-diagonal array for a
     * 1x1 matrix; confirm against the MTJ version in use.
     * <p>Decompiler note: access was widened from {@code protected}.
     */
    public void setupGMRFWeights() {
        setupSufficientStatistics();

        final double[] offDiagonal = new double[fieldLength - 1];
        final double[] diagonal = new double[fieldLength];

        if (timeAwareSmoothing) {
            // The scalar does not depend on the loop index, so read it once.
            final double fieldScalar = getFieldScalar();
            for (int i = 0; i < offDiagonal.length; i++) {
                offDiagonal[i] = ((-2.0d) / (coalescentIntervals[i] + coalescentIntervals[i + 1])) * fieldScalar;
            }
        } else {
            for (int i = 0; i < offDiagonal.length; i++) {
                offDiagonal[i] = -1.0d;
            }
        }

        // Interior rows have two neighbours.
        for (int i = 1; i < fieldLength - 1; i++) {
            diagonal[i] = -(offDiagonal[i] + offDiagonal[i - 1]);
        }
        // The first and last rows have one neighbour each; a 1x1 field has none.
        if (offDiagonal.length > 0) {
            diagonal[0] = -offDiagonal[0];
            diagonal[fieldLength - 1] = -offDiagonal[fieldLength - 2];
        }

        weightMatrix = new SymmTridiagMatrix(diagonal, offDiagonal);
    }

    /**
     * Returns a copy of the current weight matrix with every diagonal and
     * sub-diagonal entry multiplied by {@code d}. The cached matrix itself is
     * not modified.
     *
     * @param d scale factor (the precision, when called from the field likelihood)
     * @return a new, scaled matrix
     */
    public SymmTridiagMatrix getScaledWeightMatrix(double d) {
        final SymmTridiagMatrix scaled = weightMatrix.copy();
        final int lastLoopRow = scaled.numRows() - 1;
        for (int row = 0; row < lastLoopRow; row++) {
            scaled.set(row, row, scaled.get(row, row) * d);
            scaled.set(row + 1, row, scaled.get(row + 1, row) * d);
        }
        // The loop stops one short of the end; scale the final diagonal entry separately.
        final int last = fieldLength - 1;
        scaled.set(last, last, scaled.get(last, last) * d);
        return scaled;
    }

    /**
     * Returns a copy of the stored weight matrix with every diagonal and
     * sub-diagonal entry multiplied by {@code d}. The stored matrix itself is
     * not modified; it must have been set by {@code storeState()} first.
     *
     * @param d scale factor
     * @return a new, scaled matrix
     */
    public SymmTridiagMatrix getStoredScaledWeightMatrix(double d) {
        final SymmTridiagMatrix scaled = storedWeightMatrix.copy();
        final int lastLoopRow = scaled.numRows() - 1;
        for (int row = 0; row < lastLoopRow; row++) {
            scaled.set(row, row, scaled.get(row, row) * d);
            scaled.set(row + 1, row, scaled.get(row + 1, row) * d);
        }
        // The loop stops one short of the end; scale the final diagonal entry separately.
        final int last = fieldLength - 1;
        scaled.set(last, last, scaled.get(last, last) * d);
        return scaled;
    }

    /**
     * Returns a scaled, lambda-mixed copy of the current weight matrix.
     * With weight matrix entries {@code w}, each diagonal entry becomes
     * {@code d * ((1 - d2) + d2 * w)} and each sub-diagonal entry
     * {@code w * d * d2}. For {@code d2 == 1} this is delegated to
     * {@link #getScaledWeightMatrix(double)}.
     *
     * @param d  scale factor (precision)
     * @param d2 mixing weight (lambda)
     * @return a new matrix; the cached matrix is not modified
     */
    public SymmTridiagMatrix getScaledWeightMatrix(double d, double d2) {
        if (d2 == 1.0d) {
            return getScaledWeightMatrix(d);
        }
        final SymmTridiagMatrix mixed = weightMatrix.copy();
        final int lastLoopRow = mixed.numRows() - 1;
        for (int row = 0; row < lastLoopRow; row++) {
            mixed.set(row, row, d * ((1.0d - d2) + (d2 * mixed.get(row, row))));
            mixed.set(row + 1, row, mixed.get(row + 1, row) * d * d2);
        }
        // The loop stops one short of the end; transform the final diagonal entry separately.
        final int last = fieldLength - 1;
        mixed.set(last, last, d * ((1.0d - d2) + (d2 * mixed.get(last, last))));
        return mixed;
    }

    /**
     * Rebuilds the intervals and the GMRF weights if they are flagged as
     * stale, then marks them as known. Does nothing otherwise.
     */
    private void makeIntervalsKnown() {
        if (!intervalsKnown) {
            wrapSetupIntervals();
            setupGMRFWeights();
            intervalsKnown = true;
        }
    }

    /**
     * Returns the number of coalescent intervals, refreshing stale intervals first.
     *
     * @return length of {@code coalescentIntervals}
     */
    public int getCoalescentIntervalDimension() {
        makeIntervalsKnown();
        final double[] intervals = coalescentIntervals;
        return intervals.length;
    }

    /**
     * Returns one coalescent interval, refreshing stale intervals first.
     *
     * @param i index into {@code coalescentIntervals}
     * @return the duration stored at index {@code i}
     */
    public double getCoalescentInterval(int i) {
        makeIntervalsKnown();
        final double[] intervals = coalescentIntervals;
        return intervals[i];
    }

    /**
     * Returns the number of coalescent events, computed as the number of
     * external nodes of the tree minus one.
     *
     * @return external node count minus one
     */
    public int getNumberOfCoalescentEvents() {
        final int tipCount = tree.getExternalNodeCount();
        return tipCount - 1;
    }

    /**
     * Returns the sufficient statistic of one coalescent event as currently
     * cached; stale intervals are not refreshed by this call.
     *
     * @param i index into {@code sufficientStatistics}
     * @return the cached statistic at index {@code i}
     */
    public double getCoalescentEventsStatisticValue(int i) {
        final double[] statistics = sufficientStatistics;
        return statistics[i];
    }

    /**
     * Rebuilds the underlying intervals and then the per-event sufficient
     * statistics. It calls {@code setupIntervals()} directly rather than the
     * {@code wrapSetupIntervals()} hook, and does not rebuild the weight matrix.
     */
    public void setupCoalescentIntervals() {
        this.setupIntervals();
        this.setupSufficientStatistics();
    }

    /**
     * Returns the cumulative sums of the coalescent intervals, refreshing
     * stale intervals first. Entry {@code k} is the sum of intervals
     * {@code 0..k}.
     *
     * @return a new array of the same length as {@code coalescentIntervals}
     */
    public double[] getCoalescentIntervalHeights() {
        makeIntervalsKnown();
        final double[] heights = new double[coalescentIntervals.length];
        double runningHeight = coalescentIntervals[0];
        heights[0] = runningHeight;
        for (int k = 1; k < heights.length; k++) {
            runningHeight = runningHeight + coalescentIntervals[k];
            heights[k] = runningHeight;
        }
        return heights;
    }

    /**
     * Returns a copy of the current unscaled weight matrix.
     *
     * @return a new matrix produced by {@code weightMatrix.copy()}
     */
    public SymmTridiagMatrix getCopyWeightMatrix() {
        final SymmTridiagMatrix duplicate = weightMatrix.copy();
        return duplicate;
    }

    /**
     * Returns a scaled, lambda-mixed copy of the stored weight matrix.
     * With stored entries {@code w}, each diagonal entry becomes
     * {@code d * ((1 - d2) + d2 * w)} and each sub-diagonal entry
     * {@code w * d * d2}. For {@code d2 == 1} this is delegated to
     * {@link #getStoredScaledWeightMatrix(double)}.
     *
     * @param d  scale factor (precision)
     * @param d2 mixing weight (lambda)
     * @return a new matrix; the stored matrix is not modified
     */
    public SymmTridiagMatrix getStoredScaledWeightMatrix(double d, double d2) {
        if (d2 == 1.0d) {
            return getStoredScaledWeightMatrix(d);
        }
        final SymmTridiagMatrix mixed = storedWeightMatrix.copy();
        final int lastLoopRow = mixed.numRows() - 1;
        for (int row = 0; row < lastLoopRow; row++) {
            mixed.set(row, row, d * ((1.0d - d2) + (d2 * mixed.get(row, row))));
            mixed.set(row + 1, row, mixed.get(row + 1, row) * d * d2);
        }
        // The loop stops one short of the end; transform the final diagonal entry separately.
        final int last = fieldLength - 1;
        mixed.set(last, last, d * ((1.0d - d2) + (d2 * mixed.get(last, last))));
        return mixed;
    }

    /**
     * Saves the current state: after the superclass has stored its part, the
     * interval and statistic arrays are copied element-wise into their stored
     * counterparts, the weight matrix is copied, and the field
     * log-likelihood is remembered.
     * <p>Decompiler note: access was widened from {@code protected}.
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    public void storeState() {
        super.storeState();

        final int intervalCount = coalescentIntervals.length;
        System.arraycopy(coalescentIntervals, 0, storedCoalescentIntervals, 0, intervalCount);

        final int statisticCount = sufficientStatistics.length;
        System.arraycopy(sufficientStatistics, 0, storedSufficientStatistics, 0, statisticCount);

        storedWeightMatrix = weightMatrix.copy();
        storedLogFieldLikelihood = logFieldLikelihood;
    }

    /**
     * Restores the state saved by {@code storeState()}. The interval and
     * statistic arrays are swapped with their stored counterparts (no
     * copying), while the weight matrix reference is simply replaced, so
     * afterwards {@code weightMatrix} and {@code storedWeightMatrix} refer to
     * the same object.
     * <p>Decompiler note: access was widened from {@code protected}.
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    public void restoreState() {
        super.restoreState();

        // Swap current and stored interval arrays.
        final double[] rejectedIntervals = coalescentIntervals;
        coalescentIntervals = storedCoalescentIntervals;
        storedCoalescentIntervals = rejectedIntervals;

        // Swap current and stored sufficient statistics.
        final double[] rejectedStatistics = sufficientStatistics;
        sufficientStatistics = storedSufficientStatistics;
        storedSufficientStatistics = rejectedStatistics;

        weightMatrix = storedWeightMatrix;
        logFieldLikelihood = storedLogFieldLikelihood;
    }

    /**
     * Marks the cached likelihood as stale whenever a registered variable
     * changes. The arguments are not inspected and the intervals are not
     * invalidated here.
     * <p>Decompiler note: access was widened from {@code protected}.
     *
     * @param variable   the variable that changed (unused)
     * @param i          index of the changed element (unused)
     * @param changeType kind of change (unused)
     */
    @Override // dr.evomodel.coalescent.OldAbstractCoalescentLikelihood, dr.inference.model.AbstractModel
    public void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        likelihoodKnown = false;
    }

    /**
     * Computes the coalescent part of the log-likelihood, refreshing stale
     * intervals first. With field values {@code g[k]} and sufficient
     * statistics {@code s[k]} the result is the sum over {@code k} of
     * {@code -g[k] - s[k] * exp(-g[k])}.
     *
     * @return the coalescent log-likelihood
     */
    protected double calculateLogCoalescentLikelihood() {
        makeIntervalsKnown();
        double total = 0.0d;
        final double[] fieldValues = popSizeParameter.getParameterValues();
        for (int k = 0; k < fieldLength; k++) {
            final double value = fieldValues[k];
            total += (-value) - (sufficientStatistics[k] * Math.exp(-value));
        }
        return total;
    }

    /**
     * Computes the log-density of the field under the GMRF prior, refreshing
     * stale intervals first. With field vector {@code x}, precision
     * {@code tau}, mixing weight {@code lambda}, field length {@code n} and
     * {@code Q = getScaledWeightMatrix(tau, lambda)} the result is
     * {@code 0.5 * (n - 1) * log(tau) - 0.5 * x'Qx - c * LOG_TWO_TIMES_PI},
     * where {@code c} is {@code (n - 1) / 2} when {@code lambda == 1} and
     * {@code n / 2} otherwise.
     *
     * @return the field log-likelihood
     */
    protected double calculateLogFieldLikelihood() {
        makeIntervalsKnown();

        final DenseVector product = new DenseVector(fieldLength);
        final DenseVector field = new DenseVector(popSizeParameter.getParameterValues());
        final double precision = precisionParameter.getParameterValue(0);
        final double lambda = lambdaParameter.getParameterValue(0);

        // product = Q * field
        getScaledWeightMatrix(precision, lambda).mult(field, product);

        double logField = 0.0d;
        logField += ((0.5d * (fieldLength - 1)) * Math.log(precision)) - (0.5d * field.dot(product));

        if (lambda == 1.0d) {
            return logField - (((fieldLength - 1) / 2.0d) * LOG_TWO_TIMES_PI);
        }
        return logField - ((fieldLength / 2.0d) * LOG_TWO_TIMES_PI);
    }

    /**
     * @return the precision parameter passed to the constructor
     */
    public Parameter getPrecisionParameter() {
        return precisionParameter;
    }

    /**
     * @return the population size (field) parameter passed to the constructor
     */
    public Parameter getPopSizeParameter() {
        return popSizeParameter;
    }

    /**
     * @return the lambda parameter passed to the constructor
     */
    public Parameter getLambdaParameter() {
        return lambdaParameter;
    }

    /**
     * Returns a copy of the current unscaled weight matrix, so callers cannot
     * modify the cached one.
     *
     * @return a new matrix produced by {@code weightMatrix.copy()}
     */
    public SymmTridiagMatrix getWeightMatrix() {
        final SymmTridiagMatrix duplicate = weightMatrix.copy();
        return duplicate;
    }

    /**
     * @return the beta parameter passed to the constructor; may be null
     */
    public Parameter getBetaParameter() {
        return betaParameter;
    }

    /**
     * @return the design matrix passed to the constructor
     */
    public MatrixParameter getDesignMatrix() {
        return dMatrix;
    }

    /**
     * Sums the squared differences between consecutive field values, each
     * divided by the mean of the two corresponding coalescent intervals.
     * Uses the currently cached intervals; stale intervals are not refreshed.
     *
     * @return the weighted sum of squared first differences
     */
    public double calculateWeightedSSE() {
        double weightedSse = 0.0d;
        double previousValue = popSizeParameter.getParameterValue(0);
        double previousInterval = coalescentIntervals[0];
        for (int k = 1; k < fieldLength; k++) {
            final double currentValue = popSizeParameter.getParameterValue(k);
            final double currentInterval = coalescentIntervals[k];
            final double difference = currentValue - previousValue;
            weightedSse += (difference * difference) / ((previousInterval + currentInterval) / 2.0d);
            previousValue = currentValue;
            previousInterval = currentInterval;
        }
        return weightedSse;
    }

    /**
     * @return the citation category of this model, {@code TREE_PRIORS}
     */
    public Citation.Category getCategory() {
        final Citation.Category category = Citation.Category.TREE_PRIORS;
        return category;
    }

    /**
     * @return the short description used when citing this model
     */
    public String getDescription() {
        final String description = "Skyride coalescent";
        return description;
    }

    /**
     * @return an immutable single-element list holding {@link #CITATION}
     */
    public List<Citation> getCitations() {
        final List<Citation> citations = Collections.singletonList(CITATION);
        return citations;
    }
}
