package dr.app.checkpoint;

import dr.app.tools.AncestralSequenceAnnotator;
import dr.evolution.tree.NodeRef;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.markovchain.MarkovChain;
import dr.inference.markovchain.MarkovChainListener;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.ParameterParser;
import dr.inference.operators.AdaptableMCMCOperator;
import dr.inference.operators.MCMCOperator;
import dr.inference.operators.OperatorSchedule;
import dr.inference.state.Factory;
import dr.inference.state.StateLoader;
import dr.inference.state.StateLoaderSaver;
import dr.inference.state.StateSaverChainListener;
import dr.math.MathUtils;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintStream;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;

/* loaded from: input_file:dr/app/checkpoint/BeastCheckpointer.class */
/**
 * Saves and restores MCMC checkpoint state to/from a tab-separated text file.
 *
 * <p>The file written by {@link #writeStateToFile} contains, in order: an {@code rng} line with
 * the random-number-generator state, a {@code state} line, an {@code lnL} line, one line per
 * connected parameter (empty for immutable parameters), one {@code operator} line per scheduled
 * operator, and one {@code tree} section per connected {@link TreeModel}. {@link
 * #readStateFromFile} consumes the same layout.
 *
 * <p>Behaviour is configured through system properties ({@value #LOAD_STATE_FILE}, {@value
 * #SAVE_STATE_FILE}, {@value #SAVE_STATE_AT}, {@value #SAVE_STATE_EVERY}, {@value #SAVE_STEM},
 * {@value #FORCE_RESUME}, {@value #CHECKPOINT_SEED}).
 */
public class BeastCheckpointer implements StateLoaderSaver {
    private static final boolean DEBUG = false;
    private static final boolean CHECK_LOAD_STATE = true;
    public static final String LOAD_STATE_FILE = "load.state.file";
    public static final String SAVE_STATE_FILE = "save.state.file";
    public static final String SAVE_STATE_AT = "save.state.at";
    public static final String SAVE_STATE_EVERY = "save.state.every";
    public static final String SAVE_STEM = "save.state.stem";
    public static final String FORCE_RESUME = "force.resume";
    public static final String CHECKPOINT_SEED = "checkpoint.seed";

    /** System property holding the tolerated absolute lnL difference when resuming. */
    private static final String EVALUATION_THRESHOLD = "mcmc.evaluation.threshold";

    private boolean forceResume = false;
    private final String loadStateFileName = System.getProperty(LOAD_STATE_FILE, null);
    private final String saveStateFileName = System.getProperty(SAVE_STATE_FILE, null);
    private final String stemFileName = System.getProperty(SAVE_STEM, null);

    /**
     * Creates the checkpointer and installs a {@link Factory} that exposes it to the MCMC.
     *
     * <p>State-saver chain listeners are registered when {@value #SAVE_STATE_AT} (save once) and/or
     * {@value #SAVE_STATE_EVERY} (save repeatedly) are set.
     */
    public BeastCheckpointer() {
        final ArrayList<MarkovChainListener> listeners = new ArrayList<>();
        String saveAt = System.getProperty(SAVE_STATE_AT);
        if (saveAt != null) {
            listeners.add(new StateSaverChainListener(this, Long.parseLong(saveAt), false));
        }
        String saveEvery = System.getProperty(SAVE_STATE_EVERY);
        if (saveEvery != null) {
            listeners.add(new StateSaverChainListener(this, Long.parseLong(saveEvery), true));
        }
        Factory.INSTANCE = new Factory() {
            @Override
            public StateLoader getInitialStateLoader() {
                if (BeastCheckpointer.this.loadStateFileName == null) {
                    return null;
                }
                return BeastCheckpointer.this.getStateLoaderObject();
            }

            @Override
            public MarkovChainListener[] getStateSaverChainListeners() {
                return listeners.toArray(new MarkovChainListener[0]);
            }

            @Override
            public StateLoaderSaver getStateLoaderSaver(final File loadFile, final File saveFile) {
                return new StateLoaderSaver() {
                    @Override
                    public boolean saveState(MarkovChain markovChain, long state, double lnL) {
                        return BeastCheckpointer.this.writeStateToFile(saveFile, state, lnL, markovChain);
                    }

                    @Override
                    public long loadState(MarkovChain markovChain, double[] savedLnL) {
                        return BeastCheckpointer.this.readStateFromFile(loadFile, markovChain, savedLnL);
                    }

                    @Override
                    public void checkLoadState(double savedLnL, double lnL) {
                        // Intentionally no likelihood check for explicitly supplied files.
                    }
                };
            }
        };
    }

