package dr.inference.operators;

import dr.inference.distribution.LinearRegression;
import dr.inference.distribution.MultivariateDistributionLikelihood;
import dr.inference.model.Parameter;
import dr.math.MathUtils;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/* loaded from: input_file:dr/inference/operators/RegressionMetropolizedIndicatorOperator.class */
/**
 * MCMC operator that flips one randomly chosen entry of an indicator parameter and then
 * redraws the regression effect through a {@link RegressionGibbsEffectOperator}.
 *
 * <p>Each call to {@link #doOperation()}:
 * <ol>
 *   <li>evaluates the log-density of the current effect under the density computed by
 *       {@code RegressionGibbsEffectOperator.computeForwardDensity},</li>
 *   <li>picks a uniformly random indicator index (rejecting indices whose mask entry is
 *       zero when a mask was supplied) and replaces the indicator value {@code v} by
 *       {@code 1 - v},</li>
 *   <li>runs the wrapped effect operator and evaluates the log-density of the new effect
 *       under the mean/precision that operator reports for its last draw,</li>
 *   <li>returns the first log-density minus the second.</li>
 * </ol>
 */
public class RegressionMetropolizedIndicatorOperator extends SimpleMCMCOperator {
    /** XML element name of this operator; also returned by {@link #getOperatorName()}. */
    public static final String MH_OPERATOR = "regressionMetropolizedIndicatorOperator";

    /**
     * XML parser for this operator.
     *
     * <p>Reads a {@code weight} attribute, one {@link Parameter} (the effect), one
     * {@link MultivariateDistributionLikelihood} (the prior, which must be multivariate
     * normal), one {@link LinearRegression}, an {@code indicator} element wrapping a
     * {@link Parameter}, and a {@code mask} element wrapping a {@link Parameter} that the
     * parser treats as absent when {@code getChild("mask")} returns {@code null}.
     */
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() { // from class: dr.inference.operators.RegressionMetropolizedIndicatorOperator.1
        private final XMLSyntaxRule[] rules = {
                AttributeRule.newDoubleRule("weight"),
                new ElementRule(Parameter.class),
                new ElementRule(MultivariateDistributionLikelihood.class),
                new ElementRule(LinearRegression.class),
                new ElementRule("indicator", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}),
                // NOTE(review): the trailing 'true' presumably marks the element as optional - confirm.
                new ElementRule("mask", new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)
        };

        @Override // dr.xml.XMLObjectParser
        public String getParserName() {
            return RegressionMetropolizedIndicatorOperator.MH_OPERATOR;
        }

        /**
         * Builds the operator from its XML description.
         *
         * @throws XMLParseException if the prior is not of the multivariate normal type, or
         *                           if a mask is given whose dimension differs from the
         *                           indicator's dimension
         */
        @Override // dr.xml.AbstractXMLObjectParser
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            double weight = xMLObject.getDoubleAttribute("weight");
            LinearRegression regression = (LinearRegression) xMLObject.getChild(LinearRegression.class);
            Parameter effectParameter = (Parameter) xMLObject.getChild(Parameter.class);
            MultivariateDistributionLikelihood prior =
                    (MultivariateDistributionLikelihood) xMLObject.getChild(MultivariateDistributionLikelihood.class);
            if (prior.getDistribution().getType().compareTo(MultivariateNormalDistribution.TYPE) != 0) {
                throw new XMLParseException("Only a multivariate normal prior is conjugate");
            }

            Parameter indicatorParameter = (Parameter) xMLObject.getChild("indicator").getChild(Parameter.class);

            // The mask is only validated when the element is present.
            Parameter maskParameter = null;
            XMLObject maskElement = xMLObject.getChild("mask");
            if (maskElement != null) {
                maskParameter = (Parameter) maskElement.getChild(Parameter.class);
                if (maskParameter.getDimension() != indicatorParameter.getDimension()) {
                    throw new XMLParseException("Indicator and mask parameter must have the same dimension");
                }
            }

