Search in sources:

Example 16 with WrappedVector

Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.

Method doBounce of class IrreversibleZigZagOperator.

/**
 * Advances the irreversible zig-zag state by at most the travel time remaining in
 * {@code initialBounceState}.
 *
 * <p>When the first candidate event ({@code firstBounce}) lies beyond the remaining time, only the
 * position is moved (by the full remaining time) and a {@code Type.NONE} state with index -1 and
 * zero remaining time is returned. Otherwise the buffers are advanced to the event through
 * {@code updateDynamics} (which also receives the precision column of the event coordinate), the
 * velocity is reflected in that coordinate, and the returned state carries the event type, the
 * event index and the time still left after the event.
 *
 * <p>The {@code momentum} argument is not read by this implementation.
 */
@SuppressWarnings("Duplicates")
@Override
BounceState doBounce(BounceState initialBounceState, MinimumTravelInformation firstBounce, WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient, WrappedVector momentum) {
    if (TIMING) {
        timer.startTimer("doBounce");
    }

    final double timeLeft = initialBounceState.remainingTime;
    final double timeToEvent = firstBounce.time;
    // Kept as a strict '<' test (not '>=') so comparisons involving NaN take the same branch as before.
    final boolean travelEndsFirst = timeLeft < timeToEvent;

    final BounceState result;
    if (!travelEndsFirst) {
        // An event happens before the remaining time runs out: move to it and bounce.
        final int coordinate = firstBounce.index;
        final WrappedVector precisionColumn = getPrecisionColumn(coordinate);
        updateDynamics(
                position.getBuffer(),
                velocity.getBuffer(),
                action.getBuffer(),
                gradient.getBuffer(),
                precisionColumn.getBuffer(),
                timeToEvent,
                coordinate);
        reflectVelocity(velocity, coordinate);
        result = new BounceState(firstBounce.type, coordinate, timeLeft - timeToEvent);
    } else {
        // No event during the remaining time: drift to the end of the trajectory.
        updatePosition(position, velocity, timeLeft);
        result = new BounceState(Type.NONE, -1, 0.0);
    }

    if (TIMING) {
        timer.stopTimer("doBounce");
    }
    return result;
}
Also used: WrappedVector (dr.math.matrixAlgebra.WrappedVector)

Example 17 with WrappedVector

Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.

Method doBounce of class ReversibleZigZagOperator.

/**
 * Advances the reversible zig-zag state by at most the travel time remaining in
 * {@code initialBounceState}.
 *
 * <p>When the first candidate event ({@code firstBounce}) lies beyond the remaining time, only the
 * position is moved (by the full remaining time) and a {@code Type.NONE} state with index -1 and
 * zero remaining time is returned. Otherwise the state is advanced to the event, the momentum in
 * coordinate {@code firstBounce.index} is reflected (boundary event) or zeroed (any other event
 * type), the velocity is reflected in that coordinate, and the returned state carries the event
 * type, the event index and the time still left after the event.
 *
 * <p>The {@code TEST_FUSED_DYNAMICS}, {@code TEST_NATIVE_INNER_BOUNCE} and
 * {@code TEST_CRITICAL_REGION} flags select between alternative Java, fused and native
 * implementations of the event update.
 */
