Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in the MindsEye project by SimiaCryptus.
Class IterativeTrainer, method run.
/**
 * Runs the training loop and returns the final fitness.
 *
 * <p>The loop stops when the wall-clock {@code timeout} elapses, when the mean fitness is no
 * longer above {@code terminateThreshold}, when {@code maxIterations} is exceeded, or when a step
 * fails to improve the fitness and {@code subject.reseed(...)} returns {@code false}. Each outer
 * iteration re-measures the subject. It then performs up to {@code iterationsPerSample}
 * orientation + line-search steps; {@code iterationsPerSample <= 0} means no per-sample limit.
 * A step that makes the fitness worse is rolled back via {@code direction.step(0, monitor)}.
 *
 * <p>After the loop, if the subject's layer is a {@link DAGNetwork}, noise is cleared on every
 * {@link StochasticComponent} it contains.
 *
 * @return the mean fitness of the final point, or {@link Double#NaN} if there is no final point
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  long lastIterationTime = System.nanoTime();
  @Nullable PointSample currentPoint = measure(true);
  mainLoop:
  while (timeoutMs > System.currentTimeMillis() && currentPoint.getMean() > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    currentPoint.freeRef();
    currentPoint = measure(true);
    assert 0 < currentPoint.delta.getMap().size() : "Nothing to optimize";
    subiterationLoop:
    for (int subiteration = 0; subiteration < iterationsPerSample || iterationsPerSample <= 0; subiteration++) {
      if (timeoutMs < System.currentTimeMillis()) {
        break mainLoop;
      }
      if (currentIteration.incrementAndGet() > maxIterations) {
        break mainLoop;
      }
      currentPoint.freeRef();
      currentPoint = measure(true);
      // Effectively-final alias so the point can be captured by the lambda below.
      @Nullable final PointSample _currentPoint = currentPoint;
      @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> orientation.orient(subject, _currentPoint, monitor));
      final LineSearchCursor direction = timedOrientation.result;
      final CharSequence directionType = direction.getDirectionType();
      // Take an extra reference so the starting point stays alive after currentPoint is reassigned;
      // it is released in the finally block.
      @Nullable final PointSample previous = currentPoint;
      previous.addRef();
      try {
        @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> step(direction, directionType, previous));
        currentPoint.freeRef();
        currentPoint = timedLineSearch.result;
        final long now = System.nanoTime();
        final CharSequence perfString = String.format("Total: %.4f; Orientation: %.4f; Line Search: %.4f", (now - lastIterationTime) / 1e9, timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9);
        lastIterationTime = now;
        monitor.log(String.format("Fitness changed from %s to %s", previous.getMean(), currentPoint.getMean()));
        if (previous.getMean() <= currentPoint.getMean()) {
          if (previous.getMean() < currentPoint.getMean()) {
            // The step made things worse: revert to the origin of the search direction.
            monitor.log(String.format("Resetting Iteration %s", perfString));
            currentPoint.freeRef();
            currentPoint = direction.step(0, monitor).point;
          } else {
            monitor.log(String.format("Static Iteration %s", perfString));
          }
          if (subject.reseed(System.nanoTime())) {
            monitor.log(String.format("Iteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break subiterationLoop;
          } else {
            monitor.log(String.format("Iteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.getMean()));
            monitor.log(String.format("Previous Error: %s -> %s", previous.getRate(), previous.getMean()));
            break mainLoop;
          }
        } else {
          // Pass perfString as an argument rather than splicing it into the format string.
          monitor.log(String.format("Iteration %s complete. Error: %s %s", currentIteration.get(), currentPoint.getMean(), perfString));
        }
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
      } finally {
        previous.freeRef();
        direction.freeRef();
      }
    }
  }
  if (subject.getLayer() instanceof DAGNetwork) {
    ((DAGNetwork) subject.getLayer()).visitLayers(layer -> {
      if (layer instanceof StochasticComponent)
        ((StochasticComponent) layer).clearNoise();
    });
  }
  // Only release the final point when one exists; the original released it unconditionally,
  // which would throw NullPointerException in exactly the case the NaN branch handles.
  final double mean;
  if (null == currentPoint) {
    mean = Double.NaN;
  } else {
    mean = currentPoint.getMean();
    currentPoint.freeRef();
  }
  return mean;
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in the MindsEye project by SimiaCryptus.
Class RoundRobinTrainer, method run.
/**
 * Runs round-robin training and returns the final fitness.
 *
 * <p>Each outer iteration re-measures the subject and then performs up to
 * {@code iterationsPerSample} macro-iterations. In each macro-iteration, every strategy in
 * {@code orientations} takes one line-search step in turn. Line-search strategies are created
 * lazily by {@code lineSearchFactory} and cached per direction type and orientation instance.
 * If a whole macro-iteration fails to lower the fitness sum, the subject is reseeded; when
 * reseeding is not possible, the run aborts.
 *
 * <p>The run stops when the wall-clock {@code timeout} elapses, when the fitness sum is no longer
 * above {@code terminateThreshold}, when {@code maxIterations} is exceeded, or on an aborted
 * macro-iteration.
 *
 * @return the fitness sum of the final point, or {@link Double#NaN} if there is no final point
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  PointSample currentPoint = measure();
  mainLoop:
  while (timeoutMs > System.currentTimeMillis() && currentPoint.sum > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    currentPoint = measure();
    subiterationLoop:
    for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
      final PointSample previousOrientations = currentPoint;
      for (@Nonnull final OrientationStrategy<?> orientation : orientations) {
        // Honor the wall-clock budget between steps, as IterativeTrainer does.
        if (timeoutMs < System.currentTimeMillis()) {
          break mainLoop;
        }
        // Stop the whole run once the iteration budget is spent. A plain `break` here would fall
        // through to the macro-iteration check below and could reseed / log a spurious failure.
        if (currentIteration.incrementAndGet() > maxIterations) {
          break mainLoop;
        }
        final LineSearchCursor direction = orientation.orient(subject, currentPoint, monitor);
        // Key line-search state by both direction type and orientation instance, so each
        // strategy keeps its own line-search parameters.
        @Nonnull final CharSequence directionType = direction.getDirectionType() + "+" + Long.toHexString(System.identityHashCode(orientation));
        final LineSearchStrategy lineSearchStrategy = lineSearchStrategyMap.computeIfAbsent(directionType, key -> {
          log.info(String.format("Constructing line search parameters: %s", directionType));
          return lineSearchFactory.apply(directionType);
        });
        final PointSample previous = currentPoint;
        currentPoint = lineSearchStrategy.step(direction, monitor);
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
        if (previous.sum == currentPoint.sum) {
          monitor.log(String.format("Iteration %s failed, ignoring. Error: %s", currentIteration.get(), currentPoint.sum));
        } else {
          monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), currentPoint.sum));
        }
      }
      if (previousOrientations.sum <= currentPoint.sum) {
        if (subject.reseed(System.nanoTime())) {
          monitor.log(String.format("MacroIteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.sum));
          break subiterationLoop;
        } else {
          monitor.log(String.format("MacroIteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.sum));
          break mainLoop;
        }
      }
    }
  }
  return null == currentPoint ? Double.NaN : currentPoint.sum;
}
Use of com.simiacryptus.mindseye.opt.line.LineSearchCursor in the MindsEye project by SimiaCryptus.
Class OwlQn, method orient.
/**
 * Produces an OWL-QN style search direction: the inner strategy's direction adjusted by the L1
 * penalty {@code factor_L1}, wrapped in a line-search cursor that keeps weights from crossing zero.
 *
 * <p>The penalty is applied to each weight component as follows:
 * <ul>
 *   <li>{@code factor_L1} is added with the sign of the weight; zero weights count as positive.</li>
 *   <li>If that changes the component's sign relative to the inner direction, the component
 *       reverts to the inner direction's value.</li>
 * </ul>
 *
 * <p>The returned cursor's {@code step(alpha, monitor)} does the following:
 * <ul>
 *   <li>It restores the origin weights, then moves each weight by {@code alpha * direction}.</li>
 *   <li>Any nonzero weight whose sign would change is clamped to zero, and the corresponding
 *       component of the reported direction is zeroed as well.</li>
 *   <li>The value passed to {@code LineSearchPoint} is the dot product of that (clamped)
 *       direction with the delta measured at the new point.</li>
 * </ul>
 *
 * @param subject     the trainable being optimized
 * @param measurement the current point; passed to the inner strategy and used as the cursor origin
 * @param monitor     training monitor passed to the inner strategy
 * @return a cursor whose direction type is {@code "OWL/QN"}
 */
