use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
the class TrustRegionStrategy method orient.
/**
 * Wraps the cursor produced by the inner strategy so that each proposed position is
 * projected through the trust region configured for its layer.
 * <p>
 * The supplied sample is inserted at the head of {@code history}, which is then trimmed
 * from the tail until it holds at most {@code maxHistory} entries.
 *
 * @param subject the trainable that is measured after each step
 * @param origin  the sample the search starts from; also recorded in the history
 * @param monitor handed to the inner strategy when it builds its cursor
 * @return a cursor whose direction type is the inner cursor's type followed by "+Trust"
 */
@Nonnull
@Override
public LineSearchCursor orient(@Nonnull final Trainable subject, final PointSample origin, final TrainingMonitor monitor) {
  // Most recent sample first; the oldest entries fall off the end.
  history.add(0, origin);
  while (history.size() > maxHistory) {
    history.remove(history.size() - 1);
  }
  final SimpleLineSearchCursor cursor = inner.orient(subject, origin, monitor);
  return new LineSearchCursorBase() {

    @Nonnull
    @Override
    public CharSequence getDirectionType() {
      final CharSequence innerType = cursor.getDirectionType();
      return innerType + "+Trust";
    }

    /**
     * Resets the inner cursor, asks it for the position at {@code alpha} and projects
     * that position in place. The derivative returned by the projection is not used here.
     */
    @Nonnull
    @Override
    public DeltaSet<Layer> position(final double alpha) {
      reset();
      @Nonnull final DeltaSet<Layer> moved = cursor.position(alpha);
      project(moved, new TrainingMonitor());
      return moved;
    }

    /**
     * Projects the deltas of {@code deltaIn} (modified in place) so that each layer's
     * target position satisfies that layer's trust region, and returns a copy of the
     * inner cursor's direction in which the component along the projection normal has
     * been removed for every layer that was actually moved by the projection.
     *
     * @param deltaIn the proposed move; its delta arrays are overwritten when projected
     * @param monitor not referenced by the current implementation
     * @return the adjusted copy of the inner cursor's direction
     */
    @Nonnull
    public DeltaSet<Layer> project(@Nonnull final DeltaSet<Layer> deltaIn, final TrainingMonitor monitor) {
      final DeltaSet<Layer> innerDirection = cursor.direction;
      @Nonnull final DeltaSet<Layer> projectedDirection = innerDirection.copy();
      deltaIn.getMap().forEach((layer, buffer) -> {
        @Nullable final double[] move = buffer.getDelta();
        if (null == move) {
          return;
        }
        final double[] here = buffer.target;
        @Nullable final double[] innerSlope = innerDirection.get(layer, here).getDelta();
        @Nullable final double[] projectedSlope = projectedDirection.get(layer, here).getDelta();
        @Nonnull final double[] wanted = ArrayUtil.add(here, move);
        final TrustRegion region = getRegionPolicy(layer);
        if (null == region) {
          return;
        }
        // Collect this layer's stored values from every remembered sample that has them.
        final double[][] pastWeights = history.stream()
            .map((@Nonnull final PointSample past) -> {
              final DoubleBuffer<Layer> stored = past.weights.getMap().get(layer);
              return null == stored ? null : stored.getDelta();
            })
            .filter(w -> null != w)
            .toArray(double[][]::new);
        final double[] allowed = region.project(pastWeights, wanted);
        // The region hands back the very same array when it leaves the position untouched.
        if (allowed == wanted) {
          return;
        }
        for (int k = 0; k < allowed.length; k++) {
          move[k] = allowed[k] - here[k];
        }
        @Nonnull final double[] normal = ArrayUtil.subtract(allowed, wanted);
        final double normalMagSq = ArrayUtil.dot(normal, normal);
        if (!(0 < normalMagSq)) {
          return;
        }
        final double a = ArrayUtil.dot(innerSlope, normal);
        // NOTE(review): the comparison against -1 is kept exactly as found; its rationale
        // is not evident from this code.
        if (a == -1) {
          return;
        }
        // Remove the component of the original slope that points along the normal.
        @Nonnull final double[] tangent = ArrayUtil.add(innerSlope, ArrayUtil.multiply(normal, -a / normalMagSq));
        for (int k = 0; k < tangent.length; k++) {
          projectedSlope[k] = tangent[k];
        }
      });
      return projectedDirection;
    }

    @Override
    public void reset() {
      cursor.reset();
    }

    /**
     * Resets the inner cursor, projects the position at {@code alpha}, applies it and
     * measures the subject there.
     */
    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      cursor.reset();
      @Nonnull final DeltaSet<Layer> moved = cursor.position(alpha);
      @Nonnull final DeltaSet<Layer> slope = project(moved, monitor);
      moved.accumulate(1);
      @Nonnull final PointSample sample = subject.measure(monitor).setRate(alpha);
      return new LineSearchPoint(sample, slope.dot(sample.delta));
    }

    @Override
    public void _free() {
      cursor.freeRef();
    }
  };
}
use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
the class DescribeOrientationWrapper method orient.
/**
 * Delegates to the inner strategy and logs a description of the resulting cursor.
 * For a {@link SimpleLineSearchCursor} the rendered weights/direction details are logged;
 * for any other cursor only its string form is logged.
 *
 * @param subject     forwarded to the inner strategy
 * @param measurement forwarded to the inner strategy
 * @param monitor     receives the log message and is forwarded to the inner strategy
 * @return the cursor produced by the inner strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  final LineSearchCursor cursor = inner.orient(subject, measurement, monitor);
  if (!(cursor instanceof SimpleLineSearchCursor)) {
    monitor.log(String.format("Non-simple cursor: %s", cursor));
    return cursor;
  }
  final SimpleLineSearchCursor simple = (SimpleLineSearchCursor) cursor;
  final DeltaSet<Layer> direction = simple.direction;
  @Nonnull final StateSet<Layer> weights = simple.origin.weights;
  final CharSequence details = DescribeOrientationWrapper.render(weights, direction);
  monitor.log(String.format("Orientation Details: %s", details));
  return cursor;
}
use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
the class LayerRateDiagnosticTrainer method run.
/**
 * Probes every layer separately with a line search and records the rate it found.
 * <p>
 * The outer loop repeats until the timeout elapses, the measured error is no longer
 * above {@code terminateThreshold}, or {@code maxIterations} is exceeded. Each pass
 * re-measures the subject and then performs up to {@code iterationsPerSample}
 * sub-iterations. A sub-iteration
 * <ol>
 * <li>logs a per-layer step estimate derived from two measurements taken a tiny step
 * apart, and</li>
 * <li>runs the line-search strategy once per layer along the orientation's direction
 * filtered down to that layer, storing a {@code LayerStats} entry for it.</li>
 * </ol>
 * In strict mode every per-layer step is reverted after it has been recorded, and the
 * candidate with the lowest error is re-applied once all layers have been tried.
 *
 * @return the per-layer statistics, as returned by {@code getLayerRates()}
 */
