package dr.evomodel.treedatalikelihood.discrete;

import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/discrete/NodeHeightTransformTest.class */
/**
 * Diagnostic report that compares gradients obtained from a
 * {@link NodeHeightGradientForDiscreteTrait} and mapped through a {@link NodeHeightTransform}
 * (labelled "Peeling" in the report) against finite-difference gradients computed by
 * {@link NumericalDerivative} (labelled "Numeric").
 *
 * <p>Three comparisons are reported: the unweighted log-likelihood, the log-likelihood minus the
 * log-Jacobian of the node-height transform, and the same quantity after an additional
 * element-wise transform ({@link #realLineTransform}).
 */
public class NodeHeightTransformTest implements Reportable {
    private static final String NODE_HEIGHT_TRANSFORM_TEST = "nodeHeightTransformTest";

    private final NodeHeightTransform nodeHeightTransform;
    private final NodeHeightGradientForDiscreteTrait nodeHeightGradient;
    /** Stored from the constructor argument; not read anywhere else in this class. */
    private final Parameter ratios;
    /** Element-wise log/logit transform composed with {@link #nodeHeightTransform}. */
    private final Transform.ComposeMultivariable realLineTransform;

    /**
     * Log-likelihood as a function of the transformed coordinates, each bounded to [0, 1].
     * NOTE(review): relies on {@code nodeHeightTransform.inverse(...)} updating the model state
     * as a side effect before the likelihood is read - confirm against NodeHeightTransform.
     */
    protected MultivariateFunction numericUnweighted = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            nodeHeightTransform.inverse(argument);
            return nodeHeightGradient.getLikelihood().getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return nodeHeightTransform.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0d;
        }

        @Override
        public double getUpperBound(int n) {
            return 1.0d;
        }
    };

    /**
     * Same as {@link #numericUnweighted}, but with the log-Jacobian of the node-height transform
     * (evaluated at the supplied argument) subtracted from the log-likelihood.
     */
    protected MultivariateFunction numericWeighted = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            nodeHeightTransform.inverse(argument);
            return nodeHeightGradient.getLikelihood().getLogLikelihood()
                    - nodeHeightTransform.getLogJacobian(argument);
        }

        @Override
        public int getNumArguments() {
            return nodeHeightTransform.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0d;
        }

        @Override
        public double getUpperBound(int n) {
            return 1.0d;
        }
    };

    /**
     * Unbounded objective: maps the argument back through {@link #realLineTransform}, writes the
     * result into the node-height parameter, and returns the log-likelihood minus the
     * log-Jacobian of {@link #realLineTransform} evaluated at those node heights.
     */
    protected MultivariateFunction numericMultipleWeighted = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            double[] heights = realLineTransform.inverse(argument, 0, argument.length);
            writeNodeHeights(heights);
            return nodeHeightGradient.getLikelihood().getLogLikelihood()
                    - realLineTransform.getLogJacobian(heights, 0, argument.length);
        }

        @Override
        public int getNumArguments() {
            return nodeHeightTransform.getNodeHeights().getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return Double.NEGATIVE_INFINITY;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /** XML parser for the {@code nodeHeightTransformTest} element. */
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() {
        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            NodeHeightTransform transform =
                    (NodeHeightTransform) xMLObject.getChild(NodeHeightTransform.class);
            NodeHeightGradientForDiscreteTrait gradient =
                    (NodeHeightGradientForDiscreteTrait) xMLObject.getChild(NodeHeightGradientForDiscreteTrait.class);
            Parameter ratioParameter = (Parameter) xMLObject.getChild(Parameter.class);
            return new NodeHeightTransformTest(transform, gradient, ratioParameter);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return new XMLSyntaxRule[]{
                    new ElementRule(NodeHeightTransform.class),
                    new ElementRule(NodeHeightGradientForDiscreteTrait.class),
                    new ElementRule(Parameter.class)
            };
        }

        @Override
        public String getParserDescription() {
            // No description is provided for this parser.
            return null;
        }

        @Override
        public Class getReturnType() {
            return NodeHeightTransformTest.class;
        }

        @Override
        public String getParserName() {
            return NODE_HEIGHT_TRANSFORM_TEST;
        }
    };

    /**
     * Stores the collaborators and builds {@link #realLineTransform}.
     *
     * <p>The element-wise part consists of one logit transform per entry of {@code parameter},
     * preceded by a single log transform when the dimension of the transform's own parameter
     * differs from that of {@code parameter} (presumably an extra, non-ratio entry - confirm).
     *
     * @param nodeHeightTransform transform between node heights and its own parameterisation
     * @param nodeHeightGradientForDiscreteTrait provider of the gradient and the likelihood
     * @param parameter the ratios parameter
     */
    public NodeHeightTransformTest(NodeHeightTransform nodeHeightTransform, NodeHeightGradientForDiscreteTrait nodeHeightGradientForDiscreteTrait, Parameter parameter) {
        this.nodeHeightTransform = nodeHeightTransform;
        this.nodeHeightGradient = nodeHeightGradientForDiscreteTrait;
        this.ratios = parameter;

        // Typed list instead of a raw ArrayList so element types are checked by the compiler.
        ArrayList<Transform> elementTransforms = new ArrayList<Transform>();
        if (nodeHeightTransform.getParameter().getDimension() != parameter.getDimension()) {
            elementTransforms.add(new Transform.LogTransform());
        }
        for (int i = 0; i < parameter.getDimension(); i++) {
            elementTransforms.add(new Transform.LogitTransform());
        }
        Transform.Array elementWise = new Transform.Array(elementTransforms, nodeHeightTransform.getParameter());
        this.realLineTransform = new Transform.ComposeMultivariable(elementWise, nodeHeightTransform);
    }

    /**
     * Produces the gradient provider's own report followed by three "Peeling" versus "Numeric"
     * gradient comparisons.
     *
     * <p>The node heights are read once, right after the gradient is computed, and a copy of that
     * snapshot is used for every subsequent call. Evaluating the numeric objectives changes the
     * model state, so re-reading the heights between steps could compare gradients taken at
     * different points. The snapshot is written back before returning so the report does not
     * leave the node heights at a finite-difference evaluation point.
     *
     * @return the combined report text
     */
    @Override
    public String getReport() {
        String report = nodeHeightGradient.getReport();
        double[] gradient = nodeHeightGradient.getGradientLogDensity();

        Parameter nodeHeights = nodeHeightTransform.getNodeHeights();
        double[] startHeights = nodeHeights.getParameterValues();
        int heightCount = nodeHeights.getDimension();

        StringBuilder sb = new StringBuilder();
        try {
            // Each call receives its own copy so that a callee modifying its argument in place
            // cannot alter the evaluation point used by the following calls.
            double[] peelingUnweighted = nodeHeightTransform.updateGradientUnWeightedLogDensity(
                    gradient, startHeights.clone(), 0, gradient.length);
            double[] numericUnweightedGradient = NumericalDerivative.gradient(
                    numericUnweighted, nodeHeightTransform.transform(startHeights.clone()));

            double[] peelingWeighted = nodeHeightTransform.updateGradientLogDensity(
                    gradient, startHeights.clone(), 0, gradient.length);
            double[] numericWeightedGradient = NumericalDerivative.gradient(
                    numericWeighted, nodeHeightTransform.transform(startHeights.clone()));

            appendComparison(sb, "\nGradient wrt Unweighted LogLikelihood:",
                    peelingUnweighted, numericUnweightedGradient);
            appendComparison(sb, "\nGradient wrt Weighted LogLikelihood:",
                    peelingWeighted, numericWeightedGradient);

            double[] peelingMultiple = realLineTransform.updateGradientLogDensity(
                    gradient, startHeights.clone(), 0, heightCount);
            double[] numericMultipleGradient = NumericalDerivative.gradient(
                    numericMultipleWeighted, realLineTransform.transform(startHeights.clone(), 0, heightCount));

            appendComparison(sb, "\nGradient wrt Multiple Weighted LogLikelihood:",
                    peelingMultiple, numericMultipleGradient);
        } finally {
            // Undo the perturbations applied while differentiating numerically.
            writeNodeHeights(startHeights);
        }
        return report + sb.toString();
    }

    /**
     * Writes the given values into the node-height parameter, fires a single change event and
     * marks the likelihood dirty so it is recomputed on the next query.
     */
    private void writeNodeHeights(double[] heights) {
        Parameter nodeHeights = nodeHeightTransform.getNodeHeights();
        for (int i = 0; i < heights.length; i++) {
            nodeHeights.setParameterValueQuietly(i, heights[i]);
        }
        nodeHeightTransform.getNodeHeights().fireParameterChangedEvent();
        nodeHeightGradient.getLikelihood().makeDirty();
    }

    /** Appends one titled "Peeling"/"Numeric" pair of gradient vectors to the report. */
    private static void appendComparison(StringBuilder sb, String title, double[] peeling, double[] numeric) {
        sb.append(title);
        sb.append("\nPeeling: ").append(new Vector(peeling));
        sb.append("\nNumeric: ").append(new Vector(numeric));
    }
}
