Search in sources:

Example 21 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Source: class LBFGS, method lbfgs.

/**
 * Computes an L-BFGS candidate direction from {@code measurement} and the recorded {@code history}
 * using the two-loop recursion, and copies it into {@code direction} when accepted.
 *
 * <p>The candidate is accepted only when its dot product with {@code measurement.delta} is negative.
 * Every failure is logged to {@code monitor} and reported as a rejection ({@code false}). Failures
 * include:
 * <ul>
 *   <li>fewer than two history points;</li>
 *   <li>a vanishing curvature term {@code s·y};</li>
 *   <li>a non-finite intermediate value;</li>
 *   <li>any other exception.</li>
 * </ul>
 * {@code copy} is only invoked on acceptance.
 *
 * @param measurement the current point; a copy of its {@code delta} seeds the recursion
 * @param monitor     receives the accepted/rejected/error log lines
 * @param history     recorded points; consecutive entries form the (s, y) difference pairs, and the
 *                    last two entries form the pair that scales the initial estimate (presumably
 *                    ordered oldest first — confirm against the caller)
 * @param direction   receives the candidate via {@code copy(p, direction)} when accepted
 * @return {@code true} if the candidate direction was accepted and copied
 */
private boolean lbfgs(@Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor, @Nonnull List<PointSample> history, @Nonnull DeltaSet<Layer> direction) {
    try {
        final int size = history.size();
        if (size < 2) {
            // The recursion needs at least one difference pair. Report this explicitly instead of
            // letting it surface as an IndexOutOfBoundsException from history.get(size - 2).
            monitor.log(String.format("LBFGS Orientation Error: %s", "insufficient history: " + size + " point(s)"));
            return false;
        }
        @Nonnull DeltaSet<Layer> p = measurement.delta.copy();
        requireFinite(p);
        @Nonnull final double[] alphas = new double[size];
        // First loop: walk the difference pairs from the end of history back to the start.
        for (int i = size - 2; i >= 0; i--) {
            @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
            @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
            final double denominator = sd.dot(yd);
            if (0 == denominator) {
                throw new IllegalStateException("Orientation vanished.");
            }
            alphas[i] = p.dot(sd) / denominator;
            p = p.subtract(yd.scale(alphas[i]));
            requireFinite(p);
        }
        // Initial estimate: scale by (s·y)/(y·y) of the last pair in history.
        @Nonnull final DeltaSet<Layer> sk = history.get(size - 1).weights.subtract(history.get(size - 2).weights);
        @Nonnull final DeltaSet<Layer> yk = history.get(size - 1).delta.subtract(history.get(size - 2).delta);
        p = p.scale(sk.dot(yk) / yk.dot(yk));
        requireFinite(p);
        // Second loop: walk the same pairs from the start of history to the end, adding corrections.
        for (int i = 0; i < size - 1; i++) {
            @Nonnull final DeltaSet<Layer> sd = history.get(i + 1).weights.subtract(history.get(i).weights);
            @Nonnull final DeltaSet<Layer> yd = history.get(i + 1).delta.subtract(history.get(i).delta);
            final double beta = p.dot(yd) / sd.dot(yd);
            p = p.add(sd.scale(alphas[i] - beta));
            requireFinite(p);
        }
        final boolean accept = measurement.delta.dot(p) < 0;
        if (accept) {
            monitor.log("Accepted: " + new Stats(direction, p));
            copy(p, direction);
        } else {
            monitor.log("Rejected: " + new Stats(direction, p));
        }
        return accept;
    } catch (Throwable e) {
        // Best-effort: any failure simply means "no L-BFGS direction this time".
        // Fall back to the exception type when there is no message (e.g. a bare NullPointerException).
        final String reason = null == e.getMessage() ? e.getClass().getName() : e.getMessage();
        monitor.log(String.format("LBFGS Orientation Error: %s", reason));
        return false;
    }
}

/**
 * Throws {@code IllegalStateException("Non-finite value")} if any component of any delta in
 * {@code set} is NaN or infinite.
 *
 * @param set the delta set to validate
 */