@Nonnull
public Map<Layer, LayerStats> run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  PointSample measure = measure();
  @Nonnull final ArrayList<Layer> layers = new ArrayList<>(measure.weights.getMap().keySet());
  while (timeoutMs > System.currentTimeMillis() && measure.sum > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    final PointSample initialPhasePoint = measure();
    measure = initialPhasePoint;
    for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
      if (currentIteration.incrementAndGet() > maxIterations) {
        break;
      }
      {
        // Diagnostic only: compare the delta measured at a tiny step with the one at
        // step 0 and log a step-size estimate per layer. Nothing computed here feeds
        // the search below.
        @Nonnull final SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
        final double stepSize = 1e-12 * orient.origin.sum;
        @Nonnull final DeltaSet<Layer> pointB = orient.step(stepSize, monitor).point.delta.copy();
        @Nonnull final DeltaSet<Layer> pointA = orient.step(0.0, monitor).point.delta.copy();
        @Nonnull final DeltaSet<Layer> d1 = pointA;
        // (delta at 0 - delta at stepSize) / stepSize
        @Nonnull final DeltaSet<Layer> d2 = d1.add(pointB.scale(-1)).scale(1.0 / stepSize);
        @Nonnull final Map<Layer, Double> steps = new HashMap<>();
        final double overallStepEstimate = d1.getMagnitude() / d2.getMagnitude();
        for (final Layer layer : layers) {
          final DoubleBuffer<Layer> a = d2.get(layer, (double[]) null);
          final DoubleBuffer<Layer> b = d1.get(layer, (double[]) null);
          final double bmag = Math.sqrt(b.deltaStatistics().sumSq());
          final double amag = Math.sqrt(a.deltaStatistics().sumSq());
          // Normalised dot product of the two buffers.
          final double dot = a.dot(b) / (amag * bmag);
          final double idealSize = bmag / (amag * dot);
          steps.put(layer, idealSize);
          monitor.log(String.format("Layers stats: %s (%s, %s, %s) => %s", layer, amag, bmag, dot, idealSize));
        }
        monitor.log(String.format("Estimated ideal rates for layers: %s (%s overall; probed at %s)", steps, overallStepEstimate, stepSize));
      }
      @Nullable SimpleLineSearchCursor bestOrient = null;
      @Nullable PointSample bestPoint = null;
      for (@Nonnull final Layer layer : layers) {
        @Nonnull SimpleLineSearchCursor orient = (SimpleLineSearchCursor) getOrientation().orient(subject, measure, monitor);
        // Restrict the search direction to the layer under test.
        @Nonnull final DeltaSet<Layer> direction = filterDirection(orient.direction, layer);
        if (direction.getMagnitude() == 0) {
          monitor.log(String.format("Zero derivative for layer %s; skipping", layer));
          continue;
        }
        orient = new SimpleLineSearchCursor(orient.subject, orient.origin, direction);
        final PointSample previous = measure;
        measure = getLineSearchStrategy().step(orient, monitor);
        if (isStrict()) {
          monitor.log(String.format("Iteration %s reverting. Error: %s", currentIteration.get(), measure.sum));
          monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
          // Keep the candidate with the LOWEST error. The original comparison was
          // inverted and retained the highest-error point as "best".
          if (null == bestPoint || bestPoint.sum > measure.sum) {
            bestOrient = orient;
            bestPoint = measure;
          }
          getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
          // Step with rate 0 and restore the previous sample so the next layer is
          // probed from the same starting measurement.
          orient.step(0, monitor);
          measure = previous;
        } else if (previous.sum == measure.sum) {
          monitor.log(String.format("Iteration %s failed. Error: %s", currentIteration.get(), measure.sum));
        } else {
          monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), measure.sum));
          monitor.log(String.format("Optimal rate for layer %s: %s", layer.getName(), measure.getRate()));
          getLayerRates().put(layer, new LayerStats(measure.getRate(), initialPhasePoint.sum - measure.sum));
        }
      }
      monitor.log(String.format("Ideal rates: %s", getLayerRates()));
      if (null != bestPoint) {
        // bestOrient is always assigned together with bestPoint.
        // NOTE(review): `measure` is not refreshed from the point returned here, so the
        // step reported below and the next sub-iteration still use the earlier sample.
        // Confirm whether that is intended.
        bestOrient.step(bestPoint.rate, monitor);
      }
      monitor.onStepComplete(new Step(measure, currentIteration.get()));
    }
  }
  return getLayerRates();
}
use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
the class GradientDescent method orient.
/**
 * Builds a cursor that searches along the measurement's delta scaled by -1.
 * A message is logged when the magnitude of that direction is below 1e-5
 * ("Low gradient") or below 1e-10 ("Zero gradient").
 *
 * @param subject     the trainable handed to the new cursor
 * @param measurement the sample whose delta defines the direction; also the cursor origin
 * @param monitor     receives the small-magnitude log messages
 * @return a cursor with direction type "GD"
 */
