package dr.evomodel.treelikelihood;

import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.UncertainSiteList;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evolution.datatype.HiddenCodons;
import dr.evolution.datatype.HiddenDataType;
import dr.evolution.tree.MutableTreeModel;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evolution.tree.TreeTraitProvider;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.tipstatesmodel.TipStatesModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;

/**
 * A {@link BeagleTreeLikelihood} that can also reconstruct ancestral states at every node of the
 * tree and expose them as an {@code int[]} node {@link TreeTrait}.
 *
 * <p>States are drawn in a root-to-tip traversal ({@link #traverseSample}). Each traversal step
 * combines the partials returned by {@code getPartials(int, double[])} and the transition matrices
 * read from the BEAGLE instance. Each state is either sampled in proportion to its conditional
 * weight or, when {@code useMAP} is set, chosen as the first state of maximal weight.
 *
 * <p>When {@code returnMarginalLogLikelihood} is false, {@link #calculateLogLikelihood()} returns
 * the sum of the log root frequency and the log transition probabilities of the sampled states,
 * instead of the value computed by the superclass.
 */
public class AncestralStateBeagleTreeLikelihood extends BeagleTreeLikelihood implements TreeTraitProvider, AncestralStateTraitProvider {
    protected TreeTraitProvider.Helper treeTraits;
    private final DataType dataType;
    /** Reconstructed states, indexed [node number][pattern]. */
    private int[][] reconstructedStates;
    private int[][] storedReconstructedStates;
    protected boolean areStatesRedrawn;
    protected boolean storedAreStatesRedrawn;
    private boolean useMAP;
    private boolean returnMarginalLogLikelihood;
    private double jointLogLikelihood;
    private double storedJointLogLikelihood;
    /** Observed tip states [tip][pattern]; only allocated when ambiguities are NOT used. */
    private int[][] tipStates;
    /** Tip partials [tip][pattern * stateCount + state]; only allocated when ambiguities are used. */
    private double[][] tipPartials;
    /** Scratch buffer for transition matrices of all categories (stateCount^2 * categoryCount). */
    private double[] probabilities;
    /** Scratch buffer for node partials of all categories (stateCount * patternCount * categoryCount). */
    private double[] partials;
    /** Rate category drawn per pattern at the root; only filled when categoryCount > 1. */
    protected int[] rateCategory;
    private final CodeFormatter formatter;

    /**
     * Turns a sequence of integer state codes into their string codes for logging.
     *
     * <p>For {@link GeneralDataType} consecutive codes are separated by a single space; for other
     * data types codes are concatenated directly.
     */
    public static class CodeFormatter {
        private final Function<String, String> appender;
        private final Function<Integer, String> getter;
        private boolean first = true;

        /**
         * @param dataType          data type used to translate state indices into codes
         * @param removeHiddenState if true, hidden-state data types report their code without the
         *                          hidden component
         */
        CodeFormatter(DataType dataType, boolean removeHiddenState) {
            // The separator goes in front of every code except the first, giving "A B C".
            this.appender = dataType instanceof GeneralDataType ? code -> " " + code : Function.identity();

            if (dataType instanceof HiddenCodons) {
                HiddenCodons codons = (HiddenCodons) dataType;
                this.getter = removeHiddenState ? codons::getTripletWithoutHiddenCode : codons::getTriplet;
            } else if (removeHiddenState && dataType instanceof HiddenDataType) {
                this.getter = ((HiddenDataType) dataType)::getCodeWithoutHiddenState;
            } else {
                this.getter = dataType::getCode;
            }
        }

        /** Returns the code for {@code state}, prefixed by the separator unless it is the first since {@link #reset()}. */
        String getCodeString(int state) {
            String code = getter.apply(state);
            if (first) {
                first = false;
                return code;
            }
            return appender.apply(code);
        }

        /** Marks the next code as the first of a new sequence. */
        void reset() {
            first = true;
        }
    }

