Usage of com.simiacryptus.mindseye.opt.line.LineSearchCursorBase in the MindsEye project by SimiaCryptus.
Class TrustRegionStrategy, method orient:
/**
 * Orients a line search from {@code origin}, wrapping the inner strategy's cursor so that every
 * trial position is projected back into the per-layer trust region.
 *
 * <p>{@code origin} is pushed onto the front of {@code history}, which is then trimmed to at most
 * {@code maxHistory} entries. For each layer, the weights recorded in that history are passed to
 * the layer's {@link TrustRegion} policy (from {@code getRegionPolicy}) when projecting.
 *
 * @param subject the trainable being optimized; measured by {@code step}
 * @param origin the current point sample; recorded at the head of the history
 * @param monitor training monitor forwarded to the inner strategy
 * @return a cursor whose positions and directional derivatives respect the trust regions
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
  history.add(0, origin);
  while (history.size() > maxHistory) {
    history.remove(history.size() - 1);
  }
  final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
  return new LineSearchCursorBase() {
    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      return cursor.getDirectionType() + "+Trust";
    }

    @Nonnull
    @Override
    public DeltaSet<Layer> position(final double alpha) {
      reset();
      @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
      // project() rewrites adjustedPosVector in place; the derivative it returns is not needed
      // here, so release it rather than leaking the copy.
      project(adjustedPosVector, new TrainingMonitor()).freeRef();
      return adjustedPosVector;
    }

    /**
     * Projects each layer's proposed step in {@code deltaIn} into that layer's trust region.
     *
     * <p>The delta arrays of {@code deltaIn} are overwritten in place so that
     * {@code target + delta} equals the projected position. The returned set is a copy of the
     * inner cursor's direction in which, for every projected layer, the component along the
     * projection normal has been removed.
     *
     * @param deltaIn proposed step; modified in place
     * @param monitor currently unused
     * @return the adjusted directional derivative; the caller owns (and must free) it
     */
    @Nonnull
    public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
      final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
      @Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
      deltaIn.getMap().forEach((layer, buffer) -> {
        @Nullable final double[] delta = buffer.getDelta();
        if (null == delta) {
          return;
        }
        final double[] currentPosition = buffer.target;
        @Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
        @Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
        @Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
        final TrustRegion region = getRegionPolicy(layer);
        if (null == region) {
          return;
        }
        // This layer's weights from each historical sample that has them.
        final double[][] historicalWeights = history.stream()
            .map((@Nonnull final PointSample x) -> {
              final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
              return null == d ? null : d.getDelta();
            })
            .filter(x -> null != x)
            .toArray(i -> new double[i][]);
        final double[] projectedPosition = region.project(historicalWeights, proposedPosition);
        // Reference comparison: assumes the policy hands back proposedPosition itself when no
        // projection was applied.
        if (projectedPosition == proposedPosition) {
          return;
        }
        for (int i = 0; i < projectedPosition.length; i++) {
          delta[i] = projectedPosition[i] - currentPosition[i];
        }
        @Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
        final double normalMagSq = ArrayUtil.dot(normal, normal);
        if (0 < normalMagSq && null != originalAlphaD && null != newAlphaD) {
          final double a = ArrayUtil.dot(originalAlphaD, normal);
          // When a == 0 the direction is already orthogonal to the normal and newAlphaD (a copy)
          // needs no change; otherwise remove the normal component: d - (d.n / |n|^2) n.
          if (a != 0) {
            @Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
            System.arraycopy(tangent, 0, newAlphaD, 0, tangent.length);
          }
        }
      });
      return newAlphaDerivative;
    }

    @Override
    public void reset() {
      cursor.reset();
    }

    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      cursor.reset();
      @Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
      @Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
      // Apply the projected step to the weights, then release the step vector.
      adjustedPosVector.accumulate(1);
      adjustedPosVector.freeRef();
      @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
      final double derivative = adjustedGradient.dot(sample.delta);
      adjustedGradient.freeRef();
      return new LineSearchPoint(sample, derivative);
    }

    @Override
    public void _free() {
      cursor.freeRef();
    }
  };
}
Usage of com.simiacryptus.mindseye.opt.line.LineSearchCursorBase in the MindsEye project by SimiaCryptus.
Class QQN, method orient:
/**
 * Orients a quadratic quasi-Newton (QQN) line search from {@code origin}.
 *
 * <p>The inner quasi-Newton strategy produces a direction {@code lbfgs}. The negated
 * {@code origin.delta} (presumably the steepest-descent direction) is rescaled to the same
 * magnitude as {@code lbfgs}. The returned cursor follows the quadratic path
 * {@code position(t) = (t - t^2) * scaledGradient + t^2 * lbfgs}, which starts along the scaled
 * gradient and reaches the quasi-Newton step at {@code t = 1}.
 *
 * <p>If the two magnitudes differ by at most 1% relative to their sum, or the gradient magnitude
 * is zero, the plain quasi-Newton cursor is returned instead.
 *
 * @param subject the trainable being optimized; measured by {@code step}
 * @param origin the current point sample; added to the inner strategy's history
 * @param monitor training monitor used for logging and history updates
 * @return the line search cursor to use from {@code origin}
 */
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, @Nonnull final PointSample origin, @Nonnull final TrainingMonitor monitor) {
  inner.addToHistory(origin, monitor);
  final SimpleLineSearchCursor lbfgsCursor = inner.orient(subject, origin, monitor);
  final DeltaSet<Layer> lbfgs = lbfgsCursor.direction;
  @Nonnull final DeltaSet<Layer> gd = origin.delta.scale(-1.0);
  final double lbfgsMag = lbfgs.getMagnitude();
  final double gdMag = gd.getMagnitude();
  // gdMag must be positive: rescaling a zero vector by lbfgsMag / 0 would yield NaN components.
  if (0 < gdMag && Math.abs(lbfgsMag - gdMag) / (lbfgsMag + gdMag) > 1e-2) {
    @Nonnull final DeltaSet<Layer> scaledGradient = gd.scale(lbfgsMag / gdMag);
    monitor.log(String.format("Returning Quadratic Cursor %s GD, %s QN", gdMag, lbfgsMag));
    gd.freeRef();
    return new LineSearchCursorBase() {
      @Nonnull
      @Override
      public CharSequence getDirectionType() {
        return CURSOR_NAME;
      }

      @Override
      public DeltaSet<Layer> position(final double t) {
        if (!Double.isFinite(t)) {
          throw new IllegalArgumentException();
        }
        return blend(t - t * t, t * t);
      }

      @Override
      public void reset() {
        lbfgsCursor.reset();
      }

      @Nonnull
      @Override
      public LineSearchPoint step(final double t, @Nonnull final TrainingMonitor monitor) {
        if (!Double.isFinite(t)) {
          throw new IllegalArgumentException();
        }
        reset();
        // Apply the offset for t to the weights, then release it.
        final DeltaSet<Layer> offset = position(t);
        offset.accumulate(1);
        offset.freeRef();
        @Nonnull final PointSample sample = subject.measure(monitor).setRate(t);
        inner.addToHistory(sample, monitor);
        // d/dt of position(t): (1 - 2t) * scaledGradient + 2t * lbfgs.
        @Nonnull final DeltaSet<Layer> tangent = blend(1 - 2 * t, 2 * t);
        final double derivative = tangent.dot(sample.delta);
        tangent.freeRef();
        return new LineSearchPoint(sample, derivative);
      }

      /**
       * Returns {@code gradientWeight * scaledGradient + quasiNewtonWeight * lbfgs}.
       *
       * <p>The intermediate scaled sets are released; the caller owns the result.
       */
      private DeltaSet<Layer> blend(final double gradientWeight, final double quasiNewtonWeight) {
        final DeltaSet<Layer> gradientPart = scaledGradient.scale(gradientWeight);
        final DeltaSet<Layer> quasiNewtonPart = lbfgs.scale(quasiNewtonWeight);
        final DeltaSet<Layer> sum = gradientPart.add(quasiNewtonPart);
        gradientPart.freeRef();
        quasiNewtonPart.freeRef();
        return sum;
      }

      @Override
      public void _free() {
        scaledGradient.freeRef();
        lbfgsCursor.freeRef();
      }
    };
  } else {
    gd.freeRef();
    return lbfgsCursor;
  }
}
Aggregations