package dr.oldevomodel.ibd;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DefaultBranchRateModel;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.oldevomodel.substmodel.AbstractSubstitutionModel;
import dr.oldevomodel.substmodel.HKY;
import dr.oldevomodel.treelikelihood.NodePosteriorTreeLikelihood;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Tree-trait provider that reports, for every tip, an average posterior IBD weight conditional on
 * the observed site patterns (the parser describes it as the "average expected number of tips ibd
 * conditional on observed patterns").
 *
 * <p>The computation is a forward (tips-to-root) pass followed by a backward (root-to-tips) pass
 * over the tree. The per-pattern, per-state node posteriors and the per-node matrices come from a
 * {@link NodePosteriorTreeLikelihood}. The per-state diagonal rates are computed for an HKY model
 * only (see {@link #getDiagonalRates(double[])}), so the substitution model must be an {@link HKY}.
 *
 * <p>Results are cached and recomputed lazily the next time the trait is read after any watched
 * model or the mutation parameter changes.
 */
public class AvgPosteriorIBDReporter extends AbstractModel implements TreeTraitProvider {
    /** Per-tip IBD weights indexed by external node number; allocated lazily. */
    protected double[] ibdweights;
    /** Forward quantities per node number, laid out as {@code [pattern * stateCount + state]}. */
    protected double[][] ibdForward;
    /** Backward quantities per node number, same layout as {@link #ibdForward}. */
    protected double[][] ibdBackward;
    /** Scratch buffer of per-state diagonal rates filled by {@link #getDiagonalRates(double[])}. */
    protected double[] diag;
    /** True while {@link #ibdweights} reflects the current state of the watched models. */
    protected boolean weightsKnown;
    protected HKY substitutionModel;
    protected TreeModel treeModel;
    protected BranchRateModel branchRateModel;
    protected Parameter mutationParameter;
    protected NodePosteriorTreeLikelihood likelihoodReporter;
    /** Scratch buffer (stateCount * stateCount) filled by the likelihood reporter per node. */
    protected double[] probabilities;
    TreeTrait avgPosteriorIBDWeight;

    public static final String IBD_REPORTER_LIKELIHOOD = "avgPosteriorIBDReporter";

    /** XML parser for the {@code avgPosteriorIBDReporter} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private XMLSyntaxRule[] rules = {
            new ElementRule(TreeModel.class),
            new ElementRule(BranchRateModel.class, true),
            new ElementRule(AbstractSubstitutionModel.class),
            new ElementRule(Parameter.class),
            new ElementRule(NodePosteriorTreeLikelihood.class)
        };

        @Override
        public String getParserName() {
            return AvgPosteriorIBDReporter.IBD_REPORTER_LIKELIHOOD;
        }

        /**
         * Builds the reporter from its child elements. The branch rate model is optional and
         * defaults to {@link DefaultBranchRateModel}.
         *
         * @throws XMLParseException if the substitution model is not an {@link HKY} model
         */
        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
            Parameter mutation = (Parameter) xo.getChild(Parameter.class);
            AbstractSubstitutionModel substModel =
                    (AbstractSubstitutionModel) xo.getChild(AbstractSubstitutionModel.class);
            // The diagonal-rate computation is HKY-specific; report a parse error here rather
            // than letting the constructor fail with a ClassCastException.
            if (!(substModel instanceof HKY)) {
                throw new XMLParseException("The " + AvgPosteriorIBDReporter.IBD_REPORTER_LIKELIHOOD
                        + " element requires an HKY substitution model");
            }
            BranchRateModel rates = (BranchRateModel) xo.getChild(BranchRateModel.class);
            if (rates == null) {
                rates = new DefaultBranchRateModel();
            }
            NodePosteriorTreeLikelihood posteriorLikelihood =
                    (NodePosteriorTreeLikelihood) xo.getChild(NodePosteriorTreeLikelihood.class);
            return new AvgPosteriorIBDReporter(posteriorLikelihood, mutation, tree, rates, substModel);
        }

        @Override
        public String getParserDescription() {
            return "This element represents a reporter for average expected number of tips ibd conditional on observed patterns.";
        }

        @Override
        public Class getReturnType() {
            // The parsed object is a reporter/TreeTraitProvider, not a Likelihood.
            return AvgPosteriorIBDReporter.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Creates the reporter and registers it as a listener on the models and parameter it reads.
     *
     * @param nodePosteriorTreeLikelihood source of per-node posteriors, node matrices and pattern
     *     weights
     * @param parameter mutation parameter; its first value scales the diagonal rates
     * @param treeModel tree whose tips receive the trait
     * @param branchRateModel per-branch rate multipliers
     * @param abstractSubstitutionModel substitution model; must be an {@link HKY} instance
     */
    AvgPosteriorIBDReporter(NodePosteriorTreeLikelihood nodePosteriorTreeLikelihood, Parameter parameter, TreeModel treeModel, BranchRateModel branchRateModel, AbstractSubstitutionModel abstractSubstitutionModel) {
        super("AvgPosteriorIBDReporter");
        this.avgPosteriorIBDWeight = new TreeTrait.D() {
            @Override
            public String getTraitName() {
                return "AvgPosteriorIBDWeight";
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            /** Returns the tip's weight plus one; internal nodes have no value (null). */
            @Override
            public Double getTrait(Tree tree, NodeRef nodeRef) {
                if (!AvgPosteriorIBDReporter.this.weightsKnown) {
                    AvgPosteriorIBDReporter.this.expectedIBD();
                    AvgPosteriorIBDReporter.this.weightsKnown = true;
                }
                if (!tree.isExternal(nodeRef)) {
                    return null;
                }
                return Double.valueOf(AvgPosteriorIBDReporter.this.ibdweights[nodeRef.getNumber()] + 1.0d);
            }
        };
        this.substitutionModel = (HKY) abstractSubstitutionModel;
        addModel(this.substitutionModel);
        this.treeModel = treeModel;
        addModel(this.treeModel);
        this.branchRateModel = branchRateModel;
        addModel(this.branchRateModel);
        this.mutationParameter = parameter;
        addVariable(this.mutationParameter);
        this.likelihoodReporter = nodePosteriorTreeLikelihood;
        int stateCount = abstractSubstitutionModel.getStateCount();
        this.probabilities = new double[stateCount * stateCount];
    }

    /** Allocates the work buffers on first use (sizes taken from the current tree/patterns). */
    private void ensureBuffers() {
        if (this.ibdweights == null) {
            int stateCount = this.substitutionModel.getStateCount();
            int nodeCount = this.treeModel.getNodeCount();
            int patternCount = this.likelihoodReporter.getPatternCount();
            this.ibdweights = new double[this.treeModel.getExternalNodeCount()];
            this.ibdForward = new double[nodeCount][stateCount * patternCount];
            this.ibdBackward = new double[nodeCount][stateCount * patternCount];
            this.diag = new double[stateCount];
        }
    }

    /**
     * Fills {@link #ibdForward} for every non-root node.
     *
     * <p>For node {@code v} with parent {@code u}, the factor for state {@code s} is
     * {@code exp(-diag[s] * t) / P[s][s]}, where {@code t = branchRate(v) * (height(u) - height(v))}
     * and {@code P} is the matrix returned by the likelihood reporter for {@code v}. Tips store
     * {@code posterior * factor}; internal nodes store
     * {@code (sum of children's forward values) * posterior * factor}.
     */
    public void forwardIBD() {
        ensureBuffers();
        getDiagonalRates(this.diag);
        // Each node needs its children's forward values, so children must be processed first.
        // TreeModel node numbers do not guarantee that ordering (internal numbering changes with
        // topology moves), so walk the tree explicitly in post-order.
        forwardIBDPostOrder(this.treeModel.getRoot());
    }

    /** Recursive post-order step of {@link #forwardIBD()}; the root itself is skipped. */
    private void forwardIBDPostOrder(NodeRef node) {
        int childCount = this.treeModel.getChildCount(node);
        for (int c = 0; c < childCount; c++) {
            forwardIBDPostOrder(this.treeModel.getChild(node, c));
        }

        NodeRef parent = this.treeModel.getParent(node);
        if (parent == null) {
            return; // the root has no incoming branch
        }

        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        int nodeNumber = node.getNumber();
        this.likelihoodReporter.getNodeMatrix(nodeNumber, this.probabilities);
        double[] posteriors = this.likelihoodReporter.getPosteriors(nodeNumber);
        double[] forward = this.ibdForward[nodeNumber];
        double branchLength = this.branchRateModel.getBranchRate(this.treeModel, node)
                * (this.treeModel.getNodeHeight(parent) - this.treeModel.getNodeHeight(node));
        boolean isTip = this.treeModel.isExternal(node);

        for (int s = 0; s < stateCount; s++) {
            double factor = Math.exp(-this.diag[s] * branchLength) / this.probabilities[s + s * stateCount];
            for (int p = 0; p < patternCount; p++) {
                int k = p * stateCount + s;
                if (isTip) {
                    forward[k] = posteriors[k] * factor;
                } else {
                    double childSum = 0.0d;
                    for (int c = 0; c < childCount; c++) {
                        childSum += this.ibdForward[this.treeModel.getChild(node, c).getNumber()][k];
                    }
                    forward[k] = childSum * posteriors[k] * factor;
                }
            }
        }
    }

    /**
     * Fills {@link #ibdBackward} for every node below {@code nodeRef} (pre-order).
     *
     * <p>Passing {@code null} starts at the root and zeroes the root's backward row first. Each
     * child's value is the parent's backward value plus the siblings' forward values, times
     * {@code parentPosterior * exp(-diag[s] * t) / P[s][s]} for the child's branch. This reads
     * {@link #ibdForward}, so {@link #forwardIBD()} should be run first.
     *
     * @param nodeRef subtree root to start from, or {@code null} for the whole tree
     */
    public void backwardIBD(NodeRef nodeRef) {
        ensureBuffers();
        NodeRef start = nodeRef;
        if (start == null) {
            start = this.treeModel.getRoot();
            int stateCount = this.substitutionModel.getStateCount();
            int patternCount = this.likelihoodReporter.getPatternCount();
            double[] rootBackward = this.ibdBackward[start.getNumber()];
            for (int k = 0; k < patternCount * stateCount; k++) {
                rootBackward[k] = 0.0d;
            }
        }
        // Diagonal rates do not change during a traversal; compute them once.
        getDiagonalRates(this.diag);
        backwardIBDPreOrder(start);
    }

    /** Recursive pre-order step of {@link #backwardIBD(NodeRef)}. */
    private void backwardIBDPreOrder(NodeRef parent) {
        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        int childCount = this.treeModel.getChildCount(parent);
        int parentNumber = parent.getNumber();
        double[] parentPosteriors = this.likelihoodReporter.getPosteriors(parentNumber);
        double[] parentBackward = this.ibdBackward[parentNumber];

        for (int c = 0; c < childCount; c++) {
            NodeRef child = this.treeModel.getChild(parent, c);
            int childNumber = child.getNumber();
            this.likelihoodReporter.getNodeMatrix(childNumber, this.probabilities);
            double branchLength = this.branchRateModel.getBranchRate(this.treeModel, child)
                    * (this.treeModel.getNodeHeight(parent) - this.treeModel.getNodeHeight(child));
            double[] childBackward = this.ibdBackward[childNumber];

            for (int p = 0; p < patternCount; p++) {
                for (int s = 0; s < stateCount; s++) {
                    int k = p * stateCount + s;
                    double acc = parentBackward[k];
                    for (int o = 0; o < childCount; o++) {
                        if (o != c) {
                            acc += this.ibdForward[this.treeModel.getChild(parent, o).getNumber()][k];
                        }
                    }
                    childBackward[k] = acc * ((parentPosteriors[k] * Math.exp(-this.diag[s] * branchLength))
                            / this.probabilities[s + s * stateCount]);
                }
            }
        }

        for (int c = 0; c < childCount; c++) {
            backwardIBDPreOrder(this.treeModel.getChild(parent, c));
        }
    }

    /**
     * Recomputes {@link #ibdweights}. Each tip weight is the pattern-weighted average, over
     * patterns, of {@code sum_s backward[tip][p, s] * posterior[tip][p, s]}. The pattern weights
     * are normalized by their total.
     */
    public void expectedIBD() {
        ensureBuffers();
        forwardIBD();
        backwardIBD(null);

        int stateCount = this.substitutionModel.getStateCount();
        int patternCount = this.likelihoodReporter.getPatternCount();
        int externalNodeCount = this.treeModel.getExternalNodeCount();
        double[] patternWeights = this.likelihoodReporter.getPatternWeights();

        double totalWeight = 0.0d;
        for (int p = 0; p < patternCount; p++) {
            totalWeight += patternWeights[p];
        }

        for (int tip = 0; tip < externalNodeCount; tip++) {
            double[] posteriors = this.likelihoodReporter.getPosteriors(tip);
            double[] backward = this.ibdBackward[tip];
            double weight = 0.0d;
            for (int p = 0; p < patternCount; p++) {
                for (int s = 0; s < stateCount; s++) {
                    int k = p * stateCount + s;
                    weight += ((backward[k] * posteriors[k]) * patternWeights[p]) / totalWeight;
                }
            }
            this.ibdweights[tip] = weight;
        }
    }

    /**
     * Writes into {@code dArr[0..3]} the off-diagonal row sums of the normalized HKY rate matrix,
     * computed from kappa and the four equilibrium frequencies, each multiplied by the first
     * value of the mutation parameter. The frequency/state order is assumed to be the nucleotide
     * order A, C, G, T (TODO confirm against the frequency model).
     *
     * @param dArr output array of length at least 4
     */
    protected void getDiagonalRates(double[] dArr) {
        double kappa = this.substitutionModel.getKappa();
        double[] pi = this.substitutionModel.getFrequencyModel().getFrequencies();
        double mu = this.mutationParameter.getParameterValue(0);
        // Normalizing constant that makes the expected substitution rate equal to one.
        double beta = 0.5d / (((pi[0] + pi[2]) * (pi[1] + pi[3])) + (kappa * ((pi[0] * pi[2]) + (pi[1] * pi[3]))));
        dArr[0] = (pi[1] + pi[3] + (pi[2] * kappa)) * mu * beta;
        dArr[1] = (pi[0] + pi[2] + (pi[3] * kappa)) * mu * beta;
        dArr[2] = (pi[1] + pi[3] + (pi[0] * kappa)) * mu * beta;
        dArr[3] = (pi[0] + pi[2] + (pi[1] * kappa)) * mu * beta;
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return new TreeTrait[]{this.avgPosteriorIBDWeight};
    }

    /** Returns the single trait this provider exposes; the requested name is not checked. */
    @Override
    public TreeTrait getTreeTrait(String str) {
        return this.avgPosteriorIBDWeight;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.branchRateModel || model == this.treeModel || model == this.substitutionModel || model == this.likelihoodReporter) {
            this.weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + model.getModelName());
        }
    }

    @Override
    protected final void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        if (variable == this.mutationParameter) {
            this.weightsKnown = false;
        } else {
            System.err.println("Weird call back to IBDReporter from " + variable.getVariableName());
        }
    }

    @Override
    protected void storeState() {
    }

    /**
     * Invalidates the cached weights. They may have been computed for the rejected proposal, and
     * nothing else resets {@link #weightsKnown} when the watched models roll back.
     */
    @Override
    protected void restoreState() {
        this.weightsKnown = false;
    }

    @Override
    protected void acceptState() {
    }
}
