package dr.evomodel.antigenic;

import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Statistic;
import dr.inference.model.Variable;
import dr.math.distributions.NormalDistribution;
import dr.util.DataTable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;

/**
 * Likelihood of observed pairwise distances between labelled locations, given the locations'
 * coordinates in an {@code mdsDimension}-dimensional space.
 *
 * <p>Point observations contribute a Normal log density centred on the Euclidean distance between
 * the two locations, with precision taken from {@code mdsPrecisionParameter}; the
 * {@code -0.5 * log(2 * pi)} constant is omitted. Upper-bound observations contribute
 * {@code log(tailCDF)} and lower-bound observations a log CDF under the same Normal. When
 * left-truncation is enabled, each observation is additionally normalised by the log probability
 * that the Normal centred on its distance is positive.
 *
 * <p>Distances, residuals, threshold terms and truncation terms are cached. Per-location and
 * per-distance "updated" flags limit recomputation to the entries affected by a change.
 */
public class MultidimensionalScalingLikelihood extends AbstractModelLikelihood {
    public static final String MULTIDIMENSIONAL_SCALING_LIKELIHOOD = "multidimensionalScalingLikelihood";

    /** XML parser: reads a square assay table from {@code fileName} and builds the likelihood. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        public static final String FILE_NAME = "fileName";
        public static final String TIP_TRAIT = "tipTrait";
        public static final String LOCATIONS = "locations";
        public static final String MDS_DIMENSION = "mdsDimension";
        public static final String MDS_PRECISION = "mdsPrecision";

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newStringRule(FILE_NAME, false, "The name of the file containing the assay table"),
                AttributeRule.newIntegerRule(MDS_DIMENSION, false, "The dimension of the space for MDS"),
                new ElementRule(LOCATIONS, MatrixParameter.class),
                new ElementRule(MDS_PRECISION, Parameter.class)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return MultidimensionalScalingLikelihood.MULTIDIMENSIONAL_SCALING_LIKELIHOOD;
        }

        /**
         * Builds the likelihood from the XML element.
         *
         * @throws XMLParseException if the table is not square or the file cannot be read
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            // try-with-resources so the file handle is released whether or not parsing succeeds
            try (FileReader reader = new FileReader(xo.getStringAttribute(FILE_NAME))) {
                DataTable<double[]> assayTable = DataTable.Double.parse(reader);
                if (assayTable.getRowCount() != assayTable.getColumnCount()) {
                    throw new XMLParseException("Data table is not symmetrical.");
                }
                int mdsDimension = xo.getIntegerAttribute(MDS_DIMENSION);
                Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION);
                MatrixParameter locations = (MatrixParameter) xo.getElementFirstChild(LOCATIONS);
                return new MultidimensionalScalingLikelihood(mdsDimension, false, mdsPrecision, locations, assayTable);
            } catch (IOException e) {
                throw new XMLParseException("Unable to read assay data from file: " + e.getMessage());
            }
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "Provides the likelihood of pairwise distance given vectors of coordinates for points according to the multidimensional scaling scheme of XXX & Rafferty (xxxx).";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MultidimensionalScalingLikelihood.class;
        }
    };

    // Number of unordered location pairs: locationCount * (locationCount - 1) / 2.
    private int distanceCount;
    private int observationCount;
    private int upperThresholdCount;
    private int lowerThresholdCount;
    private int pointObservationCount;
    private int thresholdCount;
    private String[] locationLabels;
    private int locationCount;
    private double[] observations;
    private ObservationType[] observationTypes;
    // Per observation: the two location indices it relates.
    private int[] rowLocationIndices;
    private int[] columnLocationIndices;
    // Observation indices partitioned by ObservationType (MISSING observations appear in none).
    private int[] upperThresholdIndices;
    private int[] lowerThresholdIndices;
    private int[] pointObservationIndices;
    private MatrixParameter locationsParameter;
    private Parameter mdsPrecisionParameter;
    private boolean likelihoodKnown;
    private double logLikelihood;
    private double storedLogLikelihood;
    protected boolean distancesKnown;
    private double sumOfSquaredResiduals;
    private double storedSumOfSquaredResiduals;
    // Pairwise distances in upper-triangle row-major order (see getDistanceIndexForObservation).
    private double[] distances;
    private double[] storedDistances;
    protected boolean[] locationUpdated;
    protected boolean[] distanceUpdated;
    protected boolean residualsKnown;
    protected boolean truncationsKnown;
    private double truncationSum;
    private double storedTruncationSum;
    // Per-distance log truncation terms.
    private double[] truncations;
    private double[] storedTruncations;
    protected boolean thresholdsKnown;
    private double thresholdSum;
    private double storedThresholdSum;
    // Per-threshold-observation log terms: upper bounds first, then lower bounds.
    private double[] thresholds;
    private double[] storedThresholds;
    private boolean isLeftTruncated;
    private int mdsDimension;

    /** Exposes the cached pairwise distances as a statistic named "distances". */
    public class Distances extends Statistic.Abstract {
        public Distances() {
            super("distances");
        }