@Override
final BounceState doBounce(BounceState initialBounceState, MinimumTravelInformation firstBounce, WrappedVector position, WrappedVector velocity, WrappedVector action, WrappedVector gradient, WrappedVector momentum) {
    if (TIMING) {
        timer.startTimer("doBounce");
    }
    final double remainingTime = initialBounceState.remainingTime;
    final double eventTime = firstBounce.time;
    final BounceState finalBounceState;
    if (remainingTime < eventTime) {
        // No event during remaining time
        updatePosition(position, velocity, remainingTime);
        finalBounceState = new BounceState(Type.NONE, -1, 0.0);
    } else {
        if (TIMING) {
            timer.startTimer("notUpdateAction");
        }
        final Type eventType = firstBounce.type;
        final int eventIndex = firstBounce.index;
        if (TEST_FUSED_DYNAMICS) {
            WrappedVector column = getPrecisionColumn(eventIndex);
            if (!TEST_NATIVE_INNER_BOUNCE) {
                updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(), column.getBuffer(), eventTime, eventIndex);
            } else {
                nativeZigZag.updateDynamics(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(), column.getBuffer(), eventTime, eventIndex, eventType.ordinal());
            }
            updateMomentumAtEvent(eventType, momentum, position, eventIndex);
            reflectVelocity(velocity, eventIndex);
            // NOTE(review): this path has no separate updateAction call; presumably the fused kernels
            // update the action from the precision column themselves — confirm.
            // Stop the timer here as well so every startTimer above has a matching stopTimer.
            if (TIMING) {
                timer.stopTimer("notUpdateAction");
            }
        } else {
            if (!TEST_NATIVE_INNER_BOUNCE) {
                updatePosition(position, velocity, eventTime);
                updateMomentum(momentum, gradient, action, eventTime);
                updateMomentumAtEvent(eventType, momentum, position, eventIndex);
                reflectVelocity(velocity, eventIndex);
                updateGradient(gradient, eventTime, action);
            } else if (TEST_CRITICAL_REGION) {
                // Same trailing (eventTime, eventIndex, type) arguments as innerBounce below; the
                // previous code passed eventIndex in the event-time slot.
                nativeZigZag.innerBounceCriticalRegion(eventTime, eventIndex, eventType.ordinal());
            } else {
                nativeZigZag.innerBounce(position.getBuffer(), velocity.getBuffer(), action.getBuffer(), gradient.getBuffer(), momentum.getBuffer(), eventTime, eventIndex, eventType.ordinal());
            }
            if (TIMING) {
                timer.stopTimer("notUpdateAction");
            }
            updateAction(action, velocity, eventIndex);
        }
        finalBounceState = new BounceState(eventType, eventIndex, remainingTime - eventTime);
    }
    if (TIMING) {
        timer.stopTimer("doBounce");
    }
    return finalBounceState;
}

/**
 * Applies the momentum change for an event in coordinate {@code eventIndex}.
 *
 * <p>A {@code Type.BOUNDARY} event reflects the momentum against the boundary. Any other event
 * type (a bounce caused by the gradient) sets that momentum coordinate to zero.
 */
private void updateMomentumAtEvent(Type eventType, WrappedVector momentum, WrappedVector position, int eventIndex) {
    if (eventType == Type.BOUNDARY) {
        // Reflect against boundary
        reflectMomentum(momentum, position, eventIndex);
    } else {
        // Bounce caused by the gradient
        setZeroMomentum(momentum, eventIndex);
    }
}
Also used: WrappedVector (dr.math.matrixAlgebra.WrappedVector)

Example 18 with WrappedVector

Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.

Method findReasonableStepSize of class NoUTurnOperator.

/**
 * Chooses an initial leapfrog step size.
 *
 * <p>If {@code forcedInitialStepSize} is non-zero it is returned unchanged. Otherwise the search
 * proceeds as follows:
 * <ul>
 *   <li>Start from a step size of 0.1 and a freshly drawn momentum.</li>
 *   <li>Take one leapfrog step from a copy of {@code initialPosition}.</li>
 *   <li>From the change in joint log-probability, decide whether to keep doubling
 *       ({@code a = 1}) or halving ({@code a = -1}) the step size.</li>
 *   <li>Repeat while {@code ratio^a > 2^-a}.</li>
 * </ul>
 *
 * <p>The leapfrog engine is reset to {@code initialPosition} on every exit path, including when
 * the search fails.
 *
 * <p>NOTE(review): unlike Algorithm 4 of Hoffman &amp; Gelman (2014), the loop keeps leaping from
 * the current position/momentum instead of restarting from the initial state. The step size is
 * also rescaled after the leap that tests it. Preserved as-is; confirm this is intended.
 *
 * @param initialPosition the position the search starts from; the array itself is not modified
 * @param forcedInitialStepSize a user-supplied step size, or 0 to search for one
 * @return the chosen step size
 * @throws RuntimeException if no acceptable step size is found within {@code options.findMax}
 *     iterations
 */
private StepSize findReasonableStepSize(double[] initialPosition, double forcedInitialStepSize) {
    if (forcedInitialStepSize != 0) {
        return new StepSize(forcedInitialStepSize);
    }

    double stepSize = 0.1;
    WrappedVector momentum = preconditioning.drawInitialMomentum();
    int count = 1;
    double[] position = Arrays.copyOf(initialPosition, dim);

    try {
        double probBefore = getJointProbability(gradientProvider, momentum);
        try {
            doLeap(position, momentum, stepSize);
        } catch (NumericInstabilityException e) {
            handleInstability();
        }
        double probAfter = getJointProbability(gradientProvider, momentum);

        // +1: the first step looked acceptable, so grow the step size; -1: shrink it.
        final double a = ((probAfter - probBefore) > Math.log(0.5) ? 1 : -1);
        // Loop invariants: stopping threshold and per-iteration scaling factor.
        final double threshold = Math.pow(2, -a);
        final double scale = Math.pow(2, a);

        double probRatio = Math.exp(probAfter - probBefore);
        while (Math.pow(probRatio, a) > threshold) {
            probBefore = probAfter;
            // "one frog jump!"
            try {
                doLeap(position, momentum, stepSize);
            } catch (NumericInstabilityException e) {
                handleInstability();
            }
            probAfter = getJointProbability(gradientProvider, momentum);
            probRatio = Math.exp(probAfter - probBefore);
            stepSize = scale * stepSize;
            count++;
            if (count > options.findMax) {
                throw new RuntimeException("Cannot find a reasonable step-size in " + options.findMax + " iterations");
            }
        }
    } finally {
        // Restore the engine even when the search throws, so it is not left at a leapfrogged position.
        leapFrogEngine.setParameter(initialPosition);
    }
    return new StepSize(stepSize);
}
Also used: WrappedVector (dr.math.matrixAlgebra.WrappedVector)

