package dr.math.distributions;

import dr.inference.loggers.LogColumn;
import dr.inference.loggers.NumberColumn;
import dr.inference.model.AbstractModel;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameter;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.math.GammaFunction;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import dr.xml.XORRule;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/* loaded from: input_file:dr/math/distributions/MultivariatePolyaDistributionLikelihood.class */
/**
 * Likelihood of a matrix of count data under the multivariate Polya (Dirichlet-multinomial)
 * distribution.
 *
 * <p>Every row of {@code data} contributes one term and all rows share the same vector of
 * Dirichlet parameters {@code alphas}. With {@code n_r = sum_k x_rk} and {@code A = sum_k alpha_k}:
 *
 * <pre>
 * log L = sum_r [ lnGamma(n_r + 1) - sum_k lnGamma(x_rk + 1)            (fixed norm, data only)
 *               + lnGamma(A) - sum_k lnGamma(alpha_k)                   (variable norm, alphas only)
 *               + sum_k lnGamma(x_rk + alpha_k) - lnGamma(n_r + A) ]    (mixed term)
 * </pre>
 *
 * <p>The alphas are either given directly ({@code alpha} element) or derived as
 * {@code alpha_k = dispersion * frequency_k} ({@code frequencies} + {@code dispersion} elements).
 * The data-only and alpha-only parts are cached and recomputed only when their inputs change.
 */
public class MultivariatePolyaDistributionLikelihood extends AbstractModel implements Likelihood {

    public static final String MVPLIKE = "mvPolyaLikelihood";
    public static final String DATA = "data";
    public static final String DISPERSION = "dispersion";
    public static final String FREQ = "frequencies";
    public static final String RATES = "alpha";

    protected Parameter frequencies;
    protected Parameter dispersion;
    protected Parameter alphas;
    /** True when the alphas are supplied directly rather than derived from frequencies and dispersion. */
    protected boolean usingAlphas;
    protected boolean isAlphasKnown;
    protected MatrixParameter data;
    protected double fixedNorm;
    protected double variableNorm;
    protected double storedFixedNorm;
    protected double storedVariableNorm;
    protected double logLikelihood;
    protected double storedLogLikelihood;
    protected boolean isLogLikelihoodKnown;
    protected boolean isFixedNormKnown;
    protected boolean isVariableNormKnown;
    /** Per-row totals of the data matrix; rebuilt as a fresh array by computeFixedNorm(). */
    protected double[] rowSums;

    // Saved copies of the cache state so restoreState() brings back a mutually consistent snapshot.
    private double[] storedRowSums;
    private boolean storedIsFixedNormKnown;
    private boolean storedIsVariableNormKnown;
    private boolean storedIsLogLikelihoodKnown;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = {
                new ElementRule(DATA, new XMLSyntaxRule[]{new ElementRule(MatrixParameter.class)}, false),
                new XORRule(
                        new ElementRule(RATES, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, false),
                        new ElementRule(FREQ, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, false)),
                new ElementRule(DISPERSION, new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true)
        };

        public String getParserName() {
            return MVPLIKE;
        }

        /**
         * Builds the likelihood from either an {@code alpha} child or a
         * {@code frequencies} + {@code dispersion} pair, checking dimensions against the data.
         */
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            if (!xo.hasChildNamed(DATA)) {
                throw new XMLParseException("Missing data element!");
            }
            MatrixParameter data = (MatrixParameter) xo.getChild(DATA).getChild(MatrixParameter.class);

            if (xo.hasChildNamed(RATES)) {
                Parameter alphas = (Parameter) xo.getChild(RATES).getChild(Parameter.class);
                if (alphas.getDimension() != data.getColumnDimension()) {
                    throw new XMLParseException("The number of data columns must match the dimension of alpha parameter ("
                            + data.getColumnDimension() + " != " + alphas.getDimension() + ")!");
                }
                return new MultivariatePolyaDistributionLikelihood(MVPLIKE, data, alphas);
            }

            if (!xo.hasChildNamed(FREQ)) {
                throw new XMLParseException("Either frequencies or alpha element has to be specified!");
            }
            Parameter frequencies = (Parameter) xo.getChild(FREQ).getChild(Parameter.class);

