package dr.evomodel.continuous;

import dr.evolution.alignment.PatternList;
import dr.evolution.tree.NodeRef;
import dr.evomodel.continuous.LatentTruncation;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundParameter;
import dr.inference.model.FastMatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.distributions.Distribution;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AndRule;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;

/* loaded from: input_file:dr/evomodel/continuous/OrderedLatentLiabilityLikelihood.class */
/**
 * Likelihood of a latent liability model on discrete tip data.
 *
 * <p>Each external node of {@code treeModel} owns one sub-parameter of {@code tipTraitParameter}
 * holding its continuous latent liabilities, plus one row of observed discrete states read from
 * {@code patternList}. The log likelihood is {@code 0} when every tip's liabilities are consistent
 * with its observed states (see {@link #validTraitForTip(int)}). Otherwise it is
 * {@code log(1 - pathParameter)}, which is {@code -Infinity} for the default path parameter of 1.
 *
 * <p>Two data models are supported:
 * <ul>
 *   <li>ordered (default): one liability per pattern; patterns with more than two classes are cut
 *       into classes by the thresholds held in {@code thresholdParameter};</li>
 *   <li>unordered ({@code isUnordered == true}): a pattern with {@code k > 2} classes uses
 *       {@code k - 1} liabilities, class 0 has an implicit liability of 0, and the observed class
 *       must carry the largest liability.</li>
 * </ul>
 *
 * <p>At construction, box bounds implied by the observed data are attached to the tip liabilities.
 */
public class OrderedLatentLiabilityLikelihood extends AbstractModelLikelihood implements LatentTruncation, Citable, SoftThresholdLikelihood {
    public static final String ORDERED_LATENT_LIABILITY_LIKELIHOOD = "orderedLatentLiabilityLikelihood";

    /** Normalization delegate whose constant is always reported as {@code 0.0}. */
    private final LatentTruncation.Delegate normalizationDelegate = new LatentTruncation.Delegate() {
        @Override // dr.evomodel.continuous.LatentTruncation.Delegate
        protected double computeNormalizationConstant(Distribution distribution) {
            return 0.0d;
        }
    };

