Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.
Class IrreversibleZigZagOperator, method doBounce.
/**
 * Moves the irreversible zig-zag process forward by one segment.
 *
 * <p>If the first candidate event ({@code firstBounce.time}) lies beyond the time still
 * available ({@code initialBounceState.remainingTime}), the position is advanced with
 * {@code updatePosition} for the remaining time. A {@code Type.NONE} state with index -1 and
 * zero remaining time is then returned.
 *
 * <p>Otherwise {@code updateDynamics} advances position, velocity, action and gradient to the
 * event time, using the precision column of the event coordinate. That coordinate's velocity is
 * then reflected, and the event type and index are returned together with the leftover time.
 *
 * <p>{@code momentum} is accepted for interface compatibility but is not used by this variant.
 */
@SuppressWarnings("Duplicates")
@Override
BounceState doBounce(BounceState initialBounceState, MinimumTravelInformation firstBounce,
                     WrappedVector position, WrappedVector velocity, WrappedVector action,
                     WrappedVector gradient, WrappedVector momentum) {
    if (TIMING) {
        timer.startTimer("doBounce");
    }

    final double timeLeft = initialBounceState.remainingTime;
    final double timeToEvent = firstBounce.time;
    final BounceState outcome;

    if (timeToEvent > timeLeft) {
        // The next event falls outside the remaining budget: travel to the end and stop.
        updatePosition(position, velocity, timeLeft);
        outcome = new BounceState(Type.NONE, -1, 0.0);
    } else {
        final Type kind = firstBounce.type;
        final int coordinate = firstBounce.index;
        // Fetched before any buffer is touched, matching the original evaluation order.
        final WrappedVector precisionColumn = getPrecisionColumn(coordinate);

        updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                gradient.getBuffer(), precisionColumn.getBuffer(), timeToEvent, coordinate);
        reflectVelocity(velocity, coordinate);

        outcome = new BounceState(kind, coordinate, timeLeft - timeToEvent);
    }

    if (TIMING) {
        timer.stopTimer("doBounce");
    }
    return outcome;
}
Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.
Class ReversibleZigZagOperator, method doBounce.
/**
 * Moves the reversible zig-zag process forward by one segment.
 *
 * <p>If the first candidate event ({@code firstBounce.time}) lies beyond the time still
 * available ({@code initialBounceState.remainingTime}), the position is advanced with
 * {@code updatePosition} for the remaining time. A {@code Type.NONE} state with index -1 and
 * zero remaining time is then returned.
 *
 * <p>Otherwise the state is advanced to the event time and the event coordinate is handled.
 * Its momentum is reflected against the boundary for a {@code Type.BOUNDARY} event, or set to
 * zero for any other (gradient-caused) event, and its velocity is reflected. Which
 * implementation does this depends on the {@code TEST_FUSED_DYNAMICS},
 * {@code TEST_NATIVE_INNER_BOUNCE} and {@code TEST_CRITICAL_REGION} flags:
 * <ul>
 *   <li>fused: a single {@code updateDynamics} call (Java or native), followed by the Java
 *       reflections; no separate {@code updateAction} call is made on this path;</li>
 *   <li>non-fused Java: position, momentum, reflections and gradient are updated step by step,
 *       then {@code updateAction};</li>
 *   <li>non-fused native: {@code nativeZigZag.innerBounce[CriticalRegion]}, then
 *       {@code updateAction}.</li>
 * </ul>
 *
 * <p>When {@code TIMING} is enabled, the {@code "notUpdateAction"} timer brackets the event work
 * that precedes {@code updateAction}. It is now stopped on the fused path too; previously it was
 * started there but never stopped.
 */
