package dr.inferencexml.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.inference.operators.hmc.NoUTurnOperator;
import dr.util.Transform;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Parses the {@code <hamiltonianMonteCarloOperator>} XML element into a Hamiltonian Monte Carlo
 * transition kernel.
 *
 * <p>The {@code mode} attribute selects the concrete operator: {@code "nuts"} (case-insensitive)
 * yields a {@link NoUTurnOperator}; any other value, including the default {@code "vanilla"},
 * yields a plain {@link HamiltonianMonteCarloOperator}.
 */
public class HamiltonianMonteCarloOperatorParser extends AbstractXMLObjectParser {
    private static final String HMC_OPERATOR = "hamiltonianMonteCarloOperator";
    private static final String WEIGHT = "weight";
    private static final String AUTO_OPTIMIZE = "autoOptimize";
    private static final String N_STEPS = "nSteps";
    private static final String STEP_SIZE = "stepSize";
    private static final String MODE = "mode";
    private static final String NUTS = "nuts";
    private static final String VANILLA = "vanilla";
    private static final String RANDOM_STEP_FRACTION = "randomStepCountFraction";
    private static final String PRECONDITIONING = "preconditioning";
    private static final String PRECONDITIONING_UPDATE_FREQUENCY = "preconditioningUpdateFrequency";
    private static final String PRECONDITIONING_DELAY = "preconditioningDelay";
    private static final String PRECONDITIONING_MEMORY = "preconditioningMemory";
    private static final String GRADIENT_CHECK_COUNT = "gradientCheckCount";
    private static final String GRADIENT_CHECK_TOLERANCE = "gradientCheckTolerance";
    private static final String MAX_ITERATIONS = "checkStepSizeMaxIterations";
    private static final String REDUCTION_FACTOR = "checkStepSizeReductionFactor";
    private static final String TARGET_ACCEPTANCE_PROBABILITY = "targetAcceptanceProbability";
    private static final String MASK = "mask";

    /** Run mode handed to {@link #factory} to build a plain {@link HamiltonianMonteCarloOperator}. */
    private static final int RUN_MODE_VANILLA = 0;
    /** Run mode handed to {@link #factory} to build a {@link NoUTurnOperator}. */
    private static final int RUN_MODE_NUTS = 1;

