Usage of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in the MindsEye project by SimiaCryptus.
Class LBFGS, method orient.
/**
 * Adds {@code measurement} to the L-BFGS history and builds a line-search cursor from it.
 *
 * <p>If {@code lbfgs} yields a direction, the cursor follows it and is labelled "LBFGS". If
 * {@code lbfgs} returns {@code null}, the cursor instead follows the negated measurement delta and
 * is labelled "GD". Afterwards entries are polled from the front of the history until at most
 * {@code maxHistory} remain (L-BFGS case) or at most {@code minHistory} remain (fallback case).
 *
 * @param subject the trainable being optimized; passed through to the cursor
 * @param measurement the current point sample; recorded in the history
 * @param monitor passed to addToHistory and lbfgs, and used for verbose trimming messages
 * @return the cursor for the next line search
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  addToHistory(measurement, monitor);
  // Fixed snapshot handed to lbfgs(). It is named distinctly from the field so the trimming loop
  // below reports the live deque, not this (unchanging) copy.
  @Nonnull final List<PointSample> historySnapshot = Arrays.asList(this.history.toArray(new PointSample[] {}));
  @Nullable final DeltaSet<Layer> result = lbfgs(measurement, monitor, historySnapshot);
  final boolean fellBack = null == result;
  final SimpleLineSearchCursor returnValue;
  if (fellBack) {
    // No L-BFGS direction available: step along the negated measurement delta instead.
    @Nonnull final DeltaSet<Layer> scale = measurement.delta.scale(-1);
    returnValue = cursor(subject, measurement, "GD", scale);
    scale.freeRef();
  } else {
    returnValue = cursor(subject, measurement, "LBFGS", result);
    result.freeRef();
  }
  // Trim from the front: to minHistory after a fallback, otherwise to maxHistory.
  while (this.history.size() > (fellBack ? minHistory : maxHistory)) {
    @Nullable final PointSample removed = this.history.pollFirst();
    if (verbose) {
      monitor.log(String.format("Removed measurement %s from history. Total: %s", Long.toHexString(System.identityHashCode(removed)), this.history.size()));
    }
    removed.freeRef();
  }
  return returnValue;
}
Usage of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in the MindsEye project by SimiaCryptus.
Class MomentumStrategy, method orient.
/**
 * Wraps the inner strategy's direction with momentum.
 *
 * <p>For every layer in the inner direction, the previous direction for that layer is scaled by
 * {@code carryOver} and added to the current step. The sum is accumulated into a fresh
 * {@code DeltaSet}, which is also remembered as the previous direction for the next call.
 *
 * <p>The inner strategy must produce a {@code SimpleLineSearchCursor}; any other cursor type
 * causes a {@code ClassCastException}.
 *
 * @param subject the trainable being optimized
 * @param measurement the current point sample
 * @param monitor forwarded to the inner strategy
 * @return a new cursor from {@code measurement} along the momentum-blended direction
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
  final SimpleLineSearchCursor innerCursor = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
  final DeltaSet<Layer> currentDirection = innerCursor.direction;
  @Nonnull final DeltaSet<Layer> blended = new DeltaSet<Layer>();
  currentDirection.getMap().forEach((lyr, step) -> {
    final double[] carried = ArrayUtil.multiply(prevDelta.get(lyr, step.target).getDelta(), carryOver);
    final double[] combined = ArrayUtil.add(carried, step.getDelta());
    blended.get(lyr, step.target).addInPlace(combined);
  });
  prevDelta = blended;
  return new SimpleLineSearchCursor(subject, measurement, blended);
}
Usage of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in the MindsEye project by SimiaCryptus.
Class RecursiveSubspace, method orient.
/**
 * Computes a search direction by running {@code train} on a layer produced by
 * {@code buildSubspace} and measuring how the weights changed.
 *
 * <p>Sequence: back up a full copy of {@code measurement} as the origin; build and train the
 * layer; take the difference between the weights' values after training and the origin's
 * weights; restore the origin; return a cursor from the origin along that difference, labelled
 * with {@code CURSOR_LABEL}.
 *
 * @param subject the trainable being optimized; passed to buildSubspace and to the returned cursor
 * @param measurement the current point sample; a full, backed-up copy of it becomes the origin
 * @param monitor passed to buildSubspace and train
 * @return a cursor starting at the restored origin, pointing along the measured weight change
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(@Nonnull Trainable subject, @Nonnull PointSample measurement, @Nonnull TrainingMonitor monitor) {
// Full copy plus backup so the starting weights can be put back after training.
@Nonnull PointSample origin = measurement.copyFull().backup();
// NOTE(review): declared @Nullable but dereferenced below without a null check.
@Nullable Layer macroLayer = buildSubspace(subject, measurement, monitor);
train(monitor, macroLayer);
// NOTE(review): evaluates the layer with a null input and immediately releases the output;
// presumably invoked only for its side effects — confirm why this call is required.
Result eval = macroLayer.eval((Result) null);
macroLayer.freeRef();
eval.getData().freeRef();
eval.freeRef();
// Presumably backupCopy() snapshots the weights' current (post-training) values, so subtracting
// origin.weights yields the change produced by training — confirm StateSet semantics.
@Nonnull StateSet<Layer> backupCopy = origin.weights.backupCopy();
@Nonnull DeltaSet<Layer> delta = backupCopy.subtract(origin.weights);
backupCopy.freeRef();
// Kept after the delta is measured: if restore() resets the live weights to the values saved by
// backup() above (assumed), measuring afterwards would yield a zero delta.
origin.restore();
@Nonnull SimpleLineSearchCursor simpleLineSearchCursor = new SimpleLineSearchCursor(subject, origin, delta);
// The cursor is assumed to retain its own references to origin and delta, so ours are released.
delta.freeRef();
origin.freeRef();
return simpleLineSearchCursor.setDirectionType(CURSOR_LABEL);
}
Usage of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in the MindsEye project by SimiaCryptus.
Class LayerReweightingStrategy, method orient.
/**
 * Delegates orientation to the inner strategy, then rescales the resulting direction per layer.
 *
 * <p>For each layer whose delta array is non-null and whose {@code getRegionPolicy} weight is a
 * positive number, the layer's delta is multiplied in place by that weight. Layers with no policy
 * weight, or a non-positive one, are left unchanged.
 *
 * @param subject the trainable being optimized
 * @param measurement the current point sample
 * @param monitor forwarded to the inner strategy
 * @return the inner strategy's cursor, whose direction has been rescaled in place
 */
