package dr.evolution.continuous;

import dr.evolution.io.NewickImporter;
import dr.evolution.tree.MutableTree;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.SimpleNode;
import dr.evolution.tree.SimpleTree;
import dr.geo.math.SphericalPolarCoordinates;
import dr.matrix.Matrix;
import dr.matrix.MutableMatrix;
import java.io.StringReader;

/* loaded from: input_file:dr/evolution/continuous/ContinuousTraitLikelihood.class */
/**
 * Computes the log-likelihood of one or more continuous traits on a strictly
 * bifurcating tree using phylogenetically independent contrasts.
 *
 * <p>Tip values are read from taxon attributes. As a side effect, the normalised tip
 * values and the reconstructed (weighted-mean) internal-node values are written back
 * onto the supplied tree as node attributes named after each trait.
 */
public class ContinuousTraitLikelihood {

    /** 2&pi;, used in the Gaussian normalising constant. */
    private static final double TWO_PI = 2.0 * Math.PI;

    /**
     * A node mirroring one node of the input tree that carries the per-trait contrast
     * bookkeeping. Construction recursively mirrors the whole subtree below the node.
     */
    public class ContrastedTraitNode extends SimpleNode {
        /** Per-trait difference between the two children's values (internal nodes only). */
        private double[] contrast;
        /** Sum of the two children's (extended) branch variances; 0 at tips. */
        private double contrastVariance;
        /** Observed value at a tip, or the weighted-mean reconstruction at an internal node. */
        private Contrastable[] traitValue;
        /** Extra variance this node contributes to its parent branch; 0 at tips. */
        private double nodeVariance;
        /** The input tree, written to when reconstructed values are stored. */
        private MutableTree tree;
        /** The node of {@link #tree} this object mirrors. */
        private NodeRef node;
        /** Names of the traits, in the same order as {@link #traitValue}. */
        private String[] traitNames;

        /**
         * Mirrors {@code nodeRef} (and, recursively, its subtree) of {@code mutableTree}.
         *
         * <p>At a tip, each trait is read from the taxon attribute of the same name.
         * {@link Number} and numeric {@link String} values are wrapped in {@link Continuous},
         * and any {@link Contrastable} is used as-is. The resulting value is stored back on
         * the tree node under the trait name.
         *
         * @param mutableTree the tree to read from (and write normalised values to)
         * @param nodeRef     the node to mirror
         * @param traitNames  the attribute names of the traits to analyse
         * @throws IllegalArgumentException if a tip lacks a trait attribute, the attribute has an
         *     unsupported type, or an internal node does not have exactly two children
         * @throws NumberFormatException    if a string attribute is not a parsable double
         */
        public ContrastedTraitNode(MutableTree mutableTree, NodeRef nodeRef, String[] traitNames) {
            init(mutableTree, nodeRef, traitNames.length);
            if (mutableTree.isExternal(nodeRef)) {
                for (int i = 0; i < traitNames.length; i++) {
                    this.traitValue[i] = readTipValue(mutableTree, nodeRef, traitNames[i]);
                    // Store the normalised Contrastable so later readers see a uniform type.
                    mutableTree.setNodeAttribute(nodeRef, traitNames[i], this.traitValue[i]);
                }
            } else {
                if (mutableTree.getChildCount(nodeRef) != 2) {
                    throw new IllegalArgumentException("Tree must be strictly bifurcating!");
                }
                addChild(new ContrastedTraitNode(mutableTree, mutableTree.getChild(nodeRef, 0), traitNames));
                addChild(new ContrastedTraitNode(mutableTree, mutableTree.getChild(nodeRef, 1), traitNames));
            }
            this.traitNames = traitNames;
        }

        /**
         * Reads one trait attribute from the taxon at a tip and converts it to a
         * {@link Contrastable}.
         */
        private Contrastable readTipValue(MutableTree mutableTree, NodeRef nodeRef, String traitName) {
            Object attribute = mutableTree.getNodeTaxon(nodeRef).getAttribute(traitName);
            if (attribute == null) {
                throw new IllegalArgumentException("attribute " + traitName + " does not exist in "
                        + mutableTree.getTaxonId(nodeRef.getNumber()));
            }
            if (attribute instanceof Number) {
                return new Continuous(((Number) attribute).doubleValue());
            }
            if (attribute instanceof String) {
                return new Continuous(Double.parseDouble((String) attribute));
            }
            if (attribute instanceof Contrastable) {
                // Covers Continuous, SphericalPolarCoordinates and any other contrastable type.
                return (Contrastable) attribute;
            }
            // Previously such values were silently left null and failed later with an NPE.
            throw new IllegalArgumentException("attribute " + traitName + " in "
                    + mutableTree.getTaxonId(nodeRef.getNumber()) + " has unsupported type "
                    + attribute.getClass().getName());
        }