@Nonnull
@Override
public SimpleLineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
  @Nonnull final DeltaSet<Layer> descent = measurement.delta.scale(-1);
  final double magnitude = descent.getMagnitude();
  final double size = Math.abs(magnitude);
  if (size < 1e-5) {
    final String template = size < 1e-10 ? "Zero gradient: %s" : "Low gradient: %s";
    monitor.log(String.format(template, magnitude));
  }
  @Nonnull final SimpleLineSearchCursor result = new SimpleLineSearchCursor(subject, measurement, descent).setDirectionType("GD");
  // Release this method's reference to the direction once the cursor has been built.
  descent.freeRef();
  return result;
}
use of com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor in project MindsEye by SimiaCryptus.
the class OwlQn method orient.
/**
 * Takes the direction of the inner strategy's cursor, adjusts it with the L1 term and
 * returns a cursor whose steps never carry a weight across zero.
 * <p>
 * For every selected layer each component of the copied direction is shifted by
 * {@code factor_L1} times -1 for a negative weight and +1 otherwise. If that shift
 * changes the component's sign relative to the inner direction, the inner value is
 * restored.
 *
 * @param subject     the trainable measured by the returned cursor
 * @param measurement the sample used as the returned cursor's origin
 * @param monitor     forwarded to the inner strategy
 * @return a cursor with direction type "OWL/QN"
 */
