Search in sources :

Example 1 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

In class IterativeTrainer, method step.

/**
 * Runs one line-search step along {@code direction}, starting from {@code previous}.
 *
 * <p>A {@link LineSearchStrategy} is cached per {@code directionType}; the first time a type is
 * seen, a strategy is built via {@code lineSearchFactory} and stored in
 * {@code lineSearchStrategyMap}. The cursor is wrapped in a {@link FailsafeLineSearchCursor} so
 * that the best point seen during the search is returned, even if the strategy's own result is
 * worse.
 *
 * @param direction     the line-search cursor describing the search direction
 * @param directionType key used to look up (or create) the cached line-search strategy
 * @param previous      the starting point handed to the failsafe cursor
 * @return the best point sample recorded by the failsafe cursor
 */
public PointSample step(@Nonnull final LineSearchCursor direction, final CharSequence directionType, @Nonnull final PointSample previous) {
    final LineSearchStrategy lineSearchStrategy;
    if (lineSearchStrategyMap.containsKey(directionType)) {
        lineSearchStrategy = lineSearchStrategyMap.get(directionType);
    } else {
        log.info(String.format("Constructing line search parameters: %s", directionType));
        // NOTE(review): the cache is keyed by the directionType argument but the strategy is built
        // from direction.getDirectionType(); presumably callers pass the same value — confirm.
        lineSearchStrategy = lineSearchFactory.apply(direction.getDirectionType());
        lineSearchStrategyMap.put(directionType, lineSearchStrategy);
    }
    @Nonnull final FailsafeLineSearchCursor wrapped = new FailsafeLineSearchCursor(direction, previous, monitor);
    try {
        // The strategy's own result is discarded; the failsafe cursor tracks the best point seen.
        lineSearchStrategy.step(wrapped, monitor).freeRef();
        return wrapped.getBest(monitor);
    } finally {
        // Release the wrapper even if the search or getBest throws.
        wrapped.freeRef();
    }
}
Also used : Nonnull(javax.annotation.Nonnull) LineSearchStrategy(com.simiacryptus.mindseye.opt.line.LineSearchStrategy) PointSample(com.simiacryptus.mindseye.lang.PointSample) FailsafeLineSearchCursor(com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor)

Example 2 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

In class ArmijoWolfeSearch, method step.

/**
 * Runs an Armijo/Wolfe bracketing line search along {@code cursor}.
 *
 * <p>{@code mu} is a lower bound on the step size: the last {@code alpha} whose slope was still
 * below {@code c2 * theta'(0)}. {@code nu} is an upper bound: the last {@code alpha} that failed
 * the Armijo sufficient-decrease test or, in strong-Wolfe mode, had a positive slope. The search
 * returns the first step that passes all tests. If {@code alpha} becomes invalid or leaves
 * {@code [minAlpha, maxAlpha]}, or the bracket collapses, it instead re-evaluates the best step
 * seen so far.
 *
 * <p>The field {@code alpha} is kept across calls, grown by {@code alphaGrowth} and capped at
 * {@code maxAlpha}, so the next search starts from this search's final step size.
 *
 * @param cursor  the line to search along; evaluated at 0 and at each trial step size
 * @param monitor receives one log line per trial and per exit condition
 * @return the selected point sample; a reference is added for the caller before returning
 */