    /** Returns this checkpointer; used by the installed {@link Factory}. */
    public BeastCheckpointer getStateLoaderObject() {
        return this;
    }

    /**
     * Saves the current state. The file name is {@code <stem>_<state>} when {@value #SAVE_STEM} is
     * set, else the {@value #SAVE_STATE_FILE} value, else a timestamped {@code beast_state_...}.
     *
     * @return {@code true} if the file was written successfully
     */
    @Override
    public boolean saveState(MarkovChain markovChain, long state, double lnL) {
        String fileName;
        if (this.stemFileName != null) {
            fileName = this.stemFileName + "_" + state;
        } else if (this.saveStateFileName != null) {
            fileName = this.saveStateFileName;
        } else {
            fileName = "beast_state_" + new SimpleDateFormat("yyyy.MM.dd.HH.mm.ss").format(Calendar.getInstance().getTime());
        }
        return writeStateToFile(new File(fileName), state, lnL, markovChain);
    }

    /**
     * Loads state from the file named by {@value #LOAD_STATE_FILE}.
     *
     * @throws IllegalStateException if that property is not set
     */
    @Override
    public long loadState(MarkovChain markovChain, double[] savedLnL) {
        if (this.loadStateFileName == null) {
            throw new IllegalStateException("No state file to load: system property " + LOAD_STATE_FILE + " is not set");
        }
        return readStateFromFile(new File(this.loadStateFileName), markovChain, savedLnL);
    }

    /**
     * Compares the stored likelihood with the one recomputed after loading.
     *
     * <p>Returns silently when {@value #FORCE_RESUME} is true, when the values are identical, or when
     * their decimal representations share at least 15 leading digits. Otherwise throws if the
     * absolute difference exceeds {@value #EVALUATION_THRESHOLD} (default 0).
     *
     * @param savedLnL likelihood stored in the state file
     * @param lnL likelihood recomputed from the restored state
     * @throws RuntimeException if the mismatch exceeds the threshold
     */
    @Override
    public void checkLoadState(double savedLnL, double lnL) {
        String forceResumeProperty = System.getProperty(FORCE_RESUME);
        if (forceResumeProperty != null) {
            this.forceResume = Boolean.parseBoolean(forceResumeProperty);
        }
        if (this.forceResume) {
            System.out.println("Forcing analysis to resume regardless of recomputed likelihood values.");
            return;
        }
        if (lnL == savedLnL) {
            System.out.println("IDENTICAL LIKELIHOODS");
            System.out.println("lnL = " + lnL);
            System.out.println("savedLnL[0] = " + savedLnL);
            return;
        }
        System.out.println("COMPARING LIKELIHOODS: " + lnL + " vs. " + savedLnL);
        String savedString = Double.toString(savedLnL);
        String restoredString = Double.toString(lnL);
        System.out.println(lnL + "    " + savedString);
        System.out.println(savedLnL + "    " + restoredString);
        // A double carries at least 15 significant decimal digits; fewer matching digits means a real difference.
        if (countMatchingDigits(savedString, restoredString) < 15) {
            String thresholdProperty = System.getProperty(EVALUATION_THRESHOLD);
            double threshold = thresholdProperty != null ? Double.parseDouble(thresholdProperty) : 0.0d;
            if (Math.abs(lnL - savedLnL) > threshold) {
                throw new RuntimeException("Saved lnL does not match recomputed value for loaded state: stored lnL: " + savedLnL + ", recomputed lnL: " + lnL + " (difference " + (savedLnL - lnL) + ").\nYour XML may require the construction of a randomly generated starting tree. Try resuming the analysis by using the same starting seed as for the original BEAST run.");
            }
            System.out.println("Saved lnL does not match recomputed value for loaded state: stored lnL: " + savedLnL + ", recomputed lnL: " + lnL + " (difference " + (savedLnL - lnL) + ").\nThreshold of " + threshold + " for restarting analysis not exceeded; continuing ...");
        }
    }

