package dr.oldevomodel.sitemodel;

import dr.evolution.alignment.Alignment;
import dr.evolution.alignment.ExtractPairs;
import dr.evolution.alignment.GapUtils;
import dr.evolution.alignment.SitePatterns;
import dr.evolution.datatype.Nucleotides;
import dr.evolution.io.Importer;
import dr.evolution.io.NexusImporter;
import dr.evomodel.continuous.TopographicalMap;
import dr.inference.model.Parameter;
import dr.math.DifferentialEvolution;
import dr.math.MultivariateFunction;
import dr.math.UnivariateFunction;
import dr.math.UnivariateMinimum;
import dr.oldevomodel.substmodel.FrequencyModel;
import dr.oldevomodel.substmodel.HKY;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:dr/oldevomodel/sitemodel/AlignmentScore.class */
/**
 * Objective function that scores a set of site patterns with a {@link ScoreMatrix}.
 *
 * <p>As a {@link UnivariateFunction} the single argument is the time passed to
 * {@link ScoreMatrix#setTime}. As a {@link MultivariateFunction} the two arguments are
 * (HKY kappa, time). Both {@code evaluate} overloads return the <em>negated</em> score,
 * presumably so that the minimising optimisers used below end up maximising the score.
 */
public class AlignmentScore implements UnivariateFunction, MultivariateFunction {

    /** Starting value given to every argument before differential evolution runs. */
    private static final double INITIAL_ARGUMENT_VALUE = 0.5d;

    /** Convergence tolerance handed to the optimisers. */
    private static final double TOLERANCE = 1.0E-6d;

    /** Pairs whose estimated distance is at or above this value are rejected by {@link #main}. */
    private static final double MAX_PAIR_DISTANCE = 0.1d;

    SiteModel siteModel;
    ScoreMatrix scoreMatrix;
    SitePatterns sitePatterns;

    /**
     * Creates an objective over {@code sitePatterns}.
     *
     * @param scoreMatrix  matrix used for scoring; its site model is cached in {@link #siteModel}
     * @param sitePatterns the patterns to score
     */
    public AlignmentScore(ScoreMatrix scoreMatrix, SitePatterns sitePatterns) {
        this.scoreMatrix = scoreMatrix;
        this.siteModel = scoreMatrix.siteModel;
        this.sitePatterns = sitePatterns;
    }

    /**
     * Sets HKY kappa and the score-matrix time, then returns the negated score.
     *
     * <p>Requires {@link #siteModel} to be a {@link GammaSiteModel} whose substitution model is
     * an {@link HKY}; otherwise the casts throw {@link ClassCastException}.
     *
     * @param arguments {@code arguments[0]} = kappa, {@code arguments[1]} = time
     * @return the negated score of {@link #sitePatterns}
     */
    @Override // dr.math.MultivariateFunction
    public double evaluate(double[] arguments) {
        double kappa = arguments[0];
        double time = arguments[1];
        HKY hky = (HKY) ((GammaSiteModel) siteModel).getSubstitutionModel();
        hky.setKappa(kappa);
        scoreMatrix.setTime(time);
        return -scoreMatrix.getScore(sitePatterns);
    }

    /**
     * Sets the score-matrix time and returns the negated score; kappa is left unchanged.
     *
     * @param time time passed to {@link ScoreMatrix#setTime}
     * @return the negated score of {@link #sitePatterns}
     */
    @Override // dr.math.UnivariateFunction
    public double evaluate(double time) {
        scoreMatrix.setTime(time);
        return -scoreMatrix.getScore(sitePatterns);
    }

    /** @return 2: kappa and time */
    @Override // dr.math.MultivariateFunction
    public int getNumArguments() {
        return 2;
    }

    /** @return lower bound of the univariate (time-only) search */
    @Override // dr.math.UnivariateFunction
    public double getLowerBound() {
        return 0.0d;
    }

    /**
     * Returns the upper bound of the univariate (time-only) search.
     *
     * <p>NOTE(review): this is 10, whereas the multivariate bound for time is 100. Confirm that
     * the difference is intended.
     */
    @Override // dr.math.UnivariateFunction
    public double getUpperBound() {
        return 10.0d;
    }

    /** @return 0 for every argument */
    @Override // dr.math.MultivariateFunction
    public double getLowerBound(int i) {
        return 0.0d;
    }

    /**
     * Returns the upper bound for argument {@code i}.
     *
     * @return 100 for kappa (0) and time (1); 10 for any other index, although
     *     {@link #getNumArguments()} only exposes two arguments
     */
    @Override // dr.math.MultivariateFunction
    public double getUpperBound(int i) {
        return (i == 0 || i == 1) ? 100.0d : 10.0d;
    }

    /**
     * Runs differential evolution on {@code objective}, starting every argument at
     * {@value #INITIAL_ARGUMENT_VALUE}.
     *
     * <p>The optimiser writes its result back into the returned array.
     */
    private static double[] optimizeArguments(AlignmentScore objective) {
        double[] arguments = new double[objective.getNumArguments()];
        DifferentialEvolution optimizer = new DifferentialEvolution(arguments.length, arguments.length * 10);
        Arrays.fill(arguments, INITIAL_ARGUMENT_VALUE);
        optimizer.optimize(objective, arguments, TOLERANCE, TOLERANCE);
        return arguments;
    }

    /**
     * Jointly optimises kappa and time for the given patterns.
     *
     * @return {@code {kappa, time, negatedScore}}, where the last element is the value of
     *     {@link #evaluate(double[])} at the optimum
     */
    public static double[] getAlignmentScore(ScoreMatrix scoreMatrix, SitePatterns sitePatterns) {
        AlignmentScore objective = new AlignmentScore(scoreMatrix, sitePatterns);
        double[] best = optimizeArguments(objective);
        double negatedScore = objective.evaluate(best);
        double[] result = Arrays.copyOf(best, best.length + 1);
        result[best.length] = negatedScore;
        return result;
    }