    /**
     * @param patternList                  alignment patterns; tip data are read from it here
     * @param treeModel                    tree on which states are reconstructed
     * @param useAmbiguities               forwarded to the superclass
     * @param rescalingScheme              forwarded to the superclass
     * @param delayRescalingUntilUnderflow forwarded to the superclass
     * @param partialsRestrictions         forwarded to the superclass
     * @param dataType                     data type used to interpret and format states
     * @param tag                          name under which the reconstructed-state trait is registered
     * @param useMAP                       pick the most probable state instead of sampling
     * @param returnMarginalLogLikelihood  if false, {@link #calculateLogLikelihood()} returns the
     *                                     joint log-likelihood of the sampled states
     */
    public AncestralStateBeagleTreeLikelihood(PatternList patternList, MutableTreeModel treeModel, BranchModel branchModel, SiteRateModel siteRateModel, BranchRateModel branchRateModel, TipStatesModel tipStatesModel, boolean useAmbiguities, PartialsRescalingScheme rescalingScheme, boolean delayRescalingUntilUnderflow, Map<Set<String>, Parameter> partialsRestrictions, DataType dataType, final String tag, boolean useMAP, boolean returnMarginalLogLikelihood) {
        super(patternList, treeModel, branchModel, siteRateModel, branchRateModel, tipStatesModel, useAmbiguities, rescalingScheme, delayRescalingUntilUnderflow, partialsRestrictions);
        this.treeTraits = new TreeTraitProvider.Helper();
        this.areStatesRedrawn = false;
        this.storedAreStatesRedrawn = false;
        this.rateCategory = null;
        this.dataType = dataType;
        this.useMAP = useMAP;
        this.returnMarginalLogLikelihood = returnMarginalLogLikelihood;

        this.probabilities = new double[stateCount * stateCount * categoryCount];
        this.partials = new double[stateCount * patternCount * categoryCount];

        final boolean ambiguous = useAmbiguities();
        if (ambiguous) {
            this.tipPartials = new double[tipCount][];
        } else {
            this.tipStates = new int[tipCount][];
        }
        for (int tip = 0; tip < tipCount; tip++) {
            int taxonIndex = patternList.getTaxonIndex(treeModel.getTaxonId(tip));
            if (ambiguous) {
                tipPartials[tip] = getPartials(patternList, taxonIndex);
            } else {
                tipStates[tip] = getStates(patternList, taxonIndex);
            }
        }

        this.reconstructedStates = new int[treeModel.getNodeCount()][patternCount];
        this.storedReconstructedStates = new int[treeModel.getNodeCount()][patternCount];
        this.formatter = new CodeFormatter(dataType, false);

        treeTraits.addTrait(new TreeTrait.IA() {
            @Override
            public String getTraitName() {
                return tag;
            }

            @Override
            public TreeTrait.Intent getIntent() {
                return TreeTrait.Intent.NODE;
            }

            @Override
            public Class getTraitClass() {
                return int[].class;
            }

            @Override
            public int[] getTrait(Tree tree, NodeRef nodeRef) {
                return AncestralStateBeagleTreeLikelihood.this.getStatesForNode(tree, nodeRef);
            }

            @Override
            public String getTraitString(Tree tree, NodeRef nodeRef) {
                return AncestralStateBeagleTreeLikelihood.formattedState(AncestralStateBeagleTreeLikelihood.this.getStatesForNode(tree, nodeRef), AncestralStateBeagleTreeLikelihood.this.formatter);
            }
        });
    }

    /**
     * Builds a tip-partials vector (pattern-major, {@code stateCount} entries per pattern) for one taxon.
     * Uncertain site lists fill their own partials; otherwise each compatible state gets 1.0.
     */
    private double[] getPartials(PatternList patternList, int taxonIndex) {
        double[] tipPartial = new double[patternCount * stateCount];
        for (int pattern = 0; pattern < patternCount; pattern++) {
            int offset = pattern * stateCount;
            if (patternList instanceof UncertainSiteList) {
                ((UncertainSiteList) patternList).fillPartials(taxonIndex, pattern, tipPartial, offset);
            } else {
                boolean[] stateSet = dataType.getStateSet(patternList.getPatternState(taxonIndex, pattern));
                for (int state = 0; state < stateCount; state++) {
                    tipPartial[offset + state] = stateSet[state] ? 1.0d : 0.0d;
                }
            }
        }
        return tipPartial;
    }

    /** Copies the observed state of one taxon for every pattern. */
    private int[] getStates(PatternList patternList, int taxonIndex) {
        int[] states = new int[patternCount];
        for (int pattern = 0; pattern < patternCount; pattern++) {
            states[pattern] = patternList.getPatternState(taxonIndex, pattern);
        }
        return states;
    }

