package dr.evomodel.substmodel;

import dr.evolution.datatype.DataType;
import dr.evolution.datatype.GeneralDataType;
import dr.evomodel.siteratemodel.SiteRateModel;
import dr.inference.model.Model;
import dr.math.KroneckerOperation;
import dr.util.Citable;
import dr.util.Citation;
import dr.util.CommonCitations;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:dr/evomodel/substmodel/ProductChainSubstitutionModel.class */
/**
 * Substitution model on the product state space of several base substitution models.
 *
 * <p>The product-chain infinitesimal matrix is built as the Kronecker sum of the base models'
 * (optionally rate-scaled) infinitesimal matrices. Its eigen decomposition is assembled from the
 * base decompositions:
 * <ul>
 *   <li>eigenvalues are combined with {@link KroneckerOperation#sum};</li>
 *   <li>eigenvectors and inverse eigenvectors are combined with {@link KroneckerOperation#product}.</li>
 * </ul>
 * Product states are labelled by concatenating the base models' state codes, with the first base
 * model's code leftmost.
 */
public class ProductChainSubstitutionModel extends BaseSubstitutionModel implements Citable {
    /** Number of base models combined into the product chain (always at least 1). */
    protected final int numBaseModel;
    /** The base substitution models, in product order. */
    protected final List<SubstitutionModel> baseModels;
    /** Optional per-base-model rate scalers; {@code null} means no rate scaling is applied. */
    protected final List<SiteRateModel> rateModels;
    /** State count of each base model, in product order. */
    protected final int[] stateSizes;
    /** Frequency model built from the base models' frequency models. */
    protected final ProductChainFrequencyModel pcFreqModel;
    /** Cached product-chain infinitesimal matrix, stateCount x stateCount, filled on demand. */
    protected double[] rateMatrix;
    /** When true, every factor of the product uses the average of all base models and rates. */
    private final boolean forceAverageModel;
    /** Lazily built averaged base process; discarded whenever a registered sub-model changes. */
    private SubstitutionProcess averageModel;

    /**
     * Creates a product chain without rate scaling.
     *
     * @param name model identifier
     * @param baseModels base substitution models, in product order; must be non-empty
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels) {
        this(name, baseModels, null);
    }

    /**
     * Creates a product chain whose factors are optionally rate-scaled.
     *
     * @param name model identifier
     * @param baseModels base substitution models, in product order; must be non-empty
     * @param rateModels one single-category rate model per base model, or {@code null}
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels,
                                         List<SiteRateModel> rateModels) {
        this(name, baseModels, rateModels, false);
    }

    /**
     * Creates a product chain.
     *
     * @param name model identifier
     * @param baseModels base substitution models, in product order; must be non-empty
     * @param rateModels one single-category rate model per base model, or {@code null} for no
     *     rate scaling
     * @param forceAverageModel if true, every factor uses the element-wise average of the base
     *     infinitesimal matrices, scaled by the mean of the rate models' rates
     * @throws RuntimeException in any of these cases:
     *     <ul>
     *       <li>{@code baseModels} is empty;</li>
     *       <li>a rate model has more than one category;</li>
     *       <li>fewer rate models than base models are supplied;</li>
     *       <li>{@code forceAverageModel} is set but the base models differ in state count.</li>
     *     </ul>
     */
    public ProductChainSubstitutionModel(String name, List<SubstitutionModel> baseModels,
                                         List<SiteRateModel> rateModels, boolean forceAverageModel) {
        super(name);
        this.rateMatrix = null;
        this.averageModel = null;
        this.baseModels = baseModels;
        this.rateModels = rateModels;
        this.forceAverageModel = forceAverageModel;
        this.numBaseModel = baseModels.size();
        if (this.numBaseModel == 0) {
            throw new RuntimeException("May not construct ProductChainSubstitutionModel with 0 base models");
        }
        if (rateModels != null) {
            for (SiteRateModel rateModel : rateModels) {
                if (rateModel.getCategoryCount() > 1) {
                    throw new RuntimeException("ProductChainSubstitutionModels with multiple categories not yet implemented");
                }
            }
            // rateModels.get(i) is read for every base model below; fail with a clear message
            // instead of an IndexOutOfBoundsException.
            if (rateModels.size() < this.numBaseModel) {
                throw new IllegalArgumentException("ProductChainSubstitutionModel requires one rate model per base model: "
                        + this.numBaseModel + " base models but " + rateModels.size() + " rate models");
            }
        }

        List<FrequencyModel> freqModels = new ArrayList<FrequencyModel>();
        this.stateSizes = new int[this.numBaseModel];
        this.stateCount = 1;
        for (int i = 0; i < this.numBaseModel; i++) {
            SubstitutionModel baseModel = baseModels.get(i);
            freqModels.add(baseModel.getFrequencyModel());
            int size = baseModel.getDataType().getStateCount();
            this.stateSizes[i] = size;
            this.stateCount *= size;
            addModel(baseModel);
            // Rate models are optional: the two-argument constructor passes null.
            if (rateModels != null) {
                addModel(rateModels.get(i));
            }
        }

        if (forceAverageModel) {
            // The averaged process is sized from stateSizes[0] and reused for every factor.
            for (int i = 1; i < this.numBaseModel; i++) {
                if (this.stateSizes[i] != this.stateSizes[0]) {
                    throw new IllegalArgumentException("forceAverageModel requires all base models to have the same state count: base model 0 has "
                            + this.stateSizes[0] + " states but base model " + i + " has " + this.stateSizes[i]);
                }
            }
        }

        this.pcFreqModel = new ProductChainFrequencyModel("pc", freqModels);
        addModel(this.pcFreqModel);
        this.dataType = new GeneralDataType(getCharacterStrings());
        this.updateMatrix = true;
    }