    /**
     * Jointly optimises kappa and time and returns only the optimised time, which is used as
     * the genetic distance.
     */
    public static double getGeneticDistance(ScoreMatrix scoreMatrix, SitePatterns sitePatterns) {
        double[] best = optimizeArguments(new AlignmentScore(scoreMatrix, sitePatterns));
        return best[best.length - 1];
    }

    /**
     * Optimises time alone with {@link UnivariateMinimum}, leaving kappa at its current value.
     *
     * @return {@code {time, negatedScore}}
     */
    public static double[] getFastAlignmentScore(ScoreMatrix scoreMatrix, SitePatterns sitePatterns) {
        AlignmentScore objective = new AlignmentScore(scoreMatrix, sitePatterns);
        double bestTime = new UnivariateMinimum().optimize(objective, TOLERANCE);
        return new double[]{bestTime, objective.evaluate(bestTime)};
    }

    /**
     * Prints one {@code value<TAB>count} line for every value from 0 to the maximum, then the
     * sum of all values (labelled "Total").
     *
     * <p>Values must be non-negative. A negative value indexes outside the count array.
     */
    private static void printFrequencyTable(List<Integer> values) {
        int maxValue = 0;
        int sum = 0;
        for (int value : values) {
            maxValue = Math.max(maxValue, value);
            sum += value;
        }
        int[] counts = new int[maxValue + 1];
        for (int value : values) {
            counts[value]++;
        }
        for (int value = 0; value < counts.length; value++) {
            System.out.println(value + "\t" + counts[value]);
        }
        System.out.println("Total = " + sum);
    }

    /**
     * Reads a NEXUS alignment and estimates a distance for every pair of sequences.
     *
     * <p>Pairs with a distance below {@value #MAX_PAIR_DISTANCE} are kept. The kept pairs are
     * sorted by {@code PairDistance}'s natural ordering, and pairs are then chosen greedily so
     * that no sequence is used twice. Finally it prints the gap-size frequency table of the
     * chosen pairs and their total alignment length.
     *
     * @param args {@code args[0]} is the path of the NEXUS alignment file
     */
    public static void main(String[] args) throws IOException, Importer.ImportException {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: AlignmentScore <nexus-alignment-file>");
        }
        Alignment alignment;
        // Close the reader once the alignment has been imported (the original leaked it).
        try (FileReader reader = new FileReader(args[0])) {
            alignment = new NexusImporter(reader).importAlignment();
        }
        ExtractPairs extractPairs = new ExtractPairs(alignment);

        // NOTE(review): presumably the relative rate (mu) given to GammaSiteModel. Its bounds
        // (1, 1) pin it at 1.0. Confirm the constructor's argument order.
        Parameter.Default muParameter = new Parameter.Default(1.0d);
        Parameter.Default kappaParameter = new Parameter.Default(1.0d);
        kappaParameter.addBounds(new Parameter.DefaultBounds(100.0d, 0.0d, 1));
        muParameter.addBounds(new Parameter.DefaultBounds(1.0d, 1.0d, 1));

        FrequencyModel frequencyModel =
                new FrequencyModel(Nucleotides.INSTANCE, new Parameter.Default(alignment.getStateFrequencies()));
        HKY hky = new HKY(kappaParameter, frequencyModel);
        GammaSiteModel gammaSiteModel = new GammaSiteModel(hky, muParameter, null, 1, null);
        // NOTE(review): meaning of the 0.1 ScoreMatrix argument is not visible here.
        ScoreMatrix scoreMatrix = new ScoreMatrix(gammaSiteModel, 0.1d);

        List<PairDistance> pairDistances = new ArrayList<>();
        int sequenceCount = alignment.getSequenceCount();
        for (int i = 0; i < sequenceCount; i++) {
            for (int j = i + 1; j < sequenceCount; j++) {
                Alignment pairAlignment = extractPairs.getPairAlignment(i, j);
                if (pairAlignment == null) {
                    System.out.print("x");
                    continue;
                }
                double distance = getGeneticDistance(scoreMatrix, new SitePatterns(pairAlignment));
                if (distance < MAX_PAIR_DISTANCE) {
                    ArrayList<Integer> gapSizes = new ArrayList<>();
                    GapUtils.getGapSizes(pairAlignment, gapSizes);
                    pairDistances.add(new PairDistance(i, j, distance, gapSizes, pairAlignment.getSiteCount()));
                    System.out.print(".");
                } else {
                    System.out.print(TopographicalMap.defaultInvalidString);
                }
            }
            System.out.println();
        }

        Collections.sort(pairDistances);
        Set<Integer> usedSequences = new HashSet<>();
        List<Integer> selectedGapSizes = new ArrayList<>();
        int totalLength = 0;
        for (PairDistance pair : pairDistances) {
            Integer first = Integer.valueOf(pair.x);
            Integer second = Integer.valueOf(pair.y);
            // Greedy selection: skip any pair that reuses an already-chosen sequence.
            if (usedSequences.contains(first) || usedSequences.contains(second)) {
                continue;
            }
            selectedGapSizes.addAll(pair.gaps);
            usedSequences.add(first);
            usedSequences.add(second);
            System.out.println("Added pair (" + first + "," + second + ") d=" + pair.distance + " L=" + pair.alignmentLength);
            totalLength += pair.alignmentLength;
        }
        printFrequencyTable(selectedGapSizes);
        System.out.println("total length=" + totalLength);
    }
}