    /** Counts leading identical characters of the two strings, ignoring '-' and '.'. */
    private static int countMatchingDigits(String a, String b) {
        int digits = 0;
        int limit = Math.min(a.length(), b.length());
        for (int i = 0; i < limit && a.charAt(i) == b.charAt(i); i++) {
            char c = a.charAt(i);
            if (c != '-' && c != '.') {
                digits++;
            }
        }
        return digits;
    }

    /**
     * Writes the complete chain state to {@code file}.
     *
     * @param file destination file (overwritten)
     * @param state current MCMC state number
     * @param lnL current log likelihood
     * @param markovChain chain whose operator schedule is saved
     * @return {@code true} on success; {@code false} (with a message on stderr) if the file could
     *     not be opened or any write failed
     */
    protected boolean writeStateToFile(File file, long state, double lnL, MarkovChain markovChain) {
        OperatorSchedule schedule = markovChain.getSchedule();
        PrintStream out;
        try {
            out = new PrintStream(new FileOutputStream(file));
        } catch (IOException e) {
            System.err.println("Unable to write file: " + e.getMessage());
            return false;
        }
        try {
            writeRngState(out);
            out.print("state\t");
            out.println(state);
            out.print("lnL\t");
            out.println(lnL);
            writeParameters(out);
            writeOperators(out, schedule);
            ArrayList<TreeParameterModel> traits = collectTreeParameterModels();
            for (Model model : Model.CONNECTED_MODEL_SET) {
                if (model instanceof TreeModel) {
                    TreeModel tree = (TreeModel) model;
                    writeTree(out, tree, traitsForTree(tree, traits));
                }
            }
        } finally {
            // Closing the PrintStream also closes the underlying FileOutputStream.
            out.close();
        }
        // PrintStream never throws IOException; failures (including on close) are only visible here.
        if (out.checkError()) {
            System.err.println("Unable to write file: " + file);
            return false;
        }
        return true;
    }

    /** Writes the "rng" line holding the random number generator state. */
    private static void writeRngState(PrintStream out) {
        int[] randomState = MathUtils.getRandomState();
        out.print("rng");
        for (int value : randomState) {
            out.print("\t");
            out.print(value);
        }
        out.println();
    }