    @Override // dr.util.Citable
    public Citation.Category getCategory() {
        return Citation.Category.SUBSTITUTION_MODELS;
    }

    @Override // dr.util.Citable
    public String getDescription() {
        return "Product chain substitution model";
    }

    /** Returns the single citation for this model. */
    @Override // dr.util.Citable
    public List<Citation> getCitations() {
        return Collections.singletonList(CommonCitations.OBRIEN_2009_LEARNING);
    }

    /**
     * Returns the product-chain eigen decomposition.
     *
     * <p>If {@code updateMatrix} is set, the decomposition and {@link #rateMatrix} are first
     * recomputed while holding this object's monitor.
     */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel, dr.evomodel.substmodel.SubstitutionProcess
    public EigenDecomposition getEigenDecomposition() {
        synchronized (this) {
            if (this.updateMatrix) {
                computeKroneckerSumsAndProducts();
            }
        }
        return this.eigenDecomposition;
    }

    /** Builds the product state labels, folding from the last base model to the first. */
    private String[] getCharacterStrings() {
        String[] codes = null;
        for (int i = this.numBaseModel - 1; i >= 0; i--) {
            codes = recursivelyAppendCharacterStates(this.baseModels.get(i).getDataType(), codes);
        }
        return codes;
    }

    /**
     * Handles a change in any registered sub-model.
     *
     * <p>Delegates to the superclass, re-fires the change, and discards the cached averaged
     * model so that it is rebuilt on next use.
     */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel, dr.inference.model.AbstractModel
    public void handleModelChangedEvent(Model model, Object object, int index) {
        super.handleModelChangedEvent(model, object, index);
        fireModelChanged(model);
        this.averageModel = null;
    }

    /**
     * Prefixes each of {@code suffixes} with every state code of {@code dataType}.
     *
     * @param dataType data type whose codes become the new leading characters
     * @param suffixes existing labels, or {@code null} to start from the empty string
     * @return labels of length {@code stateCount * suffixes.length}, where entry
     *     {@code state * suffixes.length + j} is {@code code(state) + suffixes[j]}
     */
    private String[] recursivelyAppendCharacterStates(DataType dataType, String[] suffixes) {
        String[] tails = (suffixes == null) ? new String[]{""} : suffixes;
        int tailCount = tails.length;
        int stateCount = dataType.getStateCount();
        String[] combined = new String[tailCount * stateCount];
        for (int state = 0; state < stateCount; state++) {
            String code = dataType.getCode(state);
            for (int j = 0; j < tailCount; j++) {
                combined[state * tailCount + j] = code + tails[j];
            }
        }
        return combined;
    }

