package dr.evomodel.branchmodel.lineagespecific;

import beagle.Beagle;
import beagle.BeagleFactory;
import dr.app.beagle.tools.BeagleSequenceSimulator;
import dr.app.beagle.tools.Partition;
import dr.evolution.alignment.PatternList;
import dr.evolution.alignment.SimpleAlignment;
import dr.evolution.datatype.DataType;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.io.NewickImporter;
import dr.evolution.tree.NodeRef;
import dr.evomodel.branchmodel.BranchModel;
import dr.evomodel.branchmodel.HomogeneousBranchModel;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.StrictClockBranchRates;
import dr.evomodel.siteratemodel.GammaSiteRateModel;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.evomodel.substmodel.FrequencyModel;
import dr.evomodel.substmodel.nucleotide.HKY;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.BufferIndexHelper;
import dr.evomodel.treelikelihood.BeagleTreeLikelihood;
import dr.evomodel.treelikelihood.PartialsRescalingScheme;
import dr.evomodel.treelikelihood.SubstitutionModelDelegate;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/**
 * Experimental BEAGLE-backed likelihood intended to evaluate the log-likelihood
 * contribution of a single branch of a tree.
 *
 * <p>A BEAGLE instance is created for the supplied pattern list and tree on
 * construction. {@link #getBranchLogLikelihood(int)} refreshes the site-rate and
 * substitution-model state, loads the tip states, and recomputes transition
 * matrices and partials for the requested node's subtree and its parent.
 *
 * <p>NOTE: the final edge log-likelihood reduction is not implemented yet, so
 * both {@link #getBranchLogLikelihood(int)} and {@link #getLogLikelihood()}
 * currently return {@code 0.0}.
 */
public class BeagleBranchLikelihood implements Likelihood {
    /** When true, buffer updates performed while traversing are echoed to stdout. */
    private static final boolean DEBUG = true;

    private PatternList patternList;
    private TreeModel treeModel;
    private BranchModel branchModel;
    private SiteRateModel siteRateModel;
    private FrequencyModel freqModel;
    private BranchRateModel branchRateModel;
    private String id = null;
    private boolean used = true;

    /** The BEAGLE instance owned by this likelihood (released by {@link #finalizeBeagle()}). */
    private Beagle beagleInstance;
    private BufferIndexHelper matrixBufferHelper;
    private BufferIndexHelper partialBufferHelper;
    private SubstitutionModelDelegate substitutionModelDelegate;
    int nodeCount;
    /** Per-node dirty flags; a node is recomputed by traverse only while its flag is true. */
    boolean[] updateNode;

    /** Log column that reports the current value of {@link #getLogLikelihood()}. */
    private class LikelihoodColumn extends NumberColumn {
        public LikelihoodColumn(String label) {
            super(label);
        }

        @Override
        public double getDoubleValue() {
            return BeagleBranchLikelihood.this.getLogLikelihood();
        }
    }

    /**
     * Creates the likelihood and immediately allocates its BEAGLE instance.
     *
     * @param patternList     site patterns whose states are loaded into the tips
     * @param treeModel       the tree whose branches are evaluated
     * @param branchModel     branch model handed to the substitution-model delegate
     * @param siteRateModel   supplies category rates, weights and count
     * @param frequencyModel  supplies the data type (and hence the state count)
     * @param branchRateModel supplies per-branch rates used to scale branch times
     */
    public BeagleBranchLikelihood(PatternList patternList, TreeModel treeModel, BranchModel branchModel, SiteRateModel siteRateModel, FrequencyModel frequencyModel, BranchRateModel branchRateModel) {
        this.patternList = patternList;
        this.treeModel = treeModel;
        this.branchModel = branchModel;
        this.siteRateModel = siteRateModel;
        this.freqModel = frequencyModel;
        this.branchRateModel = branchRateModel;
        loadBeagleInstance();
    }

    /**
     * Prepares BEAGLE buffers for the branch above node {@code branchIndex}.
     *
     * <p>Pushes category rates/weights and root frequencies, updates the
     * substitution models, loads tip states, then recomputes matrices and
     * partials for the node's subtree and for its parent (if any).
     *
     * @param branchIndex number of the node whose parent branch is evaluated
     * @return currently always {@code 0.0}: the edge log-likelihood reduction
     *         has not been implemented yet (TODO)
     */
    public double getBranchLogLikelihood(int branchIndex) {
        beagleInstance.setCategoryRates(siteRateModel.getCategoryRates());
        beagleInstance.setCategoryWeights(0, siteRateModel.getCategoryProportions());
        beagleInstance.setStateFrequencies(0, substitutionModelDelegate.getRootStateFrequencies());
        substitutionModelDelegate.updateSubstitutionModels(beagleInstance);
        setTipPartials();

        // Mark every node dirty so traverse() recomputes everything it visits once.
        updateNode = new boolean[nodeCount];
        Arrays.fill(updateNode, true);

        NodeRef node = treeModel.getNode(branchIndex);
        traverse(treeModel, node);
        NodeRef parent = treeModel.getParent(node);
        // The root has no parent; traversing null would throw a NullPointerException.
        if (parent != null) {
            traverse(treeModel, parent);
        }

        // TODO: reduce the prepared buffers to an edge log-likelihood via BEAGLE.
        // Until then the (zero-initialised) output slot is returned unchanged.
        double[] logLikelihood = new double[1];
        return logLikelihood[0];
    }