@Override
public SimpleLineSearchCursor orient(final Trainable subject, final PointSample measurement, final TrainingMonitor monitor) {
  final SimpleLineSearchCursor orient = inner.orient(subject, measurement, monitor);
  final DeltaSet<Layer> direction = orient.direction;
  direction.getMap().forEach((layer, buffer) -> {
    if (null == buffer.getDelta()) {
      return;
    }
    final Double weight = getRegionPolicy(layer);
    if (null != weight && 0 < weight) {
      final DoubleBuffer<Layer> deltaBuffer = direction.get(layer, buffer.target);
      // Fetch the backing array once rather than once per element.
      final double[] target = deltaBuffer.getDelta();
      @Nonnull final double[] adjusted = ArrayUtil.multiply(target, weight);
      System.arraycopy(adjusted, 0, target, 0, adjusted.length);
    }
  });
  return orient;
}
Usage of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in the MindsEye project by SimiaCryptus.
Class QuantifyOrientationWrapper, method orient.
/**
 * Delegates to the inner strategy and logs how large the proposed step is relative to the weights.
 *
 * <p>When the inner cursor is a {@code SimpleLineSearchCursor`}, the origin's weights are grouped by
 * {@code getId}. For each weight the ratio rms(direction delta) / rms(weight) is computed. The
 * numerator is 0 when the direction has no entry for that layer, and a zero denominator is
 * replaced by 1. A group with a single weight logs that ratio directly; larger groups log a
 * {@code DoubleStatistics} summary. Any other cursor type is only reported in the log.
 *
 * @param subject the trainable being optimized
 * @param measurement the current point sample
 * @param monitor receives the statistics log line
 * @return the inner strategy's cursor, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
  if (!(cursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", cursor));
    return cursor;
  }
  final SimpleLineSearchCursor simple = (SimpleLineSearchCursor) cursor;
  final DeltaSet<Layer> direction = simple.direction;
  @Nonnull final StateSet<Layer> weights = simple.origin.weights;
  final Map<CharSequence, CharSequence> dataMap = weights.stream()
      .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
      .entrySet().stream()
      .collect(Collectors.toMap(group -> group.getKey(), group -> {
        final List<Double> ratios = group.getValue().stream().map(state -> {
          final DoubleBuffer<Layer> step = direction.getMap().get(state.layer);
          final double weightRms = state.deltaStatistics().rms();
          final double stepRms = null == step ? 0 : step.deltaStatistics().rms();
          return stepRms / (0 == weightRms ? 1 : weightRms);
        }).collect(Collectors.toList());
        return 1 == ratios.size()
            ? Double.toString(ratios.get(0))
            : new DoubleStatistics().accept(ratios.stream().mapToDouble(Double::doubleValue).toArray()).toString();
      }));
  monitor.log(String.format("Line search stats: %s", dataMap));
  return cursor;
}
Aggregations