        /** Copies node metadata from the input tree and allocates per-trait storage. */
        private void init(MutableTree mutableTree, NodeRef nodeRef, int traitCount) {
            setHeight(mutableTree.getNodeHeight(nodeRef));
            setRate(mutableTree.getNodeRate(nodeRef));
            setId(mutableTree.getTaxonId(nodeRef.getNumber()));
            setNumber(nodeRef.getNumber());
            setTaxon(mutableTree.getNodeTaxon(nodeRef));
            this.contrast = new double[traitCount];
            this.contrastVariance = 0.0d;
            this.traitValue = new Contrastable[traitCount];
            this.nodeVariance = 0.0d;
            this.tree = mutableTree;
            this.node = nodeRef;
        }

        /** Returns the per-trait contrasts (the live internal array, not a copy). */
        public double[] getTraitContrasts() {
            return this.contrast;
        }

        /** Returns the variance associated with this node's contrasts. */
        public double getContrastVariance() {
            return this.contrastVariance;
        }

        /** Returns the extra variance this node adds to its parent branch. */
        public double getNodeVariance() {
            return this.nodeVariance;
        }

        /** Returns the observed (tip) or reconstructed (internal) value of trait {@code i}. */
        public Contrastable getTraitValue(int i) {
            return this.traitValue[i];
        }

        /** Returns the number of traits carried by this node. */
        public int getTraitCount() {
            return this.traitValue.length;
        }

        /**
         * Post-order computation of independent contrasts for this subtree.
         *
         * <p>For each internal node, the extended branch variance of a child is the child's
         * {@code nodeVariance} plus the child's branch length raised to {@code d}. The
         * contrast is the difference of the children's values. The node's reconstructed value
         * is their inverse-variance weighted mean, and it is also written onto the input tree.
         *
         * @param d exponent applied to every branch length (1 leaves branch lengths unchanged)
         */
        public void calculateContrasts(double d) {
            if (isExternal()) {
                return;
            }
            ContrastedTraitNode left = (ContrastedTraitNode) getChild(0);
            ContrastedTraitNode right = (ContrastedTraitNode) getChild(1);
            left.calculateContrasts(d);
            right.calculateContrasts(d);

            double leftVariance = left.nodeVariance + Math.pow(getHeight() - left.getHeight(), d);
            double rightVariance = right.nodeVariance + Math.pow(getHeight() - right.getHeight(), d);
            this.contrastVariance = leftVariance + rightVariance;
            this.nodeVariance = (leftVariance * rightVariance) / (leftVariance + rightVariance);

            double leftWeight = 1.0d / leftVariance;
            double rightWeight = 1.0d / rightVariance;
            for (int i = 0; i < getTraitCount(); i++) {
                this.contrast[i] = left.traitValue[i].getDifference(right.traitValue[i]);
                this.traitValue[i] = left.traitValue[i].getWeightedMean(
                        leftWeight, left.traitValue[i], rightWeight, right.traitValue[i]);
                this.tree.setNodeAttribute(this.node, this.traitNames[i], this.traitValue[i]);
            }
        }
    }

    /**
     * Computes the log-likelihood of the named traits on {@code mutableTree}.
     *
     * @param mutableTree     a strictly bifurcating tree whose taxa carry the trait attributes;
     *                        node attributes are written to it as a side effect
     * @param strArr          the trait attribute names
     * @param contrastableArr output array; entry {@code i} receives the value reconstructed at the
     *                        root for trait {@code i}. Its length must not exceed the number of traits.
     * @param d               exponent applied to every branch length
     * @return the log-likelihood
     */
    public double calculateLikelihood(MutableTree mutableTree, String[] strArr, Contrastable[] contrastableArr, double d) {
        ContrastedTraitNode root = new ContrastedTraitNode(mutableTree, mutableTree.getRoot(), strArr);
        root.calculateContrasts(d);
        for (int i = 0; i < contrastableArr.length; i++) {
            contrastableArr[i] = root.getTraitValue(i);
        }
        return calculateTraitsLikelihood(root);
    }

    /** Dispatches to the single- or multi-trait likelihood depending on the trait count. */
    private double calculateTraitsLikelihood(ContrastedTraitNode root) {
        int traitCount = root.getTraitCount();
        return traitCount == 1
                ? calculateSingleTraitLikelihood(root)
                : calculateMultipleTraitsLikelihood(root, traitCount);
    }

