package dr.evomodel.continuous.hmc;

import dr.evolution.tree.Tree;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitPartialsProvider;
import dr.evomodel.treedatalikelihood.continuous.ElementaryVectorDataModel;
import dr.inference.hmc.PrecisionColumnProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Variable;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.WrappedVector;
import dr.xml.Reportable;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:dr/evomodel/continuous/hmc/TreePrecisionColumnProvider.class */
/**
 * Provides columns of the tree-induced trait precision matrix by multiplying that matrix
 * with elementary (unit) vectors via a {@link TreePrecisionTraitProductProvider}.
 *
 * <p>Column {@code i} corresponds to taxon {@code i / dimTrait}, trait dimension
 * {@code i % dimTrait}. Computed columns are memoised in {@link #treeCache}; the cache is
 * dropped whenever the tree model reports a change.
 *
 * <p>Note: when the tip data model is a {@link ContinuousTraitDataModel}, computing a column
 * overwrites the tip-data parameter with the elementary vector. The original values are not
 * restored ({@code RESET_DATA} is {@code false}).
 */
public class TreePrecisionColumnProvider extends AbstractModel implements PrecisionColumnProvider, Reportable {
    private final TreePrecisionTraitProductProvider productProvider;
    final Tree tree;
    private final ContinuousDataLikelihoodDelegate likelihoodDelegate;
    private final ContinuousTraitPartialsProvider tipData;
    /** Memoised precision columns keyed by column index; cleared on tree changes. */
    private final Map<Integer, double[]> treeCache = new HashMap<>();
    private final int numTaxa;
    private final int dimTrait;
    private static final boolean DEBUG = false;
    private static final boolean DEBUG_CACHE = false;
    private static final boolean RESET_DATA = false;

    /**
     * Creates a column provider backed by the given precision-product provider.
     *
     * @param treePrecisionTraitProductProvider source of the tree, likelihood delegate,
     *                                          tip data model and precision-vector products
     */
    public TreePrecisionColumnProvider(TreePrecisionTraitProductProvider treePrecisionTraitProductProvider) {
        super("treePrecisionColumnProvider");
        this.productProvider = treePrecisionTraitProductProvider;
        this.tree = treePrecisionTraitProductProvider.getTree();
        this.likelihoodDelegate = treePrecisionTraitProductProvider.likelihoodDelegate;
        this.tipData = treePrecisionTraitProductProvider.getDataModel();
        this.numTaxa = this.tree.getExternalNodeCount();
        this.dimTrait = this.likelihoodDelegate.getTraitDim();

        // Column indexing below assumes exactly one trait per taxon.
        assert this.likelihoodDelegate.getTraitCount() == 1;

        // Listen to the tree so cached columns are invalidated when its topology/branches change.
        if (this.tree instanceof TreeModel) {
            addModel((TreeModel) this.tree);
        }
    }

    /**
     * Returns column {@code index} of the tree trait precision matrix, computing and caching
     * it on first request.
     *
     * @param index column index in {@code [0, numTaxa * dimTrait)}
     * @return the cached column array (shared with the cache; callers should not modify it)
     */
    public double[] getColumn(int index) {
        // Explicit get/put rather than computeIfAbsent: computing a column mutates model
        // parameters, and any resulting event that clears the cache must not happen inside
        // a HashMap compute call (ConcurrentModificationException).
        double[] column = this.treeCache.get(index);
        if (column == null) {
            column = setDataModelAndGetColumn(index);
            this.treeCache.put(index, column);
            if (DEBUG) {
                debug(column, index);
            }
        }
        return column;
    }

    /** Drops all cached columns when the tree model changes. */
    @Override // dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.tree) {
            this.treeCache.clear();
        }
    }

    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
    }

    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
    }

    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    /**
     * Loads the elementary vector for column {@code index} into the tip data model and asks
     * the product provider for precision times that vector.
     *
     * @throws UnsupportedOperationException if the tip data model type is not supported
     */
    private double[] setDataModelAndGetColumn(int index) {
        CompoundParameter parameter = this.tipData.getParameter();
        if (this.tipData instanceof ContinuousTraitDataModel) {
            // Overwrites the tip-data parameter values in place (not restored afterwards).
            ReadableVector.Utils.setParameter(makeElementaryVector(index), parameter);
        } else if (this.tipData instanceof ElementaryVectorDataModel) {
            // Arguments: taxon, trait (0 — presumably the single trait asserted in the
            // constructor), dimension within the trait.
            ((ElementaryVectorDataModel) this.tipData)
                    .setTipTraitDimParameters(index / this.dimTrait, 0, index % this.dimTrait);
        } else {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        return this.productProvider.getProduct(parameter);
    }

    /** Builds a length {@code numTaxa * dimTrait} vector with a single 1 at {@code index}. */
    private WrappedVector makeElementaryVector(int index) {
        double[] unit = new double[this.numTaxa * this.dimTrait];
        unit[index] = 1.0d;
        return new WrappedVector.Raw(unit);
    }

    /** Debug aid: reads column {@code index} from the full tree trait precision matrix. */
    private double[] expensiveColumn(int index) {
        return this.likelihoodDelegate.getTreeTraitPrecision()[index];
    }

    /** Debug aid: prints the product-based column next to the directly computed one. */
    private void debug(double[] column, int index) {
        double[] direct = expensiveColumn(index);
        System.err.println("via FCD: " + new WrappedVector.Raw(column));
        System.err.println("direct : " + new WrappedVector.Raw(direct));
        System.err.println();
    }

    /**
     * Reports every column of the precision matrix, one per line, in column-index order
     * (taxon-major, trait dimension minor).
     */
    @Override
    public String getReport() {
        StringBuilder sb = new StringBuilder();
        for (int taxon = 0; taxon < this.numTaxa; taxon++) {
            for (int dim = 0; dim < this.dimTrait; dim++) {
                sb.append(new WrappedVector.Raw(getColumn(taxon * this.dimTrait + dim))).append("\n");
            }
        }
        return sb.toString();
    }
}