@Override
final BounceState doBounce(BounceState initialBounceState, MinimumTravelInformation firstBounce,
                           WrappedVector position, WrappedVector velocity, WrappedVector action,
                           WrappedVector gradient, WrappedVector momentum) {
    if (TIMING) {
        timer.startTimer("doBounce");
    }

    double remainingTime = initialBounceState.remainingTime;
    double eventTime = firstBounce.time;

    final BounceState finalBounceState;

    if (remainingTime < eventTime) {
        // No event during remaining time
        updatePosition(position, velocity, remainingTime);
        finalBounceState = new BounceState(Type.NONE, -1, 0.0);
    } else {

        if (TIMING) {
            timer.startTimer("notUpdateAction");
        }

        final Type eventType = firstBounce.type;
        final int eventIndex = firstBounce.index;

        if (TEST_FUSED_DYNAMICS) {
            WrappedVector column = getPrecisionColumn(eventIndex);

            if (!TEST_NATIVE_INNER_BOUNCE) {
                updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(),
                        gradient.getBuffer(), momentum.getBuffer(), column.getBuffer(),
                        eventTime, eventIndex);
            } else {
                nativeZigZag.updateDynamics(position.getBuffer(), velocity.getBuffer(),
                        action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(),
                        column.getBuffer(), eventTime, eventIndex, eventType.ordinal());
            }

            if (eventType == Type.BOUNDARY) {
                // Reflect against boundary
                reflectMomentum(momentum, position, eventIndex);
            } else {
                // Bounce caused by the gradient
                setZeroMomentum(momentum, eventIndex);
            }
            reflectVelocity(velocity, eventIndex);

            // Balance the "notUpdateAction" timer started above; this path has no separate
            // updateAction call, so the timer ends here.
            if (TIMING) {
                timer.stopTimer("notUpdateAction");
            }
        } else {

            if (!TEST_NATIVE_INNER_BOUNCE) {
                updatePosition(position, velocity, eventTime);
                updateMomentum(momentum, gradient, action, eventTime);

                if (eventType == Type.BOUNDARY) {
                    // Reflect against boundary
                    reflectMomentum(momentum, position, eventIndex);
                } else {
                    // Bounce caused by the gradient
                    setZeroMomentum(momentum, eventIndex);
                }

                reflectVelocity(velocity, eventIndex);
                updateGradient(gradient, eventTime, action);
            } else {
                if (TEST_CRITICAL_REGION) {
                    nativeZigZag.innerBounceCriticalRegion(eventIndex, eventIndex,
                            eventType.ordinal());
                } else {
                    nativeZigZag.innerBounce(position.getBuffer(), velocity.getBuffer(),
                            action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(),
                            eventTime, eventIndex, eventType.ordinal());
                }
            }

            if (TIMING) {
                timer.stopTimer("notUpdateAction");
            }

            updateAction(action, velocity, eventIndex);
        }

        finalBounceState = new BounceState(eventType, eventIndex, remainingTime - eventTime);
    }

    if (TIMING) {
        timer.stopTimer("doBounce");
    }

    return finalBounceState;
}
Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.
Class NoUTurnOperator, method findReasonableStepSize.
/**
 * Chooses an initial leapfrog step size.
 *
 * <p>If {@code forcedInitialStepSize} is non-zero it is returned as-is. Otherwise the search
 * proceeds as follows:
 * <ol>
 *   <li>Start from a step size of 0.1 and a momentum drawn from {@code preconditioning}.</li>
 *   <li>Take one leapfrog step and compute the ratio of joint probabilities after/before it.</li>
 *   <li>Set {@code a = +1} if that ratio exceeds 1/2, otherwise {@code a = -1}.</li>
 *   <li>While {@code ratio^a > 2^-a}, keep taking leapfrog steps and multiplying the step size
 *       by {@code 2^a}.</li>
 * </ol>
 * This appears to follow the heuristic of Hoffman &amp; Gelman (2014), Algorithm 4.
 *
 * <p>NOTE(review): unlike that algorithm, each leap here continues from the previous
 * position/momentum and compares against the previous (not the initial) joint probability. The
 * step size is also rescaled only after the leap that used the old value. This behavior is
 * unchanged; confirm it is intended.
 *
 * <p>Whenever a search is performed, the leapfrog engine's parameter is reset to
 * {@code initialPosition} before the method returns or throws.
 *
 * @param initialPosition position to start the search from; not modified (a copy is leapt)
 * @param forcedInitialStepSize step size to use verbatim when non-zero
 * @return the chosen step size
 * @throws RuntimeException if the search exceeds {@code options.findMax} iterations
 */
