package dr.evomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.datatype.DataType;
import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.EpochBranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.CodonPartitionedRobustCounting;
import dr.evomodel.substmodel.MarkovJumpsSubstitutionModel;
import dr.evomodel.substmodel.SubstitutionModel;
import dr.evomodel.substmodel.UniformizedSubstitutionModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.evoxml.util.GraphMLUtils;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.markovjumps.MarkovJumpsRegisterAcceptor;
import dr.inference.markovjumps.MarkovJumpsType;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:dr/evomodel/treelikelihood/MarkovJumpsBeagleTreeLikelihood.class */
/**
 * BEAGLE-based tree likelihood that, on top of ancestral state reconstruction, computes Markov-jump
 * statistics (counts / rewards) per branch and, optionally, complete substitution histories.
 *
 * <p>Each call to {@link #addRegister} creates one register per substitution model held by the
 * substitution-model delegate. Register results are exposed as tree traits and as log columns.
 */
public class MarkovJumpsBeagleTreeLikelihood extends AncestralStateBeagleTreeLikelihood implements MarkovJumpsRegisterAcceptor, MarkovJumpsTraitProvider {
    public static final String ALL_HISTORY = "history_all";
    public static final String HISTORY = "history";
    public static final String TOTAL_COUNTS = "allTransitions";

    // Per-register bookkeeping: every list (and scaleByTime) is indexed by register number.
    private final List<MarkovJumpsSubstitutionModel> markovjumps;
    private final List<Integer> branchModelNumber;
    private final List<Parameter> registerParameter;
    private final List<String> jumpTag;
    private final List<double[][]> expectedJumps;
    private boolean[] scaleByTime;

    private boolean logHistory;
    private boolean useCompactHistory;
    // [node number][pattern] history strings; remains null until a HISTORY register is added.
    private String[][] histories;
    private int historyRegisterNumber;
    private int numRegisters;

    // Scratch space holding stateCount * stateCount entries for every rate category.
    private final double[] tmpProbabilities;
    // Scratch space for a single category's stateCount * stateCount matrix. Kept separate from
    // tmpProbabilities because tmpProbabilities can itself be the source a category slice is copied
    // from (leaf branches), and copying into it would clobber category 0 for later registers.
    private final double[] tmpCategoryProbabilities;
    private final double[][] condJumps;

    private final boolean useUniformization;
    private final int nSimulants;
    private final boolean reportUnconditionedColumns;

    /** Log column reporting, for one register and one site, the jumps summed over all nodes. */
    protected class ConditionedCountColumn extends CountColumn {
        public ConditionedCountColumn(String tag, int registerIndex, int siteIndex) {
            super(CodonPartitionedRobustCounting.SITE_SPECIFIC_PREFIX + tag, registerIndex, siteIndex);
        }

        @Override
        public double getDoubleValue() {
            double[][] jumps = getMarkovJumpsForRegister(treeModel, indexRegistration);
            double total = 0.0d;
            for (int node = 0; node < treeModel.getNodeCount(); node++) {
                total += jumps[node][indexSite];
            }
            return total;
        }
    }

    /**
     * Base class for count columns. A non-negative site index is appended to the label (1-based),
     * wrapped in {@code GraphMLUtils.START_ATTRIBUTE} / {@code END_ATTRIBUTE}.
     */
    protected abstract class CountColumn extends NumberColumn {
        protected int indexRegistration;
        protected int indexSite;

        public CountColumn(String label, int registerIndex, int siteIndex) {
            super(label + (siteIndex >= 0 ? GraphMLUtils.START_ATTRIBUTE + (siteIndex + 1) + GraphMLUtils.END_ATTRIBUTE : ""));
            this.indexRegistration = registerIndex;
            this.indexSite = siteIndex;
        }

        @Override
        public abstract double getDoubleValue();
    }

    /**
     * Log column reporting marginal rate x expected tree length (sum over non-root nodes of branch
     * rate x branch length), optionally scaled by the site's rate-category rate.
     */
    protected class UnconditionedCountColumn extends CountColumn {
        int[] rateCategory;