private static void requireFinite(@Nonnull DeltaSet<Layer> set) {
    final boolean finite = set.stream().parallel().allMatch(d -> Arrays.stream(d.getDelta()).allMatch(Double::isFinite));
    if (!finite) {
        throw new IllegalStateException("Non-finite value");
    }
}
Also used: Arrays (java.util.Arrays), LineSearchPoint (com.simiacryptus.mindseye.opt.line.LineSearchPoint), Collectors (java.util.stream.Collectors), ArrayUtil (com.simiacryptus.util.ArrayUtil), TreeSet (java.util.TreeSet), Trainable (com.simiacryptus.mindseye.eval.Trainable), Delta (com.simiacryptus.mindseye.lang.Delta), List (java.util.List), TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor), Map (java.util.Map), PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer), Layer (com.simiacryptus.mindseye.lang.Layer), SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet), DoubleBufferSet (com.simiacryptus.mindseye.lang.DoubleBufferSet), Comparator (java.util.Comparator), Nonnull (javax.annotation.Nonnull), PointSample (com.simiacryptus.mindseye.lang.PointSample), Nullable (javax.annotation.Nullable)

Example 22 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Source: class IterativeTrainer, method measure.

/**
 * Measure point sample.
 * <p>
 * Measures the training subject, re-measuring up to {@code maxRetries} (10) additional times while
 * the measured mean is non-finite, for at most 11 attempts. When {@code reset} is set, each attempt
 * first does the following:
 * <ul>
 *   <li>resets {@code orientation};</li>
 *   <li>shuffles every {@link StochasticComponent} layer of a {@link DAGNetwork} subject;</li>
 *   <li>reseeds the subject.</li>
 * </ul>
 * The sample from a discarded attempt is released with {@code freeRef()} before re-measuring.
 *
 * @param reset whether to reset orientation state and reseed the subject before each attempt
 * @return the point sample, whose mean is finite
 * @throws IterativeStopException if reseeding fails on a retry, or the mean is still non-finite
 *                                after all retries
 */
@Nullable
public PointSample measure(boolean reset) {
    final int maxRetries = 10;
    @Nullable PointSample currentPoint = null;
    int retries = 0;
    do {
        if (reset) {
            orientation.reset();
            if (subject.getLayer() instanceof DAGNetwork) {
                ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
                    if (layer instanceof StochasticComponent)
                        ((StochasticComponent) layer).shuffle(StochasticComponent.random.get().nextLong());
                });
            }
            if (!subject.reseed(System.nanoTime())) {
                // A subject that cannot be reseeded would just reproduce the same bad measurement.
                if (retries > 0)
                    throw new IterativeStopException("Failed to reset training subject");
            } else {
                monitor.log("Reset training subject");
            }
        }
        if (null != currentPoint) {
            currentPoint.freeRef();
        }
        currentPoint = subject.measure(monitor);
        // Retry only while the measurement is non-finite and attempts remain. (The previous form
        // `10 < retries++` was false on the first pass, so a retry never happened.)
    } while (!Double.isFinite(currentPoint.getMean()) && retries++ < maxRetries);
    if (!Double.isFinite(currentPoint.getMean())) {
        currentPoint.freeRef();
        throw new IterativeStopException();
    }
    return currentPoint;
}
Also used: StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent), IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException), PointSample (com.simiacryptus.mindseye.lang.PointSample), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), Nullable (javax.annotation.Nullable)

Example 23 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Source: class IterativeTrainer, method run.

/**
 * Runs the iterative training loop.
 * <p>
 * The loop stops when any of the following happens:
 * <ul>
 *   <li>the timeout elapses;</li>
 *   <li>the mean fitness is no longer above {@code terminateThreshold};</li>
 *   <li>{@code maxIterations} is exceeded;</li>
 *   <li>an iteration fails to improve and the subject cannot be reseeded.</li>
 * </ul>
 * <p>
 * Each sub-iteration re-measures the subject (with reset) and asks {@code orientation} for a line
 * search cursor. It then performs a line search along that cursor via {@code step}. If the line
 * search makes the mean worse, the current point is replaced by the cursor's point at step 0.
 * Before returning, stochastic layers of a {@link DAGNetwork} subject have their noise cleared.
 *
 * @return the final mean fitness, or {@code NaN} if no point is available
 */