    /**
     * Post-order recomputation of transition matrices and partials below {@code node}.
     *
     * <p>A non-root node still flagged dirty gets a fresh transition matrix for
     * its parent branch. An internal node gets fresh partials whenever either
     * child was updated. Visited nodes are marked clean.
     *
     * @return true if this node's matrix or partials were recomputed
     */
    private boolean traverse(TreeModel tree, NodeRef node) {
        boolean updated = false;
        int nodeNumber = node.getNumber();
        NodeRef parent = tree.getParent(node);

        if (parent != null && updateNode[nodeNumber]) {
            double branchRate = branchRateModel.getBranchRate(tree, node);
            double parentHeight = tree.getNodeHeight(parent);
            double nodeHeight = tree.getNodeHeight(node);
            // Effective branch length = branch rate * (parent height - node height).
            double branchLength = branchRate * (parentHeight - nodeHeight);

            substitutionModelDelegate.flipMatrixBuffer(nodeNumber);
            substitutionModelDelegate.updateTransitionMatrices(beagleInstance, new int[]{nodeNumber}, new double[]{branchLength}, 1);

            if (DEBUG) {
                System.out.println("At branch " + nodeNumber);
                System.out.println(" Length " + branchLength + ": node " + nodeNumber + ", height=" + nodeHeight + " parent " + parent);
                System.out.println(" Populating transition matrix buffer");
            }
            updateNode[nodeNumber] = false;
            updated = true;
        }

        if (!tree.isExternal(node)) {
            NodeRef child1 = tree.getChild(node, 0);
            boolean child1Updated = traverse(tree, child1);
            NodeRef child2 = tree.getChild(node, 1);
            boolean child2Updated = traverse(tree, child2);

            if (child1Updated || child2Updated) {
                partialBufferHelper.flipOffset(nodeNumber);
                // Operation layout (presumably per the BEAGLE API; confirm):
                // {destPartials, destScaleWrite, destScaleRead,
                //  child1Partials, child1Matrix, child2Partials, child2Matrix};
                // -1 entries mean "no scaling buffer".
                int[] operation = {
                        partialBufferHelper.getOffsetIndex(nodeNumber), -1, -1,
                        partialBufferHelper.getOffsetIndex(child1.getNumber()),
                        substitutionModelDelegate.getMatrixIndex(child1.getNumber()),
                        partialBufferHelper.getOffsetIndex(child2.getNumber()),
                        substitutionModelDelegate.getMatrixIndex(child2.getNumber())
                };
                beagleInstance.updatePartials(operation, 1, -1);

                if (DEBUG) {
                    System.out.println("At branch " + nodeNumber);
                    System.out.println(" Child nodes updated");
                    System.out.println(" Populating partial buffer");
                }
                updateNode[nodeNumber] = false;
                updated = true;
            }
        }
        return updated;
    }

    /**
     * Returns the log-likelihood.
     *
     * @return currently always {@code 0.0} (not implemented)
     */
    @Override
    public double getLogLikelihood() {
        return 0.0d;
    }

    /**
     * Recomputes the transition matrix of every node from its branch length
     * scaled by its branch rate.
     *
     * <p>NOTE(review): not called anywhere in this class. It flips
     * {@code matrixBufferHelper} but passes the raw node number to the delegate.
     * Confirm the intended buffer indexing before using it.
     */
    private void populateTransitionBuffers() {
        for (NodeRef node : treeModel.getNodes()) {
            int nodeNumber = node.getNumber();
            matrixBufferHelper.flipOffset(nodeNumber);
            double length = treeModel.getBranchLength(node) * branchRateModel.getBranchRate(treeModel, node);
            substitutionModelDelegate.updateTransitionMatrices(beagleInstance, new int[]{nodeNumber}, new double[]{length}, 1);
        }
    }

    /**
     * Loads the observed states of every tree tip into BEAGLE.
     *
     * <p>Tip {@code i} of the tree is matched to its taxon in the pattern list by id.
     *
     * @throws IllegalArgumentException if a tree taxon is absent from the pattern list
     */
    private void setTipPartials() {
        int patternCount = patternList.getPatternCount();
        int taxonCount = treeModel.getTaxonCount();
        for (int tip = 0; tip < taxonCount; tip++) {
            String taxonId = treeModel.getTaxonId(tip);
            int taxonIndex = patternList.getTaxonIndex(taxonId);
            if (taxonIndex < 0) {
                throw new IllegalArgumentException("Taxon '" + taxonId + "' from the tree is not present in the pattern list");
            }
            int[] states = new int[patternCount];
            for (int pattern = 0; pattern < patternCount; pattern++) {
                states[pattern] = patternList.getPatternState(taxonIndex, pattern);
            }
            beagleInstance.setTipStates(tip, states);
        }
    }