    /**
     * Copies the product-chain infinitesimal matrix ({@code stateCount * stateCount} entries)
     * into {@code matrix}, recomputing it first if needed.
     */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel, dr.evomodel.substmodel.SubstitutionProcess
    public void getInfinitesimalMatrix(double[] matrix) {
        getEigenDecomposition();
        System.arraycopy(this.rateMatrix, 0, matrix, 0, this.stateCount * this.stateCount);
    }

    /**
     * Returns the rate multiplier for base model {@code index}.
     *
     * <p>This is category 0 of its rate model or, when {@code forceAverageModel} is set, the
     * mean of category 0 over all rate models. Requires {@code rateModels != null}.
     */
    double getRateForModel(int index) {
        if (!this.forceAverageModel) {
            return this.rateModels.get(index).getRateForCategory(0);
        }
        double total = 0.0d;
        for (SiteRateModel rateModel : this.rateModels) {
            total += rateModel.getRateForCategory(0);
        }
        return total / this.rateModels.size();
    }

    /**
     * Scales {@code values} by the rate of base model {@code index}.
     *
     * @return {@code values} itself (not a copy) when there are no rate models or the rate is
     *     exactly 1; otherwise a new, scaled array
     */
    public double[] scaleForProductChain(double[] values, int index) {
        if (this.rateModels == null) {
            return values;
        }
        double rate = getRateForModel(index);
        if (rate == 1.0d) {
            return values;
        }
        double[] scaled = new double[values.length];
        for (int k = 0; k < values.length; k++) {
            scaled[k] = rate * values[k];
        }
        return scaled;
    }

    /**
     * Builds a process whose infinitesimal matrix is the element-wise mean of the base models'
     * infinitesimal matrices.
     *
     * <p>Both the averaged matrix and its eigen decomposition are cached inside the returned
     * object. Its decomposition is sized from {@code stateSizes[0]}; the constructor ensures that
     * all base models share that size when averaging is enabled. The other
     * {@link SubstitutionProcess} methods throw.
     */
    private SubstitutionProcess computeAverageModel() {
        return new SubstitutionProcess() {
            private double[] averageMatrix = null;
            private EigenDecomposition eigenDecomposition = null;

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public void getTransitionProbabilities(double distance, double[] matrix) {
                throw new RuntimeException("Should not be called");
            }

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public EigenDecomposition getEigenDecomposition() {
                if (this.eigenDecomposition == null) {
                    int size = ProductChainSubstitutionModel.this.stateSizes[0];
                    double[] flat = new double[size * size];
                    getInfinitesimalMatrix(flat);
                    double[][] square = new double[size][size];
                    for (int row = 0; row < size; row++) {
                        System.arraycopy(flat, row * size, square[row], 0, size);
                    }
                    this.eigenDecomposition = ProductChainSubstitutionModel.this
                            .getDefaultEigenSystem(size).decomposeMatrix(square);
                }
                return this.eigenDecomposition;
            }

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public FrequencyModel getFrequencyModel() {
                throw new RuntimeException("Should not be called");
            }

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public void getInfinitesimalMatrix(double[] matrix) {
                if (this.averageMatrix == null) {
                    List<SubstitutionModel> models = ProductChainSubstitutionModel.this.baseModels;
                    int length = matrix.length;
                    int modelCount = models.size();
                    double[][] perModel = new double[modelCount][length];
                    for (int m = 0; m < modelCount; m++) {
                        models.get(m).getInfinitesimalMatrix(perModel[m]);
                    }
                    this.averageMatrix = new double[length];
                    for (int k = 0; k < length; k++) {
                        double sum = 0.0d;
                        for (int m = 0; m < modelCount; m++) {
                            sum += perModel[m][k];
                        }
                        this.averageMatrix[k] = sum / modelCount;
                    }
                }
                System.arraycopy(this.averageMatrix, 0, matrix, 0, this.averageMatrix.length);
            }

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public DataType getDataType() {
                throw new RuntimeException("Should not be called");
            }

            @Override // dr.evomodel.substmodel.SubstitutionProcess
            public boolean canReturnComplexDiagonalization() {
                throw new RuntimeException("Should not be called");
            }
        };
    }