    /** Writes one line per connected parameter; immutable parameters get an empty line. */
    private static void writeParameters(PrintStream out) {
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            if (!parameter.isImmutable()) {
                out.print(ParameterParser.PARAMETER);
                out.print("\t");
                out.print(parameter.getParameterName());
                out.print("\t");
                out.print(parameter.getDimension());
                for (int i = 0; i < parameter.getDimension(); i++) {
                    out.print("\t");
                    out.print(parameter.getParameterUntransformedValue(i));
                }
            }
            // Always terminate the line so the reader can consume exactly one line per parameter.
            out.print("\n");
        }
    }

    /** Writes one line per scheduled operator: name, accept/reject counts and adaptation state. */
    private static void writeOperators(PrintStream out, OperatorSchedule schedule) {
        for (int i = 0; i < schedule.getOperatorCount(); i++) {
            MCMCOperator operator = schedule.getOperator(i);
            out.print("operator");
            out.print("\t");
            out.print(operator.getOperatorName());
            out.print("\t");
            out.print(operator.getAcceptCount());
            out.print("\t");
            out.print(operator.getRejectCount());
            if (operator instanceof AdaptableMCMCOperator) {
                AdaptableMCMCOperator adaptable = (AdaptableMCMCOperator) operator;
                out.print("\t");
                out.print(adaptable.getAdaptableParameter());
                out.print("\t");
                out.print(adaptable.getAdaptationCount());
            }
            out.println();
        }
    }

    /** Writes the node table and edge table (with trait values) of one tree. */
    private static void writeTree(PrintStream out, TreeModel tree, ArrayList<TreeParameterModel> treeTraits) {
        out.print("tree");
        out.print("\t");
        out.println(tree.getModelName());
        out.println("#node height taxon");
        int nodeCount = tree.getNodeCount();
        out.println(nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            NodeRef node = tree.getNode(i);
            out.print(node.getNumber());
            out.print("\t");
            out.print(tree.getNodeHeight(node));
            if (tree.isExternal(node)) {
                out.print("\t");
                out.print(tree.getNodeTaxon(node).getId());
            }
            out.println();
        }
        out.println("#edges");
        out.println("#child-node parent-node L/R-child traits");
        out.println(nodeCount);
        for (int i = 0; i < nodeCount; i++) {
            NodeRef node = tree.getNode(i);
            NodeRef parent = tree.getParent(node);
            if (parent == null) {
                continue; // the root has no incoming edge
            }
            out.print(node.getNumber());
            out.print("\t");
            out.print(parent.getNumber());
            out.print("\t");
            out.print(childIndex(tree, parent, node));
            for (TreeParameterModel trait : treeTraits) {
                out.print("\t");
                out.print(trait.getNodeValue(tree, node));
            }
            out.println();
        }
    }

    /** Returns 0 or 1 depending on which child slot of {@code parent} holds {@code child}. */
    private static int childIndex(TreeModel tree, NodeRef parent, NodeRef child) {
        if (tree.getChild(parent, 0) == child) {
            return 0;
        }
        if (tree.getChild(parent, 1) == child) {
            return 1;
        }
        throw new RuntimeException("Operation currently only supported for nodes with 2 children.");
    }

    /**
     * Restores the chain state previously written by {@link #writeStateToFile}.
     *
     * @param file state file to read
     * @param markovChain chain whose operator schedule is restored
     * @param savedLnL if non-null, element 0 receives the stored log likelihood
     * @return the stored MCMC state number
     * @throws RuntimeException if the file cannot be read or does not match the current model
     */
    protected long readStateFromFile(File file, MarkovChain markovChain, double[] savedLnL) {
        OperatorSchedule schedule = markovChain.getSchedule();
        try (BufferedReader in = new BufferedReader(new FileReader(file))) {
            String[] fields = readRequiredLine(in).split("\t");
            int[] rngState = null;
            if (fields[0].equals("rng")) {
                rngState = parseRngState(fields);
                fields = readRequiredLine(in).split("\t");
            }
            long state = parseStateNumber(fields);
            parseSavedLnL(readRequiredLine(in).split("\t"), savedLnL);
            try {
                readParameters(in);
                readOperators(in, schedule);
                readTrees(in);
            } catch (NumberFormatException e) {
                throw new RuntimeException("Unable to parse state file " + file + ": " + e.getMessage(), e);
            }
            restoreRandomState(rngState);
            return state;
        } catch (IOException e) {
            throw new RuntimeException("Unable to read file: " + e.getMessage(), e);
        }
    }

    /** Reads a line, failing with a descriptive error instead of returning null at end of file. */
    private static String readRequiredLine(BufferedReader in) throws IOException {
        String line = in.readLine();
        if (line == null) {
            throw new RuntimeException("Unexpected end of state file");
        }
        return line;
    }

    /** Parses the integers following the "rng" tag. */
    private static int[] parseRngState(String[] fields) {
        int[] rngState = new int[fields.length - 1];
        try {
            for (int i = 0; i < rngState.length; i++) {
                rngState[i] = Integer.parseInt(fields[i + 1]);
            }
        } catch (NumberFormatException e) {
            throw new RuntimeException("Unable to read random number generator state from state file", e);
        }
        return rngState;
    }

    /** Parses the "state" line. */
    private static long parseStateNumber(String[] fields) {
        if (fields.length < 2 || !fields[0].equals("state")) {
            throw new RuntimeException("Unable to read state number from state file");
        }
        try {
            return Long.parseLong(fields[1]);
        } catch (NumberFormatException e) {
            throw new RuntimeException("Unable to read state number from state file", e);
        }
    }

    /** Validates the lnL line and, if {@code savedLnL} is non-null, stores its value in element 0. */
    private static void parseSavedLnL(String[] fields, double[] savedLnL) {
        if (!fields[0].equals(AncestralSequenceAnnotator.LIKELIHOOD)) {
            throw new RuntimeException("Unable to read lnL from state file");
        }
        if (savedLnL == null) {
            return;
        }
        if (fields.length < 2) {
            throw new RuntimeException("Unable to read lnL from state file");
        }
        try {
            savedLnL[0] = Double.parseDouble(fields[1]);
        } catch (NumberFormatException e) {
            throw new RuntimeException("Unable to read lnL from state file", e);
        }
    }

    /** Reads one line per connected parameter and restores mutable parameter values. */
    private static void readParameters(BufferedReader in) throws IOException {
        for (Parameter parameter : Parameter.CONNECTED_PARAMETER_SET) {
            String line = readRequiredLine(in);
            if (parameter.isImmutable()) {
                // writeStateToFile emits an empty line for immutable parameters; consume it and move on.
                continue;
            }
            String[] fields = line.split("\t");
            if (fields.length < 3) {
                throw new RuntimeException("Unable to read parameter " + parameter.getParameterName() + " from state file, read: " + line);
            }
            int dimension = Integer.parseInt(fields[2]);
            if (dimension != parameter.getDimension()) {
                System.err.println("Unable to match state parameter dimension: " + dimension + ", expecting " + parameter.getDimension() + " for parameter: " + parameter.getParameterName());
                System.err.print("Read from file: ");
                for (String field : fields) {
                    System.err.print(field + "\t");
                }
                System.err.println();
            }
            if (fields[1].equals("branchRates.categories.rootNodeNumber")) {
                parameter.setParameterValue(0, Double.parseDouble(fields[3]));
            } else {
                if (fields.length < 3 + parameter.getDimension()) {
                    throw new RuntimeException("Too few values in state file for parameter: " + parameter.getParameterName());
                }
                for (int i = 0; i < parameter.getDimension(); i++) {
                    parameter.setParameterUntransformedValue(i, Double.parseDouble(fields[i + 3]));
                }
            }
        }
    }

    /** Reads one line per scheduled operator and restores its counters and adaptation state. */
    private static void readOperators(BufferedReader in, OperatorSchedule schedule) throws IOException {
        for (int i = 0; i < schedule.getOperatorCount(); i++) {
            MCMCOperator operator = schedule.getOperator(i);
            String[] fields = readRequiredLine(in).split("\t");
            String readName = fields.length > 1 ? fields[1] : "";
            if (!readName.equals(operator.getOperatorName())) {
                throw new RuntimeException("Unable to match operator: " + readName);
            }
            if (fields.length < 4) {
                throw new RuntimeException("Operator missing values: " + readName);
            }
            operator.setAcceptCount(Integer.parseInt(fields[2]));
            operator.setRejectCount(Integer.parseInt(fields[3]));
            if (operator instanceof AdaptableMCMCOperator) {
                if (fields.length != 6) {
                    throw new RuntimeException("Coercable operator missing parameter: " + readName);
                }
                AdaptableMCMCOperator adaptable = (AdaptableMCMCOperator) operator;
                adaptable.setAdaptableParameter(Double.parseDouble(fields[4]));
                adaptable.setAdaptationCount(Long.parseLong(fields[5]));
            }
        }
    }

    /**
     * Reads consecutive "tree" sections and adopts each into the connected tree model with the same
     * name. Throws if any connected tree model was not restored.
     */
    private static void readTrees(BufferedReader in) throws IOException {
        HashSet<String> unmatchedTrees = new HashSet<>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel) {
                unmatchedTrees.add(model.getModelName());
            }
        }
        ArrayList<TreeParameterModel> traits = collectTreeParameterModels();
        String line = in.readLine();
        while (line != null) {
            String[] header = line.split("\t");
            if (!header[0].equals("tree")) {
                break;
            }
            if (header.length < 2) {
                throw new RuntimeException("Unable to read tree name from state file");
            }
            // Capture the tree name before the section body is read, so matching never sees edge data.
            TreeModel tree = findTreeModel(header[1]);
            if (tree != null) {
                readTree(in, tree, traitsForTree(tree, traits));
                unmatchedTrees.remove(tree.getModelName());
            }
            line = in.readLine();
        }
        if (!unmatchedTrees.isEmpty()) {
            StringBuilder sb = new StringBuilder();
            for (String name : unmatchedTrees) {
                sb.append("Expecting, but unable to match state parameter:").append(name).append("\n");
            }
            throw new RuntimeException("\n" + sb.toString());
        }
    }

    /** Returns the first connected tree model whose model name equals {@code name}, or null. */
    private static TreeModel findTreeModel(String name) {
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeModel && name.equals(model.getModelName())) {
                return (TreeModel) model;
            }
        }
        return null;
    }

    /** Reads one tree section body (after its header) and adopts its structure and trait values. */
    private static void readTree(BufferedReader in, TreeModel tree, ArrayList<TreeParameterModel> treeTraits) throws IOException {
        in.readLine(); // "#node height taxon"
        int nodeCount = Integer.parseInt(readRequiredLine(in).split("\t")[0]);
        double[] nodeHeights = new double[nodeCount];
        // Taxon names are read for the first (nodeCount + 1) / 2 nodes only.
        String[] taxaNames = new String[(nodeCount + 1) / 2];
        for (int i = 0; i < nodeCount; i++) {
            String[] fields = readRequiredLine(in).split("\t");
            nodeHeights[i] = Double.parseDouble(fields[1]);
            if (i < taxaNames.length) {
                taxaNames[i] = fields[2];
            }
        }
        in.readLine(); // "#edges"
        in.readLine(); // "#child-node parent-node L/R-child traits"
        int edgeNodeCount = Integer.parseInt(readRequiredLine(in).split("\t")[0]);
        double[][] traitValues = new double[treeTraits.size()][edgeNodeCount];
        int[] childOrder = new int[edgeNodeCount];
        int[] parents = new int[edgeNodeCount];
        for (int i = 0; i < edgeNodeCount; i++) {
            childOrder[i] = -1;
            parents[i] = -1;
        }
        // Every node except the root has exactly one incoming edge.
        for (int i = 0; i < edgeNodeCount - 1; i++) {
            String line = in.readLine();
            if (line == null) {
                break; // tolerate a truncated edge list, as before
            }
            String[] fields = line.split("\t");
            int child = Integer.parseInt(fields[0]);
            parents[child] = Integer.parseInt(fields[1]);
            childOrder[child] = Integer.parseInt(fields[2]);
            for (int t = 0; t < treeTraits.size(); t++) {
                traitValues[t][child] = Double.parseDouble(fields[3 + t]);
            }
        }
        tree.beginTreeEdit();
        tree.adoptTreeStructure(parents, nodeHeights, childOrder, taxaNames);
        if (!treeTraits.isEmpty()) {
            // Only this tree's trait models match the rows of traitValues.
            System.out.println("adopting " + treeTraits.size() + " trait models to treeModel " + tree.getId());
            tree.adoptTraitData(parents, treeTraits, traitValues, taxaNames);
        }
        tree.endTreeEdit();
    }

    /**
     * Re-seeds the RNG from {@value #CHECKPOINT_SEED} if set; otherwise restores the saved state
     * (when the file contained one).
     */
    private static void restoreRandomState(int[] rngState) {
        String seed = System.getProperty(CHECKPOINT_SEED);
        if (seed != null) {
            try {
                MathUtils.setSeed(Long.parseLong(seed));
            } catch (NumberFormatException e) {
                throw new RuntimeException("Invalid value for " + CHECKPOINT_SEED + ": " + seed, e);
            }
        } else if (rngState != null) {
            MathUtils.setRandomState(rngState);
        }
    }

    /** Collects all connected {@link TreeParameterModel}s in model-set iteration order. */
    private static ArrayList<TreeParameterModel> collectTreeParameterModels() {
        ArrayList<TreeParameterModel> traits = new ArrayList<>();
        for (Model model : Model.CONNECTED_MODEL_SET) {
            if (model instanceof TreeParameterModel) {
                traits.add((TreeParameterModel) model);
            }
        }
        return traits;
    }

    /** Returns the trait models attached (by identity) to {@code tree}, preserving order. */
    private static ArrayList<TreeParameterModel> traitsForTree(TreeModel tree, ArrayList<TreeParameterModel> traits) {
        ArrayList<TreeParameterModel> treeTraits = new ArrayList<>();
        for (TreeParameterModel trait : traits) {
            if (trait.getTreeModel() == tree) {
                treeTraits.add(trait);
            }
        }
        return treeTraits;
    }
}