    /** Releases the underlying BEAGLE instance; this object must not be used afterwards. */
    public void finalizeBeagle() throws Throwable {
        beagleInstance.finalize();
    }

    /** Creates the substitution-model delegate, buffer helpers and the BEAGLE instance. */
    private void loadBeagleInstance() {
        substitutionModelDelegate = new SubstitutionModelDelegate(treeModel, branchModel);
        DataType dataType = freqModel.getDataType();
        int patternCount = patternList.getPatternCount();
        nodeCount = treeModel.getNodeCount();
        matrixBufferHelper = new BufferIndexHelper(nodeCount, 0);

        int externalNodeCount = treeModel.getExternalNodeCount();
        int internalNodeCount = treeModel.getInternalNodeCount();
        // Tips occupy indices below externalNodeCount; only internal nodes get double buffers.
        partialBufferHelper = new BufferIndexHelper(nodeCount, externalNodeCount);
        int scaleBufferCount = new BufferIndexHelper(internalNodeCount + 1, 0).getBufferCount();

        beagleInstance = BeagleFactory.loadBeagleInstance(
                externalNodeCount,
                partialBufferHelper.getBufferCount(),
                externalNodeCount,
                dataType.getStateCount(),
                patternCount,
                substitutionModelDelegate.getEigenBufferCount(),
                substitutionModelDelegate.getMatrixBufferCount(),
                siteRateModel.getCategoryCount(),
                scaleBufferCount,
                new int[]{0},
                0L,
                0L);
    }

    @Override
    public LogColumn[] getColumns() {
        String label = getId() == null ? "likelihood" : getId();
        return new LogColumn[]{new LikelihoodColumn(label)};
    }

    @Override
    public String getId() {
        return id;
    }

    @Override
    public void setId(String id) {
        this.id = id;
    }

    /** This likelihood is not backed by a {@link Model}; always returns null. */
    @Override
    public Model getModel() {
        return null;
    }

    /** No cached state to invalidate; intentionally a no-op. */
    @Override
    public void makeDirty() {
    }

    @Override
    public String prettyName() {
        return Likelihood.Abstract.getPrettyName(this);
    }

    /** Returns a new set containing only this likelihood. */
    @Override
    public Set<Likelihood> getLikelihoodSet() {
        return new HashSet<Likelihood>(Arrays.asList(this));
    }

    @Override
    public boolean isUsed() {
        return used;
    }

    @Override
    public void setUsed() {
        used = true;
    }

    @Override
    public boolean evaluateEarly() {
        return false;
    }

    /**
     * Ad-hoc driver: simulates 1000 sites on a 4-taxon tree under HKY, prints the
     * alignment and the BeagleTreeLikelihood value, then evaluates branch 4.
     */
    public static void main(String[] args) {
        try {
            MathUtils.setSeed(666L);

            TreeModel treeModel = new TreeModel(new NewickImporter("((SimSeq1:22.0,SimSeq2:22.0):12.0,(SimSeq3:23.1,SimSeq4:23.1):10.899999999999999);").importTree(null));
            FrequencyModel frequencyModel = new FrequencyModel(Nucleotides.INSTANCE, new Parameter.Default(new double[]{0.25d, 0.25d, 0.25d, 0.25d}));
            HKY hky = new HKY(new Parameter.Default(1, 1.0d), frequencyModel);
            HomogeneousBranchModel branchModel = new HomogeneousBranchModel(hky);
            StrictClockBranchRates branchRates = new StrictClockBranchRates(new Parameter.Default(1, 1.0d));
            GammaSiteRateModel siteRateModel = new GammaSiteRateModel("siteModel");

            int siteCount = 1000;
            ArrayList<Partition> partitions = new ArrayList<Partition>();
            partitions.add(new Partition(treeModel, branchModel, siteRateModel, branchRates, frequencyModel, 0, siteCount - 1, 1));

            SimpleAlignment alignment = new BeagleSequenceSimulator(partitions).simulate(false, false);
            System.out.println(alignment);

            BeagleTreeLikelihood treeLikelihood = new BeagleTreeLikelihood(alignment, treeModel, branchModel, siteRateModel, branchRates, null, false, PartialsRescalingScheme.DEFAULT, true);
            System.out.println("BTL(homogeneous) = " + treeLikelihood.getLogLikelihood());

            BeagleBranchLikelihood branchLikelihood = new BeagleBranchLikelihood(alignment, treeModel, branchModel, siteRateModel, frequencyModel, branchRates);
            System.out.println(branchLikelihood.getBranchLogLikelihood(4));
            branchLikelihood.finalizeBeagle();
        } catch (Throwable th) {
            // Exception is a subtype of Throwable, so one handler covers both cases.
            th.printStackTrace();
            System.exit(-1);
        }
    }
}