    /**
     * Returns the process used for factor {@code index}: the base model itself, or the shared
     * averaged process (built lazily) when {@code forceAverageModel} is set.
     */
    private SubstitutionProcess getBaseModel(int index) {
        if (!this.forceAverageModel) {
            return this.baseModels.get(index);
        }
        if (this.averageModel == null) {
            this.averageModel = computeAverageModel();
        }
        return this.averageModel;
    }

    /**
     * Recomputes {@link #rateMatrix} and {@code eigenDecomposition} from the base factors, then
     * clears {@code updateMatrix}.
     *
     * <p>Inverse eigenvectors are accumulated in transposed form and transposed back once at the
     * end.
     */
    private void computeKroneckerSumsAndProducts() {
        int dim = this.stateSizes[0];
        double[] firstMatrix = new double[dim * dim];
        getBaseModel(0).getInfinitesimalMatrix(firstMatrix);
        double[] rates = scaleForProductChain(firstMatrix, 0);

        EigenDecomposition firstEigen = getBaseModel(0).getEigenDecomposition();
        double[] eigenValues = scaleForProductChain(firstEigen.getEigenValues(), 0);
        double[] eigenVectors = firstEigen.getEigenVectors();
        double[] inverseTransposed = transpose(firstEigen.getInverseEigenVectors(), dim);

        for (int m = 1; m < this.numBaseModel; m++) {
            SubstitutionProcess factor = getBaseModel(m);
            int size = this.stateSizes[m];

            double[] factorMatrix = new double[size * size];
            factor.getInfinitesimalMatrix(factorMatrix);
            rates = KroneckerOperation.sum(rates, dim, scaleForProductChain(factorMatrix, m), size);

            EigenDecomposition factorEigen = factor.getEigenDecomposition();
            double[] factorValues = scaleForProductChain(factorEigen.getEigenValues(), m);
            double[] factorVectors = factorEigen.getEigenVectors();
            double[] factorInverseTransposed = transpose(factorEigen.getInverseEigenVectors(), size);

            eigenValues = KroneckerOperation.sum(eigenValues, factorValues);
            eigenVectors = KroneckerOperation.product(eigenVectors, dim, dim, factorVectors, size, size);
            inverseTransposed = KroneckerOperation.product(inverseTransposed, dim, dim, factorInverseTransposed, size, size);
            dim *= size;
        }

        this.rateMatrix = rates;
        this.eigenDecomposition = new EigenDecomposition(eigenVectors, transpose(inverseTransposed, dim), eigenValues);
        this.updateMatrix = false;
    }

    /** Returns a new array holding the transpose of the flattened {@code dim x dim} matrix. */
    private static double[] transpose(double[] matrix, int dim) {
        double[] result = new double[dim * dim];
        for (int row = 0; row < dim; row++) {
            for (int col = 0; col < dim; col++) {
                result[col * dim + row] = matrix[row * dim + col];
            }
        }
        return result;
    }

    /** Returns the product-chain frequency model built from the base frequency models. */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel, dr.evomodel.substmodel.SubstitutionProcess
    public FrequencyModel getFrequencyModel() {
        return this.pcFreqModel;
    }

    /** Intentionally empty; not used by this model. */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel
    protected void frequenciesChanged() {
    }

    /** Intentionally empty; not used by this model. */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel
    protected void ratesChanged() {
    }

    /** Intentionally empty; the rate matrix is assembled in {@code computeKroneckerSumsAndProducts}. */
    @Override // dr.evomodel.substmodel.BaseSubstitutionModel
    protected void setupRelativeRates(double[] rates) {
    }
}
