package dr.inference.multidimensionalscaling;

import dr.math.distributions.NormalDistribution;

/* loaded from: input_file:dr/inference/multidimensionalscaling/MultiDimensionalScalingCoreImpl.class */
/**
 * Plain-Java implementation of {@link MultiDimensionalScalingCore}.
 *
 * <p>The core keeps a square matrix of pairwise observations and one coordinate vector per
 * location. For every ordered pair of locations it caches an "increment" (the squared residual
 * between the Euclidean distance of the two locations and the observation, optionally scaled and
 * corrected for left truncation) together with half the sum of all increments, and derives the
 * log-likelihood from that sum.
 *
 * <p>When exactly one location has changed since the last {@link #storeState()}, only the
 * matching row of the increment cache is recomputed; the previous row is kept so that
 * {@link #restoreState()} can put it back, and {@link #acceptState()} mirrors the new row into the
 * matching column.
 */
public class MultiDimensionalScalingCoreImpl implements MultiDimensionalScalingCore {

    /** Bit of the {@code initialize} flags argument that enables the left-truncated model. */
    private static final long LEFT_TRUNCATION_FLAG = 1L << 5;

    /**
     * Sentinel index. As {@code updatedLocation} it means "no single-location update is pending";
     * as the index argument of {@link #updateLocation(int, double[])} it means "all locations".
     */
    private static final int NO_SINGLE_LOCATION = -1;

    /** log(2 * pi), computed once instead of on every likelihood evaluation. */
    private static final double LOG_TWO_PI = Math.log(2.0d * Math.PI);

    private int embeddingDimension;
    private int locationCount;
    private int observationCount;
    private double precision;
    private double storedPrecision;
    private double[][] observations;
    private double[][] locations;
    private double[][] storedLocations;
    private double[][] increments;
    // Previous contents of the increment row of updatedLocation; null when no incremental
    // (single-row) update has been evaluated since the last store.
    private double[] storedIncrements;
    private double sumOfIncrements;
    private double storedSumOfIncrements;
    private boolean isLeftTruncated = false;
    private int updatedLocation = NO_SINGLE_LOCATION;
    private boolean incrementsKnown = false;
    private boolean sumOfIncrementsKnown = false;

