package dr.evomodelxml.continuous.hmc;

import dr.evolution.tree.Tree;
import dr.evomodel.treedatalikelihood.DataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.BranchSpecificGradient;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch;
import dr.evomodel.treedatalikelihood.hmc.AbstractDiffusionGradient;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;

/**
 * XML parser for the {@code attenuationGradient} element.
 *
 * <p>Builds a gradient object for a {@link TreeDataLikelihood} whose data-likelihood delegate is a
 * {@link ContinuousDataLikelihoodDelegate}, taken with respect to an attenuation
 * {@link MatrixParameterInterface}. The {@code parameter} attribute selects which part of the
 * attenuation matrix is differentiated ({@code both}, {@code correlation} or {@code diagonal});
 * only {@code diagonal} is currently implemented, the other modes throw when built.
 */
public class AttenuationGradientParser extends AbstractXMLObjectParser {
    private static final String ATTENUATION_GRADIENT = "attenuationGradient";
    private static final String PARAMETER = "parameter";
    private static final String ATTENUATION_CORRELATION = "correlation";
    private static final String ATTENUATION_DIAGONAL = "diagonal";
    private static final String ATTENUATION_BOTH = "both";
    private static final String TRAIT_NAME = "traitName";
    private static final String DEFAULT_TRAIT_NAME = "trait";

    private final XMLSyntaxRule[] rules = {
            new ElementRule(Likelihood.class),
            new ElementRule(MatrixParameterInterface.class)
    };

    /**
     * Which part of the attenuation matrix the gradient is taken with respect to.
     * Each constant knows how to build the corresponding gradient object.
     */
    public enum ParameterMode {
        /** Gradient with respect to the full attenuation matrix (not yet implemented). */
        WRT_BOTH {
            @Override
            public Object factory(BranchSpecificGradient branchSpecificGradient,
                                  TreeDataLikelihood treeDataLikelihood,
                                  MatrixParameterInterface matrixParameterInterface) {
                throw new UnsupportedOperationException(
                        "Gradient wrt full attenuation not yet implemented.");
            }
        },
        /** Gradient with respect to the correlation part of the attenuation (not yet implemented). */
        WRT_CORRELATION {
            @Override
            public Object factory(BranchSpecificGradient branchSpecificGradient,
                                  TreeDataLikelihood treeDataLikelihood,
                                  MatrixParameterInterface matrixParameterInterface) {
                throw new UnsupportedOperationException(
                        "Gradient wrt correlation of attenuation not yet implemented.");
            }
        },
        /** Gradient with respect to the diagonal of the attenuation matrix. */
        WRT_DIAGONAL {
            @Override
            public Object factory(BranchSpecificGradient branchSpecificGradient,
                                  TreeDataLikelihood treeDataLikelihood,
                                  MatrixParameterInterface matrixParameterInterface) {
                return AbstractDiffusionGradient.ParameterDiffusionGradient.createDiagonalAttenuationGradient(
                        branchSpecificGradient, treeDataLikelihood, matrixParameterInterface);
            }
        };

        /**
         * Builds the gradient object for this mode.
         *
         * @param branchSpecificGradient   per-branch gradient helper built by the parser
         * @param treeDataLikelihood       likelihood being differentiated
         * @param matrixParameterInterface attenuation matrix parameter
         * @return the gradient object
         * @throws UnsupportedOperationException if this mode is not implemented yet
         */
        abstract Object factory(BranchSpecificGradient branchSpecificGradient,
                                TreeDataLikelihood treeDataLikelihood,
                                MatrixParameterInterface matrixParameterInterface);
    }

    @Override
    public String getParserName() {
        return ATTENUATION_GRADIENT;
    }