        @Override // dr.inference.model.Statistic
        public int getDimension() {
            return MultidimensionalScalingLikelihood.this.distanceCount;
        }

        /** Returns distance {@code i}, recomputing stale distances first if necessary. */
        @Override // dr.inference.model.Statistic
        public double getStatisticValue(int i) {
            if (!MultidimensionalScalingLikelihood.this.distancesKnown) {
                MultidimensionalScalingLikelihood.this.calculateDistances();
            }
            return MultidimensionalScalingLikelihood.this.distances[i];
        }
    }

    /** How an observation constrains the distance between its two locations. */
    public enum ObservationType {
        POINT,
        UPPER_BOUND,
        LOWER_BOUND,
        MISSING
    }

    /**
     * Creates an uninitialised likelihood with the given name; subclasses must call
     * {@link #initialize} before use.
     */
    public MultidimensionalScalingLikelihood(String name) {
        super(name);
        this.likelihoodKnown = false;
        this.distancesKnown = false;
        this.residualsKnown = false;
        this.truncationsKnown = false;
        this.thresholdsKnown = false;
    }

    /**
     * Creates a likelihood whose observations are the strict upper triangle of a square table.
     * Every entry is treated as a POINT observation, and the row labels become the location labels.
     *
     * @param mdsDimension       dimension of the embedding space
     * @param leftTruncated      whether to apply left-truncation at zero
     * @param mdsPrecision       precision parameter (its value 0 is used)
     * @param locationsParameter matrix of location coordinates
     * @param dataTable          square table of observed distances
     */
    public MultidimensionalScalingLikelihood(int mdsDimension, boolean leftTruncated, Parameter mdsPrecision, MatrixParameter locationsParameter, DataTable<double[]> dataTable) {
        super(MULTIDIMENSIONAL_SCALING_LIKELIHOOD);
        this.likelihoodKnown = false;
        this.distancesKnown = false;
        this.residualsKnown = false;
        this.truncationsKnown = false;
        this.thresholdsKnown = false;

        String[] labels = dataTable.getRowLabels();
        int n = dataTable.getRowCount();
        int pairCount = ((n - 1) * n) / 2;

        double[] values = new double[pairCount];
        ObservationType[] types = new ObservationType[pairCount];
        int[] rows = new int[pairCount];
        int[] columns = new int[pairCount];

        int k = 0;
        for (int r = 0; r < n; r++) {
            double[] row = dataTable.getRow(r);
            for (int c = r + 1; c < n; c++) {
                values[k] = row[c];
                types[k] = ObservationType.POINT;
                rows[k] = r;
                columns[k] = c;
                k++;
            }
        }
        initialize(mdsDimension, leftTruncated, mdsPrecision, locationsParameter, labels, values, types, rows, columns);
    }

