Usage of dr.math.MultivariateFunction in the beast-mcmc project (beast-dev).
Class BranchSubstitutionParameterGradient, method numericWrap:
/**
 * Wraps the tree-data log likelihood as a {@link MultivariateFunction} of the branch rates,
 * so that it can be differentiated numerically.
 *
 * <p>{@code evaluate} writes each {@code argument[i]} into the branch rate of tree node {@code i}
 * (the root is skipped) and returns {@code treeDataLikelihood.getLogLikelihood()}. Only
 * {@link ArbitraryBranchRates} is supported; any other branch-rate model makes {@code evaluate}
 * throw.
 *
 * @param parameter parameter whose dimension defines the number of function arguments
 * @return a function with lower bound 0 and upper bound +infinity on every argument
 */
private MultivariateFunction numericWrap(final Parameter parameter) {
    return new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            if (!(branchRateModel instanceof ArbitraryBranchRates)) {
                throw new RuntimeException("Not yet tested with ProxyParameter.");
            }
            ArbitraryBranchRates branchRates = (ArbitraryBranchRates) branchRateModel;
            Tree tree = treeDataLikelihood.getTree();
            // NOTE(review): argument index i is applied to tree node number i. With the root
            // skipped, this relies on the root's node number being >= argument.length (e.g. the
            // last node) for every non-root branch to be set. TODO confirm this matches the
            // ordering of the analytic gradient.
            for (int i = 0; i < argument.length; ++i) {
                NodeRef node = tree.getNode(i);
                // The root is skipped, presumably because it has no parent branch to carry a rate.
                if (!tree.isRoot(node)) {
                    branchRates.setBranchRate(tree, node, argument[i]);
                }
            }
            // NOTE(review): explicit invalidation is disabled; the likelihood is assumed to be
            // notified of the rate changes by setBranchRate itself. TODO confirm.
            // treeDataLikelihood.makeDirty();
            return treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            // Branch rates are non-negative.
            return 0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };
}
Usage of dr.math.MultivariateFunction in the beast-mcmc project (beast-dev).
Class HamiltonianMonteCarloOperator, method checkGradient:
/**
 * Compares the analytic gradient from {@code gradientProvider} with a numerical gradient of
 * {@code joint} computed by {@link NumericalDerivative#gradient}.
 *
 * <p>When {@code transform} is null, the comparison is made on the raw parameter. Otherwise it is
 * made in transformed space:
 * <ul>
 *   <li>the numeric function evaluates the joint log likelihood at the inverse-transformed point,
 *       minus the transform's log Jacobian;</li>
 *   <li>the analytic gradient is mapped through {@code transform.updateGradientLogDensity}.</li>
 * </ul>
 *
 * <p>Numerical differentiation overwrites the parameter. Its original values are therefore
 * snapshotted first, used for every later computation that needs the evaluation point, and
 * restored in a {@code finally} block, so they are restored even when the check fails.
 *
 * @param joint likelihood whose log density is differentiated numerically
 * @throws RuntimeException if the parameter and gradient dimensions differ, or if the analytic and
 *     numeric gradients are not within {@code runtimeOptions.gradientCheckTolerance}
 */
void checkGradient(final Likelihood joint) {
    if (parameter.getDimension() != gradientProvider.getDimension()) {
        throw new RuntimeException("Unequal dimensions");
    }

    MultivariateFunction numeric = new MultivariateFunction() {
        @Override
        public double evaluate(double[] argument) {
            if (transform == null) {
                ReadableVector.Utils.setParameter(argument, parameter);
                return joint.getLogLikelihood();
            } else {
                double[] untransformedValue = transform.inverse(argument, 0, argument.length);
                ReadableVector.Utils.setParameter(untransformedValue, parameter);
                return joint.getLogLikelihood() - transform.getLogJacobian(untransformedValue, 0, untransformedValue.length);
            }
        }

        @Override
        public int getNumArguments() {
            return parameter.getDimension();
        }

        // NOTE(review): these report the untransformed parameter's bounds even when a transform
        // is active and arguments are in transformed space. TODO confirm this is intended, or that
        // NumericalDerivative ignores bounds.
        @Override
        public double getLowerBound(int n) {
            return parameter.getBounds().getLowerLimit(n);
        }

        @Override
        public double getUpperBound(int n) {
            return parameter.getBounds().getUpperLimit(n);
        }
    };

    double[] analyticalGradientOriginal = gradientProvider.getGradientLogDensity();
    // Snapshot of the point at which the analytic gradient was taken. numeric.evaluate() overwrites
    // the parameter, so reading parameter values after NumericalDerivative.gradient would return
    // the last perturbed point instead.
    double[] restoredParameterValue = parameter.getParameterValues();

    try {
        if (transform == null) {
            // Pass a copy so the snapshot cannot be altered by the differentiation routine.
            double[] numericGradientOriginal = NumericalDerivative.gradient(numeric, restoredParameterValue.clone());
            if (!MathUtils.isClose(analyticalGradientOriginal, numericGradientOriginal, runtimeOptions.gradientCheckTolerance)) {
                String sb = "Gradients do not match:\n" + "\tAnalytic: " + new WrappedVector.Raw(analyticalGradientOriginal) + "\n" + "\tNumeric : " + new WrappedVector.Raw(numericGradientOriginal) + "\n";
                throw new RuntimeException(sb);
            }
        } else {
            int dim = restoredParameterValue.length;
            double[] transformedParameter = transform.transform(restoredParameterValue.clone(), 0, dim);
            double[] numericGradientTransformed = NumericalDerivative.gradient(numeric, transformedParameter);
            // Uses the snapshot, not the (now perturbed) live parameter values.
            double[] analyticalGradientTransformed = transform.updateGradientLogDensity(analyticalGradientOriginal, restoredParameterValue.clone(), 0, dim);
            if (!MathUtils.isClose(analyticalGradientTransformed, numericGradientTransformed, runtimeOptions.gradientCheckTolerance)) {
                String sb = "Transformed Gradients do not match:\n" + "\tAnalytic: " + new WrappedVector.Raw(analyticalGradientTransformed) + "\n" + "\tNumeric : " + new WrappedVector.Raw(numericGradientTransformed) + "\n" + "\tParameter : " + new WrappedVector.Raw(restoredParameterValue) + "\n" + "\tTransformed Parameter : " + new WrappedVector.Raw(transformedParameter) + "\n";
                throw new RuntimeException(sb);
            }
        }
    } finally {
        // Always undo the perturbations made by numerical differentiation, including on mismatch.
        ReadableVector.Utils.setParameter(restoredParameterValue, parameter);
    }
}
Aggregations