@Nonnull
@Override
public LineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
  @Nonnull final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
  @Nonnull final DeltaSet<Layer> searchDirection = gradient.direction.copy();
  for (@Nonnull final Layer layer : getLayers(gradient.direction.getMap().keySet())) {
    final double[] weights = gradient.direction.getMap().get(layer).target;
    @Nullable final double[] delta = gradient.direction.getMap().get(layer).getDelta();
    @Nullable final double[] searchDir = searchDirection.get(layer, weights).getDelta();
    // Check before the first dereference; the original asserted only after the loop.
    assert null != searchDir;
    for (int i = 0; i < searchDir.length; i++) {
      final int directionSign = sign(delta[i]);
      searchDir[i] += factor_L1 * (weights[i] < 0 ? -1.0 : 1.0);
      // Fall back to the inner direction when the L1 shift flipped the sign.
      if (sign(searchDir[i]) != directionSign) {
        searchDir[i] = delta[i];
      }
    }
  }
  return new SimpleLineSearchCursor(subject, measurement, searchDirection) {
    /**
     * Restores the origin weights, moves each weight by {@code alpha} times its
     * direction component, and sets to zero any non-zero weight whose sign would
     * change. The matching component is zeroed in the direction copy that is used to
     * compute the reported derivative.
     */
    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      origin.weights.stream().forEach(d -> d.restore());
      @Nonnull final DeltaSet<Layer> currentDirection = direction.copy();
      direction.getMap().forEach((layer, buffer) -> {
        @Nullable final double[] delta = buffer.getDelta();
        if (null == delta) {
          return;
        }
        final double[] target = buffer.target;
        @Nullable final double[] currentDelta = currentDirection.get(layer, target).getDelta();
        for (int i = 0; i < delta.length; i++) {
          final double prevValue = target[i];
          final double newValue = prevValue + delta[i] * alpha;
          if (sign(prevValue) != 0 && sign(prevValue) != sign(newValue)) {
            // The step would cross zero: stop at zero and drop this component.
            currentDelta[i] = 0;
            target[i] = 0;
          } else {
            target[i] = newValue;
          }
        }
      });
      @Nonnull final PointSample measure = subject.measure(monitor).setRate(alpha);
      return new LineSearchPoint(measure, currentDirection.dot(measure.delta));
    }
  }.setDirectionType("OWL/QN");
}
Aggregations