Search in sources :

Example 1 with SimpleLineSearchCursor

use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

the class LBFGS method orient.

/**
 * Chooses a search direction for the next line search.
 *
 * <p>The measurement is first recorded via {@code addToHistory}; a snapshot of the history is then
 * handed to {@code lbfgs}. When {@code lbfgs} yields no result, the cursor falls back to the
 * measurement's delta scaled by -1 (labelled "GD"); otherwise the returned delta is used
 * (labelled "LBFGS"). Finally the history is trimmed: down to {@code minHistory} entries after a
 * fallback, or {@code maxHistory} entries after a successful LBFGS step.
 *
 * @param subject     the trainable handed on to the created cursor
 * @param measurement the current point sample; supplies the fallback delta
 * @param monitor     receives log output (only when {@code verbose} is set, as far as this method is concerned)
 * @return the cursor built by {@code cursor(...)} for the chosen direction
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    // Disabled sanity check that re-measured the subject and compared it to the supplied measurement.
    // if (getClass().desiredAssertionStatus()) {
    // double verify = subject.measure(monitor).getMean();
    // double input = measurement.getMean();
    // boolean isDifferent = Math.abs(verify - input) > 1e-2;
    // if (isDifferent) throw new AssertionError(String.format("Invalid input point: %s != %s", verify, input));
    // monitor.log(String.format("Verified input point: %s == %s", verify, input));
    // }
    addToHistory(measurement, monitor);
    // Fixed-size snapshot of the history for lbfgs; deliberately NOT named "history" so it cannot
    // shadow the field that is trimmed below.
    @Nonnull final List<PointSample> historySnapshot = Arrays.asList(this.history.toArray(new PointSample[] {}));
    @Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, historySnapshot);
    SimpleLineSearchCursor returnValue;
    if (null == result) {
        // No LBFGS direction available: fall back to the negated measurement delta.
        @Nonnull DeltaSet<Layer> scale = measurement.delta.scale(-1);
        returnValue = cursor(subject, measurement, "GD", scale);
        scale.freeRef();
    } else {
        returnValue = cursor(subject, measurement, "LBFGS", result);
        result.freeRef();
    }
    // Only the null-ness of result is consulted here; the reference itself was released above.
    final int historyLimit = null == result ? minHistory : maxHistory;
    while (this.history.size() > historyLimit) {
        @Nullable final PointSample remove = this.history.pollFirst();
        if (verbose) {
            // Report the live size of the field, not the size of the earlier snapshot.
            monitor.log(String.format("Removed measurement %s to history. Total: %s", Long.toHexString(System.identityHashCode(remove)), this.history.size()));
        }
        remove.freeRef();
    }
    return returnValue;
}
Also used : Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable)

Example 2 with SimpleLineSearchCursor

use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

the class MomentumStrategy method orient.

/**
 * Wraps the inner strategy's direction with a carry-over of the previously returned delta.
 *
 * <p>For every entry of the inner direction, the new value is
 * {@code previous * carryOver + current}; the combined set replaces {@code prevDelta} and becomes
 * the direction of the returned cursor.
 *
 * @param subject     the trainable passed to the inner strategy and to the new cursor
 * @param measurement the current point sample
 * @param monitor     forwarded to the inner strategy
 * @return a new cursor over the combined delta set
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
    final LineSearchCursor innerCursor = inner.orient(subject, measurement, monitor);
    // The inner strategy is expected to hand back a SimpleLineSearchCursor; anything else fails the cast.
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) innerCursor;
    final DeltaSet<Layer> innerDirection = simpleCursor.direction;
    @Nonnull final DeltaSet<Layer> combined = new DeltaSet<Layer>();
    innerDirection.getMap().forEach((layer, current) -> {
        final DoubleBuffer<Layer> previous = prevDelta.get(layer, current.target);
        final DoubleBuffer<Layer> destination = combined.get(layer, current.target);
        final double[] carried = ArrayUtil.multiply(previous.getDelta(), carryOver);
        final double[] summed = ArrayUtil.add(carried, current.getDelta());
        destination.addInPlace(summed);
    });
    prevDelta = combined;
    return new SimpleLineSearchCursor(subject, measurement, combined);
}
Also used : SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Layer(com.simiacryptus.mindseye.lang.Layer) Nonnull(javax.annotation.Nonnull)

Example 3 with SimpleLineSearchCursor

use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

the class RecursiveSubspace method orient.

/**
 * Derives a search direction by training a subspace layer and measuring how far the weights moved.
 *
 * <p>Sequence: copy and back up the measurement, build and train the subspace layer, evaluate it
 * once, take the difference between a fresh backup copy of the weights and the origin's weights,
 * restore the origin, and wrap that difference in a cursor labelled {@code CURSOR_LABEL}.
 *
 * @param subject     the trainable used to build the subspace and handed to the cursor
 * @param measurement the current point sample; it is copied, not used directly as the cursor origin
 * @param monitor     passed to {@code buildSubspace} and {@code train}
 * @return a cursor starting at the copied origin with the computed delta as direction
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
    // Full copy of the measurement with backup() applied; restore() is called on it further down.
    @Nonnull PointSample origin = measurement.copyFull().backup();
    @Nullable Layer macroLayer = buildSubspace(subject, measurement, monitor);
    train(monitor, macroLayer);
    // Evaluate once with a null input and discard the output.
    // NOTE(review): presumably this pushes the trained subspace state into the real weights — confirm.
    Result eval = macroLayer.eval((Result) null);
    macroLayer.freeRef();
    eval.getData().freeRef();
    eval.freeRef();
    // Direction = (backup copy of the weights taken now) minus (origin weights).
    @Nonnull StateSet<Layer> backupCopy = origin.weights.backupCopy();
    @Nonnull DeltaSet<Layer> delta = backupCopy.subtract(origin.weights);
    backupCopy.freeRef();
    // Restore the origin before the cursor is constructed from it.
    origin.restore();
    @Nonnull SimpleLineSearchCursor simpleLineSearchCursor = new SimpleLineSearchCursor(subject, origin, delta);
    // Release this method's references after handing both objects to the cursor.
    delta.freeRef();
    origin.freeRef();
    return simpleLineSearchCursor.setDirectionType(CURSOR_LABEL);
}
Also used : Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) PointSample(com.simiacryptus.mindseye.lang.PointSample) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 4 with SimpleLineSearchCursor

use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

the class LayerReweightingStrategy method orient.

/**
 * Rescales the inner strategy's direction layer by layer.
 *
 * <p>Each buffer of the inner cursor's direction that has a delta is multiplied, in place, by the
 * weight returned from {@code getRegionPolicy(layer)}. Layers whose weight is {@code null} or not
 * strictly positive are left untouched.
 *
 * @param subject     forwarded to the inner strategy
 * @param measurement forwarded to the inner strategy
 * @param monitor     forwarded to the inner strategy
 * @return the inner strategy's cursor, with its direction modified in place
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, final PointSample measurement, final TrainingMonitor monitor) {
    final SimpleLineSearchCursor innerCursor = inner.orient(subject, measurement, monitor);
    final DeltaSet<Layer> direction = innerCursor.direction;
    direction.getMap().forEach((layer, buffer) -> {
        if (null == buffer.getDelta()) {
            return;
        }
        final Double weight = getRegionPolicy(layer);
        // Written as a negated "0 < weight" so that a NaN weight is skipped as well.
        if (null == weight || !(0 < weight)) {
            return;
        }
        final DoubleBuffer<Layer> target = direction.get(layer, buffer.target);
        @Nonnull final double[] scaled = ArrayUtil.multiply(target.getDelta(), weight);
        // Write the scaled values back into the buffer's own delta array.
        for (int index = 0; index < scaled.length; index++) {
            target.getDelta()[index] = scaled[index];
        }
    });
    return innerCursor;
}
Also used : Nonnull(javax.annotation.Nonnull) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) Layer(com.simiacryptus.mindseye.lang.Layer)

Example 5 with SimpleLineSearchCursor

use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.

the class QuantifyOrientationWrapper method orient.

/**
 * Delegates to the inner strategy and logs, per id, how large the direction is relative to the weights.
 *
 * <p>For a {@link SimpleLineSearchCursor}, the origin's weights are grouped by {@code getId}; for
 * each weight entry the ratio {@code rms(direction delta) / rms(weight delta)} is computed (a
 * missing direction counts as 0, a zero denominator is replaced by 1). A group with a single entry
 * is reported as that number, larger groups as a {@code DoubleStatistics} summary. Any other cursor
 * type is merely logged.
 *
 * @param subject     forwarded to the inner strategy
 * @param measurement forwarded to the inner strategy
 * @param monitor     forwarded to the inner strategy and used for the log line
 * @return the inner strategy's cursor, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
    if (!(cursor instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", cursor));
        return cursor;
    }
    final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
    final DeltaSet<Layer> direction = simpleCursor.direction;
    @Nonnull final StateSet<Layer> weights = simpleCursor.origin.weights;
    final Map<CharSequence, CharSequence> dataMap = weights.stream()
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet()
        .stream()
        .collect(Collectors.toMap(group -> group.getKey(), group -> {
            final List<Double> ratios = group.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> dirDelta = direction.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double directionRms;
                if (null == dirDelta) {
                    directionRms = 0;
                } else {
                    directionRms = dirDelta.deltaStatistics().rms();
                }
                // Guard against division by zero for weights with no magnitude.
                final double divisor = (0 == weightRms) ? 1 : weightRms;
                return directionRms / divisor;
            }).collect(Collectors.toList());
            if (1 == ratios.size()) {
                return Double.toString(ratios.get(0));
            }
            final double[] values = ratios.stream().mapToDouble(x -> x).toArray();
            return new DoubleStatistics().accept(values).toString();
        }));
    monitor.log(String.format("Line search stats: %s", dataMap));
    return cursor;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)11 SimpleLineSearchCursor (com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor)11 Nonnull (javax.annotation.Nonnull)11 PointSample (com.simiacryptus.mindseye.lang.PointSample)7 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)5 LineSearchCursor (com.simiacryptus.mindseye.opt.line.LineSearchCursor)5 Nullable (javax.annotation.Nullable)5 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)4 Trainable (com.simiacryptus.mindseye.eval.Trainable)3 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)3 LineSearchPoint (com.simiacryptus.mindseye.opt.line.LineSearchPoint)3 PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer)2 LineSearchCursorBase (com.simiacryptus.mindseye.opt.line.LineSearchCursorBase)2 List (java.util.List)2 Map (java.util.Map)2 Collectors (java.util.stream.Collectors)2 Result (com.simiacryptus.mindseye.lang.Result)1 StateSet (com.simiacryptus.mindseye.lang.StateSet)1 FullyConnectedLayer (com.simiacryptus.mindseye.layers.java.FullyConnectedLayer)1 TrustRegion (com.simiacryptus.mindseye.opt.region.TrustRegion)1