    /**
     * Allocates all internal buffers and resets the cached state.
     *
     * @param i  embedding dimension (number of coordinates per location)
     * @param i2 number of locations
     * @param j  bit flags; bit 5 (value 32) selects the left-truncated model
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void initialize(int i, int i2, long j) {
        embeddingDimension = i;
        locationCount = i2;
        // Number of unordered pairs of distinct locations.
        observationCount = (i2 * (i2 - 1)) / 2;

        observations = new double[i2][i2];
        increments = new double[i2][i2];
        locations = new double[i2][i];
        storedLocations = new double[i2][i];

        storedIncrements = null;
        incrementsKnown = false;
        sumOfIncrementsKnown = false;
        updatedLocation = NO_SINGLE_LOCATION;
        isLeftTruncated = (j & LEFT_TRUNCATION_FLAG) != 0;
    }

    /**
     * Copies a row-major {@code locationCount x locationCount} observation matrix into the core.
     *
     * @param dArr flattened observation matrix
     * @throws RuntimeException if the array length is not {@code locationCount * locationCount}
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void setPairwiseData(double[] dArr) {
        if (dArr.length != locationCount * locationCount) {
            throw new RuntimeException("Observation data is not the correct dimension");
        }
        for (int row = 0, offset = 0; row < locationCount; row++, offset += locationCount) {
            System.arraycopy(dArr, offset, observations[row], 0, locationCount);
        }
    }

    /**
     * Returns a fresh row-major copy of the observation matrix.
     *
     * @return flattened observations of length {@code locationCount * locationCount}
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public double[] getPairwiseData() {
        final double[] flattened = new double[locationCount * locationCount];
        for (int row = 0, offset = 0; row < locationCount; row++, offset += locationCount) {
            System.arraycopy(observations[row], 0, flattened, offset, locationCount);
        }
        return flattened;
    }

    /**
     * @return the embedding dimension passed to {@link #initialize(int, int, long)}
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public int getInternalDimension() {
        return embeddingDimension;
    }

    /**
     * Sets the precision from the first element of the given array.
     *
     * <p>In the left-truncated model the cached increments depend on the precision, so both
     * caches are invalidated; otherwise the increments are plain squared residuals and stay valid.
     *
     * @param dArr parameter array; only element 0 (the precision) is read
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void setParameters(double[] dArr) {
        precision = dArr[0];
        if (isLeftTruncated) {
            incrementsKnown = false;
            sumOfIncrementsKnown = false;
        }
    }

    /**
     * Replaces the coordinates of one location, or of all locations when {@code i} is -1.
     *
     * <p>The arguments are validated before any internal state is modified, so a rejected call
     * leaves the core exactly as it was.
     *
     * @param i    index of the location to update, or -1 to update every location
     * @param dArr new coordinates: {@code embeddingDimension} values for a single location, or
     *             {@code embeddingDimension * locationCount} values (location-major) for all
     * @throws RuntimeException               if {@code dArr} has the wrong length
     * @throws ArrayIndexOutOfBoundsException if {@code i} is neither -1 nor a valid location index
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void updateLocation(int i, double[] dArr) {
        final boolean singleLocation = i != NO_SINGLE_LOCATION;

        // Validate first: throwing after the caches were touched would leave them inconsistent.
        final int expectedLength = singleLocation ? embeddingDimension : embeddingDimension * locationCount;
        if (dArr.length != expectedLength) {
            throw new RuntimeException("Location is not the correct dimension");
        }
        if (singleLocation && (i < 0 || i >= locationCount)) {
            throw new ArrayIndexOutOfBoundsException("Location index out of range: " + i);
        }

        if (updatedLocation != NO_SINGLE_LOCATION || !singleLocation) {
            // More than one location changed since the last store: the single-row shortcut
            // no longer applies, so force a full recomputation.
            incrementsKnown = false;
            storedIncrements = null;
        }

        if (singleLocation) {
            updatedLocation = i;
            System.arraycopy(dArr, 0, locations[i], 0, embeddingDimension);
        } else {
            for (int location = 0, offset = 0; location < locationCount; location++, offset += embeddingDimension) {
                System.arraycopy(dArr, offset, locations[location], 0, embeddingDimension);
            }
        }
        sumOfIncrementsKnown = false;
    }

    /**
     * Returns the log-likelihood for the current locations, observations and precision,
     * refreshing the cached sum of increments first if it is stale.
     *
     * <p>The result is {@code 0.5 * (log(precision) - log(2*pi)) * observationCount} minus either
     * {@code sumOfIncrements} (left-truncated model, where the increments already carry the
     * {@code 0.5 * precision} factor) or {@code 0.5 * precision * sumOfIncrements}.
     *
     * @return the log-likelihood
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public double calculateLogLikelihood() {
        if (!sumOfIncrementsKnown) {
            if (incrementsKnown) {
                updateSumOfSquaredResiduals();
            } else {
                computeSumOfSquaredResiduals();
            }
            sumOfIncrementsKnown = true;
        }

        final double normalization = 0.5d * (Math.log(precision) - LOG_TWO_PI) * observationCount;
        if (isLeftTruncated) {
            return normalization - sumOfIncrements;
        }
        return normalization - ((0.5d * precision) * sumOfIncrements);
    }

    /**
     * Saves the current sum of increments, locations and precision, and clears the record of any
     * pending single-location update.
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void storeState() {
        storedSumOfIncrements = sumOfIncrements;
        storedIncrements = null;
        for (int location = 0; location < locationCount; location++) {
            System.arraycopy(locations[location], 0, storedLocations[location], 0, embeddingDimension);
        }
        updatedLocation = NO_SINGLE_LOCATION;
        storedPrecision = precision;
    }

    /**
     * Reverts to the state saved by {@link #storeState()}.
     *
     * <p>If a single-row incremental update was evaluated, the saved row is copied back and the
     * increment cache is valid again; otherwise the increment cache is marked stale. The location
     * buffers are swapped rather than copied.
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void restoreState() {
        sumOfIncrements = storedSumOfIncrements;
        sumOfIncrementsKnown = true;

        if (storedIncrements == null) {
            incrementsKnown = false;
        } else {
            System.arraycopy(storedIncrements, 0, increments[updatedLocation], 0, locationCount);
            incrementsKnown = true;
        }

        final double[][] swap = storedLocations;
        storedLocations = locations;
        locations = swap;

        precision = storedPrecision;
    }

    /**
     * Commits a single-location incremental update by mirroring the recomputed increment row
     * into the matching column. Does nothing if no incremental update was evaluated.
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void acceptState() {
        if (storedIncrements == null) {
            return;
        }
        // The transposed entries are deliberately left untouched until the update is accepted.
        final double[] updatedRow = increments[updatedLocation];
        for (int other = 0; other < locationCount; other++) {
            increments[other][updatedLocation] = updatedRow[other];
        }
    }

    /**
     * Gradient computation is not available in this implementation.
     *
     * @param dArr unused
     * @throws UnsupportedOperationException always
     */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void getGradient(double[] dArr) {
        throw new UnsupportedOperationException("Not yet implemented.");
    }