    /** XML parser for {@code <orderedLatentLiabilityLikelihood>} elements. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String TIP_TRAIT = "tipTrait";
        public static final String THRESHOLD_PARAMETER = "threshold";
        public static final String NUM_CLASSES = "numClasses";
        public static final String IS_UNORDERED = "isUnordered";
        public static final String N_DATA = "NData";
        public static final String N_TRAITS = "NTraits";

        private final XMLSyntaxRule[] rules = {
                // Data layout comes either from a trait likelihood or from explicit NData/NTraits.
                new XORRule(
                        new ElementRule(AbstractMultivariateTraitLikelihood.class),
                        new AndRule(AttributeRule.newIntegerRule(N_DATA), AttributeRule.newIntegerRule(N_TRAITS))),
                new ElementRule(TIP_TRAIT, CompoundParameter.class, "The parameter of tip locations from the tree"),
                new ElementRule(THRESHOLD_PARAMETER, CompoundParameter.class, "The parameter with nonzero thershold values"),
                new ElementRule(NUM_CLASSES, Parameter.class, "Number of multinomial classes in each dimention"),
                new ElementRule(PatternList.class, "The binary/multinomial tip data"),
                new ElementRule(TreeModel.class, "The tree model"),
                AttributeRule.newBooleanRule(IS_UNORDERED, true)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return OrderedLatentLiabilityLikelihood.ORDERED_LATENT_LIABILITY_LIKELIHOOD;
        }

        /**
         * Builds the likelihood after checking that the tip trait parameter (and, for the ordered
         * model, the pattern count) matches the declared per-taxon dimension.
         *
         * @throws XMLParseException if the tip trait parameter or the data have the wrong dimension
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            TreeModel treeModel = (TreeModel) xo.getChild(TreeModel.class);
            int taxonCount = treeModel.getTaxonCount();
            CompoundParameter tipTraits = (CompoundParameter) xo.getElementFirstChild(TIP_TRAIT);

            int numData;
            int traitDim;
            if (xo.hasAttribute(N_DATA) && xo.hasAttribute(N_TRAITS)) {
                String numDataText = (String) xo.getAttribute(N_DATA);
                String traitDimText = (String) xo.getAttribute(N_TRAITS);
                numData = Integer.parseInt(numDataText);
                traitDim = Integer.parseInt(traitDimText);
            } else {
                AbstractMultivariateTraitLikelihood traitLikelihood =
                        (AbstractMultivariateTraitLikelihood) xo.getChild(AbstractMultivariateTraitLikelihood.class);
                if (traitLikelihood != null) {
                    numData = traitLikelihood.getNumData();
                    traitDim = traitLikelihood.getDimTrait();
                } else {
                    // No layout information: assume one sub-parameter per taxon.
                    numData = 1;
                    if (tipTraits.getParameterCount() != taxonCount) {
                        throw new XMLParseException("Tip trait parameter is wrong dimension");
                    }
                    traitDim = tipTraits.getDimension() / taxonCount;
                }
            }

            PatternList patterns = (PatternList) xo.getChild(PatternList.class);
            CompoundParameter thresholds = (CompoundParameter) xo.getElementFirstChild(THRESHOLD_PARAMETER);
            Parameter numClasses = (Parameter) xo.getElementFirstChild(NUM_CLASSES);
            boolean isUnordered = (Boolean) xo.getAttribute(IS_UNORDERED, false);

            int perTaxonDim = numData * traitDim;
            if (tipTraits.getDimension() != taxonCount * perTaxonDim) {
                throw new XMLParseException("Tip trait parameter is wrong dimension in latent liability model");
            }
            // Unordered multi-class patterns use several liabilities each, so only the ordered
            // model requires one pattern per liability.
            if (!isUnordered && patterns.getPatternCount() != perTaxonDim) {
                throw new XMLParseException("Data are wrong dimension in latent liability model. Pattern count = "
                        + patterns.getPatternCount() + ", while per-taxon parameter dimension = " + perTaxonDim);
            }
            return new OrderedLatentLiabilityLikelihood(treeModel, patterns, tipTraits, thresholds, numClasses, isUnordered);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "Provides the likelihood of a latent liability model on multivariate ordered trait data";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return OrderedLatentLiabilityLikelihood.class;
        }
    };

    public TreeModel treeModel;
    private PatternList patternList;
    public CompoundParameter tipTraitParameter;
    private CompoundParameter thresholdParameter;
    public Parameter numClasses;
    /** Observed value treated as missing in the unordered model: the largest entry of {@code numClasses}. */
    private int naCode;
    private boolean isUnordered;
    /** Observed states, indexed by external node number then pattern. */
    private int[][] tipData;
    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown = false;
    private double logLikelihood;
    private double storedLogLikelihood;
    private double pathParameter = 1.0d;

    /**
     * Creates the likelihood, reads the tip data and attaches data-implied bounds to the tip
     * liabilities.
     *
     * @param treeModel          tree whose external nodes carry the data
     * @param patternList        observed discrete states
     * @param tipTraitParameter  one sub-parameter of latent liabilities per tip
     * @param thresholdParameter ordered-model thresholds, one sub-parameter per multi-class pattern
     * @param numClasses         number of classes of each pattern
     * @param isUnordered        {@code true} for the unordered model
     */
    public OrderedLatentLiabilityLikelihood(TreeModel treeModel, PatternList patternList, CompoundParameter tipTraitParameter, CompoundParameter thresholdParameter, Parameter numClasses, boolean isUnordered) {
        super(ORDERED_LATENT_LIABILITY_LIKELIHOOD);
        this.treeModel = treeModel;
        this.patternList = patternList;
        this.tipTraitParameter = tipTraitParameter;
        this.thresholdParameter = thresholdParameter;
        this.numClasses = numClasses;
        this.isUnordered = isUnordered;
        this.naCode = computeNaCode();
        addVariable(tipTraitParameter);
        addVariable(thresholdParameter);
        setTipDataValuesForAllNodes();

        StringBuilder sb = new StringBuilder();
        sb.append("Constructing a latent liability likelihood model:\n");
        sb.append("\tBinary patterns: ").append(patternList.getId()).append("\n");
        sb.append("\tPlease cite:\n").append(Citable.Utils.getCitationString(this));
        Logger.getLogger("dr.evomodel.continous").info(sb.toString());
    }

