package dr.evomodel.antigenic;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.TreeUtils;
import dr.evomodel.tree.TreeModel;
import dr.evomodelxml.treelikelihood.TreeTraitParserUtilities;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.GammaFunction;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.StringAttributeRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/antigenic/NPAntigenicLikelihood.class */
/**
 * Conditional likelihood for a distance-dependent Chinese restaurant process (ddCRP)
 * clustering of two-dimensional tip traits on a tree.
 *
 * <p>Each datum {@code i} (one component parameter of {@link #traitParameter}) carries a link
 * {@code links[i]} to another datum (or to itself) and a cluster label {@code assignments[i]}.
 * The log-likelihood is the sum of per-cluster marginal data log-likelihoods (cached in
 * {@link #logLikelihoodsVector}) plus the ddCRP link term. In that term a self-link has weight
 * {@code alpha} and a link {@code i -> j} has weight {@code depMatrix[i][j]}. {@code depMatrix}
 * holds the tree path length between two tips raised to the power {@code -transformFactor}.
 */
public class NPAntigenicLikelihood extends AbstractModelLikelihood {
    public static final String NP_ANTIGENIC_LIKELIHOOD = "NPAntigenicLikelihood";

    /** External nodes of {@link #treeModel}, captured once at construction. */
    Set<NodeRef> allTips;
    /** Per-datum trait values; each component parameter is read at indices 0 and 1. */
    CompoundParameter traitParameter;
    /** ddCRP concentration (weight of a self-link); only element 0 is read. */
    Parameter alpha;
    Parameter clusterPrec;
    Parameter priorPrec;
    Parameter priorMean;
    /** Cluster label of each datum. */
    Parameter assignments;
    /** ddCRP link target of each datum (a value equal to the datum's own index is a self-link). */
    Parameter links;
    /** Second coordinate of the cluster means; written by {@link #setMeans}. */
    Parameter means2;
    /** First coordinate of the cluster means; written by {@link #setMeans}. */
    Parameter means1;
    Parameter locationDrift;
    Parameter offsets;
    /** Always {@code false}: the constructor ignores the value supplied by the parser. */
    boolean hasDrift;
    private boolean depMatrixKnown;
    private boolean[] dataMatrixKnown;
    private boolean logLikelihoodKnown;
    private double logLikelihood;
    /** Per-cluster validity flags for {@link #logLikelihoodsVector}. */
    private boolean[] logLikelihoodsVectorKnown;
    /** Set when a proposal invalidated {@link #depMatrix}; cleared on accept/restore. */
    boolean proposedChangeDepMatrix;
    boolean proposedChangeDataMatrix;
    TreeModel treeModel;
    String traitName;
    /** Symmetric tip-by-tip weights {@code 1 / dist^transformFactor}; diagonal left at 0. */
    double[][] depMatrix;
    /** Element-wise log of the off-diagonal entries of {@link #depMatrix}. */
    double[][] logDepMatrix;
    /** Cached per-cluster marginal log-likelihoods, indexed by cluster label. */
    double[] logLikelihoodsVector;
    double[] storedLogLikelihoodsVector;
    int numdata;
    /** Exponent applied to tree distances when building {@link #depMatrix}. */
    Parameter transformFactor;
    double k0;
    double v0;
    double[][] T0Inv;
    double[] m;
    double logDetT0;