    /** Marks both the increment cache and the cached sum as stale. */
    @Override // dr.inference.multidimensionalscaling.MultiDimensionalScalingCore
    public void makeDirty() {
        sumOfIncrementsKnown = false;
        incrementsKnown = false;
    }

    /**
     * Recomputes every entry of the increment matrix and the sum of increments from scratch.
     * Each unordered pair appears twice in the full matrix, hence the final division by two.
     */
    private void computeSumOfSquaredResiduals() {
        final double oneOverSd = Math.sqrt(precision);
        final double halfPrecision = 0.5d * precision;

        sumOfIncrements = 0.0d;
        for (int row = 0; row < locationCount; row++) {
            for (int column = 0; column < locationCount; column++) {
                final double increment = computeIncrement(row, column, halfPrecision, oneOverSd);
                increments[row][column] = increment;
                sumOfIncrements += increment;
            }
        }
        sumOfIncrements /= 2.0d;

        incrementsKnown = true;
        sumOfIncrementsKnown = true;
    }

    /**
     * Recomputes only the increment row of {@code updatedLocation} and adjusts the sum of
     * increments by the difference. The previous row is saved in {@code storedIncrements} so that
     * {@link #restoreState()} can undo the change; the matching column is left for
     * {@link #acceptState()}.
     */
    private void updateSumOfSquaredResiduals() {
        final double oneOverSd = Math.sqrt(precision);
        final double halfPrecision = 0.5d * precision;
        final int row = updatedLocation;

        storedIncrements = new double[locationCount];
        System.arraycopy(increments[row], 0, storedIncrements, 0, locationCount);

        double delta = 0.0d;
        for (int column = 0; column < locationCount; column++) {
            final double increment = computeIncrement(row, column, halfPrecision, oneOverSd);
            delta += increment - increments[row][column];
            increments[row][column] = increment;
        }
        sumOfIncrements += delta;
    }

    /**
     * Computes the increment for the ordered pair ({@code first}, {@code second}): the squared
     * residual, and in the left-truncated model that value scaled by {@code halfPrecision} plus
     * the truncation term for off-diagonal pairs.
     */
    private double computeIncrement(int first, int second, double halfPrecision, double oneOverSd) {
        final double distance = calculateDistance(locations[first], locations[second]);
        final double residual = distance - observations[first][second];
        double increment = residual * residual;
        if (isLeftTruncated) {
            increment = halfPrecision * increment;
            if (first != second) {
                increment += computeTruncation(distance, oneOverSd);
            }
        }
        return increment;
    }

    /**
     * Euclidean distance between two coordinate vectors over the first
     * {@code embeddingDimension} components.
     */
    private double calculateDistance(double[] dArr, double[] dArr2) {
        double sumOfSquares = 0.0d;
        for (int dim = 0; dim < embeddingDimension; dim++) {
            final double difference = dArr[dim] - dArr2[dim];
            sumOfSquares += difference * difference;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Truncation term for a pair: the standard normal CDF evaluated at {@code d * d2}, where the
     * callers pass the distance and the square root of the precision.
     * NOTE(review): the {@code true} argument presumably requests the log of the CDF - confirm
     * against {@code NormalDistribution.standardCDF}.
     */
    private double computeTruncation(double d, double d2) {
        return NormalDistribution.standardCDF(d * d2, true);
    }
}
