Search in sources :

Example 11 with SimpleLineSearchCursor

Use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

From the class QQN, the method orient.

/**
 * Chooses the cursor for the next line search by combining the direction proposed by the
 * {@code inner} strategy with the negated gradient of {@code origin}.
 *
 * <p>The origin is first recorded in the inner strategy's history and the inner strategy is asked
 * for its own cursor. The magnitude of that cursor's direction is then compared with the
 * magnitude of the negated gradient:
 * <ul>
 *   <li>If the relative difference exceeds {@code 1e-2}, a cursor following the quadratic path
 *       {@code scaledGradient * (t - t^2) + lbfgs * t^2} is returned, where
 *       {@code scaledGradient} is the negated gradient rescaled to the magnitude of the inner
 *       direction.</li>
 *   <li>Otherwise the inner cursor is returned unchanged. This also happens when the comparison
 *       is NaN, or when the rescaling factor {@code lbfgsMag / gdMag} is not finite (for example
 *       a zero-magnitude gradient), because scaling by such a factor would yield NaN/infinite
 *       components.</li>
 * </ul>
 *
 * @param subject the trainable that is re-measured at every step of the returned cursor
 * @param origin  the sample at the current point; its {@code delta} supplies the gradient
 * @param monitor receives the log message emitted when the quadratic cursor is selected
 * @return the quadratic cursor described above, or the inner strategy's cursor
 */
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin, @Nonnull final TrainingMonitor monitor) {
    inner.addToHistory(origin, monitor);
    final SimpleLineSearchCursor lbfgsCursor = inner.orient(subject, origin, monitor);
    final DeltaSet<Layer> lbfgs = lbfgsCursor.direction;
    @Nonnull final DeltaSet<Layer> gd = origin.delta.scale(-1.0);
    final double lbfgsMag = lbfgs.getMagnitude();
    final double gdMag = gd.getMagnitude();
    // Factor that rescales the negated gradient to the magnitude of the inner direction.
    // It is infinite or NaN when gdMag is zero (or the quotient overflows).
    final double gradientScale = lbfgsMag / gdMag;
    // Written as a positive "> 1e-2" test so that a NaN ratio counts as "do not differ",
    // exactly like the original comparison.
    final boolean magnitudesDiffer = Math.abs(lbfgsMag - gdMag) / (lbfgsMag + gdMag) > 1e-2;
    if (!magnitudesDiffer || !Double.isFinite(gradientScale)) {
        // Either both directions are already of comparable size, or the gradient cannot be
        // rescaled without producing non-finite values: use the inner cursor as-is.
        gd.freeRef();
        return lbfgsCursor;
    }
    @Nonnull final DeltaSet<Layer> scaledGradient = gd.scale(gradientScale);
    monitor.log(String.format("Returning Quadratic Cursor %s GD, %s QN", gdMag, lbfgsMag));
    // Only the rescaled copy is needed from here on.
    gd.freeRef();
    return new LineSearchCursorBase() {

        @Nonnull
        @Override
        public CharSequence getDirectionType() {
            return CURSOR_NAME;
        }

        /**
         * Evaluates the quadratic path at {@code t}: it equals zero at {@code t = 0} and the
         * inner direction at {@code t = 1}.
         */
        @Override
        public DeltaSet<Layer> position(final double t) {
            if (!Double.isFinite(t)) {
                throw new IllegalArgumentException("Line search parameter must be finite: " + t);
            }
            // NOTE(review): the intermediate sets produced by scale(...) are not explicitly
            // released here - confirm against DeltaSet's ownership rules whether they should be.
            return scaledGradient.scale(t - t * t).add(lbfgs.scale(t * t));
        }

        @Override
        public void reset() {
            lbfgsCursor.reset();
        }

        /**
         * Resets via the inner cursor, applies the path offset for {@code t}, re-measures the
         * subject, records the new sample in the inner history, and returns the sample together
         * with the dot product of the path tangent and the sample's delta.
         */
        @Nonnull
        @Override
        public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
            if (!Double.isFinite(t)) {
                throw new IllegalArgumentException("Line search parameter must be finite: " + t);
            }
            reset();
            position(t).accumulate(1);
            @Nonnull final PointSample sample = subject.measure(monitor).setRate(t);
            inner.addToHistory(sample, monitor);
            // d/dt of position(t): scaledGradient * (1 - 2t) + lbfgs * 2t.
            @Nonnull final DeltaSet<Layer> tangent = scaledGradient.scale(1 - 2 * t).add(lbfgs.scale(2 * t));
            return new LineSearchPoint(sample, tangent.dot(sample.delta));
        }

        @Override
        public void _free() {
            scaledGradient.freeRef();
            lbfgsCursor.freeRef();
        }
    };
}
Also used : TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchPoint(com.simiacryptus.mindseye.opt.line.LineSearchPoint) PointSample(com.simiacryptus.mindseye.lang.PointSample) Layer(com.simiacryptus.mindseye.lang.Layer) LineSearchCursorBase(com.simiacryptus.mindseye.opt.line.LineSearchCursorBase)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)11 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)11 Nonnull (javax.annotation.Nonnull)11 PointSample (com.simiacryptus.mindseye.lang.PointSample)7 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)5 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)5 Nullable (javax.annotation.Nullable)5 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)4 Trainable (com.simiacryptus.mindseye.eval.Trainable)3 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)3 LineSearchPoint (com.simiacryptus.mindseye.opt.line.LineSearchPoint)3 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)2 LineSearchCursorBase (com.simiacryptus.mindseye.opt.line.LineSearchCursorBase)2 List (java.util.List)2 Map (java.util.Map)2 Collectors (java.util.stream.Collectors)2 Result (com.simiacryptus.mindseye.lang.Result)1 StateSet (com.simiacryptus.mindseye.lang.StateSet)1 FullyConnectedLayer (com.simiacryptus.mindseye.layers.java.FullyConnectedLayer)1 TrustRegion (com.simiacryptus.mindseye.opt.region.TrustRegion)1