        public UnconditionedCountColumn(String tag, int registerIndex, int siteIndex, int[] rateCategories) {
            super(CodonPartitionedRobustCounting.UNCONDITIONED_PREFIX + tag, registerIndex, siteIndex);
            this.rateCategory = rateCategories;
        }

        /** Per-register (not per-site) column without rate-category scaling. */
        public UnconditionedCountColumn(String tag, int registerIndex) {
            this(tag, registerIndex, -1, null);
        }

        /**
         * Kept for source compatibility with callers that passed the enclosing instance explicitly;
         * the first argument is ignored.
         */
        public UnconditionedCountColumn(MarkovJumpsBeagleTreeLikelihood ignored, String tag, int registerIndex) {
            this(tag, registerIndex);
        }

        @Override
        public double getDoubleValue() {
            double expected = markovjumps.get(indexRegistration).getMarginalRate() * getExpectedTreeLength();
            if (this.rateCategory != null) {
                expected *= siteRateModel.getRateForCategory(this.rateCategory[indexSite]);
            }
            return expected;
        }

        private double getExpectedTreeLength() {
            double length = 0.0d;
            for (int n = 0; n < treeModel.getNodeCount(); n++) {
                NodeRef node = treeModel.getNode(n);
                if (!treeModel.isRoot(node)) {
                    length += branchRateModel.getBranchRate(treeModel, node) * treeModel.getBranchLength(node);
                }
            }
            return length;
        }
    }

    /**
     * Creates the likelihood. The first fourteen arguments are forwarded unchanged to the superclass.
     *
     * @param useUniformization          sample jumps with {@link UniformizedSubstitutionModel} instead of integrating them
     * @param reportUnconditionedColumns also emit {@link UnconditionedCountColumn}s in {@link #getColumns()}
     * @param nSimulants                 number of simulants passed to each {@link UniformizedSubstitutionModel}
     */
    public MarkovJumpsBeagleTreeLikelihood(PatternList patternList, MutableTreeModel mutableTreeModel, BranchModel branchModel, SiteRateModel siteRateModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean z, PartialsRescalingScheme partialsRescalingScheme, boolean z2, Map<Set<String>, Parameter> map, DataType dataType, String str, boolean z3, boolean z4, boolean useUniformization, boolean reportUnconditionedColumns, int nSimulants) {
        super(patternList, mutableTreeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, z, partialsRescalingScheme, z2, map, dataType, str, z3, z4);
        this.logHistory = false;
        this.useCompactHistory = false;
        this.histories = null;
        this.historyRegisterNumber = -1;
        this.useUniformization = useUniformization;
        this.reportUnconditionedColumns = reportUnconditionedColumns;
        this.nSimulants = nSimulants;
        this.markovjumps = new ArrayList<>();
        this.branchModelNumber = new ArrayList<>();
        this.registerParameter = new ArrayList<>();
        this.jumpTag = new ArrayList<>();
        this.expectedJumps = new ArrayList<>();
        final int matrixSize = this.stateCount * this.stateCount;
        this.tmpProbabilities = new double[matrixSize * this.categoryCount];
        this.tmpCategoryProbabilities = new double[matrixSize];
        this.condJumps = new double[this.categoryCount][matrixSize];
    }