    /** XML parser producing an {@link NPAntigenicLikelihood} from an {@code NPAntigenicLikelihood} element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String CLUSTER_PREC = "clusterPrec";
        public static final String PRIOR_PREC = "priorPrec";
        public static final String PRIOR_MEAN = "priorMean";
        public static final String ASSIGNMENTS = "assignments";
        public static final String LINKS = "links";
        public static final String MEANS_1 = "clusterMeans1";
        public static final String MEANS_2 = "clusterMeans2";
        public static final String TRANSFORM_FACTOR = "transformFactor";
        public static final String CHI = "chi";
        public static final String OFFSETS = "offsets";
        public static final String LOCATION_DRIFT = "locationDrift";
        boolean integrate = false;
        private final XMLSyntaxRule[] rules = {
                new StringAttributeRule("traitName", "The name of the trait for which a likelihood should be calculated"),
                new ElementRule("traitParameter", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(PRIOR_PREC, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(CLUSTER_PREC, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(PRIOR_MEAN, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule("assignments", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule("links", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(TRANSFORM_FACTOR, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(MEANS_1, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule(MEANS_2, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule("chi", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                new ElementRule("offsets", Parameter.class),
                new ElementRule("locationDrift", Parameter.class),
                new ElementRule(TreeModel.class)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return NPAntigenicLikelihood.NP_ANTIGENIC_LIKELIHOOD;
        }

        /**
         * Reads the child parameters and the tree, then builds the likelihood.
         * The {@code chi} element supplies the ddCRP concentration ({@code alpha}).
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            TreeModel treeModel = (TreeModel) xMLObject.getChild(TreeModel.class);
            Parameter clusterPrec = (Parameter) xMLObject.getChild(CLUSTER_PREC).getChild(Parameter.class);
            Parameter priorPrec = (Parameter) xMLObject.getChild(PRIOR_PREC).getChild(Parameter.class);
            Parameter priorMean = (Parameter) xMLObject.getChild(PRIOR_MEAN).getChild(Parameter.class);
            Parameter assignments = (Parameter) xMLObject.getChild("assignments").getChild(Parameter.class);
            Parameter links = (Parameter) xMLObject.getChild("links").getChild(Parameter.class);
            Parameter means2 = (Parameter) xMLObject.getChild(MEANS_2).getChild(Parameter.class);
            Parameter means1 = (Parameter) xMLObject.getChild(MEANS_1).getChild(Parameter.class);
            Parameter chi = (Parameter) xMLObject.getChild("chi").getChild(Parameter.class);
            Parameter transformFactor = (Parameter) xMLObject.getChild(TRANSFORM_FACTOR).getChild(Parameter.class);
            Parameter locationDrift = (Parameter) xMLObject.getChild("locationDrift").getChild(Parameter.class);
            Parameter offsets = (Parameter) xMLObject.getChild("offsets").getChild(Parameter.class);
            boolean hasDrift = offsets.getDimension() > 1;
            CompoundParameter traits = new TreeTraitParserUtilities()
                    .parseTraitsFromTaxonAttributes(xMLObject, "trait", treeModel, this.integrate).traitParameter;
            return new NPAntigenicLikelihood(treeModel, traits, assignments, links, chi, clusterPrec, priorMean,
                    priorPrec, transformFactor, means1, means2, locationDrift, offsets, Boolean.valueOf(hasDrift));
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "conditional likelihood ddCRP";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return NPAntigenicLikelihood.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /**
     * Builds the likelihood and resets every datum to its own cluster with a self-link.
     *
     * @param treeModel       tree whose tip-to-tip path lengths define the ddCRP weights
     * @param traitParameter  per-datum trait values (one component parameter per datum)
     * @param assignments     cluster labels; overwritten with {@code i} for each datum
     * @param links           ddCRP links; overwritten with {@code i} (self-link) for each datum
     * @param alpha           ddCRP concentration (element 0)
     * @param clusterPrec     cluster precision (element 0 read)
     * @param priorMean       prior mean; elements 0 and 1 read
     * @param priorPrec       prior precision (element 0 read)
     * @param transformFactor exponent applied to tree distances
     * @param means1          storage for the first cluster-mean coordinate
     * @param means2          storage for the second cluster-mean coordinate
     * @param locationDrift   stored only
     * @param offsets         stored and registered as a variable
     * @param hasDrift        currently ignored; {@link #hasDrift} is always set to {@code false}
     */
    public NPAntigenicLikelihood(TreeModel treeModel, CompoundParameter traitParameter, Parameter assignments,
                                 Parameter links, Parameter alpha, Parameter clusterPrec, Parameter priorMean,
                                 Parameter priorPrec, Parameter transformFactor, Parameter means1, Parameter means2,
                                 Parameter locationDrift, Parameter offsets, Boolean hasDrift) {
        super(NP_ANTIGENIC_LIKELIHOOD);
        this.depMatrixKnown = false;
        this.logLikelihoodKnown = false;
        this.logLikelihood = 0.0d;
        this.proposedChangeDepMatrix = false;
        this.proposedChangeDataMatrix = false;
        this.assignments = assignments;
        this.links = links;
        this.clusterPrec = clusterPrec;
        this.priorPrec = priorPrec;
        this.priorMean = priorMean;
        this.treeModel = treeModel;
        this.traitParameter = traitParameter;
        this.transformFactor = transformFactor;
        this.means1 = means1;
        this.means2 = means2;
        this.alpha = alpha;
        this.locationDrift = locationDrift;
        this.offsets = offsets;
        // The hasDrift argument is deliberately not used (original behavior preserved).
        this.hasDrift = false;
        addVariable(traitParameter);
        addVariable(assignments);
        addVariable(links);
        addModel(treeModel);
        addVariable(alpha);
        addVariable(transformFactor);
        // NOTE(review): alpha is registered a second time here; this looks redundant.
        // Confirm that AbstractModel.addVariable ignores duplicates before removing it.
        addVariable(this.alpha);
        addVariable(offsets);
        this.numdata = traitParameter.getParameterCount();
        this.allTips = TreeUtils.getExternalNodes(treeModel, treeModel.getRoot());
        setDepMatrix();
        // Start with every datum in its own cluster and linked to itself.
        for (int i = 0; i < this.numdata; i++) {
            assignments.setParameterValue(i, i);
            links.setParameterValue(i, i);
        }
        int clusterSlots = links.getDimension() + 1;
        this.logLikelihoodsVector = new double[clusterSlots];
        this.logLikelihoodsVectorKnown = new boolean[clusterSlots];
        this.storedLogLikelihoodsVector = new double[clusterSlots];
        this.m = new double[]{priorMean.getParameterValue(0), priorMean.getParameterValue(1)};
        this.v0 = 2.0d;
        this.k0 = priorPrec.getParameterValue(0) / clusterPrec.getParameterValue(0);
        this.T0Inv = new double[2][2];
        this.T0Inv[0][0] = this.v0 / clusterPrec.getParameterValue(0);
        this.T0Inv[1][1] = this.v0 / clusterPrec.getParameterValue(0);
        this.T0Inv[1][0] = 0.0d;
        this.T0Inv[0][1] = 0.0d;
        this.logDetT0 = -Math.log(this.T0Inv[0][0] * this.T0Inv[1][1]);
    }