    public CompoundParameter getTipTraitParameter() {
        return tipTraitParameter;
    }

    public PatternList getPatternList() {
        return patternList;
    }

    /**
     * Fills {@link #tipData} and attaches data-implied bounds to the tip liabilities: one bounds
     * object on the whole parameter for a {@link FastMatrixParameter}, otherwise one per tip.
     */
    private void setTipDataValuesForAllNodes() {
        int tipCount = treeModel.getExternalNodeCount();
        if (tipData == null) {
            tipData = new int[tipCount][patternList.getPatternCount()];
        }
        int dimPerTip = tipTraitParameter.getParameter(0).getDimension();
        double[] upper = new double[dimPerTip * tipCount];
        double[] lower = new double[dimPerTip * tipCount];
        fillTipBounds(upper, lower, dimPerTip);

        // NOTE(review): argument order (upper, lower) is preserved from the original call.
        if (tipTraitParameter instanceof FastMatrixParameter) {
            tipTraitParameter.addBounds(new Parameter.DefaultBounds(upper, lower));
            return;
        }
        for (int tip = 0; tip < tipCount; tip++) {
            int from = tip * dimPerTip;
            double[] tipUpper = Arrays.copyOfRange(upper, from, from + dimPerTip);
            double[] tipLower = Arrays.copyOfRange(lower, from, from + dimPerTip);
            tipTraitParameter.getParameter(tip).addBounds(new Parameter.DefaultBounds(tipUpper, tipLower));
        }
    }

    /**
     * Computes bounds for every tip into the concatenated arrays (tip {@code t} occupies
     * {@code [t * dimPerTip, (t + 1) * dimPerTip)}).
     */
    private void fillTipBounds(double[] upper, double[] lower, int dimPerTip) {
        // Start fully unbounded so dimensions no pattern constrains are never pinned to stale zeros.
        Arrays.fill(upper, Double.POSITIVE_INFINITY);
        Arrays.fill(lower, Double.NEGATIVE_INFINITY);
        for (int tip = 0; tip < treeModel.getExternalNodeCount(); tip++) {
            int taxonIndex = patternList.getTaxonIndex(treeModel.getTaxonId(tip));
            fillBoundsForNode(treeModel.getExternalNode(tip), taxonIndex, upper, lower, tip * dimPerTip, dimPerTip);
        }
    }

    /**
     * Records one tip's observed states and writes its bounds, walking the liabilities exactly as
     * {@link #validTraitForTip(int)} does.
     *
     * @throws IllegalArgumentException if the data need more liabilities than {@code dimPerTip}
     */
    private void fillBoundsForNode(NodeRef node, int taxonIndex, double[] upper, double[] lower, int offset, int dimPerTip) {
        int nodeNumber = node.getNumber();
        int end = offset + dimPerTip;
        int latent = offset;
        for (int pattern = 0; pattern < patternList.getPatternCount(); pattern++) {
            int datum = patternList.getPattern(pattern)[taxonIndex];
            tipData[nodeNumber][pattern] = datum;

            int classes = isUnordered ? (int) numClasses.getParameterValue(pattern) : 0;
            int width = isUnordered ? unorderedWidth(classes) : 1;
            if (latent + width > end) {
                throw new IllegalArgumentException(
                        "Tip trait parameter has too few dimensions for the data in latent liability model");
            }
            if (isUnordered) {
                setUnorderedBounds(datum, classes, upper, lower, latent);
            } else {
                setOrderedBounds(datum, upper, lower, latent);
            }
            latent += width;
        }
    }

    /** Number of liabilities an unordered pattern with {@code classes} classes occupies. */
    private static int unorderedWidth(int classes) {
        return classes > 2 ? classes - 1 : 1;
    }