public double run() {
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    long lastIterationTime = System.nanoTime();
    @Nullable PointSample currentPoint = measure(true);
    mainLoop: while (timeoutMs > System.currentTimeMillis() && currentPoint.getMean() > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        currentPoint.freeRef();
        currentPoint = measure(true);
        assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
        subiterationLoop: for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
            if (timeoutMs < System.currentTimeMillis()) {
                break mainLoop;
            }
            if (currentIteration.incrementAndGet() > maxIterations) {
                break mainLoop;
            }
            currentPoint.freeRef();
            currentPoint = measure(true);
            @Nullable final PointSample _currentPoint = currentPoint;
            @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
            final LineSearchCursor direction = timedOrientation.result;
            final CharSequence directionType = direction.getDirectionType();
            @Nullable final PointSample previous = currentPoint;
            // Keep our own reference: currentPoint's reference is released and replaced below,
            // but `previous` is still read for logging and comparison.
            previous.addRef();
            try {
                @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> step(direction, directionType, previous));
                currentPoint.freeRef();
                currentPoint = timedLineSearch.result;
                final long now = System.nanoTime();
                final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f", (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
                lastIterationTime = now;
                monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
                if (previous.getMean() <= currentPoint.getMean()) {
                    if (previous.getMean() < currentPoint.getMean()) {
                        monitor.log(String.format("Resetting Iteration %s", perfString));
                        currentPoint.freeRef();
                        currentPoint = direction.step(0, monitor).point;
                    } else {
                        monitor.log(String.format("Static Iteration %s", perfString));
                    }
                    if (subject.reseed(System.nanoTime())) {
                        monitor.log(String.format("Iteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.getMean()));
                        monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
                        break subiterationLoop;
                    } else {
                        monitor.log(String.format("Iteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.getMean()));
                        monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
                        break mainLoop;
                    }
                } else {
                    // Pass perfString as an argument rather than concatenating it into the format
                    // string, so its contents can never be interpreted as format specifiers.
                    monitor.log(String.format("Iteration %s complete. Error: %s %s", currentIteration.get(), currentPoint.getMean(), perfString));
                }
                monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
            } finally {
                previous.freeRef();
                direction.freeRef();
            }
        }
    }
    if (subject.getLayer() instanceof DAGNetwork) {
        ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
            if (layer instanceof StochasticComponent)
                ((StochasticComponent) layer).clearNoise();
        });
    }
    // Only release the point if there is one; the original null-checked it and then
    // dereferenced it unconditionally.
    final double mean;
    if (null == currentPoint) {
        mean = Double.NaN;
    } else {
        mean = currentPoint.getMean();
        currentPoint.freeRef();
    }
    return mean;
}
Also used: Nonnull (javax.annotation.Nonnull), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor), LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor), StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent), PointSample (com.simiacryptus.mindseye.lang.PointSample), Nullable (javax.annotation.Nullable)

Example 24 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Source: class RoundRobinTrainer, method run.

/**
 * Runs round-robin training. Each sub-iteration applies every orientation strategy in
 * {@code orientations} in turn, performing a line search along each resulting direction.
 * <p>
 * Line search strategies are created lazily per direction type via {@code lineSearchFactory} and
 * cached in {@code lineSearchStrategyMap}. Training stops when any of the following happens:
 * <ul>
 *   <li>the timeout elapses (checked before each sub-iteration, as well as per main-loop pass);</li>
 *   <li>the fitness {@code sum} is no longer above {@code terminateThreshold};</li>
 *   <li>the iteration budget {@code maxIterations} is exhausted;</li>
 *   <li>a full round of orientations fails to improve and the subject cannot be reseeded.</li>
 * </ul>
 *
 * @return the final fitness sum, or {@code NaN} if no point is available
 */
