package dr.inference.multidimensionalscaling;

import dr.evomodel.antigenic.MultidimensionalScalingLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.util.DataTable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.Reportable;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/* loaded from: input_file:dr/inference/multidimensionalscaling/MultiDimensionalScalingLikelihood.class */
/**
 * Likelihood of a table of observed pairwise distances between locations given the locations'
 * coordinates, under a multidimensional scaling (MDS) model.
 *
 * <p>Likelihood and gradient evaluations are delegated to a {@link MultiDimensionalScalingCore}
 * (chosen in {@link #getCore()}). This class converts the input table into the flat observation
 * vector handed to the core, keeps the core in sync with parameter-change events, and forwards
 * store/restore/accept calls to it.
 */
public class MultiDimensionalScalingLikelihood extends AbstractModelLikelihood implements Reportable, GradientWrtParameterProvider {

    /** System property read by {@link #getCore()}; a value >= 1 selects the native core and is used as its flags. */
    private static final String REQUIRED_FLAGS_PROPERTY = "mds.required.flags";

    /** XML element name, also used as the model name. */
    private static final String MULTIDIMENSIONAL_SCALING_LIKELIHOOD = "multiDimensionalScalingLikelihood";

    /** Bit OR-ed into the core flags when {@code includeTruncation} is requested. */
    private static final long TRUNCATION_FLAG = 32L;

    /** Parser for the {@code <multiDimensionalScalingLikelihood>} XML element. */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        static final String FILE_NAME = "fileName";
        static final String LOCATIONS = "locations";
        static final String MDS_DIMENSION = "mdsDimension";
        static final String MDS_PRECISION = "mdsPrecision";
        static final String INCLUDE_TRUNCATION = "includeTruncation";
        static final String USE_OLD = "useOld";
        static final String FORCE_REORDER = "forceReorder";

        private final XMLSyntaxRule[] rules = {
                AttributeRule.newStringRule(FILE_NAME, false, "The name of the file containing the assay table"),
                AttributeRule.newIntegerRule(MDS_DIMENSION, false, "The dimension of the space for MDS"),
                new ElementRule(LOCATIONS, MatrixParameterInterface.class),
                AttributeRule.newBooleanRule(USE_OLD, true),
                AttributeRule.newBooleanRule(INCLUDE_TRUNCATION, true),
                AttributeRule.newBooleanRule(FORCE_REORDER, true),
                new ElementRule(MDS_PRECISION, Parameter.class)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return MULTIDIMENSIONAL_SCALING_LIKELIHOOD;
        }

        /**
         * Reads the distance table named by {@code fileName} and builds either this likelihood or,
         * when {@code useOld="true"}, the legacy {@link MultidimensionalScalingLikelihood}.
         *
         * @throws XMLParseException if the file cannot be read, the table is not square, or the
         *                           legacy implementation is requested with a non-{@link MatrixParameter}
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            DataTable<double[]> distanceTable;
            // try-with-resources so the file handle is released even if parsing fails.
            try (FileReader reader = new FileReader(xo.getStringAttribute(FILE_NAME))) {
                distanceTable = DataTable.Double.parse(reader);
            } catch (IOException e) {
                throw new XMLParseException("Unable to read assay data from file: " + e.getMessage());
            }

            if (distanceTable.getRowCount() != distanceTable.getColumnCount()) {
                throw new XMLParseException("Data table is not square (" + distanceTable.getRowCount()
                        + " rows vs " + distanceTable.getColumnCount() + " columns).");
            }

            int mdsDimension = xo.getIntegerAttribute(MDS_DIMENSION);
            MatrixParameterInterface locations = (MatrixParameterInterface) xo.getElementFirstChild(LOCATIONS);
            Parameter mdsPrecision = (Parameter) xo.getElementFirstChild(MDS_PRECISION);
            boolean useOld = (Boolean) xo.getAttribute(USE_OLD, false);
            boolean includeTruncation = (Boolean) xo.getAttribute(INCLUDE_TRUNCATION, false);
            boolean forceReorder = (Boolean) xo.getAttribute(FORCE_REORDER, false);

            if (!useOld) {
                return new MultiDimensionalScalingLikelihood(mdsDimension, mdsPrecision, locations,
                        distanceTable, includeTruncation, forceReorder);
            }

            // The legacy implementation needs a concrete MatrixParameter; report a clear parse error
            // instead of letting the cast fail with a ClassCastException.
            if (!(locations instanceof MatrixParameter)) {
                throw new XMLParseException("Option '" + USE_OLD + "' requires '" + LOCATIONS
                        + "' to be a MatrixParameter");
            }
            System.err.println("USE OLD");
            return new MultidimensionalScalingLikelihood(mdsDimension, includeTruncation, mdsPrecision,
                    (MatrixParameter) locations, distanceTable);
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            return "Provides the likelihood of pairwise distance given vectors of coordinates for points according to the multidimensional scaling scheme of XXX & Rafferty (to fill in).";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MultiDimensionalScalingLikelihood.class;
        }
    };

    private final int mdsDimension;
    /** Internal dimension reported by the core after initialization. */
    private final int vectorDimension;
    private final int locationCount;
    private MultiDimensionalScalingCore mdsCore;
    /** Location labels in the order used for the observations (matches the locations' column names). */
    private String[] locationLabels;
    private Parameter mdsPrecisionParameter;
    private MatrixParameterInterface locationsParameter;
    private boolean likelihoodKnown = false;
    private double logLikelihood;
    private double storedLogLikelihood;
    /** Flags passed to the core; set by {@link #getCore()} and {@link #initialize}. */
    private long flags = 0L;
    /** Row-major {@code locationCount x locationCount} matrix of observed distances. */
    private double[] observations;
    /** Lazily allocated buffer reused by {@link #getGradientLogDensity()}. */
    private double[] gradient;

