package dr.evomodel.treedatalikelihood.discrete;

import cern.colt.matrix.impl.AbstractFormatter;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.ProcessOnTreeDelegate;
import dr.evomodelxml.operators.RandomWalkIntegerNodeHeightWeightedOperatorParser;
import dr.inference.model.Bounds;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.Vector;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:dr/evomodel/treedatalikelihood/discrete/NodeHeightToRatiosFullTransformDelegate.class */
/**
 * Node-height/ratio transform delegate whose transformed vector carries one
 * extra leading entry in front of the ratios.
 *
 * <p>The combined vector is laid out as
 * {@code [rootHeight - maxTipHeight, ratio_0, ..., ratio_{n-1}]}, where
 * {@code maxTipHeight} is the largest external-node height found at
 * construction time (never below {@code 0.0}). The leading entry is exposed
 * through a proxy parameter that reads from and writes to the root of the tree.
 */
public class NodeHeightToRatiosFullTransformDelegate extends NodeHeightToRatiosTransformDelegate {
    /** Largest tip height seen in the constructor; starts from 0.0, so it is never negative. */
    private final double maxTipHeight;
    /** Proxy for the root height expressed relative to {@link #maxTipHeight}. */
    private final Parameter heightParameter;
    /** Compound of {@link #heightParameter} followed by the constructor's third argument. */
    private CompoundParameter rootHeightAndRatios;

    /**
     * One-dimensional proxy parameter whose value is the tree's root height
     * minus the enclosing delegate's {@code maxTipHeight}. It holds no value of
     * its own: reads and writes go straight to the tree.
     */
    private class HeightParameter extends Parameter.Proxy {
        private final TreeModel tree;
        private Bounds<Double> bounds = null;

        /**
         * @param rootedTree    tree whose root height backs this parameter
         * @param initialBounds bounds reported by {@link #getBounds()}
         */
        private HeightParameter(TreeModel rootedTree, Bounds<Double> initialBounds) {
            super("LocationShiftedRootHeightParameter", 1);
            this.tree = rootedTree;
            addBounds(initialBounds);
        }

        /** Replaces (rather than accumulates) the stored bounds. */
        @Override
        public void addBounds(Bounds<Double> bounds) {
            this.bounds = bounds;
        }

        /** @return the bounds most recently handed to {@link #addBounds(Bounds)} */
        @Override
        public Bounds<Double> getBounds() {
            return this.bounds;
        }

        /**
         * @param i ignored; the parameter has a single dimension
         * @return current root height minus the enclosing delegate's maximum tip height
         */
        @Override
        public double getParameterValue(int i) {
            final double rootHeight = this.tree.getNodeHeight(this.tree.getRoot());
            return rootHeight - NodeHeightToRatiosFullTransformDelegate.this.maxTipHeight;
        }

        /**
         * Sets the root height of the tree from a shifted value.
         *
         * @param i ignored; the parameter has a single dimension
         * @param d root height relative to the maximum tip height
         */
        @Override
        public void setParameterValue(int i, double d) {
            applyShiftedRootHeight(d);
        }

        /**
         * Same effect as {@link #setParameterValue(int, double)}: the write goes
         * through the tree's {@code setNodeHeight}.
         *
         * @param i ignored; the parameter has a single dimension
         * @param d root height relative to the maximum tip height
         */
        @Override
        public void setParameterValueQuietly(int i, double d) {
            applyShiftedRootHeight(d);
        }

        /**
         * Same effect as {@link #setParameterValue(int, double)}.
         *
         * @param i ignored; the parameter has a single dimension
         * @param d root height relative to the maximum tip height
         */
        @Override
        public void setParameterValueNotifyChangedAll(int i, double d) {
            applyShiftedRootHeight(d);
        }

        /** Shared body of the three setters: un-shift the value and store it on the root. */
        private void applyShiftedRootHeight(double shiftedHeight) {
            final NodeRef root = this.tree.getRoot();
            this.tree.setNodeHeight(root, NodeHeightToRatiosFullTransformDelegate.this.getRootHeight(shiftedHeight));
        }
    }