    /** Ordered model: datum 0 forces {@code x <= 0}, datum 1 forces {@code x >= 0}, others are free. */
    private static void setOrderedBounds(int datum, double[] upper, double[] lower, int index) {
        switch (datum) {
            case 0:
                upper[index] = 0.0d;
                lower[index] = Double.NEGATIVE_INFINITY;
                break;
            case 1:
                upper[index] = Double.POSITIVE_INFINITY;
                lower[index] = 0.0d;
                break;
            default:
                upper[index] = Double.POSITIVE_INFINITY;
                lower[index] = Double.NEGATIVE_INFINITY;
                break;
        }
    }

    /**
     * Unordered model: writes necessary bounds for one pattern starting at {@code start}. Class 0
     * has implicit liability 0, so datum 0 forces every liability of the pattern {@code <= 0};
     * datum {@code j >= 1} forces liability {@code j - 1 >= 0}. Missing or out-of-range data and
     * single-class patterns stay unbounded.
     */
    private void setUnorderedBounds(int datum, int classes, double[] upper, double[] lower, int start) {
        if (classes <= 1 || datum == naCode || datum < 0 || datum >= classes) {
            return;
        }
        if (datum == 0) {
            for (int d = 0; d < unorderedWidth(classes); d++) {
                upper[start + d] = 0.0d;
            }
        } else {
            lower[start + datum - 1] = 0.0d;
        }
    }