    /** Classification of a single pairwise observation. */
    public enum ObservationType {
        POINT,
        UPPER_BOUND,
        LOWER_BOUND,
        MISSING
    }

    @Override // dr.xml.Reportable
    public String getReport() {
        return getId() + ": " + getLogLikelihood();
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Likelihood getLikelihood() {
        return this;
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public Parameter getParameter() {
        return locationsParameter;
    }

    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public int getDimension() {
        return locationsParameter.getDimension();
    }

    /**
     * Returns the gradient of the log-likelihood with respect to the locations, as computed by the core.
     * The returned array is an internal buffer that is overwritten on the next call.
     */
    @Override // dr.inference.hmc.GradientWrtParameterProvider
    public double[] getGradientLogDensity() {
        if (gradient == null) {
            gradient = new double[locationsParameter.getDimension()];
        }
        mdsCore.getGradient(gradient);
        return gradient;
    }

    /**
     * Creates the likelihood from a square table of pairwise distances.
     *
     * <p>Only the strictly upper triangle of the table is read; it is mirrored into the lower
     * triangle and the diagonal is set to 0. NaN entries are classified as
     * {@link ObservationType#MISSING}, everything else as {@link ObservationType#POINT}.
     *
     * @param mdsDimension      dimension of the embedding space (must equal the locations' row dimension)
     * @param mdsPrecision      precision parameter passed to the core
     * @param locations         matrix whose columns are the location coordinates
     * @param dataTable         labelled square table of observed distances
     * @param includeTruncation if true, sets {@link #TRUNCATION_FLAG} in the core flags
     * @param forceReorder      if true, reorder the table so its rows/columns follow the column order of
     *                          {@code locations}, matched by parameter name; otherwise use table order
     */
    public MultiDimensionalScalingLikelihood(int mdsDimension, Parameter mdsPrecision,
                                             MatrixParameterInterface locations, DataTable<double[]> dataTable,
                                             boolean includeTruncation, boolean forceReorder) {
        super(MULTIDIMENSIONAL_SCALING_LIKELIHOOD);
        this.mdsDimension = mdsDimension;
        this.locationCount = dataTable.getRowCount();

        String[] rowLabels = dataTable.getRowLabels();
        int[] order = forceReorder
                ? getPermutation(rowLabels, locations)
                : identityPermutation(locationCount);

        String[] orderedLabels = new String[locationCount];
        for (int i = 0; i < locationCount; i++) {
            orderedLabels[i] = rowLabels[order[i]];
        }

        double[][] distances = readOrderedDistances(dataTable, order);
        ObservationType[] observationTypes = new ObservationType[locationCount * locationCount];
        this.observations = buildObservations(distances, observationTypes);

        this.vectorDimension = initialize(mdsDimension, mdsPrecision, includeTruncation, locations,
                orderedLabels, observations, observationTypes);
    }

    /** Returns {@code [0, 1, ..., n-1]}. */
    private static int[] identityPermutation(int n) {
        int[] permutation = new int[n];
        for (int i = 0; i < n; i++) {
            permutation[i] = i;
        }
        return permutation;
    }

    /** Reads the upper triangle of {@code table} in the given row/column order and mirrors it into a full matrix. */
    private static double[][] readOrderedDistances(DataTable<double[]> table, int[] order) {
        int n = order.length;
        double[][] distances = new double[n][n];
        for (int i = 0; i < n; i++) {
            double[] row = table.getRow(order[i]);
            for (int j = i + 1; j < n; j++) {
                double value = row[order[j]];
                distances[i][j] = value;
                distances[j][i] = value;
            }
        }
        return distances;
    }

    /** Flattens {@code distances} row-major (diagonal forced to 0) and fills {@code types} accordingly. */
    private static double[] buildObservations(double[][] distances, ObservationType[] types) {
        int n = distances.length;
        double[] flat = new double[n * n];
        int k = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++, k++) {
                if (i == j) {
                    flat[k] = 0.0;
                    types[k] = ObservationType.POINT;
                } else {
                    flat[k] = distances[i][j];
                    types[k] = Double.isNaN(flat[k]) ? ObservationType.MISSING : ObservationType.POINT;
                }
            }
        }
        return flat;
    }