    @Override
    public BranchModel getBranchModel() {
        return this.branchModel;
    }

    @Override
    public TreeTrait[] getTreeTraits() {
        return treeTraits.getTreeTraits();
    }

    @Override
    public TreeTrait getTreeTrait(String key) {
        return treeTraits.getTreeTrait(key);
    }

    @Override
    public void handleModelChangedEvent(Model model, Object object, int index) {
        super.handleModelChangedEvent(model, object, index);
        fireModelChanged(model);
    }

    /**
     * Returns the reconstructed states of {@code node}, recomputing the likelihood and/or redrawing
     * states first if they are out of date. The returned array is internal storage, not a copy.
     *
     * @throws RuntimeException if {@code tree} is not the tree given to the constructor
     */
    public int[] getStatesForNode(Tree tree, NodeRef node) {
        if (tree != this.treeModel) {
            throw new RuntimeException("Can only reconstruct states on treeModel given to constructor");
        }
        if (!likelihoodKnown) {
            calculateLogLikelihood();
            likelihoodKnown = true;
        }
        if (!areStatesRedrawn) {
            redrawAncestralStates();
        }
        return reconstructedStates[node.getNumber()];
    }

    // NOTE(review): two scale buffers more than the internal node count; the reason is not visible here.
    @Override
    protected int getScaleBufferCount() {
        return internalNodeCount + 2;
    }

    /** Samples an index proportionally to {@code weights}, or returns the first maximal index when useMAP is set. */
    private int drawChoice(double[] weights) {
        if (!useMAP) {
            return MathUtils.randomChoicePDF(weights);
        }
        int best = 0;
        for (int k = 1; k < weights.length; k++) {
            if (weights[k] > weights[best]) {
                best = k;
            }
        }
        return best;
    }

    @Override
    public void makeDirty() {
        super.makeDirty();
        areStatesRedrawn = false;
    }

    /** Redraws all ancestral states from the root and resets the joint log-likelihood accumulator. */
    public void redrawAncestralStates() {
        jointLogLikelihood = 0.0d;
        traverseSample(treeModel, treeModel.getRoot(), null, null);
        areStatesRedrawn = true;
    }

    /**
     * Computes the superclass log-likelihood; when {@code returnMarginalLogLikelihood} is false,
     * redraws states and returns their joint log-likelihood instead.
     */
    @Override
    public double calculateLogLikelihood() {
        areStatesRedrawn = false;
        double marginal = super.calculateLogLikelihood();
        if (returnMarginalLogLikelihood) {
            return marginal;
        }
        redrawAncestralStates();
        return jointLogLikelihood;
    }

    @Override
    public String formattedState(int[] states) {
        return formattedState(states, formatter);
    }

    /** Formats {@code states} as a double-quoted string of codes using {@code codeFormatter}. */
    public static String formattedState(int[] states, CodeFormatter codeFormatter) {
        StringBuilder sb = new StringBuilder();
        sb.append('"');
        codeFormatter.reset();
        for (int state : states) {
            sb.append(codeFormatter.getCodeString(state));
        }
        sb.append('"');
        return sb.toString();
    }

    /** Reads the transition matrices (all categories) of the branch above node {@code nodeNum} into {@code matrix}. */
    protected void getMatrix(int nodeNum, double[] matrix) {
        this.f9beagle.getTransitionMatrix(this.substitutionModelDelegate.getMatrixIndex(nodeNum), matrix);
    }

    /**
     * Replaces the stored states of tip {@code tip}, pushes them to BEAGLE and marks the model dirty.
     *
     * @throws IllegalStateException if ambiguities are in use (tips are stored as partials)
     */
    public void setTipStates(int tip, int[] states) {
        requireTipStates();
        System.arraycopy(states, 0, tipStates[tip], 0, states.length);
        this.f9beagle.setTipStates(tip, states);
        makeDirty();
    }

    /**
     * Copies {@code states.length} stored states of tip {@code tip} into {@code states}.
     *
     * @throws IllegalStateException if ambiguities are in use (tips are stored as partials)
     */
    public void getTipStates(int tip, int[] states) {
        requireTipStates();
        System.arraycopy(tipStates[tip], 0, states, 0, states.length);
    }