    /**
     * Reads the {@code parameter} attribute (default {@code both}) case-insensitively.
     *
     * @param xo the XML element being parsed
     * @return the selected {@link ParameterMode}
     * @throws XMLParseException if the attribute holds an unrecognized value
     */
    private ParameterMode parseParameterMode(XMLObject xo) throws XMLParseException {
        // Locale.ROOT: default-locale lowercasing breaks e.g. "DIAGONAL" under a Turkish locale.
        final String value =
                ((String) xo.getAttribute(PARAMETER, ATTENUATION_BOTH)).toLowerCase(Locale.ROOT);
        switch (value) {
            case ATTENUATION_BOTH:
                return ParameterMode.WRT_BOTH;
            case ATTENUATION_CORRELATION:
                return ParameterMode.WRT_CORRELATION;
            case ATTENUATION_DIAGONAL:
                return ParameterMode.WRT_DIAGONAL;
            default:
                throw new XMLParseException("Unrecognized value '" + value + "' for attribute '"
                        + PARAMETER + "' of element '" + ATTENUATION_GRADIENT + "'; expected one of '"
                        + ATTENUATION_BOTH + "', '" + ATTENUATION_CORRELATION + "', '"
                        + ATTENUATION_DIAGONAL + "'");
        }
    }

    /**
     * Builds the attenuation gradient described by {@code xo}.
     *
     * @param xo the {@code attenuationGradient} element; must contain a {@link TreeDataLikelihood}
     *           backed by a {@link ContinuousDataLikelihoodDelegate} and a
     *           {@link MatrixParameterInterface}
     * @return the gradient object produced by the selected {@link ParameterMode}
     * @throws XMLParseException if the {@code parameter} attribute is invalid, the tree data
     *                           likelihood is missing, or its delegate is not continuous
     */
    @Override
    public Object parseXMLObject(XMLObject xo) throws XMLParseException {
        // Validate the mode first so a bad attribute fails before any gradient objects are built.
        final ParameterMode parameterMode = parseParameterMode(xo);
        final String traitName = (String) xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);

        final MatrixParameterInterface attenuation =
                (MatrixParameterInterface) xo.getChild(MatrixParameterInterface.class);
        final TreeDataLikelihood treeDataLikelihood =
                (TreeDataLikelihood) xo.getChild(TreeDataLikelihood.class);
        // The syntax rules only require a generic Likelihood, so the concrete type must be checked here.
        if (treeDataLikelihood == null) {
            throw new XMLParseException("Element '" + ATTENUATION_GRADIENT
                    + "' requires a TreeDataLikelihood child");
        }

        final DataLikelihoodDelegate delegate = treeDataLikelihood.getDataLikelihoodDelegate();
        if (!(delegate instanceof ContinuousDataLikelihoodDelegate)) {
            throw new XMLParseException("Element '" + ATTENUATION_GRADIENT
                    + "' requires a TreeDataLikelihood with a ContinuousDataLikelihoodDelegate");
        }
        final ContinuousDataLikelihoodDelegate continuousDelegate =
                (ContinuousDataLikelihoodDelegate) delegate;
        final int traitDim = delegate.getTraitDim();
        final Tree tree = treeDataLikelihood.getTree();

        // Per-branch gradient is always taken wrt the diagonal selection strength,
        // which matches the only implemented mode (WRT_DIAGONAL).
        final ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient branchGradient =
                new ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient(
                        traitDim, tree, continuousDelegate,
                        new ArrayList<>(Arrays.asList(
                                ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient
                                        .DerivationParameter.WRT_DIAGONAL_SELECTION_STRENGTH)));

        final BranchSpecificGradient branchSpecificGradient = new BranchSpecificGradient(
                traitName, treeDataLikelihood, continuousDelegate, branchGradient, attenuation);

        return parameterMode.factory(branchSpecificGradient, treeDataLikelihood, attenuation);
    }

    @Override
    public XMLSyntaxRule[] getSyntaxRules() {
        return this.rules;
    }

    @Override
    public String getParserDescription() {
        return "Gradient of a continuous-trait tree data likelihood with respect to an attenuation "
                + "matrix parameter (currently only its diagonal).";
    }

    @Override
    public Class getReturnType() {
        return AbstractDiffusionGradient.ParameterDiffusionGradient.class;
    }
}
