Use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
From the class QQN, method orient.
/**
 * Chooses the cursor used for the line search that starts at {@code origin}.
 * <p>
 * The origin is first added to the inner strategy's history and the inner strategy is asked for
 * its own cursor. The magnitude of that cursor's direction is then compared with the magnitude of
 * the negated origin delta. When the relative difference of the two magnitudes exceeds 1e-2, a
 * cursor that blends both directions quadratically in {@code t} is returned; otherwise the inner
 * strategy's cursor is returned as is.
 *
 * @param subject the trainable that is measured on every step of the blended cursor
 * @param origin  the sample the search starts from
 * @param monitor receives the log message and is forwarded to the inner strategy
 * @return the blended cursor, or the inner strategy's cursor when the magnitudes are close
 */
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin, @Nonnull final TrainingMonitor monitor) {
  inner.addToHistory(origin, monitor);
  final SimpleLineSearchCursor quasiNewtonCursor = inner.orient(subject, origin, monitor);
  final DeltaSet<Layer> quasiNewton = quasiNewtonCursor.direction;
  @Nonnull final DeltaSet<Layer> descent = origin.delta.scale(-1.0);
  final double quasiNewtonNorm = quasiNewton.getMagnitude();
  final double descentNorm = descent.getMagnitude();
  final double relativeGap = Math.abs(quasiNewtonNorm - descentNorm) / (quasiNewtonNorm + descentNorm);

  // Deliberately a negated '>' rather than '<=': a NaN ratio (e.g. 0/0) must also take this path.
  if (!(relativeGap > 1e-2)) {
    descent.freeRef();
    return quasiNewtonCursor;
  }

  // The negated delta rescaled by the ratio of the two magnitudes.
  @Nonnull final DeltaSet<Layer> rescaledDescent = descent.scale(quasiNewtonNorm / descentNorm);
  monitor.log(String.format("Returning Quadratic Cursor %s GD, %s QN", descentNorm, quasiNewtonNorm));
  descent.freeRef();

  return new LineSearchCursorBase() {

    /** Rejects NaN and infinite step sizes. */
    private void requireFinite(final double t) {
      if (!Double.isFinite(t)) {
        throw new IllegalArgumentException();
      }
    }

    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      return CURSOR_NAME;
    }

    /**
     * Evaluates the blended path at {@code t}:
     * {@code (t - t^2) * rescaledDescent + t^2 * quasiNewton}.
     */
    @Override
    public DeltaSet<Layer> position(final double t) {
      requireFinite(t);
      final double linearWeight = t - t * t;
      final double quadraticWeight = t * t;
      final DeltaSet<Layer> descentTerm = rescaledDescent.scale(linearWeight);
      final DeltaSet<Layer> quasiNewtonTerm = quasiNewton.scale(quadraticWeight);
      return descentTerm.add(quasiNewtonTerm);
    }

    /** Delegates to the inner strategy's cursor. */
    @Override
    public void reset() {
      quasiNewtonCursor.reset();
    }

    /**
     * Resets, applies the blended position for {@code t}, measures the subject, records the
     * measurement in the inner strategy's history, and returns the point paired with the dot
     * product of the path tangent and the measured delta.
     */
    @Nonnull
    @Override
    public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
      requireFinite(t);
      reset();
      final DeltaSet<Layer> offset = position(t);
      offset.accumulate(1);
      @Nonnull final PointSample measured = subject.measure(monitor).setRate(t);
      inner.addToHistory(measured, monitor);
      // Derivative of the blended path with respect to t:
      // (1 - 2t) * rescaledDescent + 2t * quasiNewton
      final DeltaSet<Layer> descentSlope = rescaledDescent.scale(1 - 2 * t);
      final DeltaSet<Layer> quasiNewtonSlope = quasiNewton.scale(2 * t);
      @Nonnull final DeltaSet<Layer> pathTangent = descentSlope.add(quasiNewtonSlope);
      final double directionalDerivative = pathTangent.dot(measured.delta);
      return new LineSearchPoint(measured, directionalDerivative);
    }

    /** Releases the rescaled descent direction and the inner cursor held by this cursor. */
    @Override
    public void _free() {
      rescaledDescent.freeRef();
      quasiNewtonCursor.freeRef();
    }
  };
}
Aggregations