    /** Fails clearly instead of with a NullPointerException when tip states were never stored. */
    private void requireTipStates() {
        if (tipStates == null) {
            throw new IllegalStateException("Tip states are unavailable when ambiguities are used; tips are stored as partials");
        }
    }

    @Override
    public void storeState() {
        super.storeState();
        if (areStatesRedrawn) {
            for (int node = 0; node < reconstructedStates.length; node++) {
                System.arraycopy(reconstructedStates[node], 0, storedReconstructedStates[node], 0, reconstructedStates[node].length);
            }
        }
        storedAreStatesRedrawn = areStatesRedrawn;
        storedJointLogLikelihood = jointLogLikelihood;
    }

    @Override
    public void restoreState() {
        super.restoreState();
        // Swap rather than copy: storeState() always overwrites the stored buffer before it is read again.
        int[][] tmp = reconstructedStates;
        reconstructedStates = storedReconstructedStates;
        storedReconstructedStates = tmp;
        areStatesRedrawn = storedAreStatesRedrawn;
        jointLogLikelihood = storedJointLogLikelihood;
    }

    /**
     * Recursively draws states for {@code node} and its subtree, conditional on the parent's states.
     *
     * @param parentStates states drawn at the parent ({@code null} at the root)
     * @param categories   rate category per pattern, or {@code null} for category 0; replaced by the
     *                     root's draws when {@code categoryCount > 1}
     */
    public void traverseSample(Tree tree, NodeRef node, int[] parentStates, int[] categories) {
        final int nodeNum = node.getNumber();
        final NodeRef parent = tree.getParent(node);

        if (tree.isExternal(node)) {
            sampleTipStates(nodeNum, parentStates, categories);
            hookCalculation(tree, parent, node, parentStates, reconstructedStates[nodeNum], null, categories);
            return;
        }

        final int[] states = new int[patternCount];
        if (parent == null) {
            categories = sampleRootStates(nodeNum, states, categories);
        } else {
            sampleInternalStates(nodeNum, parentStates, categories, states);
            hookCalculation(tree, parent, node, parentStates, states, probabilities, categories);
        }

        traverseSample(tree, tree.getChild(node, 0), states, categories);
        traverseSample(tree, tree.getChild(node, 1), states, categories);
    }

    /** Rate category of {@code pattern}, treating a missing category array as category 0. */
    private static int categoryOf(int[] categories, int pattern) {
        return categories == null ? 0 : categories[pattern];
    }

    /** Draws states for a tip into {@code reconstructedStates[nodeNum]}, resolving ambiguous observations. */
    private void sampleTipStates(int nodeNum, int[] parentStates, int[] categories) {
        final int[] states = reconstructedStates[nodeNum];
        final double[] conditional = new double[stateCount];
        final int matrixSize = stateCount * stateCount;

        if (useAmbiguities()) {
            getMatrix(nodeNum, probabilities);
            final double[] tipPartial = tipPartials[nodeNum];
            for (int j = 0; j < patternCount; j++) {
                // Row of the transition matrix for the parent's state in this pattern's rate category.
                final int rowOffset = categoryOf(categories, j) * matrixSize + parentStates[j] * stateCount;
                System.arraycopy(probabilities, rowOffset, conditional, 0, stateCount);
                for (int s = 0; s < stateCount; s++) {
                    conditional[s] *= tipPartial[j * stateCount + s];
                }
                states[j] = drawChoice(conditional);
                if (!returnMarginalLogLikelihood) {
                    jointLogLikelihood += Math.log(probabilities[rowOffset + states[j]]);
                }
            }
            return;
        }

        getTipStates(nodeNum, states);
        // The matrix is only needed for ambiguous observations or joint-likelihood bookkeeping;
        // fetch it at most once per tip.
        boolean matrixLoaded = false;
        for (int j = 0; j < patternCount; j++) {
            final int observed = states[j];
            final boolean ambiguousState = dataType.isAmbiguousState(observed);
            if (!ambiguousState && returnMarginalLogLikelihood) {
                continue;
            }
            if (!matrixLoaded) {
                getMatrix(nodeNum, probabilities);
                matrixLoaded = true;
            }
            final int rowOffset = categoryOf(categories, j) * matrixSize + parentStates[j] * stateCount;
            if (ambiguousState) {
                System.arraycopy(probabilities, rowOffset, conditional, 0, stateCount);
                if (useAmbiguities && !dataType.isUnknownState(observed)) {
                    // Restrict the draw to states compatible with the partially ambiguous observation.
                    boolean[] stateSet = dataType.getStateSet(observed);
                    for (int s = 0; s < stateCount; s++) {
                        if (!stateSet[s]) {
                            conditional[s] = 0.0d;
                        }
                    }
                }
                states[j] = drawChoice(conditional);
            }
            if (!returnMarginalLogLikelihood) {
                jointLogLikelihood += Math.log(probabilities[rowOffset + states[j]]);
            }
        }
    }