    /**
     * Syntax rules for the element. Every attribute read by {@link #parseXMLObject} is declared
     * here so it is documented and type-checked rather than reported as unhandled.
     */
    protected final XMLSyntaxRule[] rules = {
            AttributeRule.newDoubleRule(WEIGHT),
            AttributeRule.newIntegerRule(N_STEPS, true),
            AttributeRule.newDoubleRule(STEP_SIZE),
            AttributeRule.newBooleanRule(AUTO_OPTIMIZE, true),
            AttributeRule.newStringRule(PRECONDITIONING, true),
            AttributeRule.newStringRule(MODE, true),
            AttributeRule.newDoubleRule(RANDOM_STEP_FRACTION, true),
            AttributeRule.newIntegerRule(PRECONDITIONING_UPDATE_FREQUENCY, true),
            AttributeRule.newIntegerRule(PRECONDITIONING_DELAY, true),
            AttributeRule.newIntegerRule(PRECONDITIONING_MEMORY, true),
            AttributeRule.newIntegerRule(GRADIENT_CHECK_COUNT, true),
            AttributeRule.newDoubleRule(GRADIENT_CHECK_TOLERANCE, true),
            AttributeRule.newIntegerRule(MAX_ITERATIONS, true),
            AttributeRule.newDoubleRule(REDUCTION_FACTOR, true),
            AttributeRule.newDoubleRule(TARGET_ACCEPTANCE_PROBABILITY, true),
            new ElementRule(Parameter.class, true),
            new ElementRule(Transform.MultivariableTransformWithParameter.class, true),
            new ElementRule(GradientWrtParameterProvider.class),
            new ElementRule(MASK, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
    };

    @Override
    public String getParserName() {
        return HMC_OPERATOR;
    }

    /**
     * Maps the {@code mode} attribute to a run mode.
     *
     * @return {@link #RUN_MODE_NUTS} when the attribute equals {@code "nuts"} ignoring case,
     *         otherwise {@link #RUN_MODE_VANILLA} (also when the attribute is absent)
     */
    private int parseRunMode(XMLObject xo) throws XMLParseException {
        // equalsIgnoreCase is locale-independent, unlike String.toLowerCase().
        final String mode = (String) xo.getAttribute(MODE, VANILLA);
        return NUTS.equalsIgnoreCase(mode) ? RUN_MODE_NUTS : RUN_MODE_VANILLA;
    }

    /**
     * Reads the {@code preconditioning} attribute via
     * {@link MassPreconditioner.Type#parseFromString}, defaulting to the name of
     * {@link MassPreconditioner.Type#NONE} when the attribute is absent.
     */
    private MassPreconditioner.Type parsePreconditioning(XMLObject xo) throws XMLParseException {
        return MassPreconditioner.Type.parseFromString(
                (String) xo.getAttribute(PRECONDITIONING, MassPreconditioner.Type.NONE.getName()));
    }

    /**
     * Builds the operator described by {@code xo}.
     *
     * @throws XMLParseException if the random step count fraction exceeds 1.0 in absolute value,
     *         if preconditioning is requested without a {@link HessianWrtParameterProvider},
     *         or if the parameter, gradient, transform or mask dimensions disagree
     */
    @Override
    public Object parseXMLObject(XMLObject xo) throws XMLParseException {
        final double weight = xo.getDoubleAttribute(WEIGHT);
        final int nSteps = (Integer) xo.getAttribute(N_STEPS, 10);
        final double stepSize = xo.getDoubleAttribute(STEP_SIZE);
        final int runMode = parseRunMode(xo);
        final MassPreconditioner.Type preconditioningType = parsePreconditioning(xo);

        // Negative fractions are folded to their absolute value rather than rejected.
        final double randomStepFraction =
                Math.abs((Double) xo.getAttribute(RANDOM_STEP_FRACTION, 0.0));
        if (randomStepFraction > 1.0) {
            // Exactly 1.0 passes this check, so the message states an inclusive bound.
            throw new XMLParseException("Random step count fraction must be <= 1.0");
        }

        final int preconditioningUpdateFrequency =
                (Integer) xo.getAttribute(PRECONDITIONING_UPDATE_FREQUENCY, 0);
        final int preconditioningDelay = (Integer) xo.getAttribute(PRECONDITIONING_DELAY, 0);
        final int preconditioningMemory = (Integer) xo.getAttribute(PRECONDITIONING_MEMORY, 0);
        final AdaptationMode adaptationMode = AdaptationMode.parseMode(xo);

        final GradientWrtParameterProvider gradientProvider =
                (GradientWrtParameterProvider) xo.getChild(GradientWrtParameterProvider.class);
        if (preconditioningType != MassPreconditioner.Type.NONE
                && !(gradientProvider instanceof HessianWrtParameterProvider)) {
            throw new XMLParseException("Unable precondition without a Hessian provider");
        }

        // Fall back to the gradient's own parameter when no explicit <parameter> child is given.
        Parameter parameter = (Parameter) xo.getChild(Parameter.class);
        if (parameter == null) {
            parameter = gradientProvider.getParameter();
        }

        final Transform transform = Transform.Util.parseTransform(xo);
        validateDimensions(gradientProvider, parameter, transform);

        final Parameter mask = parseMask(xo, gradientProvider);

        final int gradientCheckCount = (Integer) xo.getAttribute(GRADIENT_CHECK_COUNT, 0);
        final double gradientCheckTolerance =
                (Double) xo.getAttribute(GRADIENT_CHECK_TOLERANCE, 0.001);
        final int checkStepSizeMaxIterations = (Integer) xo.getAttribute(MAX_ITERATIONS, 10);
        final double checkStepSizeReductionFactor = (Double) xo.getAttribute(REDUCTION_FACTOR, 0.1);
        final double targetAcceptanceProbability =
                (Double) xo.getAttribute(TARGET_ACCEPTANCE_PROBABILITY, 0.8);

        final HamiltonianMonteCarloOperator.Options runtimeOptions =
                new HamiltonianMonteCarloOperator.Options(
                        stepSize,
                        nSteps,
                        randomStepFraction,
                        preconditioningUpdateFrequency,
                        preconditioningDelay,
                        preconditioningMemory,
                        gradientCheckCount,
                        gradientCheckTolerance,
                        checkStepSizeMaxIterations,
                        checkStepSizeReductionFactor,
                        targetAcceptanceProbability);

        return factory(adaptationMode, weight, gradientProvider, parameter, transform, mask,
                runtimeOptions, preconditioningType, runMode);
    }

    /**
     * Checks that the parameter's dimension agrees with the transform (when it is a
     * {@link Transform.MultivariableTransform}) or otherwise with the gradient provider.
     *
     * @throws XMLParseException naming the dimension that was actually compared
     */
    private static void validateDimensions(GradientWrtParameterProvider gradientProvider,
                                           Parameter parameter,
                                           Transform transform) throws XMLParseException {
        final int parameterDimension = parameter.getDimension();
        // instanceof is false for null, so no separate null check is needed.
        if (transform instanceof Transform.MultivariableTransform) {
            final int transformDimension =
                    ((Transform.MultivariableTransform) transform).getDimension();
            if (transformDimension != parameterDimension) {
                throw new XMLParseException("Transform (" + transformDimension
                        + ") must be the same dimensions as the parameter (" + parameterDimension + ")");
            }
        } else if (gradientProvider.getDimension() != parameterDimension) {
            throw new XMLParseException("Gradient (" + gradientProvider.getDimension()
                    + ") must be the same dimensions as the parameter (" + parameterDimension + ")");
        }
    }

    /**
     * Reads the optional {@code <mask>} child.
     *
     * @return the mask parameter, or {@code null} when no {@code <mask>} element is present
     * @throws XMLParseException if the mask's dimension differs from the gradient's
     */
    private static Parameter parseMask(XMLObject xo, GradientWrtParameterProvider gradientProvider)
            throws XMLParseException {
        if (!xo.hasChildNamed(MASK)) {
            return null;
        }
        final Parameter mask = (Parameter) xo.getElementFirstChild(MASK);
        if (mask.getDimension() != gradientProvider.getDimension()) {
            throw new XMLParseException("Mask (" + mask.getDimension()
                    + ") must be the same dimension as the gradient (" + gradientProvider.getDimension() + ")");
        }
        return mask;
    }

    /**
     * Instantiates the operator. Subclasses may override this to build a different operator
     * from the same parsed inputs.
     *
     * @param runMode {@code 0} builds a {@link HamiltonianMonteCarloOperator}; any other value
     *                builds a {@link NoUTurnOperator}
     */
    protected HamiltonianMonteCarloOperator factory(AdaptationMode adaptationMode,
                                                    double weight,
                                                    GradientWrtParameterProvider gradientProvider,
                                                    Parameter parameter,
                                                    Transform transform,
                                                    Parameter mask,
                                                    HamiltonianMonteCarloOperator.Options runtimeOptions,
                                                    MassPreconditioner.Type preconditioningType,
                                                    int runMode) {
        if (runMode == RUN_MODE_VANILLA) {
            return new HamiltonianMonteCarloOperator(adaptationMode, weight, gradientProvider,
                    parameter, transform, mask, runtimeOptions, preconditioningType);
        }
        return new NoUTurnOperator(adaptationMode, weight, gradientProvider,
                parameter, transform, mask, runtimeOptions, preconditioningType);
    }

    @Override
    public XMLSyntaxRule[] getSyntaxRules() {
        return this.rules;
    }

    @Override
    public String getParserDescription() {
        return "Returns a Hamiltonian Monte Carlo transition kernel";
    }

    @Override
    public Class getReturnType() {
        return HamiltonianMonteCarloOperator.class;
    }
}