@Nonnull
@Override
public LineSearchCursor orient(final Trainable subject, @Nonnull final PointSample measurement, final TrainingMonitor monitor) {
  // The inner strategy must return a SimpleLineSearchCursor; anything else fails with ClassCastException.
  @Nonnull final SimpleLineSearchCursor gradient = (SimpleLineSearchCursor) inner.orient(subject, measurement, monitor);
  @Nonnull final DeltaSet<Layer> searchDirection = gradient.direction.copy();
  for (@Nonnull final Layer layer : getLayers(gradient.direction.getMap().keySet())) {
    final double[] weights = gradient.direction.getMap().get(layer).target;
    @Nullable final double[] delta = gradient.direction.getMap().get(layer).getDelta();
    @Nullable final double[] searchDir = searchDirection.get(layer, weights).getDelta();
    // Both buffers are dereferenced below; check before use rather than after.
    assert null != delta;
    assert null != searchDir;
    for (int i = 0; i < searchDir.length; i++) {
      // Capture the inner direction's sign before searchDir[i] is modified.
      final int directionSign = sign(delta[i]);
      // L1 term: add the penalty with the weight's sign (zero weights treated as positive).
      searchDir[i] += factor_L1 * (weights[i] < 0 ? -1.0 : 1.0);
      // If the penalty flipped this component's sign, fall back to the unpenalized value.
      if (sign(searchDir[i]) != directionSign) {
        searchDir[i] = delta[i];
      }
    }
  }
  return new SimpleLineSearchCursor(subject, measurement, searchDirection) {
    @Nonnull
    @Override
    public LineSearchPoint step(final double alpha, final TrainingMonitor monitor) {
      // Every trial step starts from the origin weights, not from the previous trial point.
      origin.weights.stream().forEach(d -> d.restore());
      @Nonnull final DeltaSet<Layer> currentDirection = direction.copy();
      direction.getMap().forEach((layer, buffer) -> {
        final double[] stepDelta = buffer.getDelta();
        if (null == stepDelta)
          return;
        @Nullable final double[] currentDelta = currentDirection.get(layer, buffer.target).getDelta();
        final double[] target = buffer.target;
        for (int i = 0; i < stepDelta.length; i++) {
          final double prevValue = target[i];
          final double newValue = prevValue + stepDelta[i] * alpha;
          if (sign(prevValue) != 0 && sign(prevValue) != sign(newValue)) {
            // Crossing zero would leave the weight's current orthant: clamp it to zero and
            // drop this component from the reported direction.
            currentDelta[i] = 0;
            target[i] = 0;
          } else {
            target[i] = newValue;
          }
        }
      });
      @Nonnull final PointSample measure = subject.measure(monitor).setRate(alpha);
      return new LineSearchPoint(measure, currentDirection.dot(measure.delta));
    }
  }.setDirectionType("OWL/QN");
}
Aggregations