    /**
     * Draws root states (and, with several rate categories, a category per pattern) into
     * {@code states} and {@code reconstructedStates[nodeNum]}.
     *
     * @return the category array to pass to the children
     */
    private int[] sampleRootStates(int nodeNum, int[] states, int[] categories) {
        getPartials(nodeNum, partials);
        final boolean multipleCategories = categoryCount > 1;
        final int categoryStride = stateCount * patternCount;
        double[] categoryWeights = null;
        double[] proportions = null;
        if (multipleCategories) {
            categories = new int[patternCount];
            categoryWeights = new double[categoryCount];
            proportions = siteRateModel.getCategoryProportions();
        }
        // Loop-invariant: fetched once rather than per pattern.
        final double[] rootFrequencies = substitutionModelDelegate.getRootStateFrequencies();
        final double[] conditional = new double[stateCount];

        for (int j = 0; j < patternCount; j++) {
            if (multipleCategories) {
                for (int c = 0; c < categoryCount; c++) {
                    final int offset = c * categoryStride + j * stateCount;
                    double sum = 0.0d;
                    for (int s = 0; s < stateCount; s++) {
                        sum += partials[offset + s];
                    }
                    categoryWeights[c] = sum * proportions[c];
                }
                categories[j] = drawChoice(categoryWeights);
            }

            System.arraycopy(partials, categoryOf(categories, j) * categoryStride + j * stateCount, conditional, 0, stateCount);
            for (int s = 0; s < stateCount; s++) {
                conditional[s] *= rootFrequencies[s];
            }

            int drawn;
            try {
                drawn = drawChoice(conditional);
            } catch (Error e) {
                // Deliberate best-effort: report the failed draw and fall back to state 0.
                System.err.println(e.toString());
                System.err.println("Please report error to Marc");
                drawn = 0;
            }
            states[j] = drawn;
            reconstructedStates[nodeNum][j] = drawn;
            if (!returnMarginalLogLikelihood) {
                jointLogLikelihood += Math.log(rootFrequencies[drawn]);
            }
        }

        if (multipleCategories) {
            if (rateCategory == null) {
                rateCategory = new int[patternCount];
            }
            System.arraycopy(categories, 0, rateCategory, 0, patternCount);
        }
        return categories;
    }

    /** Draws states for an internal (non-root) node into {@code states} and {@code reconstructedStates[nodeNum]}. */
    private void sampleInternalStates(int nodeNum, int[] parentStates, int[] categories, int[] states) {
        // Reusing the shared partials buffer is safe: it is only read in this loop, before recursion.
        getPartials(nodeNum, partials);
        getMatrix(nodeNum, probabilities);
        final double[] conditional = new double[stateCount];
        for (int j = 0; j < patternCount; j++) {
            final int category = categoryOf(categories, j);
            final int rowOffset = category * stateCount * stateCount + parentStates[j] * stateCount;
            final int partialOffset = category * stateCount * patternCount + j * stateCount;
            for (int s = 0; s < stateCount; s++) {
                conditional[s] = partials[partialOffset + s] * probabilities[rowOffset + s];
            }
            states[j] = drawChoice(conditional);
            reconstructedStates[nodeNum][j] = states[j];
            if (!returnMarginalLogLikelihood) {
                jointLogLikelihood += Math.log(probabilities[rowOffset + states[j]]);
            }
        }
    }

    /**
     * Extension hook called for every non-root node after its states are drawn; a no-op here.
     *
     * @param matrix the node's transition matrices, or {@code null} for tips
     */
    protected void hookCalculation(Tree tree, NodeRef parent, NodeRef node, int[] parentStates, int[] nodeStates, double[] matrix, int[] categories) {
    }
}
