Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in project MindsEye by SimiaCryptus.
From the class MomentumStrategy, method orient:
/**
 * Produces a search direction that mixes the inner strategy's direction with the
 * direction produced by the previous call.
 *
 * <p>For every layer in the inner direction the new delta is
 * {@code previousDelta * carryOver + currentDelta}. The blended set is stored in
 * {@code prevDelta} so the next call can carry it forward.
 *
 * <p>The cursor returned by the inner strategy is cast to
 * {@link SimpleLineSearchCursor} without a type check.
 *
 * @param subject     the trainable being optimized
 * @param measurement the sample the direction is computed from
 * @param monitor     the training monitor passed to the inner strategy
 * @return a simple cursor over the blended direction
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
  final LineSearchCursor innerCursor = inner.orient(subject, measurement, monitor);
  final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) innerCursor;
  final DeltaSet<Layer> rawDirection = simpleCursor.direction;
  @Nonnull final DeltaSet<Layer> blended = new DeltaSet<>();
  rawDirection.getMap().forEach((layer, current) -> {
    final DoubleBuffer<Layer> previous = prevDelta.get(layer, current.target);
    // blended[layer] += previous * carryOver + current
    blended.get(layer, current.target).addInPlace(
        ArrayUtil.add(
            ArrayUtil.multiply(previous.getDelta(), carryOver),
            current.getDelta()));
  });
  prevDelta = blended;
  return new SimpleLineSearchCursor(subject, measurement, blended);
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in project MindsEye by SimiaCryptus.
From the class ValidatingTrainer, method runStep:
/**
 * Runs one training iteration: orient, line-search, then validate the outcome.
 *
 * <p>The line search strategy is cached per direction type in
 * {@code phase.lineSearchStrategyMap}; a missing entry is created through
 * {@code phase.lineSearchFactory} and logged. Both the orientation and the line
 * search are timed, and the timings (in seconds) are attached to the result.
 *
 * @param previousPoint the point the step starts from
 * @param phase         the training phase supplying the subject, orientation and line search configuration
 * @return the step result holding the previous point, the best point found and the two timings
 * @throws IllegalStateException if the best point's mean is greater than the previous point's mean
 */