            if (!xo.hasChildNamed(DISPERSION)) {
                throw new XMLParseException("dispersion element has to be specified when using frequencies parametrization");
            }
            Parameter dispersion = (Parameter) xo.getChild(DISPERSION).getChild(Parameter.class);
            if (dispersion.getDimension() != 1) {
                throw new XMLParseException("Dispersion parameter must be of dimension exactly 1!");
            }
            if (frequencies.getDimension() != data.getColumnDimension()) {
                throw new XMLParseException("The number of data columns must match the dimension of frequencies parameter ("
                        + data.getColumnDimension() + " != " + frequencies.getDimension() + ")!");
            }
            return new MultivariatePolyaDistributionLikelihood(MVPLIKE, data, frequencies, dispersion);
        }

        public String getParserDescription() {
            return "Likelihood of a matrix of counts under the multivariate Polya (Dirichlet-multinomial) "
                    + "distribution, parametrized either by alpha or by frequencies and dispersion.";
        }

        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        public Class getReturnType() {
            // The parser produces the likelihood itself, not a matrix parameter.
            return MultivariatePolyaDistributionLikelihood.class;
        }
    };

    /**
     * Frequencies/dispersion parametrization: {@code alpha_k = dispersion * frequencies_k}.
     *
     * @param name        model name
     * @param data        count matrix, one observation vector per row
     * @param frequencies category frequencies, one per data column
     * @param dispersion  single-element dispersion (scale) parameter
     */
    public MultivariatePolyaDistributionLikelihood(String name, MatrixParameter data, Parameter frequencies, Parameter dispersion) {
        super(name);
        this.frequencies = frequencies;
        this.dispersion = dispersion;
        // Internal, unregistered parameter holding the derived alphas.
        this.alphas = new Parameter.Default(frequencies.getDimension());
        computeAlphas();
        this.data = data;
        this.isFixedNormKnown = false;
        this.isVariableNormKnown = false;
        addVariable(this.frequencies);
        addVariable(this.dispersion);
        addVariable(this.data);
        if (this.alphas.getDimension() != data.getColumnDimension()) {
            System.err.println("Dimensions of the frequency vector and number of columns do not match!");
        }
    }

    /**
     * Direct parametrization by the Dirichlet alphas.
     *
     * @param name   model name
     * @param data   count matrix, one observation vector per row
     * @param alphas Dirichlet parameters, one per data column
     */
    public MultivariatePolyaDistributionLikelihood(String name, MatrixParameter data, Parameter alphas) {
        super(name);
        this.alphas = alphas;
        this.isAlphasKnown = true;
        this.usingAlphas = true;
        // Placeholders only: never registered, so they never trigger alpha recomputation.
        this.frequencies = new Parameter.Default(alphas.getDimension());
        this.dispersion = new Parameter.Default(1);
        this.data = data;
        this.isFixedNormKnown = false;
        this.isVariableNormKnown = false;
        addVariable(this.alphas);
        addVariable(this.data);
        if (this.alphas.getDimension() != data.getColumnDimension()) {
            System.err.println("Dimensions of the frequency vector and number of columns do not match!");
        }
    }

    /**
     * Name-only constructor; leaves all parameters null.
     * NOTE(review): presumably intended for subclasses that initialise the fields themselves.
     */
    public MultivariatePolyaDistributionLikelihood(String name) {
        super(name);
    }

    /** Recomputes {@code alphas[k] = dispersion * frequencies[k]} for every component. */
    protected void computeAlphas() {
        final double scale = dispersion.getParameterValue(0);
        final double[] freqs = frequencies.getParameterValues();
        for (int k = 0; k < alphas.getDimension(); k++) {
            alphas.setParameterValueQuietly(k, scale * freqs[k]);
        }
        // Values were written quietly; re-setting element 0 to its own value through
        // setParameterValueNotifyChangedAll is used to announce the whole-vector update.
        alphas.setParameterValueNotifyChangedAll(0, alphas.getParameterValue(0));
        isAlphasKnown = true;
    }

    /**
     * Computes the full log-likelihood, refreshing any stale cached part first.
     *
     * @return the log-likelihood summed over all data rows
     */
    public double calculateLogLikelihood() {
        if (!isAlphasKnown) {
            computeAlphas();
        }
        if (!isFixedNormKnown) {
            computeFixedNorm();
        }
        if (!isVariableNormKnown) {
            computeVariableNorm();
        }

        final double[] alphaValues = alphas.getParameterValues();
        final double alphaSum = sum(alphaValues);
        final int rowCount = data.getRowDimension();
        final int columnCount = data.getColumnDimension();

        double logL = fixedNorm + variableNorm;
        for (int row = 0; row < rowCount; row++) {
            for (int col = 0; col < columnCount; col++) {
                logL += GammaFunction.lnGamma(data.getParameterValue(row, col) + alphaValues[col]);
            }
            logL -= GammaFunction.lnGamma(rowSums[row] + alphaSum);
        }
        return logL;
    }

    /**
     * Recomputes the row totals and the data-only (multinomial coefficient) part of the
     * log-likelihood. A new {@code rowSums} array is allocated every time, so storeState()
     * can keep the previous array by reference.
     */
    protected void computeFixedNorm() {
        final int rowCount = data.getRowDimension();
        final int columnCount = data.getColumnDimension();
        final double[] sums = new double[rowCount];
        double norm = 0.0;
        for (int row = 0; row < rowCount; row++) {
            double rowSum = 0.0;
            for (int col = 0; col < columnCount; col++) {
                final double count = data.getParameterValue(row, col);
                rowSum += count;
                norm -= GammaFunction.lnGamma(count + 1.0);
            }
            sums[row] = rowSum;
            norm += GammaFunction.lnGamma(rowSum + 1.0);
        }
        rowSums = sums;
        fixedNorm = norm;
        isFixedNormKnown = true;
    }

    /**
     * Recomputes the alpha-only part {@code lnGamma(A) - sum_k lnGamma(alpha_k)}, multiplied by
     * the number of data rows because every row contributes it once.
     */
    protected void computeVariableNorm() {
        final double[] alphaValues = alphas.getParameterValues();
        double norm = GammaFunction.lnGamma(sum(alphaValues));
        for (double alpha : alphaValues) {
            norm -= GammaFunction.lnGamma(alpha);
        }
        variableNorm = norm * data.getRowDimension();
        isVariableNormKnown = true;
    }

    /** Sequential sum of the given values. */
    private static double sum(double[] values) {
        double total = 0.0;
        for (double v : values) {
            total += v;
        }
        return total;
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        // This class registers no sub-models; invalidate the total defensively in case a subclass does.
        isLogLikelihoodKnown = false;
    }

    /**
     * Marks the cached parts that depend on the changed variable as stale. Variables are matched by
     * reference identity, which is not affected by missing or duplicated ids.
     */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        if (variable == frequencies || variable == dispersion) {
            isAlphasKnown = false;
            isVariableNormKnown = false;
        } else if (variable == data) {
            isFixedNormKnown = false;
            // The variable norm is scaled by the number of data rows.
            isVariableNormKnown = false;
        } else if (variable == alphas) {
            isVariableNormKnown = false;
        } else {
            // Unrecognised source: recompute everything rather than risk a stale cache.
            if (!usingAlphas) {
                isAlphasKnown = false;
            }
            isFixedNormKnown = false;
            isVariableNormKnown = false;
        }
        isLogLikelihoodKnown = false;
    }

    /** Saves cached values together with their validity flags. */
    @Override
    protected void storeState() {
        storedVariableNorm = variableNorm;
        storedFixedNorm = fixedNorm;
        storedLogLikelihood = logLikelihood;
        storedRowSums = rowSums;
        storedIsFixedNormKnown = isFixedNormKnown;
        storedIsVariableNormKnown = isVariableNormKnown;
        storedIsLogLikelihoodKnown = isLogLikelihoodKnown;
    }

    /**
     * Restores the cached values and flags saved by storeState(). Derived alphas are only marked
     * stale here; they are rebuilt on the next evaluation, after the underlying parameters have
     * been restored regardless of restore ordering.
     */
    @Override
    protected void restoreState() {
        variableNorm = storedVariableNorm;
        fixedNorm = storedFixedNorm;
        logLikelihood = storedLogLikelihood;
        rowSums = storedRowSums;
        isFixedNormKnown = storedIsFixedNormKnown;
        isVariableNormKnown = storedIsVariableNormKnown;
        isLogLikelihoodKnown = storedIsLogLikelihoodKnown;
        if (!usingAlphas) {
            isAlphasKnown = false;
        }
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Model getModel() {
        return this;
    }

    /** Returns the cached log-likelihood, computing and caching it first if stale. */
    @Override
    public double getLogLikelihood() {
        if (!isLogLikelihoodKnown) {
            logLikelihood = calculateLogLikelihood();
            isLogLikelihoodKnown = true;
        }
        return logLikelihood;
    }

    @Override
    public void makeDirty() {
        isLogLikelihoodKnown = false;
        isVariableNormKnown = false;
        isFixedNormKnown = false;
    }

    @Override
    public String prettyName() {
        return "Multivariate Polya Distribution Likelihood";
    }

    @Override
    public boolean evaluateEarly() {
        return false;
    }

    @Override
    public Set<Likelihood> getLikelihoodSet() {
        Set<Likelihood> likelihoods = new HashSet<>();
        likelihoods.add(this);
        return likelihoods;
    }

    @Override
    public void setUsed() {
    }

    /** Single log column reporting the current log-likelihood under this model's id. */
    public LogColumn[] getColumns() {
        return new LogColumn[]{new NumberColumn(getId()) {
            public double getDoubleValue() {
                return MultivariatePolyaDistributionLikelihood.this.getLogLikelihood();
            }
        }};
    }
}