    /**
     * Stores the observations, partitions them by type, sets up the locations parameter, allocates
     * the caches, registers the variables and the {@link Distances} statistic, and marks everything
     * dirty.
     */
    public void initialize(int mdsDimension, boolean isLeftTruncated, Parameter mdsPrecisionParameter, MatrixParameter locationsParameter, String[] locationLabels, double[] observations, ObservationType[] observationTypes, int[] rowLocationIndices, int[] columnLocationIndices) {
        this.mdsDimension = mdsDimension;
        this.locationCount = locationLabels.length;
        this.distanceCount = (this.locationCount * (this.locationCount - 1)) / 2;
        this.locationLabels = locationLabels;
        this.observations = observations;
        this.observationTypes = observationTypes;
        this.rowLocationIndices = rowLocationIndices;
        this.columnLocationIndices = columnLocationIndices;
        this.observationCount = observations.length;

        // Count each type explicitly so MISSING observations are not counted as points
        // (pointObservationCount scales the log(precision) term of the likelihood).
        this.upperThresholdCount = 0;
        this.lowerThresholdCount = 0;
        this.pointObservationCount = 0;
        for (int i = 0; i < this.observationCount; i++) {
            switch (observationTypes[i]) {
                case POINT:
                    this.pointObservationCount++;
                    break;
                case UPPER_BOUND:
                    this.upperThresholdCount++;
                    break;
                case LOWER_BOUND:
                    this.lowerThresholdCount++;
                    break;
                default:
                    break; // MISSING contributes to no partition
            }
        }
        this.thresholdCount = this.upperThresholdCount + this.lowerThresholdCount;

        this.upperThresholdIndices = new int[this.upperThresholdCount];
        this.lowerThresholdIndices = new int[this.lowerThresholdCount];
        this.pointObservationIndices = new int[this.pointObservationCount];
        int upper = 0;
        int lower = 0;
        int point = 0;
        for (int i = 0; i < this.observationCount; i++) {
            switch (observationTypes[i]) {
                case POINT:
                    this.pointObservationIndices[point++] = i;
                    break;
                case UPPER_BOUND:
                    this.upperThresholdIndices[upper++] = i;
                    break;
                case LOWER_BOUND:
                    this.lowerThresholdIndices[lower++] = i;
                    break;
                default:
                    break;
            }
        }

        this.locationsParameter = locationsParameter;
        setupLocationsParameter(this.locationsParameter);
        addVariable(locationsParameter);
        this.locationUpdated = new boolean[locationsParameter.getParameterCount()];
        this.distances = new double[this.distanceCount];
        this.storedDistances = new double[this.distanceCount];
        this.distanceUpdated = new boolean[this.distanceCount];
        this.truncations = new double[this.distanceCount];
        this.storedTruncations = new double[this.distanceCount];
        this.thresholds = new double[this.thresholdCount];
        this.storedThresholds = new double[this.thresholdCount];
        this.mdsPrecisionParameter = mdsPrecisionParameter;
        addVariable(mdsPrecisionParameter);
        this.isLeftTruncated = isLeftTruncated;
        makeDirty();
        addStatistic(new Distances());
    }