    /**
     * Builds the delegate and wires up the combined root-height-and-ratios parameter.
     *
     * @param treeModel       tree being transformed; forwarded to the superclass
     * @param parameter       forwarded to the superclass; its dimension must equal the
     *                        tree's internal node count
     * @param parameter2      forwarded to the superclass and used as the second member of
     *                        the compound parameter (presumably the ratios; confirm
     *                        against the superclass)
     * @param branchRateModel forwarded to the superclass
     * @throws RuntimeException if {@code parameter} does not have one entry per internal node
     */
    public NodeHeightToRatiosFullTransformDelegate(TreeModel treeModel, Parameter parameter, Parameter parameter2, BranchRateModel branchRateModel) {
        super(treeModel, parameter, parameter2, branchRateModel);
        if (treeModel.getInternalNodeCount() != parameter.getDimension()) {
            throw new RuntimeException("Use all internal node (including root) for this transform.");
        }

        // Scan the external nodes for the largest height, starting from 0.0.
        double highestTip = 0.0d;
        for (int tip = 0; tip < this.tree.getExternalNodeCount(); tip++) {
            final double tipHeight = this.tree.getNodeHeight(this.tree.getNode(tip));
            highestTip = tipHeight > highestTip ? tipHeight : highestTip;
        }
        this.maxTipHeight = highestTip;

        final Bounds<Double> heightBounds = new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, 0.0d, 1);
        this.heightParameter = new HeightParameter(this.tree, heightBounds);

        final Parameter[] members = new Parameter[]{this.heightParameter, parameter2};
        this.rootHeightAndRatios = new CompoundParameter("rootHeightAndRatios", members);

