Usage of com.simiacryptus.mindseye.lang.PointSample in the MindsEye project by SimiaCryptus.
Class IterativeTrainer, method step.
/**
 * Performs one line-search step along {@code direction}, starting from {@code previous}.
 *
 * <p>A {@link LineSearchStrategy} is looked up in {@code lineSearchStrategyMap} under
 * {@code directionType}. On a miss, one is built by {@code lineSearchFactory} from the cursor's own
 * {@code getDirectionType()} and cached under {@code directionType}. The cursor is then wrapped in a
 * {@code FailsafeLineSearchCursor}, the strategy is run against the wrapper, and the wrapper's best
 * sample is returned.
 *
 * @param direction     the line-search cursor describing the search direction
 * @param directionType the key under which the line-search strategy is cached
 * @param previous      the point sample the search starts from; handed to the failsafe wrapper
 * @return the best point sample reported by the failsafe wrapper
 */
public PointSample step(@Nonnull final LineSearchCursor direction, final CharSequence directionType, @Nonnull final PointSample previous) {
  final LineSearchStrategy lineSearchStrategy;
  if (lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = lineSearchStrategyMap.get(directionType);
  } else {
    log.info(String.format("Constructing line search parameters: %s", directionType));
    // NOTE(review): the cache key is directionType, but the factory receives the cursor's own
    // direction type; presumably intentional (callers may decorate the key) — confirm.
    lineSearchStrategy = lineSearchFactory.apply(direction.getDirectionType());
    lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  @Nonnull final FailsafeLineSearchCursor wrapped = new FailsafeLineSearchCursor(direction, previous, monitor);
  try {
    // The strategy's own result is discarded; the wrapper's best sample is what gets returned.
    lineSearchStrategy.step(wrapped, monitor).freeRef();
    return wrapped.getBest(monitor);
  } finally {
    // Release the wrapper even when the line search throws, so its references are not leaked.
    wrapped.freeRef();
  }
}
Usage of com.simiacryptus.mindseye.lang.PointSample in the MindsEye project by SimiaCryptus.
Class ArmijoWolfeSearch, method step.
/**
 * Runs an Armijo/Wolfe line search along {@code cursor} and returns the selected point.
 *
 * <p>The trial step {@code alpha} carries over from the previous call. It is multiplied by
 * {@code alphaGrowth} and capped at {@code maxAlpha}. Each iteration evaluates the cursor at
 * {@code alpha} and classifies the result:
 * <ul>
 *   <li>Armijo (sufficient-decrease) failure, or a positive derivative when strong Wolfe is enabled:
 *       {@code alpha} becomes the new ceiling {@code nu};</li>
 *   <li>weak Wolfe (curvature) failure: {@code alpha} becomes the new floor {@code mu};</li>
 *   <li>otherwise the point at {@code alpha} is accepted and returned.</li>
 * </ul>
 * The search also stops, returning the point at the best {@code alpha} seen so far, in these cases:
 * <ul>
 *   <li>{@code isAlphaValid()} is false;</li>
 *   <li>the bracket {@code [mu, nu]} has collapsed;</li>
 *   <li>{@code |alpha|} leaves {@code [minAlpha, maxAlpha]}.</li>
 * </ul>
 * If the starting derivative is non-negative, the point at {@code alpha = 0} is returned immediately.
 *
 * @param cursor  the line to search along
 * @param monitor receives progress log messages
 * @return the selected point sample; every return path adds a reference to it for the caller
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
  // Keep memory of alpha from one iteration to next, but have a bias for growing the value
  alpha = Math.min(maxAlpha, alpha * alphaGrowth);
  // Bracket on acceptable step sizes: mu is the floor, nu the ceiling
  double mu = 0;
  double nu = Double.POSITIVE_INFINITY;
  final LineSearchPoint startPoint = cursor.step(0, monitor);
  @Nullable LineSearchPoint lastStep = null;
  try {
    // theta'(0)
    final double startLineDeriv = startPoint.derivative;
    // theta(0)
    final double startValue = startPoint.point.getMean();
    if (0 <= startPoint.derivative) {
      // A non-negative starting derivative means this is not a descent direction; stay at alpha = 0.
      monitor.log(String.format("th(0)=%s;dx=%s (ERROR: Starting derivative non-negative)", startValue, startLineDeriv));
      return stepPoint(cursor, monitor, 0);
    }
    monitor.log(String.format("th(0)=%s;dx=%s", startValue, startLineDeriv));
    int stepBias = 0;
    double bestAlpha = 0;
    double bestValue = startPoint.point.getMean();
    while (true) {
      if (!isAlphaValid()) {
        PointSample point = stepPoint(cursor, monitor, bestAlpha);
        monitor.log(String.format("INVALID ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
        return point;
      }
      if (mu >= nu - absoluteTolerance) {
        loosenMetaparameters();
        PointSample point = stepPoint(cursor, monitor, bestAlpha);
        monitor.log(String.format("mu >= nu (%s): th(%s)=%s", mu, bestAlpha, point.getMean()));
        return point;
      }
      if (nu - mu < nu * relativeTolerance) {
        loosenMetaparameters();
        PointSample point = stepPoint(cursor, monitor, bestAlpha);
        monitor.log(String.format("mu ~= nu (%s): th(%s)=%s", mu, bestAlpha, point.getMean()));
        return point;
      }
      if (Math.abs(alpha) < minAlpha) {
        PointSample point = stepPoint(cursor, monitor, bestAlpha);
        monitor.log(String.format("MIN ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
        alpha = minAlpha;
        return point;
      }
      if (Math.abs(alpha) > maxAlpha) {
        PointSample point = stepPoint(cursor, monitor, bestAlpha);
        monitor.log(String.format("MAX ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
        alpha = maxAlpha;
        return point;
      }
      LineSearchPoint newValue = cursor.step(alpha, monitor);
      synchronized (this) {
        if (null != lastStep) {
          lastStep.freeRef();
        }
        lastStep = newValue;
      }
      double lastValue = lastStep.point.getMean();
      if (bestValue > lastValue) {
        bestAlpha = alpha;
        bestValue = lastValue;
      }
      // Treat NaN / infinite values as an Armijo failure below.
      if (!Double.isFinite(lastValue)) {
        lastValue = Double.POSITIVE_INFINITY;
      }
      if (lastValue > startValue + alpha * c1 * startLineDeriv) {
        // Value did not decrease (enough). A decrease is guaranteed for an infinitesimal rate, so the rate must be less than this; this is a new ceiling
        monitor.log(String.format("Armijo: th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
        nu = alpha;
        stepBias = Math.min(-1, stepBias - 1);
      } else if (isStrongWolfe() && lastStep.derivative > 0) {
        // If the slope is increasing, then we can go lower by choosing a lower rate; this is a new ceiling
        monitor.log(String.format("WOLF (strong): th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
        nu = alpha;
        stepBias = Math.min(-1, stepBias - 1);
      } else if (lastStep.derivative < c2 * startLineDeriv) {
        // The slope is still steeper than c2 * theta'(0), so we want a rate of at least this value; this is a new floor
        monitor.log(String.format("WOLFE (weak): th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
        mu = alpha;
        stepBias = Math.max(1, stepBias + 1);
      } else {
        monitor.log(String.format("END: th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
        PointSample point = lastStep.point;
        point.addRef();
        return point;
      }
      // Next trial: grow while there is no ceiling, shrink while there is no floor, otherwise bisect.
      if (!Double.isFinite(nu)) {
        alpha = (1 + Math.abs(stepBias)) * alpha;
      } else if (0.0 == mu) {
        alpha = nu / (1 + Math.abs(stepBias));
      } else {
        alpha = (mu + nu) / 2;
      }
    }
  } finally {
    synchronized (this) {
      if (null != lastStep) {
        lastStep.freeRef();
      }
      lastStep = null;
    }
    startPoint.freeRef();
  }
}
Usage of com.simiacryptus.mindseye.lang.PointSample in the MindsEye project by SimiaCryptus.
Class ArmijoWolfeSearch, method stepPoint.
/**
 * Evaluates {@code cursor} at the given step size and returns only the measured point.
 *
 * <p>A reference is added to the sample before the enclosing {@code LineSearchPoint} is released,
 * so the returned sample remains valid for the caller.
 *
 * @param cursor    the line-search cursor to evaluate
 * @param monitor   the training monitor passed to {@code cursor.step}
 * @param bestAlpha the step size to evaluate at
 * @return the point sample measured at {@code bestAlpha}
 */
private PointSample stepPoint(@Nonnull LineSearchCursor cursor, TrainingMonitor monitor, double bestAlpha) {
  final LineSearchPoint evaluated = cursor.step(bestAlpha, monitor);
  final PointSample sample = evaluated.point;
  sample.addRef();
  evaluated.freeRef();
  return sample;
}
Usage of com.simiacryptus.mindseye.lang.PointSample in the MindsEye project by SimiaCryptus.
Class SimpleLineSearchCursor, method step.
/**
 * Moves the subject to step size {@code alpha} along this cursor's direction and measures it there.
 *
 * <p>The cursor is {@code reset()} first. For a non-zero {@code alpha}, the direction is then
 * accumulated at that rate. The measured sample is tagged with {@code alpha} as its rate. The line
 * derivative is the dot product of the direction with the sample's delta.
 *
 * @param alpha   the step size along the direction; must be finite
 * @param monitor the training monitor passed to {@code subject.measure}
 * @return a new {@code LineSearchPoint} pairing the measured sample with its line derivative
 * @throws IllegalArgumentException if {@code alpha} is NaN or infinite
 */
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
  if (!Double.isFinite(alpha)) {
    throw new IllegalArgumentException("alpha must be finite: " + alpha);
  }
  reset();
  if (0.0 != alpha) {
    direction.accumulate(alpha);
  }
  @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
  try {
    final double dot = direction.dot(sample.delta);
    return new LineSearchPoint(sample, dot);
  } finally {
    // Drop this method's reference even if the dot product or construction throws;
    // the returned LineSearchPoint is expected to hold its own reference to the sample.
    sample.freeRef();
  }
}
Usage of com.simiacryptus.mindseye.lang.PointSample in the MindsEye project by SimiaCryptus.
Class LBFGS, method orient.
/**
 * Computes a search direction from {@code measurement} and the L-BFGS history.
 *
 * <p>The steps are:
 * <ol>
 *   <li>{@code measurement} is added to the history, and {@code lbfgs} is evaluated on a snapshot of
 *       that history.</li>
 *   <li>If {@code lbfgs} yields no direction ({@code null}), a "GD" cursor along the negated delta is
 *       returned. Otherwise an "LBFGS" cursor along the computed direction is returned.</li>
 *   <li>Before returning, the oldest history entries are removed and released until at most
 *       {@code minHistory} (GD case) or {@code maxHistory} (LBFGS case) entries remain.</li>
 * </ol>
 *
 * @param subject     the trainable the returned cursor operates on
 * @param measurement the current point sample
 * @param monitor     receives history bookkeeping and verbose log messages
 * @return a line-search cursor labeled "GD" or "LBFGS"
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  addToHistory(measurement, monitor);
  // Snapshot of the history handed to lbfgs(); this.history is trimmed separately below.
  @Nonnull final List<PointSample> history = Arrays.asList(this.history.toArray(new PointSample[] {}));
  @Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, history);
  final SimpleLineSearchCursor returnValue;
  if (null == result) {
    // No quasi-Newton direction available: fall back to steepest descent along -gradient.
    @Nonnull DeltaSet<Layer> scale = measurement.delta.scale(-1);
    returnValue = cursor(subject, measurement, "GD", scale);
    scale.freeRef();
  } else {
    returnValue = cursor(subject, measurement, "LBFGS", result);
    result.freeRef();
  }
  // Trim to minHistory after a GD fallback, maxHistory otherwise.
  while (this.history.size() > (null == result ? minHistory : maxHistory)) {
    @Nullable final PointSample remove = this.history.pollFirst();
    if (verbose) {
      // Report the live history size; the local "history" is the pre-trim snapshot.
      monitor.log(String.format("Removed measurement %s from history. Total: %s", Long.toHexString(System.identityHashCode(remove)), this.history.size()));
    }
    remove.freeRef();
  }
  return returnValue;
}
Aggregations