    /** Returns the internal row-major observation array (not a copy). */
    public double[] getObservations() {
        return observations;
    }

    public MatrixParameterInterface getMatrixParameter() {
        return locationsParameter;
    }

    /**
     * Maps each column of {@code locations} to the index of the table row with the same label.
     *
     * @return array whose entry {@code c} is the table index of the label named by column {@code c}
     * @throws IllegalArgumentException if the number of labels differs from the number of columns
     */
    private int[] getPermutation(String[] labels, MatrixParameterInterface locations) {
        if (labels.length != locations.getColumnDimension()) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        int count = labels.length;
        Map<String, Integer> indexByLabel = new HashMap<>(count);
        for (int i = 0; i < count; i++) {
            indexByLabel.put(labels[i], i);
        }
        int[] permutation = new int[count];
        for (int column = 0; column < count; column++) {
            String name = locations.getParameter(column).getParameterName();
            Integer index = indexByLabel.get(name);
            if (index == null) {
                // The entry stays 0; the resulting name mismatch is normally rejected later by
                // setupLocationsParameter.
                Logger.getLogger("dr.app.beagle").info("Missing label!!! (" + name + ")");
            } else {
                permutation[column] = index;
            }
        }
        return permutation;
    }

    /**
     * Chooses the MDS core. If the system property {@value #REQUIRED_FLAGS_PROPERTY} parses to a value
     * >= 1, a {@link MassivelyParallelMDSImpl} is used and that value becomes {@link #flags};
     * otherwise a {@link MultiDimensionalScalingCoreImpl} is used.
     *
     * @throws NumberFormatException if the property is set but is not a valid long
     */
    private MultiDimensionalScalingCore getCore() {
        String property = System.getProperty(REQUIRED_FLAGS_PROPERTY);
        long requestedFlags = (property == null) ? 0L : Long.parseLong(property.trim());

        if (requestedFlags < 1) {
            System.err.println("Computer mode found: " + requestedFlags + " vs. " + property);
            return new MultiDimensionalScalingCoreImpl();
        }
        System.err.println("Attempting to use a native MDS core with flag: " + requestedFlags + "; may the force be with you ....");
        this.flags = requestedFlags;
        return new MassivelyParallelMDSImpl();
    }

    public int getMdsDimension() {
        return mdsDimension;
    }

    public int getLocationCount() {
        return locationCount;
    }

    /**
     * Creates and configures the core, validates and registers the parameters, and pushes the
     * initial precision, observations and locations into the core.
     *
     * @param mdsDimension      dimension passed to the core
     * @param mdsPrecision      precision parameter
     * @param includeTruncation if true, sets {@link #TRUNCATION_FLAG}
     * @param locations         locations matrix (validated by {@link #setupLocationsParameter})
     * @param labels            location labels, in observation order
     * @param observations      row-major pairwise observations
     * @param observationTypes  per-observation classification; NOTE(review): currently not passed to the core
     * @return the core's internal dimension
     */
    protected int initialize(int mdsDimension, Parameter mdsPrecision, boolean includeTruncation,
                             MatrixParameterInterface locations, String[] labels, double[] observations,
                             ObservationType[] observationTypes) {
        this.mdsCore = getCore();
        if (includeTruncation) {
            this.flags |= TRUNCATION_FLAG;
        }
        System.err.println("Initializing with flags: " + this.flags);
        mdsCore.initialize(mdsDimension, locationCount, this.flags);

        this.locationLabels = labels;
        this.locationsParameter = locations;
        int internalDimension = mdsCore.getInternalDimension();
        setupLocationsParameter(locationsParameter);
        addVariable(locations);

        this.mdsPrecisionParameter = mdsPrecision;
        addVariable(mdsPrecision);

        mdsCore.setParameters(mdsPrecisionParameter.getParameterValues());
        mdsCore.setPairwiseData(observations);
        updateAllLocations(locations);
        makeDirty();
        return internalDimension;
    }