public double run() {
    final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
    PointSample currentPoint = measure();
    mainLoop: while (timeoutMs > System.currentTimeMillis() && currentPoint.sum > terminateThreshold) {
        if (currentIteration.get() > maxIterations) {
            break;
        }
        currentPoint = measure();
        subiterationLoop: for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
            // Honor the timeout between rounds too, consistent with IterativeTrainer.run.
            if (timeoutMs < System.currentTimeMillis()) {
                break mainLoop;
            }
            final PointSample previousOrientations = currentPoint;
            for (@Nonnull final OrientationStrategy<?> orientation : orientations) {
                if (currentIteration.incrementAndGet() > maxIterations) {
                    // Budget exhausted: stop entirely. A plain `break` here fell through to the
                    // "MacroIteration failed" check below, spuriously reseeding the subject.
                    break mainLoop;
                }
                final LineSearchCursor direction = orientation.orient(subject, currentPoint, monitor);
                @Nonnull final CharSequence directionType = direction.getDirectionType() + "+" + Long.toHexString(System.identityHashCode(orientation));
                LineSearchStrategy lineSearchStrategy;
                if (lineSearchStrategyMap.containsKey(directionType)) {
                    lineSearchStrategy = lineSearchStrategyMap.get(directionType);
                } else {
                    log.info(String.format("Constructing line search parameters: %s", directionType));
                    lineSearchStrategy = lineSearchFactory.apply(directionType);
                    lineSearchStrategyMap.put(directionType, lineSearchStrategy);
                }
                final PointSample previous = currentPoint;
                currentPoint = lineSearchStrategy.step(direction, monitor);
                monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
                if (previous.sum == currentPoint.sum) {
                    monitor.log(String.format("Iteration %s failed, ignoring. Error: %s", currentIteration.get(), currentPoint.sum));
                } else {
                    monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), currentPoint.sum));
                }
            }
            if (previousOrientations.sum <= currentPoint.sum) {
                if (subject.reseed(System.nanoTime())) {
                    monitor.log(String.format("MacroIteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.sum));
                    break subiterationLoop;
                } else {
                    monitor.log(String.format("MacroIteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.sum));
                    break mainLoop;
                }
            }
        }
    }
    return null == currentPoint ? Double.NaN : currentPoint.sum;
}
Also used: LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor), Nonnull (javax.annotation.Nonnull), LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy), PointSample (com.simiacryptus.mindseye.lang.PointSample)

Example 25 with PointSample

Use of com.simiacryptus.mindseye.lang.PointSample in project MindsEye by SimiaCryptus.

Source: class RoundRobinTrainer, method measure.

/**
 * Measures the training subject until the fitness sum is finite.
 * <p>
 * Every attempt, including the first, reseeds the subject before measuring. At most 11 attempts
 * are made.
 *
 * @return a point sample whose {@code sum} is finite
 * @throws IterativeStopException if a retry's reseed fails, or if all attempts are used up
 */
public PointSample measure() {
    for (int attempt = 0; ; attempt++) {
        // Reseed on every attempt; a failed reseed is only fatal once we are retrying.
        final boolean reseeded = subject.reseed(System.nanoTime());
        if (attempt > 0 && !reseeded) {
            throw new IterativeStopException();
        }
        if (attempt > 10) {
            throw new IterativeStopException();
        }
        final PointSample sample = subject.measure(monitor);
        if (Double.isFinite(sample.sum)) {
            return sample;
        }
    }
}
Also used: IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException), PointSample (com.simiacryptus.mindseye.lang.PointSample)

Aggregations (number of examples using each type)

PointSample (com.simiacryptus.mindseye.lang.PointSample): 33; Nonnull (javax.annotation.Nonnull): 24; Layer (com.simiacryptus.mindseye.lang.Layer): 16; Nullable (javax.annotation.Nullable): 14; DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 10; TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 9; SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor): 9; StateSet (com.simiacryptus.mindseye.lang.StateSet): 8; LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor): 8; List (java.util.List): 8; Trainable (com.simiacryptus.mindseye.eval.Trainable): 7; Arrays (java.util.Arrays): 7; Collectors (java.util.stream.Collectors): 7; IterativeStopException (com.simiacryptus.mindseye.lang.IterativeStopException): 6; Map (java.util.Map): 6; IntStream (java.util.stream.IntStream): 6; DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer): 5; PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer): 5; FailsafeLineSearchCursor (com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor): 5; LineSearchStrategy (com.simiacryptus.mindseye.opt.line.LineSearchStrategy): 5