Search in sources:

Example 1 with MultivariateFunction

Use of dr.math.MultivariateFunction in project beast-mcmc by beast-dev.

Method numericWrap of class BranchSubstitutionParameterGradient.

/**
 * Wraps the tree-data log-likelihood as a {@link MultivariateFunction} of per-branch rates, for
 * use in numerical differentiation.
 *
 * <p>Only {@link ArbitraryBranchRates} is handled; for any other branch-rate model the returned
 * function's {@code evaluate} throws a {@link RuntimeException}.
 *
 * @param parameter only its dimension is used, as the number of arguments of the returned function
 * @return a function whose {@code evaluate(x)} writes {@code x[k]} as the rate of tree node
 *     {@code k} (the root is skipped) and returns {@code treeDataLikelihood.getLogLikelihood()};
 *     every argument has lower bound 0 and upper bound positive infinity
 */
private MultivariateFunction numericWrap(final Parameter parameter) {
    final MultivariateFunction wrapped = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            // Only ArbitraryBranchRates is handled: it is cast below to call setBranchRate.
            if (!(branchRateModel instanceof ArbitraryBranchRates)) {
                throw new RuntimeException("Not yet tested with ProxyParameter.");
            }
            final ArbitraryBranchRates rates = (ArbitraryBranchRates) branchRateModel;
            final Tree currentTree = treeDataLikelihood.getTree();

            // NOTE(review): argument index k is used directly as the tree node number, and the
            // root's entry is ignored. If argument.length equals the number of non-root branches,
            // this only lines up when the root is the highest-numbered node -- confirm against
            // ArbitraryBranchRates' node-to-parameter indexing.
            for (int k = 0; k < argument.length; k++) {
                final NodeRef branch = currentTree.getNode(k);
                if (currentTree.isRoot(branch)) {
                    continue;
                }
                rates.setBranchRate(currentTree, branch, argument[k]);
            }

            // NOTE(review): an explicit treeDataLikelihood.makeDirty() call was left commented out
            // here; presumably setBranchRate already invalidates the cached likelihood -- confirm.
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            // Rates are non-negative.
            return 0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };
    return wrapped;
}
Also used: MultivariateFunction (dr.math.MultivariateFunction), ArbitraryBranchRates (dr.evomodel.branchratemodel.ArbitraryBranchRates)

Example 2 with MultivariateFunction

Use of dr.math.MultivariateFunction in project beast-mcmc by beast-dev.

Method checkGradient of class HamiltonianMonteCarloOperator.

/**
 * Checks the analytical gradient from {@code gradientProvider} against a numerical gradient
 * (computed by {@link NumericalDerivative#gradient}) of {@code joint}'s log-likelihood with
 * respect to {@code parameter}.
 *
 * <p>When {@code transform} is non-null, the comparison is done in the transformed space. The
 * numerical gradient is taken of the log-likelihood minus {@code transform.getLogJacobian}
 * evaluated at the untransformed values. The analytical gradient is mapped through
 * {@code transform.updateGradientLogDensity} at the original (unperturbed) parameter values.
 *
 * <p>The parameter's original values are written back when the check finishes, whether it
 * succeeds or throws.
 *
 * @param joint the likelihood whose log value is differentiated numerically
 * @throws RuntimeException if the parameter and gradient-provider dimensions differ, or if the
 *     gradients differ by more than {@code runtimeOptions.gradientCheckTolerance}
 */
void checkGradient(final Likelihood joint) {
    if (parameter.getDimension() != gradientProvider.getDimension()) {
        throw new RuntimeException("Unequal dimensions");
    }
    MultivariateFunction numeric = new MultivariateFunction() {

        @Override
        public double evaluate(double[] argument) {
            // Side effect: overwrites the parameter with the evaluation point.
            if (transform == null) {
                ReadableVector.Utils.setParameter(argument, parameter);
                return joint.getLogLikelihood();
            } else {
                double[] untransformedValue = transform.inverse(argument, 0, argument.length);
                ReadableVector.Utils.setParameter(untransformedValue, parameter);
                return joint.getLogLikelihood() - transform.getLogJacobian(untransformedValue, 0, untransformedValue.length);
            }
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        // NOTE(review): these are the untransformed parameter bounds, even when evaluate()
        // receives transformed arguments; confirm no consumer of this function relies on them.
        @Override
        public double getLowerBound(int n) {
            return parameter.getBounds().getLowerLimit(n);
        }

        @Override
        public double getUpperBound(int n) {
            return parameter.getBounds().getUpperLimit(n);
        }
    };
    double[] analyticalGradientOriginal = gradientProvider.getGradientLogDensity();
    // Snapshot taken before any numerical evaluation. evaluate() overwrites the parameter, so it
    // must not be re-read once NumericalDerivative.gradient has run. Callees receive copies so
    // this snapshot stays intact for the final restore.
    double[] restoredParameterValue = parameter.getParameterValues();
    int dimension = restoredParameterValue.length;
    try {
        if (transform == null) {
            double[] numericGradientOriginal = NumericalDerivative.gradient(numeric, restoredParameterValue.clone());
            if (!MathUtils.isClose(analyticalGradientOriginal, numericGradientOriginal, runtimeOptions.gradientCheckTolerance)) {
                String sb = "Gradients do not match:\n"
                        + "\tAnalytic: " + new WrappedVector.Raw(analyticalGradientOriginal) + "\n"
                        + "\tNumeric : " + new WrappedVector.Raw(numericGradientOriginal) + "\n";
                throw new RuntimeException(sb);
            }
        } else {
            double[] transformedParameter = transform.transform(restoredParameterValue.clone(), 0, dimension);
            double[] numericGradientTransformed = NumericalDerivative.gradient(numeric, transformedParameter);
            // Use the snapshot, not parameter.getParameterValues(): after the numerical gradient
            // the parameter holds the last perturbed evaluation point.
            double[] analyticalGradientTransformed = transform.updateGradientLogDensity(
                    analyticalGradientOriginal, restoredParameterValue.clone(), 0, dimension);
            if (!MathUtils.isClose(analyticalGradientTransformed, numericGradientTransformed, runtimeOptions.gradientCheckTolerance)) {
                String sb = "Transformed Gradients do not match:\n"
                        + "\tAnalytic: " + new WrappedVector.Raw(analyticalGradientTransformed) + "\n"
                        + "\tNumeric : " + new WrappedVector.Raw(numericGradientTransformed) + "\n"
                        + "\tParameter : " + new WrappedVector.Raw(restoredParameterValue.clone()) + "\n"
                        + "\tTransformed Parameter : " + new WrappedVector.Raw(transformedParameter) + "\n";
                throw new RuntimeException(sb);
            }
        }
    } finally {
        // Always put the original values back, including when a mismatch is being reported.
        ReadableVector.Utils.setParameter(restoredParameterValue, parameter);
    }
}
Also used: MultivariateFunction (dr.math.MultivariateFunction), WrappedVector (dr.math.matrixAlgebra.WrappedVector)

Aggregations

MultivariateFunction (dr.math.MultivariateFunction): 2 uses; ArbitraryBranchRates (dr.evomodel.branchratemodel.ArbitraryBranchRates): 1 use; WrappedVector (dr.math.matrixAlgebra.WrappedVector): 1 use