    /**
     * Validates or sizes the locations matrix, labels each location parameter, and gives unbounded
     * bounds to any parameter without bounds. A dimension mismatch prints the problem to stderr and
     * terminates the JVM.
     */
    protected void setupLocationsParameter(MatrixParameter matrixParameter) {
        if (matrixParameter.getColumnDimension() > 0) {
            boolean dimensionsMatch = true;
            if (matrixParameter.getColumnDimension() != this.locationCount) {
                System.err.println("locationsParameter column dimension (" + matrixParameter.getColumnDimension() + ") is not equal to the locationCount (" + this.locationCount + ")");
                dimensionsMatch = false;
            }
            if (matrixParameter.getRowDimension() != this.mdsDimension) {
                System.err.println("locationsParameter row dimension (" + matrixParameter.getRowDimension() + ") is not equal to the mdsDimension (" + this.mdsDimension + ")");
                dimensionsMatch = false;
            }
            if (!dimensionsMatch) {
                System.exit(-1);
            }
        } else {
            // NOTE(review): these setters receive (mdsDimension, locationCount), the reverse of the
            // getter checks above; left as-is because MatrixParameter's setter semantics are not
            // visible here — confirm.
            matrixParameter.setColumnDimension(this.mdsDimension);
            matrixParameter.setRowDimension(this.locationCount);
        }
        for (int i = 0; i < this.locationLabels.length; i++) {
            matrixParameter.getParameter(i).setId(this.locationLabels[i]);
        }
        for (int i = 0; i < matrixParameter.getParameterCount(); i++) {
            Parameter parameter = matrixParameter.getParameter(i);
            if (!hasBounds(parameter)) {
                parameter.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, parameter.getDimension()));
            }
        }
    }

    /** Returns whether {@code parameter} already reports bounds. */
    private static boolean hasBounds(Parameter parameter) {
        try {
            return parameter.getBounds() != null;
        } catch (NullPointerException e) {
            // The original code relied on getBounds() throwing NPE when no bounds are set;
            // presumably that is Parameter's contract — confirm.
            return false;
        }
    }

    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object obj, int i) {
    }

    /**
     * Invalidates caches affected by a variable change. A location change marks that location
     * updated and invalidates distances and every derived term. A precision change marks all
     * distances updated so the precision-dependent terms are recomputed, without recomputing the
     * distances themselves.
     */
    @Override // dr.inference.model.AbstractModel
    public void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType changeType) {
        if (variable == this.locationsParameter) {
            if (index < 0) {
                // No single element identified; integer division would map this to location 0 (or
                // throw), leaving other moved locations stale, so mark every location.
                Arrays.fill(this.locationUpdated, true);
            } else {
                // Each location holds mdsDimension coordinates in the flattened index space.
                this.locationUpdated[index / this.mdsDimension] = true;
            }
            this.distancesKnown = false;
            this.residualsKnown = false;
            this.thresholdsKnown = false;
            this.truncationsKnown = false;
        } else if (variable == this.mdsPrecisionParameter) {
            Arrays.fill(this.distanceUpdated, true);
            this.residualsKnown = false;
            this.thresholdsKnown = false;
            this.truncationsKnown = false;
        }
        this.likelihoodKnown = false;
    }

    /** Copies the cached arrays and sums into their stored counterparts. */
    @Override // dr.inference.model.AbstractModel
    public void storeState() {
        System.arraycopy(this.distances, 0, this.storedDistances, 0, this.distances.length);
        System.arraycopy(this.truncations, 0, this.storedTruncations, 0, this.truncations.length);
        System.arraycopy(this.thresholds, 0, this.storedThresholds, 0, this.thresholds.length);
        this.storedLogLikelihood = this.logLikelihood;
        this.storedTruncationSum = this.truncationSum;
        this.storedThresholdSum = this.thresholdSum;
        this.storedSumOfSquaredResiduals = this.sumOfSquaredResiduals;
    }

    /** Swaps the stored arrays back in, restores the cached sums, and marks all caches known. */
    @Override // dr.inference.model.AbstractModel
    public void restoreState() {
        double[] swap = this.storedDistances;
        this.storedDistances = this.distances;
        this.distances = swap;
        this.distancesKnown = true;

        swap = this.storedTruncations;
        this.storedTruncations = this.truncations;
        this.truncations = swap;

        swap = this.storedThresholds;
        this.storedThresholds = this.thresholds;
        this.thresholds = swap;

        this.logLikelihood = this.storedLogLikelihood;
        this.likelihoodKnown = true;
        this.truncationSum = this.storedTruncationSum;
        this.truncationsKnown = true;
        this.thresholdSum = this.storedThresholdSum;
        this.thresholdsKnown = true;
        this.sumOfSquaredResiduals = this.storedSumOfSquaredResiduals;
        this.residualsKnown = true;
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
    }

    /** Invalidates every cache and marks all locations and distances as updated. */
    public void makeDirty() {
        this.distancesKnown = false;
        this.likelihoodKnown = false;
        this.residualsKnown = false;
        this.truncationsKnown = false;
        this.thresholdsKnown = false;
        Arrays.fill(this.locationUpdated, true);
        Arrays.fill(this.distanceUpdated, true);
    }

    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        return this;
    }

    /**
     * Returns the cached log likelihood, recomputing stale distances and terms first. Clears all
     * per-location and per-distance update flags afterwards.
     */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            if (!this.distancesKnown) {
                calculateDistances();
                this.residualsKnown = false;
            }
            this.logLikelihood = computeLogLikelihood();
            Arrays.fill(this.locationUpdated, false);
            Arrays.fill(this.distanceUpdated, false);
        }
        return this.logLikelihood;
    }

    /**
     * Combines the terms: {@code 0.5 * n * log(precision) - 0.5 * precision * SSR}, plus the
     * threshold log terms, minus the truncation log terms when left-truncated.
     */
    protected double computeLogLikelihood() {
        double precision = this.mdsPrecisionParameter.getParameterValue(0);
        if (!this.residualsKnown) {
            this.sumOfSquaredResiduals = calculateSumOfSquaredResiduals();
        }
        double logL = ((0.5d * Math.log(precision)) * this.pointObservationCount) - ((0.5d * precision) * this.sumOfSquaredResiduals);
        if (this.thresholdCount > 0) {
            if (!this.thresholdsKnown) {
                this.thresholdSum = calculateThresholdObservations(precision);
            }
            logL += this.thresholdSum;
        }
        if (this.isLeftTruncated) {
            if (!this.truncationsKnown) {
                calculateTruncations(precision);
            }
            this.truncationSum = calculateTruncationSum();
            logL -= this.truncationSum;
        }
        this.likelihoodKnown = true;
        return logL;
    }

    /**
     * Updates and sums the per-observation threshold log terms. Upper bounds use
     * {@code log(tailCDF(obs, distance, sd))}; lower bounds use the log CDF. Only entries whose
     * distance is flagged updated (or whose locations coincide) are recomputed.
     *
     * @param precision the MDS precision; the standard deviation is {@code 1 / sqrt(precision)}
     * @return the sum of all threshold log terms
     */
    protected double calculateThresholdObservations(double precision) {
        double sum = 0.0d;
        double sd = 1.0d / Math.sqrt(precision);
        int t = 0;
        for (int k = 0; k < this.upperThresholdCount; k++) {
            int obs = this.upperThresholdIndices[k];
            int dist = getDistanceIndexForObservation(obs);
            if (dist == -1) {
                this.thresholds[t] = Math.log(NormalDistribution.tailCDF(this.observations[obs], 0.0d, sd));
            } else if (this.distanceUpdated[dist]) {
                this.thresholds[t] = Math.log(NormalDistribution.tailCDF(this.observations[obs], this.distances[dist], sd));
            }
            if (Double.isInfinite(this.thresholds[t])) {
                System.err.println("Error calculation threshold probability");
            }
            sum += this.thresholds[t];
            t++;
        }
        for (int k = 0; k < this.lowerThresholdCount; k++) {
            int obs = this.lowerThresholdIndices[k];
            int dist = getDistanceIndexForObservation(obs);
            if (dist == -1) {
                this.thresholds[t] = NormalDistribution.cdf(this.observations[obs], 0.0d, sd, true);
            } else if (this.distanceUpdated[dist]) {
                this.thresholds[t] = NormalDistribution.cdf(this.observations[obs], this.distances[dist], sd, true);
            }
            if (Double.isInfinite(this.thresholds[t])) {
                System.err.println("Error calculation threshold probability");
            }
            sum += this.thresholds[t];
            t++;
        }
        this.thresholdsKnown = true;
        return sum;
    }

    /**
     * Recomputes, for each updated distance d, the log CDF at d of a zero-mean Normal with
     * sd {@code 1 / sqrt(precision)}, i.e. the log probability that N(d, sd) is positive.
     */
    protected void calculateTruncations(double precision) {
        double sd = 1.0d / Math.sqrt(precision);
        for (int i = 0; i < this.distanceCount; i++) {
            if (this.distanceUpdated[i]) {
                this.truncations[i] = NormalDistribution.cdf(this.distances[i], 0.0d, sd, true);
            }
        }
        this.truncationsKnown = true;
    }

    /**
     * Sums the truncation log term over every observation (all types, including MISSING).
     * Coincident-location observations use {@code log(0.5)}, the CDF of a zero-mean Normal at 0.
     */
    protected double calculateTruncationSum() {
        double sum = 0.0d;
        for (int i = 0; i < this.observationCount; i++) {
            int dist = getDistanceIndexForObservation(i);
            sum += (dist != -1) ? this.truncations[dist] : Math.log(0.5d);
        }
        return sum;
    }

    /** Returns the sum of squared (distance - observation) over POINT observations. */
    protected double calculateSumOfSquaredResiduals() {
        double sum = 0.0d;
        for (int obs : this.pointObservationIndices) {
            int dist = getDistanceIndexForObservation(obs);
            double residual = (dist == -1) ? -this.observations[obs] : this.distances[dist] - this.observations[obs];
            sum += residual * residual;
        }
        this.residualsKnown = true;
        return sum;
    }

    /**
     * Recomputes distances for pairs where either location is flagged updated, marking those
     * distances updated. The order is upper-triangle row-major.
     */
    protected void calculateDistances() {
        int k = 0;
        for (int i = 0; i < this.locationCount; i++) {
            for (int j = i + 1; j < this.locationCount; j++) {
                if (this.locationUpdated[i] || this.locationUpdated[j]) {
                    this.distances[k] = calculateDistance(this.locationsParameter.getParameter(i), this.locationsParameter.getParameter(j));
                    this.distanceUpdated[k] = true;
                }
                k++;
            }
        }
        this.distancesKnown = true;
    }

    /**
     * Maps an observation to its index in {@code distances}, or -1 when both ends resolve to the
     * same location. The pair is ordered (smaller, larger) before indexing.
     */
    private int getDistanceIndexForObservation(int observationIndex) {
        int i = getLocationIndex(this.rowLocationIndices[observationIndex]);
        int j = getLocationIndex(this.columnLocationIndices[observationIndex]);
        if (i == j) {
            return -1;
        }
        if (i > j) {
            int tmp = i;
            i = j;
            j = tmp;
        }
        // Rows 0..i-1 of the strict upper triangle hold (n-1) + (n-2) + ... + (n-i) = i*(2n-i-1)/2
        // entries; the product i*(2n-i-1) is always even, so the division is exact.
        return (i * (2 * this.locationCount - i - 1)) / 2 + (j - i - 1);
    }

    /** Maps a raw location index to the location used for distances; identity here. */
    protected int getLocationIndex(int i) {
        return i;
    }

    public String[] getLocationLabels() {
        return this.locationLabels;
    }

    /** Euclidean distance over the first {@code mdsDimension} coordinates of the two parameters. */
    protected double calculateDistance(Parameter parameter, Parameter parameter2) {
        double sum = 0.0d;
        for (int d = 0; d < this.mdsDimension; d++) {
            double diff = parameter.getParameterValue(d) - parameter2.getParameterValue(d);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    public int getMDSDimension() {
        return this.mdsDimension;
    }

    public int getLocationCount() {
        return this.locationCount;
    }

    public MatrixParameter getLocationsParameter() {
        return this.locationsParameter;
    }
}