    /**
     * Adds one register per substitution model for the given parameter.
     *
     * @param parameter       registration values; must have dimension stateCount^2 for COUNTS and
     *                        stateCount for REWARDS
     * @param markovJumpsType COUNTS, REWARDS or HISTORY (HISTORY requires uniformization)
     * @param scaleJumps      when true, per-branch results are divided by (branch rate x category rate)
     * @throws RuntimeException on wrong dimension, HISTORY without uniformization, a second complete
     *                          history, or HISTORY with more than one simulant
     */
    @Override
    public void addRegister(Parameter parameter, MarkovJumpsType markovJumpsType, boolean scaleJumps) {
        boolean badCountsDimension = markovJumpsType == MarkovJumpsType.COUNTS && parameter.getDimension() != this.stateCount * this.stateCount;
        boolean badRewardsDimension = markovJumpsType == MarkovJumpsType.REWARDS && parameter.getDimension() != this.stateCount;
        if (badCountsDimension || badRewardsDimension) {
            throw new RuntimeException("Register parameter of wrong dimension");
        }
        addVariable(parameter);
        final String id = parameter.getId();
        final boolean isEpochModel = this.branchModel instanceof EpochBranchModel;

        for (int i = 0; i < this.substitutionModelDelegate.getSubstitutionModelCount(); i++) {
            this.registerParameter.add(parameter);
            MarkovJumpsSubstitutionModel model = createMarkovJumpsModel(this.substitutionModelDelegate.getSubstitutionModel(i), markovJumpsType);
            this.markovjumps.add(model);
            this.branchModelNumber.add(i);
            addModel(model);
            setupRegistration(this.numRegisters);

            String traitName = this.substitutionModelDelegate.getSubstitutionModelCount() == 1 ? id : id + i;
            this.jumpTag.add(traitName);
            this.expectedJumps.add(new double[this.treeModel.getNodeCount()][this.patternCount]);
            appendScaleByTime(scaleJumps);

            if (markovJumpsType != MarkovJumpsType.HISTORY) {
                TreeTrait.DA perBranch = createJumpsTrait(id, this.numRegisters);
                this.treeTraits.addTrait(traitName + "_base", perBranch);
                this.treeTraits.addTrait(id, new TreeTrait.SumAcrossArrayD(new TreeTrait.SumOverTreeDA(perBranch)));
            } else {
                // For epoch models the history traits are registered once (for the first model only).
                if (i == 0 || !isEpochModel) {
                    registerCompleteHistory(id, model);
                }
                if (isEpochModel) {
                    for (MarkovJumpsSubstitutionModel registered : this.markovjumps) {
                        ((UniformizedSubstitutionModel) registered).setSaveCompleteHistory(true);
                    }
                }
            }
            this.numRegisters++;
        }
    }

    /** Wraps a substitution model in a uniformized (sampling) or integrating Markov-jumps model. */
    private MarkovJumpsSubstitutionModel createMarkovJumpsModel(SubstitutionModel substitutionModel, MarkovJumpsType type) {
        if (this.useUniformization) {
            return new UniformizedSubstitutionModel(substitutionModel, type, this.nSimulants);
        }
        if (type == MarkovJumpsType.HISTORY) {
            throw new RuntimeException("Can only report complete history using uniformization");
        }
        return new MarkovJumpsSubstitutionModel(substitutionModel, type);
    }

    /** Grows scaleByTime by one slot and stores the flag for the register being added. */
    private void appendScaleByTime(boolean scale) {
        int oldLength = this.scaleByTime == null ? 0 : this.scaleByTime.length;
        boolean[] grown = oldLength == 0 ? new boolean[1] : Arrays.copyOf(this.scaleByTime, oldLength + 1);
        grown[oldLength] = scale;
        this.scaleByTime = grown;
    }