Example 19 with WrappedVector

Use of dr.math.matrixAlgebra.WrappedVector in project beast-mcmc by beast-dev.

Method buildBaseCase of class NoUTurnOperator.

/**
 * Builds a single-step tree state by taking one leapfrog step of size
 * {@code direction * stepSize} from copies of the given position and momentum.
 *
 * <p>The method computes and returns:
 * <ul>
 *   <li>the node count: 1 when {@code logSliceU <= logJoint}, else 0;</li>
 *   <li>a continue flag: {@code logSliceU < options.logProbErrorTol + logJoint};</li>
 *   <li>the acceptance probability {@code min(1, exp(logJoint - initialJointDensity))}, used for
 *       dual averaging;</li>
 *   <li>an acceptance-state count of 1.</li>
 * </ul>
 * Here {@code logJoint} is the joint log-probability after the step.
 *
 * <p>The leapfrog engine is pointed at the working copy for the step and reset to
 * {@code inPosition} before returning. The caller's arrays are not modified.
 *
 * @param inPosition starting position (copied)
 * @param inMomentum starting momentum (copied)
 * @param direction sign multiplier applied to {@code stepSize}
 * @param logSliceU log of the slice variable
 * @param stepSize magnitude of the leapfrog step
 * @param initialJointDensity joint log-probability the acceptance probability is measured against
 * @return the resulting tree state, holding the stepped position and momentum
 */
private TreeState buildBaseCase(double[] inPosition, double[] inMomentum, int direction, double logSliceU, double stepSize, double initialJointDensity) {
    // Deep copies so the leapfrog step never touches the caller's arrays.
    final double[] position = Arrays.copyOf(inPosition, inPosition.length);
    final double[] momentumCopy = Arrays.copyOf(inMomentum, inMomentum.length);
    final WrappedVector momentum = new WrappedVector.Raw(momentumCopy);

    leapFrogEngine.setParameter(position);

    final double signedStep = direction * stepSize;
    // "one frog jump!"
    try {
        doLeap(position, momentum, signedStep);
    } catch (NumericInstabilityException ex) {
        handleInstability();
    }

    final double logJointAfterLeap = getJointProbability(gradientProvider, momentum);

    final boolean withinSlice = logSliceU <= logJointAfterLeap;
    final int nodeCount = withinSlice ? 1 : 0;
    final boolean keepGoing = logSliceU < options.logProbErrorTol + logJointAfterLeap;

    // Dual-averaging statistics for this single state.
    final double acceptanceProbability = Math.min(1.0, Math.exp(logJointAfterLeap - initialJointDensity));
    final int acceptanceStateCount = 1;

    leapFrogEngine.setParameter(inPosition);
    return new TreeState(position, momentum.getBuffer(), nodeCount, keepGoing, acceptanceProbability, acceptanceStateCount);
}
Also used: WrappedVector (dr.math.matrixAlgebra.WrappedVector)

Aggregations

WrappedVector (dr.math.matrixAlgebra.WrappedVector): 19; DenseMatrix64F (org.ejml.data.DenseMatrix64F): 4; ReadableVector (dr.math.matrixAlgebra.ReadableVector): 2; WrappedMatrix (dr.math.matrixAlgebra.WrappedMatrix): 2; ContinuousTraitDataModel (dr.evomodel.treedatalikelihood.continuous.ContinuousTraitDataModel): 1; ElementaryVectorDataModel (dr.evomodel.treedatalikelihood.continuous.ElementaryVectorDataModel): 1; Parameter (dr.inference.model.Parameter): 1; Utils.setParameter (dr.math.matrixAlgebra.ReadableVector.Utils.setParameter): 1