    /** Sends every location to the core (location index -1). */
    private void updateAllLocations(MatrixParameterInterface locations) {
        mdsCore.updateLocation(-1, locations.getParameterValues());
    }

    /**
     * Checks that {@code locations} has {@link #locationCount} columns, {@link #mdsDimension} rows and
     * column names equal to {@link #locationLabels}; installs infinite bounds on columns without bounds.
     *
     * @throws IllegalArgumentException if the column dimension is not positive
     * @throws RuntimeException         on any dimension or name mismatch
     */
    private void setupLocationsParameter(MatrixParameterInterface locations) {
        int columns = locations.getColumnDimension();
        if (columns <= 0) {
            throw new IllegalArgumentException("Dimensions on matrix must be set");
        }
        if (columns != locationCount) {
            throw new RuntimeException("locationsParameter column dimension (" + columns + ") is not equal to the locationCount (" + locationCount + ")");
        }
        if (locations.getRowDimension() != mdsDimension) {
            throw new RuntimeException("locationsParameter row dimension (" + locations.getRowDimension() + ") is not equal to the mdsDimension (" + mdsDimension + ")");
        }
        for (int i = 0; i < locationLabels.length; i++) {
            String name = locations.getParameter(i).getParameterName();
            if (name.compareTo(locationLabels[i]) != 0) {
                throw new RuntimeException("Mismatched trait parameter name (" + name + ") and data dimension name (" + locationLabels[i] + ")");
            }
        }
        for (int i = 0; i < columns; i++) {
            Parameter column = locations.getParameter(i);
            try {
                column.getBounds();
            } catch (NullPointerException e) {
                // getBounds() is treated as failing with NPE when no bounds are set; install unbounded ones.
                column.addBounds(new Parameter.DefaultBounds(Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, column.getDimension()));
            }
        }
    }

    /** No sub-models are registered by this class (only variables), so there is nothing to handle. */
    @Override // dr.inference.model.AbstractModel
    protected void handleModelChangedEvent(Model model, Object object, int index) {
    }

    /**
     * Pushes changed locations or precision into the core and invalidates the cached likelihood.
     * An index of -1 on the locations is handled as "all locations changed".
     */
    @Override // dr.inference.model.AbstractModel
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        if (variable == locationsParameter) {
            if (index == -1) {
                updateAllLocations(locationsParameter);
            } else {
                // Assumes the flat index is column-major with mdsDimension entries per column — TODO confirm.
                int column = index / mdsDimension;
                mdsCore.updateLocation(column, locationsParameter.getColumnValues(column));
            }
        } else if (variable == mdsPrecisionParameter) {
            mdsCore.setParameters(mdsPrecisionParameter.getParameterValues());
        }
        likelihoodKnown = false;
    }

    @Override // dr.inference.model.AbstractModel
    protected void storeState() {
        storedLogLikelihood = logLikelihood;
        mdsCore.storeState();
    }

    /** Restores the stored log-likelihood, which is valid again, so the cache is marked known. */
    @Override // dr.inference.model.AbstractModel
    protected void restoreState() {
        logLikelihood = storedLogLikelihood;
        likelihoodKnown = true;
        mdsCore.restoreState();
    }

    @Override // dr.inference.model.AbstractModel
    protected void acceptState() {
        mdsCore.acceptState();
    }

    @Override // dr.inference.model.Likelihood
    public void makeDirty() {
        likelihoodKnown = false;
        mdsCore.makeDirty();
    }

    @Override // dr.inference.model.Likelihood
    public Model getModel() {
        return this;
    }

    /** Returns the cached log-likelihood, recomputing it with the core when invalidated. */
    @Override // dr.inference.model.Likelihood
    public double getLogLikelihood() {
        if (!likelihoodKnown) {
            logLikelihood = mdsCore.calculateLogLikelihood();
            likelihoodKnown = true;
        }
        return logLikelihood;
    }

    /** Returns the first element of the precision parameter. */
    public double getMDSPrecision() {
        return mdsPrecisionParameter.getParameterValue(0);
    }
}