    /**
     * Multi-trait log-likelihood based on the contrast cross-product matrix.
     *
     * @param root       root of an already-contrasted tree
     * @param traitCount number of traits
     * @return the log-likelihood
     * @throws IllegalStateException if the matrix routines report a dimension error, which
     *     cannot happen for the consistently sized matrices built here
     */
    private double calculateMultipleTraitsLikelihood(ContrastedTraitNode root, int traitCount) {
        SimpleTree tree = new SimpleTree(root);
        int internalNodeCount = tree.getInternalNodeCount();

        // Symmetric matrix of mean variance-standardised contrast cross-products.
        double[][] rates = new double[traitCount][traitCount];
        for (int a = 0; a < traitCount; a++) {
            for (int b = a; b < traitCount; b++) {
                double sum = 0.0d;
                for (int k = 0; k < internalNodeCount; k++) {
                    ContrastedTraitNode node = (ContrastedTraitNode) tree.getInternalNode(k);
                    sum += (node.contrast[a] * node.contrast[b]) / node.contrastVariance;
                }
                double mean = sum / internalNodeCount;
                rates[a][b] = mean;
                rates[b][a] = mean;
            }
        }

        // The determinant must be taken before inverting, since the inversion works in place.
        double determinant;
        try {
            determinant = Matrix.Util.det(Matrix.Util.createMutableMatrix(rates));
        } catch (Matrix.NotSquareException e) {
            throw new IllegalStateException("rate matrix is not square", e);
        }

        MutableMatrix inverse = Matrix.Util.createMutableMatrix(rates);
        try {
            Matrix.Util.invert(inverse);
        } catch (Matrix.NotSquareException e) {
            throw new IllegalStateException("rate matrix is not square", e);
        }

        // Dedicated, correctly sized scratch buffers for c^T * inverse * c. Previously the
        // traitCount x traitCount rate matrix was reused for the traitCount x 1 product.
        MutableMatrix weighted = Matrix.Util.createMutableMatrix(new double[traitCount][1]);
        MutableMatrix quadratic = Matrix.Util.createMutableMatrix(new double[1][1]);

        double sum = 0.0d;
        for (int k = 0; k < internalNodeCount; k++) {
            ContrastedTraitNode node = (ContrastedTraitNode) tree.getInternalNode(k);
            double[] contrasts = node.getTraitContrasts();
            try {
                Matrix.Util.product(inverse, Matrix.Util.createColumnVector(contrasts), weighted);
                Matrix.Util.product(Matrix.Util.createRowVector(contrasts), weighted, quadratic);
            } catch (Matrix.WrongDimensionException e) {
                throw new IllegalStateException("inconsistent contrast/rate-matrix dimensions", e);
            }
            double contrastVariance = node.getContrastVariance();
            sum = sum + quadratic.getElement(0, 0) / contrastVariance
                    + traitCount * Math.log(contrastVariance);
        }

        int observationCount = internalNodeCount + 1;
        return (-(((sum + (traitCount * Math.log(root.getNodeVariance())))
                + (observationCount * Math.log(determinant)))
                + ((observationCount * traitCount) * Math.log(TWO_PI)))) / 2.0d;
    }

    /**
     * Single-trait log-likelihood. The rate is estimated as the mean squared standardised
     * contrast.
     *
     * @param root root of an already-contrasted tree
     * @return the log-likelihood
     */
    private double calculateSingleTraitLikelihood(ContrastedTraitNode root) {
        SimpleTree tree = new SimpleTree(root);
        int internalNodeCount = tree.getInternalNodeCount();

        double sumSquaredContrasts = 0.0d;
        double sumLogVariances = 0.0d;
        for (int k = 0; k < internalNodeCount; k++) {
            ContrastedTraitNode node = (ContrastedTraitNode) tree.getInternalNode(k);
            double contrast = node.getTraitContrasts()[0];
            double contrastVariance = node.getContrastVariance();
            sumSquaredContrasts += (contrast * contrast) / contrastVariance;
            sumLogVariances += Math.log(contrastVariance);
            if (node.isRoot()) {
                sumLogVariances += Math.log(node.getNodeVariance());
            }
        }

        double rate = sumSquaredContrasts / internalNodeCount;
        return (-((((internalNodeCount + 1) * Math.log(TWO_PI * rate)) + sumLogVariances)
                + (sumSquaredContrasts / rate))) / 2.0d;
    }

    /** Small demonstration on a four-taxon tree with two traits. */
    public static void main(String[] strArr) throws Exception {
        MutableTree tree = (MutableTree) new NewickImporter(
                new StringReader("((A:1, B:1):1,(C:1, D:1):1);")).importTree(null);
        double[] trait1 = {1.1d, 1.95d, 3.15d, 4.39d};
        double[] trait2 = {5.2d, 3.8d, 3.1d, 1.95d};
        for (int taxon = 0; taxon < trait1.length; taxon++) {
            tree.setTaxonAttribute(taxon, "U1", new Continuous(trait1[taxon]));
        }
        for (int taxon = 0; taxon < trait2.length; taxon++) {
            tree.setTaxonAttribute(taxon, "U2", new Continuous(trait2[taxon]));
        }

        ContinuousTraitLikelihood likelihood = new ContinuousTraitLikelihood();

        Contrastable[] jointMles = new Contrastable[2];
        System.out.println("logL = " + likelihood.calculateLikelihood(tree, new String[]{"U1", "U2"}, jointMles, 1.0d));
        System.out.println("mle(trait1) = " + jointMles[0]);
        System.out.println("mle(trait2) = " + jointMles[1]);

        Contrastable[] singleMle = new Contrastable[1];
        System.out.println("logL (trait1) = " + likelihood.calculateLikelihood(tree, new String[]{"U1"}, singleMle, 1.0d));
        System.out.println("mle(trait1) = " + singleMle[0]);
        System.out.println("logL (trait2) = " + likelihood.calculateLikelihood(tree, new String[]{"U2"}, singleMle, 1.0d));
        System.out.println("mle(trait2) = " + singleMle[0]);
    }
}