private StepSize findReasonableStepSize(double[] initialPosition, double forcedInitialStepSize) {
    if (forcedInitialStepSize != 0) {
        return new StepSize(forcedInitialStepSize);
    }

    double stepSize = 0.1;
    WrappedVector momentum = preconditioning.drawInitialMomentum();
    int count = 1;
    double[] position = Arrays.copyOf(initialPosition, dim);

    double probBefore = getJointProbability(gradientProvider, momentum);

    try {
        try {
            doLeap(position, momentum, stepSize);
        } catch (NumericInstabilityException e) {
            handleInstability();
        }

        double probAfter = getJointProbability(gradientProvider, momentum);

        // a = +1: the first step looked "good" (ratio > 1/2), so grow the step size;
        // a = -1: it looked "bad", so shrink it.
        double a = ((probAfter - probBefore) > Math.log(0.5) ? 1 : -1);
        double probRatio = Math.exp(probAfter - probBefore);

        while (Math.pow(probRatio, a) > Math.pow(2, -a)) {
            probBefore = probAfter;

            // "one frog jump!"
            try {
                doLeap(position, momentum, stepSize);
            } catch (NumericInstabilityException e) {
                handleInstability();
            }

            probAfter = getJointProbability(gradientProvider, momentum);
            probRatio = Math.exp(probAfter - probBefore);
            stepSize = Math.pow(2, a) * stepSize;

            count++;
            if (count > options.findMax) {
                throw new RuntimeException("Cannot find a reasonable step-size in " + options.findMax + " " + "iterations");
            }
        }
    } finally {
        // Restore the model even when the search aborts, so a failed search does not leave
        // the parameter at a leapt (possibly unstable) position.
        leapFrogEngine.setParameter(initialPosition);
    }

    return new StepSize(stepSize);
}
Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.
Class NoUTurnOperator, method buildBaseCase.
/**
 * Builds the base case of the tree recursion: a single leapfrog step of size
 * {@code direction * stepSize}, starting from copies of {@code inPosition} and
 * {@code inMomentum}.
 *
 * <p>The returned {@link TreeState} holds:
 * <ul>
 *   <li>the leapt position and momentum;</li>
 *   <li>{@code numNodes = 1} if {@code logSliceU <= logJointProbAfter}, else 0;</li>
 *   <li>a continue flag that is false once {@code logJointProbAfter} is not above
 *       {@code logSliceU - options.logProbErrorTol};</li>
 *   <li>for dual averaging, {@code acceptProb = min(1, exp(logJointProbAfter -
 *       initialJointDensity))} with one accept-probability state.</li>
 * </ul>
 *
 * <p>The leapfrog engine's parameter is set to the copied position for the leap. It is reset to
 * {@code inPosition} before the method returns or throws.
 *
 * @param inPosition starting position; not modified
 * @param inMomentum starting momentum; not modified
 * @param direction sign applied to {@code stepSize}
 * @param logSliceU log slice variable
 * @param stepSize leapfrog step size magnitude
 * @param initialJointDensity log joint density of the trajectory's initial state
 * @return the single-node tree state
 */
private TreeState buildBaseCase(double[] inPosition, double[] inMomentum, int direction,
                                double logSliceU, double stepSize, double initialJointDensity) {
    // Deep copies: the leap mutates these in place, the caller's arrays stay untouched.
    final double[] position = Arrays.copyOf(inPosition, inPosition.length);
    final WrappedVector momentum =
            new WrappedVector.Raw(Arrays.copyOf(inMomentum, inMomentum.length));

    double logJointProbAfter;

    leapFrogEngine.setParameter(position);
    try {
        // "one frog jump!"
        try {
            doLeap(position, momentum, direction * stepSize);
        } catch (NumericInstabilityException e) {
            handleInstability();
        }
        logJointProbAfter = getJointProbability(gradientProvider, momentum);
    } finally {
        // Always restore the model to the input position, even if the leap aborted.
        leapFrogEngine.setParameter(inPosition);
    }

    final int numNodes = (logSliceU <= logJointProbAfter ? 1 : 0);
    final boolean flagContinue = (logSliceU < options.logProbErrorTol + logJointProbAfter);

    // Values for dual-averaging
    final double acceptProb = Math.min(1.0, Math.exp(logJointProbAfter - initialJointDensity));
    final int numAcceptProbStates = 1;

    return new TreeState(position, momentum.getBuffer(), numNodes, flagContinue, acceptProb,
            numAcceptProbStates);
}
Aggregations