@Override
public PointSample step(@Nonnull final LineSearchCursor cursor, @Nonnull final TrainingMonitor monitor) {
    // Keep memory of alpha from one iteration to next, but have a bias for growing the value
    alpha = Math.min(maxAlpha, alpha * alphaGrowth);
    double mu = 0;
    double nu = Double.POSITIVE_INFINITY;
    final LineSearchPoint startPoint = cursor.step(0, monitor);
    @Nullable LineSearchPoint lastStep = null;
    try {
        // theta'(0)
        final double startLineDeriv = startPoint.derivative;
        // theta(0)
        final double startValue = startPoint.point.getMean();
        if (0 <= startPoint.derivative) {
            // Not a descent direction: return the point at alpha = 0 instead of searching.
            monitor.log(String.format("th(0)=%s;dx=%s (ERROR: Starting derivative non-negative)", startValue, startLineDeriv));
            return stepPoint(cursor, monitor, 0);
        }
        monitor.log(String.format("th(0)=%s;dx=%s", startValue, startLineDeriv));
        // stepBias counts consecutive moves in the same direction (negative = shrinking, positive = growing)
        int stepBias = 0;
        double bestAlpha = 0;
        double bestValue = startPoint.point.getMean();
        while (true) {
            if (!isAlphaValid()) {
                PointSample point = stepPoint(cursor, monitor, bestAlpha);
                monitor.log(String.format("INVALID ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
                return point;
            }
            // Bracket has collapsed (absolute tolerance).
            if (mu >= nu - absoluteTolerance) {
                loosenMetaparameters();
                PointSample point = stepPoint(cursor, monitor, bestAlpha);
                monitor.log(String.format("mu >= nu (%s): th(%s)=%s", mu, bestAlpha, point.getMean()));
                return point;
            }
            // Bracket has collapsed (relative tolerance).
            if (nu - mu < nu * relativeTolerance) {
                loosenMetaparameters();
                PointSample point = stepPoint(cursor, monitor, bestAlpha);
                monitor.log(String.format("mu ~= nu (%s): th(%s)=%s", mu, bestAlpha, point.getMean()));
                return point;
            }
            if (Math.abs(alpha) < minAlpha) {
                PointSample point = stepPoint(cursor, monitor, bestAlpha);
                monitor.log(String.format("MIN ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
                alpha = minAlpha;
                return point;
            }
            if (Math.abs(alpha) > maxAlpha) {
                PointSample point = stepPoint(cursor, monitor, bestAlpha);
                monitor.log(String.format("MAX ALPHA (%s): th(%s)=%s", alpha, bestAlpha, point.getMean()));
                alpha = maxAlpha;
                return point;
            }
            LineSearchPoint newValue = cursor.step(alpha, monitor);
            synchronized (this) {
                // Release the previous trial before replacing it.
                if (null != lastStep)
                    lastStep.freeRef();
                lastStep = newValue;
            }
            double lastValue = lastStep.point.getMean();
            if (bestValue > lastValue) {
                bestAlpha = alpha;
                bestValue = lastValue;
            }
            // Treat NaN / infinite objective values as "too large" so the Armijo test rejects them.
            if (!Double.isFinite(lastValue)) {
                lastValue = Double.POSITIVE_INFINITY;
            }
            if (lastValue > startValue + alpha * c1 * startLineDeriv) {
                // Value did not decrease (enough) - it is guaranteed to decrease given an infinitesimal rate; the rate must be less than this; this is a new ceiling
                monitor.log(String.format("Armijo: th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
                nu = alpha;
                stepBias = Math.min(-1, stepBias - 1);
            } else if (isStrongWolfe() && lastStep.derivative > 0) {
                // If the slope is increasing, then we can go lower by choosing a lower rate; this is a new ceiling
                monitor.log(String.format("WOLF (strong): th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
                nu = alpha;
                stepBias = Math.min(-1, stepBias - 1);
            } else if (lastStep.derivative < c2 * startLineDeriv) {
                // Current slope decreases at no more than X - if it is still decreasing that fast, we know we want a rate of at least this value; this is a new floor
                monitor.log(String.format("WOLFE (weak): th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
                mu = alpha;
                stepBias = Math.max(1, stepBias + 1);
            } else {
                // All conditions satisfied: accept this step.
                monitor.log(String.format("END: th(%s)=%s; dx=%s delta=%s", alpha, lastValue, lastStep.derivative, startValue - lastValue));
                PointSample point = lastStep.point;
                point.addRef();
                return point;
            }
            // Choose the next trial: expand while unbounded above, shrink while no floor, else bisect.
            if (!Double.isFinite(nu)) {
                alpha = (1 + Math.abs(stepBias)) * alpha;
            } else if (0.0 == mu) {
                alpha = nu / (1 + Math.abs(stepBias));
            } else {
                alpha = (mu + nu) / 2;
            }
        }
    } finally {
        synchronized (this) {
            if (null != lastStep)
                lastStep.freeRef();
            lastStep = null;
        }
        startPoint.freeRef();
    }
}
Also used : PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable)

Example 3 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

In class ArmijoWolfeSearch, method stepPoint.

/**
 * Evaluates {@code cursor} at step size {@code bestAlpha} and returns the resulting sample.
 *
 * @param cursor    the line-search cursor to evaluate
 * @param monitor   passed through to {@code cursor.step}
 * @param bestAlpha the step size to evaluate
 * @return the evaluated point sample, with a reference added for the caller
 */
private PointSample stepPoint(@Nonnull LineSearchCursor cursor, TrainingMonitor monitor, double bestAlpha) {
    final LineSearchPoint step = cursor.step(bestAlpha, monitor);
    try {
        final PointSample point = step.point;
        // Take the caller's reference before the LineSearchPoint wrapper is released below.
        point.addRef();
        return point;
    } finally {
        // Release the wrapper even if addRef throws.
        step.freeRef();
    }
}
Also used : PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 4 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

In class SimpleLineSearchCursor, method step.

/**
 * Evaluates the subject at step size {@code alpha} along this cursor's direction.
 *
 * <p>The weights are first reset. For a non-zero {@code alpha}, the direction is then
 * accumulated into them, scaled by {@code alpha}. Finally the subject is measured, and the
 * directional derivative is computed as the dot product of the direction with the measured
 * delta.
 *
 * @param alpha   the step size; must be finite
 * @param monitor passed through to {@code subject.measure}
 * @return the measured sample paired with its directional derivative
 * @throws IllegalArgumentException if {@code alpha} is NaN or infinite
 */
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
    if (!Double.isFinite(alpha)) {
        throw new IllegalArgumentException(String.format("Step size must be finite: %s", alpha));
    }
    reset();
    if (0.0 != alpha) {
        direction.accumulate(alpha);
    }
    @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
    try {
        final double dot = direction.dot(sample.delta);
        return new LineSearchPoint(sample, dot);
    } finally {
        // Release our reference even if dot() or the LineSearchPoint constructor throws.
        sample.freeRef();
    }
}
Also used : Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample)

Example 5 with PointSample

use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

In class LBFGS, method orient.

/**
 * Records {@code measurement} in the history and builds a line-search cursor from it.
 *
 * <p>When {@code lbfgs(...)} yields a direction, the cursor uses it, labelled "LBFGS". When it
 * returns null, the cursor falls back to the negated gradient ({@code delta.scale(-1)}),
 * labelled "GD". Afterwards the oldest history entries are released until at most
 * {@code minHistory} (on fallback) or {@code maxHistory} (on success) remain.
 *
 * @param subject     the trainable the cursor will evaluate
 * @param measurement the current point sample
 * @param monitor     receives verbose history-trimming logs
 * @return the cursor for the chosen search direction
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    addToHistory(measurement, monitor);
    // Snapshot the history so lbfgs() works on a stable copy; this.history is trimmed below.
    @Nonnull final List<PointSample> history = Arrays.asList(this.history.toArray(new PointSample[] {}));
    @Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, history);
    SimpleLineSearchCursor returnValue;
    if (null == result) {
        @Nonnull DeltaSet<Layer> scale = measurement.delta.scale(-1);
        try {
            returnValue = cursor(subject, measurement, "GD", scale);
        } finally {
            scale.freeRef();
        }
    } else {
        try {
            returnValue = cursor(subject, measurement, "LBFGS", result);
        } finally {
            result.freeRef();
        }
    }
    // A failed L-BFGS step shrinks the retained history to minHistory.
    while (this.history.size() > (null == result ? minHistory : maxHistory)) {
        @Nullable final PointSample remove = this.history.pollFirst();
        if (verbose) {
            // Report the live history size, not the pre-trim snapshot held in the local 'history'.
            monitor.log(String.format("Removed measurement %s from history. Total: %s", Long.toHexString(System.identityHashCode(remove)), this.history.size()));
        }
        remove.freeRef();
    }
    return returnValue;
}
Also used : Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable)

Aggregations

Usage counts (class, fully-qualified name, number of uses):
PointSample (com.simiacryptus.mindseye.lang.PointSample): 33
Nonnull (javax.annotation.Nonnull): 24
Layer (com.simiacryptus.mindseye.lang.Layer): 16
Nullable (javax.annotation.Nullable): 14
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 10
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 9
SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor): 9
StateSet (com.simiacryptus.mindseye.lang.StateSet): 8
LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor): 8
List (java.util.List): 8
Trainable (com.simiacryptus.mindseye.eval.Trainable): 7
Arrays (java.util.Arrays): 7
Collectors (java.util.stream.Collectors): 7
IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException): 6
Map (java.util.Map): 6
IntStream (java.util.stream.IntStream): 6
DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer): 5
PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer): 5
FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor): 5
LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy): 5