        this.nodeHeights = new NodeHeightProxyParameter(RandomWalkIntegerNodeHeightWeightedOperatorParser.INTERNAL_NODE_HEIGHTS, this.tree, true);
    }

    /**
     * Converts a shifted root height back to an absolute one.
     * (The decompiler reports this method was originally {@code private}.)
     *
     * @param d root height relative to the maximum tip height
     * @return {@code d + maxTipHeight}
     */
    public double getRootHeight(double d) {
        return this.maxTipHeight + d;
    }

    /**
     * Applies the given node heights, refreshes the ratios and returns the combined vector.
     *
     * @param dArr node heights to apply
     * @return shifted root height followed by the ratios
     */
    @Override
    public double[] transform(double[] dArr) {
        setNodeHeights(dArr);
        updateRatios();
        final double[] combined = setCombinedValues();
        return combined;
    }

    /**
     * Produces a text report showing the node heights and the combined values,
     * each both as currently stored and as recomputed through the opposite
     * direction of the transform.
     *
     * @return the multi-line report
     */
    @Override
    String getReport() {
        updateRatios();

        // Each part is rendered to text immediately, in this order, because
        // inverse(...) and transform(...) modify state the later parts read.
        final String heightsFromRatios = String.valueOf(new Vector(inverse(setCombinedValues())));
        final String storedHeights = String.valueOf(new Vector(getNodeHeights().getParameterValues()));
        final String ratiosFromHeights = String.valueOf(new Vector(transform(this.nodeHeights.getParameterValues())));
        final String storedRatios = String.valueOf(new Vector(setCombinedValues()));

        return "NodeHeight by inverse ratios: " + heightsFromRatios + "\n"
                + "NodeHeights: " + storedHeights + AbstractFormatter.DEFAULT_SLICE_SEPARATOR
                + "ratios by transform nodeHeights: " + ratiosFromHeights + "\n"
                + "ratios: " + storedRatios + "\n";
    }

    /**
     * @param nodeRef node to look up
     * @return the same index that {@code getNodeHeightGradientIndex} reports for the node
     */
    @Override
    protected int getNodeHeightIndex(NodeRef nodeRef) {
        final int gradientIndex = getNodeHeightGradientIndex(nodeRef);
        return gradientIndex;
    }

    /**
     * Extends the superclass result with a leading entry for the height parameter.
     *
     * @param dArr  gradient array; passed to the superclass and used for the leading entry
     * @param dArr2 passed through to the superclass
     * @return array of length {@code ratios.getDimension() + 1}: the leading entry is the
     *         height-parameter term for {@code dArr} minus the same term for
     *         {@code getLogTimeArray()}; the remaining entries are copied from the
     *         superclass result
     */
    @Override
    public double[] updateGradientLogDensity(double[] dArr, double[] dArr2) {
        final double[] fromSuper = super.updateGradientLogDensity(dArr, dArr2);
        final int ratioCount = this.ratios.getDimension();
        final double[] combined = new double[ratioCount + 1];
        System.arraycopy(fromSuper, 0, combined, 1, ratioCount);

        final double gradientTerm = updateHeightParameterGradientUnweightedLogDensity(dArr);
        final double logTimeTerm = updateHeightParameterGradientUnweightedLogDensity(getLogTimeArray());
        combined[0] = gradientTerm - logTimeTerm;
        return combined;
    }

    /**
     * Extends the superclass result with a leading entry for the height parameter.
     *
     * @param dArr  gradient array; passed to the superclass and used for the leading entry
     * @param dArr2 passed through to the superclass
     * @param i     passed through to the superclass
     * @param i2    passed through to the superclass
     * @return array of length {@code ratios.getDimension() + 1}: the height-parameter term
     *         for {@code dArr}, followed by the superclass result
     */
    @Override
    public double[] updateGradientUnWeightedLogDensity(double[] dArr, double[] dArr2, int i, int i2) {
        final double[] fromSuper = super.updateGradientUnWeightedLogDensity(dArr, dArr2, i, i2);
        final int ratioCount = this.ratios.getDimension();
        final double[] combined = new double[ratioCount + 1];
        combined[0] = updateHeightParameterGradientUnweightedLogDensity(dArr);
        System.arraycopy(fromSuper, 0, combined, 1, ratioCount);
        return combined;
    }

    /**
     * Despite the name this only reads: it assembles a fresh array holding the
     * first value of the compound parameter followed by the current ratios.
     */
    private double[] setCombinedValues() {
        final int ratioCount = this.ratios.getDimension();
        final double[] combined = new double[ratioCount + 1];
        System.arraycopy(this.ratios.getParameterValues(), 0, combined, 1, ratioCount);
        combined[0] = this.rootHeightAndRatios.getParameterValue(0);
        return combined;
    }

    /**
     * Writes the leading entry to the root height, then lets the superclass
     * invert the remaining entries.
     *
     * @param dArr combined vector: shifted root height followed by the ratios
     * @return whatever the superclass inverse returns for the ratio entries
     */
    @Override
    public double[] inverse(double[] dArr) {
        this.heightParameter.setParameterValue(0, dArr[0]);
        final double[] ratiosOnly = separateRatios(dArr);
        return super.inverse(ratiosOnly);
    }

    /** Copies entries {@code 1 .. ratios.getDimension()} of the combined vector into a new array. */
    private double[] separateRatios(double[] combined) {
        final double[] ratiosOnly = new double[this.ratios.getDimension()];
        System.arraycopy(combined, 1, ratiosOnly, 0, ratiosOnly.length);
        return ratiosOnly;
    }

    /** @return the compound parameter (height proxy followed by the constructor's third argument) */
    @Override
    public Parameter getParameter() {
        return this.rootHeightAndRatios;
    }

    /**
     * Computes the dot product of {@code gradient} with a per-internal-node
     * multiplier array.
     *
     * <p>The root's multiplier is 1. For every other internal node reported by
     * the pre-order traversal, the multiplier is that node's ratio times its
     * parent's multiplier. This relies on a parent being visited before its
     * children, presumably guaranteed by the pre-order traversal.
     *
     * @param gradient array indexed like the node-height gradient
     * @return sum over {@code i} of {@code gradient[i] * multiplier[i]}
     */
    private double updateHeightParameterGradientUnweightedLogDensity(double[] gradient) {
        this.preOrderTraversal.updateAllNodes();
        this.preOrderTraversal.dispatchTreeTraversalCollectBranchAndNodeOperations();

        final double[] multipliers = new double[this.tree.getInternalNodeCount()];
        final List<ProcessOnTreeDelegate.NodeOperation> operations = this.preOrderTraversal.getNodeOperations();
        multipliers[getNodeHeightGradientIndex(this.tree.getRoot())] = 1.0d;

        for (ProcessOnTreeDelegate.NodeOperation operation : operations) {
            final NodeRef node = this.tree.getNode(operation.getLeftChild());
            if (this.tree.isRoot(node) || this.tree.isExternal(node)) {
                continue; // only non-root internal nodes get a ratio-derived multiplier
            }
            final int nodeIndex = getNodeHeightGradientIndex(node);
            final double ratio = this.ratios.getParameterValue(getRatiosIndex(node));
            final double parentMultiplier = multipliers[getNodeHeightGradientIndex(this.tree.getParent(node))];
            multipliers[nodeIndex] = ratio * parentMultiplier;
        }

        double total = 0.0d;
        for (int k = 0; k < gradient.length; k++) {
            total += gradient[k] * multipliers[k];
        }
        return total;
    }
}