            RegressionMetropolizedIndicatorOperator operator = new RegressionMetropolizedIndicatorOperator(
                    regression, effectParameter, indicatorParameter, prior, maskParameter);
            operator.setWeight(weight);
            return operator;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public String getParserDescription() {
            // NOTE(review): this text does not mention indicators or regression and looks
            // copied from another operator's parser; left unchanged because it is runtime text.
            return "This element returns a multivariate Gibbs operator on an internal node trait.";
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public Class getReturnType() {
            return MCMCOperator.class;
        }

        @Override // dr.xml.AbstractXMLObjectParser, dr.xml.XMLObjectParser
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }
    };

    /** Optional mask; indicator positions whose mask entry equals zero are never flipped. May be null. */
    private final Parameter mask;
    /** Indicator parameter; one entry is replaced by (1 - value) per operation. */
    private final Parameter indicators;
    /** Effect parameter whose values are scored before and after the wrapped operator runs. */
    private final Parameter effect;
    /** Wrapped operator that redraws the effect and reports the mean/variance/precision it used. */
    private final RegressionGibbsEffectOperator effectOperator;

    // Working buffers: allocated on first use, then replaced by the arrays the effect
    // operator hands back after each draw.
    private double[] mean = null;
    private double[][] variance = null;
    private double[][] precision = null;

    /**
     * Creates the operator.
     *
     * @param linearRegression                   regression model passed to the wrapped effect operator
     * @param parameter                          effect parameter
     * @param parameter2                         indicator parameter
     * @param multivariateDistributionLikelihood prior passed to the wrapped effect operator
     * @param parameter3                         mask parameter, or {@code null} for no mask
     */
    public RegressionMetropolizedIndicatorOperator(LinearRegression linearRegression, Parameter parameter, Parameter parameter2, MultivariateDistributionLikelihood multivariateDistributionLikelihood, Parameter parameter3) {
        this.effectOperator = new RegressionGibbsEffectOperator(linearRegression, parameter, parameter2, multivariateDistributionLikelihood);
        this.effect = parameter;
        this.indicators = parameter2;
        this.mask = parameter3;
    }

    /**
     * @return always {@code null}; this operator offers no performance suggestion
     */
    public String getPerformanceSuggestion() {
        return null;
    }

    @Override // dr.inference.operators.SimpleMCMCOperator, dr.inference.operators.MCMCOperator
    public String getOperatorName() {
        return MH_OPERATOR;
    }

    /**
     * Flips one indicator, redraws the effect, and returns the difference of the two
     * effect log-densities (before minus after).
     *
     * @return log-density of the old effect under the forward density minus log-density of
     *         the new effect under the density the wrapped operator used for its draw
     *         (presumably the log Hastings ratio expected by the MCMC framework - confirm)
     * @throws RuntimeException if a mask is present and every one of its entries is zero
     */
    @Override // dr.inference.operators.SimpleMCMCOperator
    public double doOperation() {
        // Guard against an endless rejection loop below. The test mirrors the loop's own
        // criterion (entry != 0) instead of a truncated integer sum, so fractional or
        // sign-cancelling mask values are not misreported as "all zeros".
        if (this.mask != null && !hasNonZeroEntry(this.mask)) {
            throw new RuntimeException("Mask parameter has all zeros");
        }

        if (this.mean == null) {
            int effectDimension = this.effect.getDimension();
            this.mean = new double[effectDimension];
            this.variance = new double[effectDimension][effectDimension];
            this.precision = new double[effectDimension][effectDimension];
        }

        // Score the current effect under the forward density; the buffers are presumably
        // filled in place by computeForwardDensity since they are read immediately after.
        this.effectOperator.computeForwardDensity(this.mean, this.variance, this.precision);
        double logDensityBefore = currentEffectLogDensity();

        // Draw an index uniformly, rejecting masked-out (zero) positions when a mask exists.
        int position;
        do {
            position = MathUtils.nextInt(this.indicators.getDimension());
        } while (this.mask != null && this.mask.getParameterValue(position) == 0.0d);
        this.indicators.setParameterValue(position, 1.0d - this.indicators.getParameterValue(position));

        // Redraw the effect and adopt the moments the wrapped operator reports for that draw.
        this.effectOperator.doOperation();
        this.mean = this.effectOperator.getLastMean();
        this.variance = this.effectOperator.getLastVariance();
        this.precision = this.effectOperator.getLastPrecision();

        return logDensityBefore - currentEffectLogDensity();
    }

    /** Returns true when at least one entry of the given parameter is different from zero. */
    private static boolean hasNonZeroEntry(Parameter parameter) {
        for (int index = 0; index < parameter.getDimension(); index++) {
            if (parameter.getParameterValue(index) != 0.0d) {
                return true;
            }
        }
        return false;
    }

    /** Evaluates the multivariate normal log-density of the current effect values using the stored mean and precision. */
    private double currentEffectLogDensity() {
        double[] effectValues = this.effect.getParameterValues();
        double logDeterminant = Math.log(MultivariateNormalDistribution.calculatePrecisionMatrixDeterminate(this.precision));
        return MultivariateNormalDistribution.logPdf(effectValues, this.mean, this.precision, logDeterminant, 1.0d);
    }
}