    /** Branch-intent trait returning the per-node jumps of one register. */
    private TreeTrait.DA createJumpsTrait(final String traitName, final int registerNumber) {
        return new TreeTrait.DA() {
            @Override
            public String getTraitName() {
                return traitName;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public double[] getTrait(Tree tree, NodeRef nodeRef) {
                return getMarkovJumpsForNodeAndRegister(tree, nodeRef, registerNumber);
            }
        };
    }

    /** Allocates the history table and registers all history traits for the current register. */
    private void registerCompleteHistory(String id, MarkovJumpsSubstitutionModel model) {
        if (this.histories != null) {
            throw new RuntimeException("Only one complete history per markovJumpTreeLikelihood is allowed");
        }
        this.histories = new String[this.treeModel.getNodeCount()][this.patternCount];
        if (this.nSimulants > 1) {
            throw new RuntimeException("Only one simulant allowed when saving complete history");
        }
        this.treeTraits.addTrait(id, new TreeTrait.SumOverTreeDA(createJumpsTrait(id, this.numRegisters)));
        this.historyRegisterNumber = this.numRegisters;
        ((UniformizedSubstitutionModel) model).setSaveCompleteHistory(true);
        if (this.useCompactHistory && this.logHistory) {
            this.treeTraits.addTrait(ALL_HISTORY, createCompactHistoryTrait());
        }
        for (int site = 0; site < this.patternCount; site++) {
            String traitName = this.patternCount == 1 ? "history" : "history_" + (site + 1);
            this.treeTraits.addTrait(traitName, createSiteHistoryTrait(traitName, site));
        }
    }

    /** Loggable array trait listing every non-empty "{...}" event of every site on a branch. */
    private TreeTrait.SA createCompactHistoryTrait() {
        return new TreeTrait.SA() {
            @Override
            public String getTraitName() {
                return ALL_HISTORY;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public boolean getFormatAsArray() {
                return true;
            }

            @Override
            public String[] getTrait(Tree tree, NodeRef nodeRef) {
                List<String> events = new ArrayList<>();
                for (int site = 0; site < MarkovJumpsBeagleTreeLikelihood.this.patternCount; site++) {
                    String history = getHistoryForNode(tree, nodeRef, site);
                    if (history == null || history.equals("{}")) {
                        continue;
                    }
                    // Strip the outer braces, then split "{a},{b}" between adjacent braces.
                    String inner = history.substring(1, history.length() - 1);
                    if (inner.contains("},{")) {
                        events.addAll(Arrays.asList(inner.split("(?<=\\}),(?=\\{)")));
                    } else {
                        events.add(inner);
                    }
                }
                return events.toArray(new String[0]);
            }

            @Override
            public boolean getLoggable() {
                return true;
            }
        };
    }

    /** Per-site history trait; yields null when the branch has no recorded (or an empty) history. */
    private TreeTrait.S createSiteHistoryTrait(final String traitName, final int site) {
        return new TreeTrait.S() {
            @Override
            public String getTraitName() {
                return traitName;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.BRANCH;
            }

            @Override
            public String getTrait(Tree tree, NodeRef nodeRef) {
                String history = getHistoryForNode(tree, nodeRef, site);
                // Entries never written by hookCalculation are null; treat them like an empty history.
                if (history == null || history.equals("{}")) {
                    return null;
                }
                return history;
            }

            @Override
            public boolean getLoggable() {
                return logHistory && !useCompactHistory;
            }
        };
    }

    public void setLogHistories(boolean z) {
        this.logHistory = z;
    }

    public void setUseCompactHistory(boolean z) {
        this.useCompactHistory = z;
    }

    /** Returns the per-pattern jumps of register {@code i} on the branch above {@code nodeRef}. */
    public double[] getMarkovJumpsForNodeAndRegister(Tree tree, NodeRef nodeRef, int i) {
        return getMarkovJumpsForRegister(tree, i)[nodeRef.getNumber()];
    }

    /** Ensures likelihood and ancestral states are current; only the internal tree is accepted. */
    private void refresh(Tree tree) {
        if (tree != this.treeModel) {
            throw new RuntimeException("Must call with internal tree");
        }
        if (!this.likelihoodKnown) {
            calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        if (!this.areStatesRedrawn) {
            redrawAncestralStates();
        }
    }

    /** Returns the [node][pattern] jumps table of register {@code i}, refreshing first. */
    public double[][] getMarkovJumpsForRegister(Tree tree, int i) {
        refresh(tree);
        return this.expectedJumps.get(i);
    }

    public String getHistoryForNode(Tree tree, NodeRef nodeRef, int i) {
        return getHistory(tree)[nodeRef.getNumber()][i];
    }

    /** Returns the [node][pattern] history table (null if no HISTORY register), refreshing first. */
    public String[][] getHistory(Tree tree) {
        refresh(tree);
        return this.histories;
    }

    /** Pushes register {@code i}'s parameter values into its model and invalidates redrawn states. */
    private void setupRegistration(int i) {
        this.markovjumps.get(i).setRegistration(this.registerParameter.get(i).getParameterValues());
        this.areStatesRedrawn = false;
    }

    /**
     * Refreshes every register backed by {@code variable}. The same parameter is registered once per
     * substitution model, so all matches must be updated, not just the first. Other variables are
     * delegated to the superclass.
     */
    @Override
    public void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        boolean isRegister = false;
        for (int r = 0; r < this.numRegisters; r++) {
            if (variable == this.registerParameter.get(r)) {
                setupRegistration(r);
                isRegister = true;
            }
        }
        if (!isRegister) {
            super.handleVariableChangedEvent(variable, i, changeType);
        }
    }

    /**
     * Per-branch hook: fills each register's jumps for the branch above {@code childNode}.
     * Registers whose substitution model is not the branch's first mapped model are zeroed.
     *
     * @param probabilities transition matrices for all categories; when null they are fetched via
     *                      getMatrix into tmpProbabilities
     */
    @Override
    protected void hookCalculation(Tree tree, NodeRef parentNode, NodeRef childNode, int[] parentStates, int[] childStates, double[] probabilities, int[] rateCategories) {
        final int childNumber = childNode.getNumber();
        double[] matrices = probabilities;
        if (matrices == null) {
            getMatrix(childNumber, this.tmpProbabilities);
            matrices = this.tmpProbabilities;
        }
        final double branchRate = this.branchRateModel.getBranchRate(tree, childNode);
        final double parentHeight = tree.getNodeHeight(parentNode);
        final double childHeight = tree.getNodeHeight(childNode);
        final double branchTime = parentHeight - childHeight;
        // Histories can only be stored once a HISTORY register has allocated the table.
        final boolean saveHistory = this.histories != null;

        for (int r = 0; r < this.markovjumps.size(); r++) {
            MarkovJumpsSubstitutionModel model = this.markovjumps.get(r);
            if (this.branchModelNumber.get(r) != this.branchModel.getBranchModelMapping(childNode).getOrder()[0]) {
                Arrays.fill(this.expectedJumps.get(r)[childNumber], 0.0d);
            } else if (this.useUniformization) {
                computeSampledMarkovJumpsForBranch((UniformizedSubstitutionModel) model, branchTime, branchRate, childNumber, parentStates, childStates, parentHeight, childHeight, matrices, this.scaleByTime[r], this.expectedJumps.get(r), rateCategories, saveHistory);
            } else {
                computeIntegratedMarkovJumpsForBranch(model, branchTime, branchRate, childNumber, parentStates, childStates, matrices, this.condJumps, this.scaleByTime[r], this.expectedJumps.get(r), rateCategories);
            }
        }
    }

    /**
     * Uniformization path: for each pattern computes conditional jumps for the from/to state pair,
     * optionally divides by (branch rate x category rate), and optionally stores the complete history.
     */
    private void computeSampledMarkovJumpsForBranch(UniformizedSubstitutionModel uniformizedSubstitutionModel, double d, double d2, int i, int[] iArr, int[] iArr2, double d3, double d4, double[] dArr, boolean z, double[][] dArr2, int[] iArr3, boolean z2) {
        final int matrixSize = this.stateCount * this.stateCount;
        for (int p = 0; p < this.patternCount; p++) {
            int category = iArr3 == null ? 0 : iArr3[p];
            double categoryRate = this.siteRateModel.getRateForCategory(category);
            double transitionProbability = dArr[(category * matrixSize) + (iArr[p] * this.stateCount) + iArr2[p]];
            double jumps = uniformizedSubstitutionModel.computeCondStatMarkovJumps(iArr[p], iArr2[p], d * d2 * categoryRate, transitionProbability);
            if (z) {
                jumps /= d2 * categoryRate;
            }
            dArr2[i][p] = jumps;
            if (z2) {
                this.histories[i][p] = uniformizedSubstitutionModel.getCompleteHistory(this.useCompactHistory ? p + 1 : -1, Double.valueOf(d3), Double.valueOf(d4));
            }
        }
    }

    /**
     * Integration path: fills one stateCount^2 conditional-jumps matrix per rate category, then
     * picks each pattern's entry for its from/to state pair.
     */
    private void computeIntegratedMarkovJumpsForBranch(MarkovJumpsSubstitutionModel markovJumpsSubstitutionModel, double d, double d2, int i, int[] iArr, int[] iArr2, double[] dArr, double[][] dArr2, boolean z, double[][] dArr3, int[] iArr3) {
        final int matrixSize = this.stateCount * this.stateCount;
        for (int c = 0; c < this.categoryCount; c++) {
            double categoryRate = this.siteRateModel.getRateForCategory(c);
            double[] categoryJumps = dArr2[c];
            if (categoryRate > 0.0d) {
                if (this.categoryCount == 1) {
                    markovJumpsSubstitutionModel.computeCondStatMarkovJumps(d * d2 * categoryRate, dArr, categoryJumps);
                } else {
                    // Copy into a dedicated buffer: dArr may alias tmpProbabilities (leaf branches).
                    System.arraycopy(dArr, c * matrixSize, this.tmpCategoryProbabilities, 0, matrixSize);
                    markovJumpsSubstitutionModel.computeCondStatMarkovJumps(d * d2 * categoryRate, this.tmpCategoryProbabilities, categoryJumps);
                }
                if (z) {
                    double scale = d2 * categoryRate;
                    for (int k = 0; k < categoryJumps.length; k++) {
                        categoryJumps[k] /= scale;
                    }
                }
            } else {
                Arrays.fill(categoryJumps, 0.0d);
                // Zero-rate category: with time scaling, a reward is the time spent in the starting state.
                if (markovJumpsSubstitutionModel.getType() == MarkovJumpsType.REWARDS && z) {
                    for (int s = 0; s < this.stateCount; s++) {
                        categoryJumps[(s * this.stateCount) + s] = d;
                    }
                }
            }
        }
        for (int p = 0; p < this.patternCount; p++) {
            int category = iArr3 == null ? 0 : iArr3[p];
            dArr3[i][p] = dArr2[category][(iArr[p] * this.stateCount) + iArr2[p]];
        }
    }

    /**
     * Builds one conditioned column per (register, site). If unconditioned columns are requested,
     * it also adds either one per (register, site) when there are several rate categories, or one per
     * register otherwise.
     */
    @Override
    public LogColumn[] getColumns() {
        int columnCount = this.patternCount * this.numRegisters;
        if (this.reportUnconditionedColumns) {
            columnCount = this.categoryCount == 1 ? columnCount + this.numRegisters : columnCount * 2;
        }
        LogColumn[] columns = new LogColumn[columnCount];
        int index = 0;
        for (int r = 0; r < this.numRegisters; r++) {
            String tag = this.jumpTag.get(r);
            for (int site = 0; site < this.patternCount; site++) {
                columns[index++] = new ConditionedCountColumn(tag, r, site);
                if (this.reportUnconditionedColumns && this.categoryCount > 1) {
                    columns[index++] = new UnconditionedCountColumn(tag, r, site, this.rateCategory);
                }
            }
            if (this.reportUnconditionedColumns && this.categoryCount == 1) {
                columns[index++] = new UnconditionedCountColumn(tag, r);
            }
        }
        return columns;
    }

    @Override
    public String getDescription() {
        return super.getDescription() + " (first citation) with MarkovJumps inference techniques (second citation)";
    }

    @Override
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<>(super.getCitations());
        citations.add(CommonCitations.MININ_2008_COUNTING);
        return citations;
    }
}