    /** No sub-models are registered (the constructor only calls {@code addVariable}). */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
    }

    /** Any change to the tip liabilities or thresholds invalidates the cached log likelihood. */
    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int i, Variable.ChangeType changeType) {
        likelihoodKnown = false;
    }

    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        storedLikelihoodKnown = likelihoodKnown;
    }

    /** Restores the cached value; it is only trusted if it was valid when stored. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        likelihoodKnown = storedLikelihoodKnown;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    @Override // dr.inference.model.Likelihood
    public void makeDirty() {
        likelihoodKnown = false;
    }

    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        return this;
    }

    /** Returns the cached log likelihood, recomputing it when invalidated. */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = computeLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /**
     * Sets the path parameter used for inconsistent states. The log likelihood depends on it, so
     * both the current and the stored cached values are invalidated.
     */
    @Override // dr.evomodel.continuous.SoftThresholdLikelihood
    public void setPathParameter(double d) {
        pathParameter = d;
        likelihoodKnown = false;
        storedLikelihoodKnown = false;
    }

    /** Returns 0 when all tips are consistent, otherwise {@code -1 / (1 - pathParameter)}. */
    @Override // dr.evomodel.continuous.SoftThresholdLikelihood
    public double getLikelihoodCorrection() {
        return allTipsValid() ? 0.0d : -1.0d / (1.0d - pathParameter);
    }

    @Override // dr.inference.model.AbstractModel
    public String toString() {
        return getClass().getName() + "(" + getLogLikelihood() + ")";
    }

    /** Returns 0 when all tips are consistent, otherwise {@code log(1 - pathParameter)}. */
    protected double computeLogLikelihood() {
        if (allTipsValid()) {
            return 0.0d;
        }
        return pathParameter == 1.0d ? Double.NEGATIVE_INFINITY : Math.log(1.0d - pathParameter);
    }

    /** True when every tip passes {@link #validTraitForTip(int)}; stops at the first failure. */
    private boolean allTipsValid() {
        for (int tip = 0; tip < tipData.length; tip++) {
            if (!validTraitForTip(tip)) {
                return false;
            }
        }
        return true;
    }

    /** Returns the (live, mutable) observed states of tip {@code i}. */
    public int[] getData(int i) {
        return tipData[i];
    }

    /**
     * Checks whether tip {@code i}'s liabilities are consistent with its observed states.
     *
     * <p>Ordered model (liability {@code x} per pattern, {@code k} classes): k == 1 always holds;
     * k == 2 holds if {@code x == 0} or datum > 1, else {@code x > 0} needs datum 1 and otherwise
     * datum 0; k > 2 uses the next threshold sub-parameter {@code t}: datum 0 needs
     * {@code x <= 0}, datum 1 needs {@code 0 <= x <= t[0]}, datum k-1 needs {@code x >= t[k-3]},
     * datum j in between needs {@code t[j-2] <= x <= t[j-1]}, and datum >= k is accepted.
     *
     * <p>Unordered model: data equal to the NA code are skipped; k == 2 behaves like the ordered
     * binary case; for k > 2 the observed class must have the maximum score, with class 0 scored 0.
     */
    @Override // dr.evomodel.continuous.LatentTruncation
    public boolean validTraitForTip(int i) {
        Parameter liabilities = tipTraitParameter.getParameter(i);
        int[] data = tipData[i];
        return isUnordered ? validUnorderedTrait(liabilities, data) : validOrderedTrait(liabilities, data);
    }

    private boolean validUnorderedTrait(Parameter liabilities, int[] data) {
        int latent = 0;
        for (int pattern = 0; pattern < data.length; pattern++) {
            int datum = data[pattern];
            int classes = (int) numClasses.getParameterValue(pattern);
            if (datum == naCode) {
                latent += classes == 1 ? 1 : classes - 1;
            } else if (classes == 1) {
                latent++;
            } else if (classes == 2) {
                double x = liabilities.getParameterValue(latent);
                if (x != 0.0d && !(x > 0.0d ? datum == 1 : datum == 0)) {
                    return false;
                }
                latent++;
            } else {
                double[] scores = new double[classes];
                scores[0] = 0.0d; // class 0 has an implicit liability of 0
                for (int c = 1; c < classes; c++) {
                    scores[c] = liabilities.getParameterValue(latent + c - 1);
                }
                if (!isMax(scores, datum)) {
                    return false;
                }
                latent += classes - 1;
            }
        }
        return true;
    }

    private boolean validOrderedTrait(Parameter liabilities, int[] data) {
        int thresholdSet = 0; // advances once per pattern with more than two classes
        for (int pattern = 0; pattern < data.length; pattern++) {
            int datum = data[pattern];
            double x = liabilities.getParameterValue(pattern);
            int classes = (int) numClasses.getParameterValue(pattern);
            boolean consistent;
            if (classes == 1) {
                consistent = true;
            } else if (classes == 2) {
                consistent = x == 0.0d || datum > 1 || (x > 0.0d ? datum == 1 : datum == 0);
            } else {
                if (datum == 0) {
                    consistent = x <= 0.0d;
                } else if (datum == 1) {
                    consistent = x >= 0.0d && x <= threshold(thresholdSet, 0);
                } else if (datum == classes - 1) {
                    consistent = x >= threshold(thresholdSet, classes - 3);
                } else if (datum > classes - 1) {
                    consistent = true;
                } else {
                    consistent = x >= threshold(thresholdSet, datum - 2) && x <= threshold(thresholdSet, datum - 1);
                }
                thresholdSet++;
            }
            if (!consistent) {
                return false;
            }
        }
        return true;
    }

    /** Value {@code index} of threshold sub-parameter {@code set}. */
    private double threshold(int set, int index) {
        return thresholdParameter.getParameter(set).getParameterValue(index);
    }

    /** True when {@code scores[candidate]} is {@code >=} every score (false if any comparison involves NaN). */
    private boolean isMax(double[] scores, int candidate) {
        for (double score : scores) {
            if (!(scores[candidate] >= score)) {
                return false;
            }
        }
        return true;
    }

    /** Largest (truncated) entry of {@code numClasses}, or 0 if it is empty. */
    private int computeNaCode() {
        int dimension = numClasses.getDimension();
        int max = 0;
        for (int d = 0; d < dimension; d++) {
            max = Math.max(max, (int) numClasses.getParameterValue(d));
        }
        return max;
    }

    @Override // dr.evomodel.continuous.LatentTruncation
    public double getNormalizationConstant(Distribution distribution) {
        return normalizationDelegate.getNormalizationConstant(distribution);
    }

    /** Returns the {@code isUnordered} flag (despite the name, {@code true} means unordered). */
    public Boolean getOrdering() {
        return Boolean.valueOf(isUnordered);
    }

    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        return Citation.Category.TRAIT_MODELS;
    }

    @Override // dr.util.Citable
    public String getDescription() {
        return "Latent Liability model";
    }

    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        List<Citation> citations = new ArrayList<>();
        citations.add(CommonCitations.CYBIS_2015_ASSESSING);
        return citations;
    }

    public Parameter getThreshold() {
        return thresholdParameter;
    }
}