    /**
     * Rebuilds {@link #depMatrix} from the current tree and {@link #transformFactor}, then
     * fills {@link #logDepMatrix} symmetrically with the log of each off-diagonal weight.
     */
    private void setDepMatrix() {
        this.depMatrixKnown = true;
        this.depMatrix = new double[this.numdata][this.numdata];
        recursion(this.treeModel.getRoot(), new ArrayList<NodeRef>());
        logCorrectMatrix(this.transformFactor.getParameterValue(0));
        this.logDepMatrix = new double[this.numdata][this.numdata];
        for (int i = 0; i < this.numdata; i++) {
            for (int j = 0; j < i; j++) {
                this.logDepMatrix[i][j] = Math.log(this.depMatrix[i][j]);
                // Mirror into the upper triangle (previously a self-assignment that left it at 0).
                this.logDepMatrix[j][i] = this.logDepMatrix[i][j];
            }
        }
    }

    /**
     * Computes the marginal log-likelihood of the two-dimensional data currently assigned to
     * cluster {@code i}, and marks that cluster's cached entry as known.
     *
     * <p>The closed form (hyperparameters {@code k0}, {@code v0}, {@code m}, {@code T0Inv} and
     * log-gamma terms) appears to be the bivariate conjugate Normal–Wishart marginal likelihood.
     * NOTE(review): confirm against the model specification.
     *
     * @param i cluster label
     * @return the log-likelihood, or 0 when no datum is assigned to {@code i}
     */
    public double getLogLikGroup(int i) {
        double logLik = 0.0d;
        int groupSize = 0;
        for (int k = 0; k < this.assignments.getDimension(); k++) {
            if (((int) this.assignments.getParameterValue(k)) == i) {
                groupSize++;
            }
        }
        if (groupSize != 0) {
            double[][] points = new double[groupSize][2];
            double[] mean = new double[2];
            int row = 0;
            for (int k = 0; k < this.assignments.getDimension(); k++) {
                if (((int) this.assignments.getParameterValue(k)) == i) {
                    points[row][0] = getData(k, 0);
                    // Second coordinate: previously read index 0 again, collapsing every point onto x = y.
                    points[row][1] = getData(k, 1);
                    mean[0] = mean[0] + points[row][0];
                    mean[1] = mean[1] + points[row][1];
                    row++;
                }
            }
            mean[0] = mean[0] / groupSize;
            mean[1] = mean[1] / groupSize;
            double kN = this.k0 + groupSize;
            double vN = this.v0 + groupSize;

            // Scatter matrix of the cluster's points about their sample mean.
            double[][] scatter = new double[2][2];
            for (double[] p : points) {
                double dx = p[0] - mean[0];
                double dy = p[1] - mean[1];
                scatter[0][0] = scatter[0][0] + (dx * dx);
                scatter[0][1] = scatter[0][1] + (dx * dy);
                scatter[1][0] = scatter[1][0] + (dx * dy);
                scatter[1][1] = scatter[1][1] + (dy * dy);
            }

            // Posterior scale matrix: prior scale + shrunken mean deviation + scatter.
            double shrink = groupSize * (this.k0 / kN);
            double[][] tN = new double[2][2];
            tN[0][0] = this.T0Inv[0][0] + (shrink * (mean[0] - this.m[0]) * (mean[0] - this.m[0])) + scatter[0][0];
            tN[0][1] = this.T0Inv[0][1] + (shrink * (mean[1] - this.m[1]) * (mean[0] - this.m[0])) + scatter[0][1];
            tN[1][0] = this.T0Inv[1][0] + (shrink * (mean[0] - this.m[0]) * (mean[1] - this.m[1])) + scatter[1][0];
            tN[1][1] = this.T0Inv[1][1] + (shrink * (mean[1] - this.m[1]) * (mean[1] - this.m[1])) + scatter[1][1];
            double logDetTN = Math.log((tN[0][0] * tN[1][1]) - (tN[0][1] * tN[1][0]));

            logLik = ((-groupSize) * Math.log(Math.PI))
                    + (Math.log(this.k0) - Math.log(kN))
                    + (((vN / 2.0d) * (-logDetTN)) - ((this.v0 / 2.0d) * this.logDetT0))
                    + GammaFunction.lnGamma(vN / 2.0d)
                    + GammaFunction.lnGamma((vN / 2.0d) - 0.5d)
                    + ((-GammaFunction.lnGamma(this.v0 / 2.0d)) - GammaFunction.lnGamma((this.v0 / 2.0d) - 0.5d));
        }
        this.logLikelihoodsVectorKnown[i] = true;
        return logLik;
    }

    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        return this;
    }

    /** Returns the live (not copied) per-cluster log-likelihood cache. */
    public double[] getLogLikelihoodsVector() {
        return this.logLikelihoodsVector;
    }

    public Parameter getLinks() {
        return this.links;
    }

    public Parameter getAssignments() {
        return this.assignments;
    }

    /**
     * Returns coordinate {@code i2} of datum {@code i}.
     *
     * @param i  datum index (component parameter of {@link #traitParameter})
     * @param i2 coordinate index (0 or 1 in this class)
     */
    public double getData(int i, int i2) {
        return this.traitParameter.getParameter(i).getParameterValue(i2);
    }

    /** Returns the live (not copied) dependency-weight matrix. */
    public double[][] getDepMatrix() {
        return this.depMatrix;
    }

    /** Returns the live (not copied) log dependency-weight matrix. */
    public double[][] getLogDepMatrix() {
        return this.logDepMatrix;
    }

    public Parameter getPriorMean() {
        return this.priorMean;
    }

    public Parameter getPriorPrec() {
        return this.priorPrec;
    }

    public Parameter getClusterPrec() {
        return this.clusterPrec;
    }

    /** Overwrites cached cluster log-likelihood {@code i} without touching its known flag. */
    public void setLogLikelihoodsVector(int i, double d) {
        this.logLikelihoodsVector[i] = d;
    }

    /** Sets the cluster label of datum {@code i}. The misspelt name is kept for compatibility. */
    public void setAssingments(int i, double d) {
        this.assignments.setParameterValue(i, d);
    }

    /** Sets the ddCRP link target of datum {@code i}. */
    public void setLinks(int i, double d) {
        this.links.setParameterValue(i, d);
    }

    /** Stores the two mean coordinates {@code dArr[0]}, {@code dArr[1]} for cluster {@code i}. */
    public void setMeans(int i, double[] dArr) {
        this.means1.setParameterValue(i, dArr[0]);
        this.means2.setParameterValue(i, dArr[1]);
    }

    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!this.logLikelihoodKnown) {
            this.logLikelihood = computeLogLikelihood();
        }
        return this.logLikelihood;
    }

    /**
     * Recomputes the full log-likelihood.
     *
     * <p>The result is the sum of cached per-cluster terms (stale entries are recomputed), plus,
     * for each datum {@code i}, {@code log(w_i) - log(alpha + sum_{j != i} depMatrix[j][i])}. Here
     * {@code w_i} is {@code alpha} for a self-link and {@code depMatrix[i][links[i]]} otherwise.
     */
    public double computeLogLikelihood() {
        if (!this.depMatrixKnown) {
            setDepMatrix();
        }
        double logLik = 0.0d;
        for (int g = 0; g < this.logLikelihoodsVector.length; g++) {
            if (!this.logLikelihoodsVectorKnown[g]) {
                this.logLikelihoodsVector[g] = getLogLikGroup(g);
            }
            logLik += this.logLikelihoodsVector[g];
        }
        double alphaValue = this.alpha.getParameterValue(0);
        for (int i = 0; i < this.links.getDimension(); i++) {
            double linkValue = this.links.getParameterValue(i);
            double withLink = (linkValue == ((double) i))
                    ? logLik + Math.log(alphaValue)
                    : logLik + Math.log(this.depMatrix[i][(int) linkValue]);
            double normalizer = 0.0d;
            for (int j = 0; j < this.numdata; j++) {
                if (j != i) {
                    normalizer += this.depMatrix[j][i];
                }
            }
            logLik = withLink - Math.log(alphaValue + normalizer);
        }
        this.logLikelihoodKnown = true;
        return logLik;
    }

    /**
     * Post-order traversal that appends every tip below {@code nodeRef} to {@code list}. For
     * each child branch, it adds the branch length to {@link #depMatrix} for every pair of tips
     * (one below the branch, one not), so the matrix accumulates tip-to-tip path lengths.
     * NOTE(review): this assumes tip node numbers lie in {@code [0, numdata)}.
     *
     * @param nodeRef subtree root
     * @param list    output list receiving the tips of the subtree
     */
    @SuppressWarnings("unchecked")
    void recursion(NodeRef nodeRef, List list) {
        if (this.treeModel.isExternal(nodeRef)) {
            list.add(nodeRef);
            return;
        }
        NodeRef leftChild = this.treeModel.getChild(nodeRef, 0);
        NodeRef rightChild = this.treeModel.getChild(nodeRef, 1);
        List<NodeRef> leftTips = new ArrayList<NodeRef>();
        List<NodeRef> rightTips = new ArrayList<NodeRef>();
        recursion(leftChild, leftTips);
        recursion(rightChild, rightTips);
        addBranchAcrossSplit(leftTips, this.treeModel.getBranchLength(leftChild));
        addBranchAcrossSplit(rightTips, this.treeModel.getBranchLength(rightChild));
        list.addAll(leftTips);
        list.addAll(rightTips);
    }

    /** Adds {@code branchLength} symmetrically to every (tip below, tip not below) pair of {@link #depMatrix}. */
    private void addBranchAcrossSplit(List<NodeRef> below, double branchLength) {
        Set<NodeRef> outside = new HashSet<NodeRef>(this.allTips);
        outside.removeAll(below);
        for (NodeRef inside : below) {
            int a = inside.getNumber();
            for (NodeRef other : outside) {
                int b = other.getNumber();
                this.depMatrix[b][a] = this.depMatrix[b][a] + branchLength;
                this.depMatrix[a][b] = this.depMatrix[a][b] + branchLength;
            }
        }
    }

    /**
     * Replaces each off-diagonal distance {@code d_ij} with {@code 1 / d_ij^d}, keeping the
     * matrix symmetric. The diagonal is left unchanged.
     */
    void logCorrectMatrix(double d) {
        for (int i = 0; i < this.numdata; i++) {
            for (int j = 0; j < i; j++) {
                this.depMatrix[i][j] = 1.0d / Math.pow(this.depMatrix[i][j], d);
                this.depMatrix[j][i] = this.depMatrix[i][j];
            }
        }
    }

    /**
     * Returns the path length between external nodes {@code i} and {@code i2}. It sums branch
     * lengths from each tip up to their most recent common ancestor.
     */
    public double getTreeDist(int i, int i2) {
        NodeRef mrca = findMRCA(i, i2);
        double dist = 0.0d;
        for (NodeRef node = this.treeModel.getExternalNode(i); node != mrca; node = this.treeModel.getParent(node)) {
            dist += this.treeModel.getBranchLength(node);
        }
        for (NodeRef node = this.treeModel.getExternalNode(i2); node != mrca; node = this.treeModel.getParent(node)) {
            dist += this.treeModel.getBranchLength(node);
        }
        return dist;
    }

    /** Finds the common ancestor of the taxa at indices {@code i} and {@code i2}. */
    private NodeRef findMRCA(int i, int i2) {
        Set<String> taxa = new HashSet<String>();
        taxa.add(this.treeModel.getTaxonId(i));
        taxa.add(this.treeModel.getTaxonId(i2));
        return TreeUtils.getCommonAncestorNode(this.treeModel, taxa);
    }

    /** Logs the leading {@code numdata x numdata} part of {@code dArr}. The misspelt name is kept for compatibility. */
    public void printInformtion(double[][] dArr) {
        StringBuilder sb = new StringBuilder("matrix \n");
        for (int i = 0; i < this.numdata; i++) {
            sb.append(" \n");
            for (int j = 0; j < this.numdata; j++) {
                sb.append(dArr[i][j] + " \t");
            }
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /** Logs the first {@code numdata} values of {@code parameter}. */
    public void printInformation(Parameter parameter) {
        StringBuilder sb = new StringBuilder("Vector \n");
        for (int i = 0; i < this.numdata; i++) {
            sb.append(parameter.getParameterValue(i) + " \t");
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /** Logs the first {@code numdata} entries of {@code iArr}. */
    public void printInformation(int[] iArr) {
        StringBuilder sb = new StringBuilder("Vector \n");
        for (int i = 0; i < this.numdata; i++) {
            sb.append(iArr[i] + " \t");
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /** Logs the taxon ids in tree index order. */
    public void printOrder() {
        StringBuilder sb = new StringBuilder("taxa \n");
        for (int i = 0; i < this.numdata; i++) {
            sb.append(" \n");
            sb.append(this.treeModel.getTaxonId(i));
        }
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    public void printInformation(double d) {
        StringBuilder sb = new StringBuilder("Info \n");
        sb.append(d);
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    public void printInformation(String str) {
        StringBuilder sb = new StringBuilder("Info \n");
        sb.append(str);
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    public void printInformation(String str, String str2) {
        StringBuilder sb = new StringBuilder("Info \n");
        sb.append(str + " and " + str2);
        Logger.getLogger("dr.evomodel").info(sb.toString());
    }

    /** Snapshots the per-cluster log-likelihood cache. */
    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        System.arraycopy(this.logLikelihoodsVector, 0, this.storedLogLikelihoodsVector, 0, this.logLikelihoodsVector.length);
    }

    /**
     * Swaps the stored cache back in. The dependency matrix is invalidated if the rejected
     * proposal changed it.
     */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        double[] current = this.logLikelihoodsVector;
        this.logLikelihoodsVector = this.storedLogLikelihoodsVector;
        this.storedLogLikelihoodsVector = current;
        this.depMatrixKnown = !this.proposedChangeDepMatrix;
        this.proposedChangeDepMatrix = false;
        this.logLikelihoodKnown = false;
    }

    @Override // dr.inference.model.Likelihood
    public void makeDirty() {
    }

    @Override // dr.inference.model.AbstractModel
    public void acceptState() {
        this.proposedChangeDepMatrix = false;
        this.proposedChangeDataMatrix = false;
    }

    /**
     * Invalidates the likelihood. A tree change also invalidates the dependency matrix and
     * flags it as a proposed change. Without that flag, {@link #restoreState()} after a rejected
     * tree move would treat a matrix built from the rejected tree as current.
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
        if (model == this.treeModel) {
            this.depMatrixKnown = false;
            this.proposedChangeDepMatrix = true;
        }
        this.logLikelihoodKnown = false;
    }

    /**
     * Invalidates the likelihood. A {@link #transformFactor} change also invalidates the
     * dependency matrix. A trait change invalidates the cached term of the affected datum's
     * cluster; {@code i / 2} maps the compound index to the datum, given two coordinates each.
     */
    @Override // dr.inference.model.AbstractModel
    protected final void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        this.logLikelihoodKnown = false;
        if (variable == this.transformFactor) {
            this.depMatrixKnown = false;
            this.proposedChangeDepMatrix = true;
        }
        if (variable == this.traitParameter) {
            this.logLikelihoodsVectorKnown[(int) this.assignments.getParameterValue(i / 2)] = false;
        }
    }
}