@Nonnull
protected StepResult runStep(@Nonnull final PointSample previousPoint, @Nonnull final TrainingPhase phase) {
  currentIteration.incrementAndGet();

  @Nonnull final TimedResult<LineSearchCursor> orientationTiming = TimedResult.time(
      () -> phase.orientation.orient(phase.trainingSubject, previousPoint, monitor));
  final LineSearchCursor searchCursor = orientationTiming.result;
  final CharSequence strategyKey = searchCursor.getDirectionType();

  // Look up the cached strategy for this direction type, creating it on first use.
  final LineSearchStrategy strategy;
  if (!phase.lineSearchStrategyMap.containsKey(strategyKey)) {
    monitor.log(String.format("Constructing line search parameters: %s", strategyKey));
    strategy = phase.lineSearchFactory.apply(searchCursor.getDirectionType());
    phase.lineSearchStrategyMap.put(strategyKey, strategy);
  } else {
    strategy = phase.lineSearchStrategyMap.get(strategyKey);
  }

  @Nonnull final TimedResult<PointSample> searchTiming = TimedResult.time(() -> {
    @Nonnull final FailsafeLineSearchCursor failsafe = new FailsafeLineSearchCursor(searchCursor, previousPoint, monitor);
    strategy.step(failsafe, monitor);
    return failsafe.getBest(monitor).restore();
  });
  final PointSample bestPoint = searchTiming.result;

  // A step must never make the objective worse than where it started.
  final double bestMean = bestPoint.getMean();
  final double previousMean = previousPoint.getMean();
  if (bestMean > previousMean) {
    throw new IllegalStateException(bestMean + " > " + previousMean);
  }

  monitor.log(compare(previousPoint, bestPoint));
  final double[] timingsSeconds = {
      orientationTiming.timeNanos / 1e9,
      searchTiming.timeNanos / 1e9
  };
  return new StepResult(previousPoint, bestPoint, timingsSeconds);
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in project MindsEye by SimiaCryptus.
From the class QuantifyOrientationWrapper, method orient:
/**
 * Delegates to the inner strategy and logs, per weight group, how large the
 * proposed direction is relative to the weights.
 *
 * <p>Weights are grouped by {@code getId}. For each weight entry the ratio
 * {@code rms(direction) / rms(weight)} is computed, using 0 for a missing
 * direction and a divisor of 1 when the weight RMS is 0. A group with a single
 * entry is reported as that number; larger groups are reported through
 * {@code DoubleStatistics}. Cursors that are not {@link SimpleLineSearchCursor}
 * are only logged as such.
 *
 * @param subject     the trainable being optimized
 * @param measurement the sample the direction is computed from
 * @param monitor     the monitor receiving the log line
 * @return the cursor produced by the inner strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
  if (!(cursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", cursor));
    return cursor;
  }

  final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
  final DeltaSet<Layer> searchDirection = simpleCursor.direction;
  @Nonnull final StateSet<Layer> currentWeights = simpleCursor.origin.weights;

  final Map<CharSequence, CharSequence> statsById = currentWeights.stream()
      .collect(Collectors.groupingBy(weight -> getId(weight), Collectors.toList()))
      .entrySet().stream()
      .collect(Collectors.toMap(group -> group.getKey(), group -> {
        final double[] ratios = group.getValue().stream().mapToDouble(weightState -> {
          final DoubleBuffer<Layer> layerDirection = searchDirection.getMap().get(weightState.layer);
          final double weightRms = weightState.deltaStatistics().rms();
          final double directionRms;
          if (null == layerDirection) {
            directionRms = 0;
          } else {
            directionRms = layerDirection.deltaStatistics().rms();
          }
          // Avoid dividing by zero for all-zero weights.
          final double divisor = weightRms == 0 ? 1 : weightRms;
          return directionRms / divisor;
        }).toArray();
        if (ratios.length == 1) {
          return Double.toString(ratios[0]);
        }
        return new DoubleStatistics().accept(ratios).toString();
      }));

  monitor.log(String.format("Line search stats: %s", statsById));
  return cursor;
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in project MindsEye by SimiaCryptus.
From the class TrustRegionStrategy, method orient:
/**
 * Wraps the inner strategy's cursor so that every proposed position is passed
 * through the per-layer trust region returned by {@code getRegionPolicy}.
 *
 * <p>The supplied origin is pushed to the front of {@code history}, which is then
 * trimmed from the tail to at most {@code maxHistory} entries. The returned cursor
 * reports the inner direction type suffixed with {@code "+Trust"}.
 *
 * @param subject the trainable that is measured after each step
 * @param origin  the current point; recorded in the history
 * @param monitor the monitor passed to the inner strategy
 * @return a cursor whose positions are projected into the trust regions
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
// Newest sample first; drop the oldest entries beyond maxHistory.
history.add(0, origin);
while (history.size() > maxHistory) {
history.remove(history.size() - 1);
}
final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
return new LineSearchCursorBase() {
// Direction type of the inner cursor with "+Trust" appended.
@Nonnull
@Override
public CharSequence getDirectionType() {
return cursor.getDirectionType() + "+Trust";
}
// Returns the inner cursor's position for alpha after running it through project().
// The derivative set returned by project() is discarded here, and a fresh
// TrainingMonitor is used instead of a caller-supplied one.
@Nonnull
@Override
public DeltaSet<Layer> position(final double alpha) {
reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
project(adjustedPosVector, new TrainingMonitor());
return adjustedPosVector;
}
// Projects each layer's proposed position (target + delta) through its trust region.
// Writes the corrected step into the arrays obtained from deltaIn's buffers and
// returns a copy of the inner direction with the corrected layers' entries
// replaced by their tangent component.
// NOTE(review): assumes getDelta() returns the live backing array so the in-place
// writes below are visible to the caller — confirm.
@Nonnull
public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
final DeltaSet<Layer> originalAlphaDerivative = cursor.direction;
@Nonnull final DeltaSet<Layer> newAlphaDerivative = originalAlphaDerivative.copy();
deltaIn.getMap().forEach((layer, buffer) -> {
@Nullable final double[] delta = buffer.getDelta();
// Nothing proposed for this layer.
if (null == delta)
return;
final double[] currentPosition = buffer.target;
// NOTE(review): both arrays are annotated @Nullable but are used below without a null check.
@Nullable final double[] originalAlphaD = originalAlphaDerivative.get(layer, currentPosition).getDelta();
@Nullable final double[] newAlphaD = newAlphaDerivative.get(layer, currentPosition).getDelta();
@Nonnull final double[] proposedPosition = ArrayUtil.add(currentPosition, delta);
final TrustRegion region = getRegionPolicy(layer);
// Layers without a region policy are left untouched.
if (null != region) {
// Collect this layer's non-null delta arrays from the recorded history samples.
final Stream<double[]> zz = history.stream().map((@Nonnull final PointSample x) -> {
final DoubleBuffer<Layer> d = x.weights.getMap().get(layer);
@Nullable final double[] z = null == d ? null : d.getDelta();
return z;
});
final double[] projectedPosition = region.project(zz.filter(x -> null != x).toArray(i -> new double[i][]), proposedPosition);
// Reference comparison: a different array instance is treated as "the region moved the point".
if (projectedPosition != proposedPosition) {
// Rewrite the step so that target + delta lands on the projected position.
for (int i = 0; i < projectedPosition.length; i++) {
delta[i] = projectedPosition[i] - currentPosition[i];
}
// Correction applied by the region (projected minus proposed).
@Nonnull final double[] normal = ArrayUtil.subtract(projectedPosition, proposedPosition);
final double normalMagSq = ArrayUtil.dot(normal, normal);
// normalMagSq));
if (0 < normalMagSq) {
final double a = ArrayUtil.dot(originalAlphaD, normal);
// NOTE(review): the intent of the (a != -1) guard is not evident from this code — confirm.
if (a != -1) {
// Remove the component of the original derivative that lies along the correction.
@Nonnull final double[] tangent = ArrayUtil.add(originalAlphaD, ArrayUtil.multiply(normal, -a / normalMagSq));
for (int i = 0; i < tangent.length; i++) {
newAlphaD[i] = tangent[i];
}
// double newAlphaDerivSq = ArrayUtil.dot(tangent, tangent);
// double originalAlphaDerivSq = ArrayUtil.dot(originalAlphaD, originalAlphaD);
// assert(newAlphaDerivSq <= originalAlphaDerivSq);
// assert(Math.abs(ArrayUtil.dot(tangent, normal)) <= 1e-4);
// monitor.log(String.format("%s: normalMagSq = %s, newAlphaDerivSq = %s, originalAlphaDerivSq = %s", layer, normalMagSq, newAlphaDerivSq, originalAlphaDerivSq));
}
}
}
}
});
return newAlphaDerivative;
}
// Delegates to the inner cursor's reset().
@Override
public void reset() {
cursor.reset();
}
// Resets the inner cursor, projects its position for alpha, accumulates the
// adjusted position with factor 1, measures the subject, and returns the sample
// together with the dot product of the adjusted gradient and the sample's delta.
@Nonnull
@Override
public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
cursor.reset();
@Nonnull final DeltaSet<Layer> adjustedPosVector = cursor.position(alpha);
@Nonnull final DeltaSet<Layer> adjustedGradient = project(adjustedPosVector, monitor);
// presumably applies the projected step to the layer weights — TODO confirm
adjustedPosVector.accumulate(1);
@Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
return new LineSearchPoint(sample, adjustedGradient.dot(sample.delta));
}
// Frees the wrapped inner cursor's reference.
@Override
public void _free() {
cursor.freeRef();
}
};
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in project MindsEye by SimiaCryptus.
From the class DescribeOrientationWrapper, method orient:
/**
 * Delegates to the inner strategy and logs a rendered description of the
 * resulting orientation.
 *
 * <p>When the inner cursor is a {@link SimpleLineSearchCursor}, its origin weights
 * and direction are passed to {@code DescribeOrientationWrapper.render} and the
 * result is logged; any other cursor type is only logged as non-simple.
 *
 * @param subject     the trainable being optimized
 * @param measurement the sample the direction is computed from
 * @param monitor     the monitor receiving the log line
 * @return the cursor produced by the inner strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
  if (!(cursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", cursor));
    return cursor;
  }

  final SimpleLineSearchCursor simpleCursor = (SimpleLineSearchCursor) cursor;
  final DeltaSet<Layer> searchDirection = simpleCursor.direction;
  @Nonnull final StateSet<Layer> originWeights = simpleCursor.origin.weights;
  final CharSequence description = DescribeOrientationWrapper.render(originWeights, searchDirection);
  monitor.log(String.format("Orientation Details: %s", description));
  return cursor;
}
Aggregations