Usage of com.simiacryptus.mindseye.opt.line.LineSearchStrategy in the MindsEye project (SimiaCryptus).
Class IterativeTrainer, method step:
/**
 * Performs one line-search step along {@code direction}, starting from {@code previous}.
 *
 * <p>The line-search strategy is looked up in {@code lineSearchStrategyMap} under
 * {@code directionType}. On a miss, a new strategy is built with {@code lineSearchFactory} and
 * cached under {@code directionType}. The direction is wrapped in a
 * {@link FailsafeLineSearchCursor}, the strategy is run against that wrapper, and the wrapper's
 * best point is returned.
 *
 * @param direction     the cursor describing the search direction
 * @param directionType the key used to look up / cache the line-search strategy
 * @param previous      the point the step starts from
 * @return the point sample reported by {@code FailsafeLineSearchCursor.getBest}
 */
public PointSample step(@Nonnull final LineSearchCursor direction, final CharSequence directionType, @Nonnull final PointSample previous) {
  final LineSearchStrategy lineSearchStrategy;
  if (lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = lineSearchStrategyMap.get(directionType);
  } else {
    log.info(String.format("Constructing line search parameters: %s", directionType));
    // NOTE(review): the factory receives direction.getDirectionType() while the cache key is the
    // directionType argument; preserved as-is — confirm callers always pass matching values.
    lineSearchStrategy = lineSearchFactory.apply(direction.getDirectionType());
    lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  @Nonnull final FailsafeLineSearchCursor wrapped = new FailsafeLineSearchCursor(direction, previous, monitor);
  try {
    // The strategy's own result is released immediately; the returned point comes from the
    // wrapper's getBest instead.
    lineSearchStrategy.step(wrapped, monitor).freeRef();
    return wrapped.getBest(monitor);
  } finally {
    // Release the wrapper on every path, including when the line search throws.
    wrapped.freeRef();
  }
}
Usage of com.simiacryptus.mindseye.opt.line.LineSearchStrategy in the MindsEye project (SimiaCryptus).
Class ValidatingTrainer, method runStep:
/**
 * Runs a single training iteration for the given phase.
 *
 * <p>Increments {@code currentIteration}, asks the phase's orientation strategy for a search
 * direction, and picks (or builds and caches in {@code phase.lineSearchStrategyMap}) a
 * line-search strategy keyed by the direction type. It then runs the line search through a
 * {@link FailsafeLineSearchCursor} and restores the best point found. Both the orientation and
 * the line search are timed.
 *
 * @param previousPoint the point the iteration starts from
 * @param phase         the training phase supplying the subject, orientation and line-search settings
 * @return a {@code StepResult} holding the previous point, the best point, and the orientation /
 *     line-search durations in seconds
 * @throws IllegalStateException if the best point's mean exceeds {@code previousPoint}'s mean
 */
@Nonnull
protected StepResult runStep(@Nonnull final PointSample previousPoint, @Nonnull final TrainingPhase phase) {
  currentIteration.incrementAndGet();
  @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> phase.orientation.orient(phase.trainingSubject, previousPoint, monitor));
  // NOTE(review): `direction` is obtained from orient() and is not released in this method; if
  // LineSearchCursor is reference-counted here, confirm who owns this reference.
  final LineSearchCursor direction = timedOrientation.result;
  final CharSequence directionType = direction.getDirectionType();
  final LineSearchStrategy lineSearchStrategy;
  if (phase.lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = phase.lineSearchStrategyMap.get(directionType);
  } else {
    monitor.log(String.format("Constructing line search parameters: %s", directionType));
    lineSearchStrategy = phase.lineSearchFactory.apply(direction.getDirectionType());
    phase.lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> {
    @Nonnull final FailsafeLineSearchCursor cursor = new FailsafeLineSearchCursor(direction, previousPoint, monitor);
    try {
      // Release the strategy's own result; the failsafe cursor supplies the point we keep.
      lineSearchStrategy.step(cursor, monitor).freeRef();
      return cursor.getBest(monitor).restore();
    } finally {
      // Free the wrapper on every path so it does not leak if the line search throws.
      cursor.freeRef();
    }
  });
  final PointSample bestPoint = timedLineSearch.result;
  // A line search must never make things worse; treat that as an internal error.
  if (bestPoint.getMean() > previousPoint.getMean()) {
    throw new IllegalStateException(bestPoint.getMean() + " > " + previousPoint.getMean());
  }
  monitor.log(compare(previousPoint, bestPoint));
  return new StepResult(previousPoint, bestPoint, new double[] { timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9 });
}
Usage of com.simiacryptus.mindseye.opt.line.LineSearchStrategy in the MindsEye project (SimiaCryptus).
Class RoundRobinTrainer, method run:
/**
 * Runs round-robin training until the timeout expires, the iteration cap is exceeded, or the
 * error sum drops to {@code terminateThreshold} or below.
 *
 * <p>Each macro-iteration re-measures the subject and then performs up to
 * {@code iterationsPerSample} sub-iterations. Each sub-iteration applies every orientation
 * strategy in turn, each with its own cached line-search strategy, keyed by the direction type
 * plus the orientation's identity hash. If a full pass over the orientations does not lower the
 * error sum, the subject is reseeded and a new macro-iteration begins. If reseeding is refused,
 * training stops.
 *
 * @return the final error sum, or {@link Double#NaN} if {@code currentPoint} is null
 */
public double run() {
  final long timeoutMs = System.currentTimeMillis() + timeout.toMillis();
  PointSample currentPoint = measure();
  mainLoop:
  while (timeoutMs > System.currentTimeMillis() && currentPoint.sum > terminateThreshold) {
    if (currentIteration.get() > maxIterations) {
      break;
    }
    // Re-measure at the start of every macro-iteration (the subject may have been reseeded).
    currentPoint = measure();
    subiterationLoop:
    for (int subiteration = 0; subiteration < iterationsPerSample; subiteration++) {
      final PointSample previousOrientations = currentPoint;
      for (@Nonnull final OrientationStrategy<?> orientation : orientations) {
        if (currentIteration.incrementAndGet() > maxIterations) {
          // Iteration budget exhausted: stop training outright. A plain `break` would only leave
          // this loop and fall through to the "MacroIteration failed" check below, which could
          // reseed the subject spuriously and keep over-incrementing currentIteration.
          break mainLoop;
        }
        final LineSearchCursor direction = orientation.orient(subject, currentPoint, monitor);
        // Key includes the orientation's identity so each orientation gets its own line search.
        @Nonnull final CharSequence directionType = direction.getDirectionType() + "+" + Long.toHexString(System.identityHashCode(orientation));
        final LineSearchStrategy lineSearchStrategy;
        if (lineSearchStrategyMap.containsKey(directionType)) {
          lineSearchStrategy = lineSearchStrategyMap.get(directionType);
        } else {
          log.info(String.format("Constructing line search parameters: %s", directionType));
          lineSearchStrategy = lineSearchFactory.apply(directionType);
          lineSearchStrategyMap.put(directionType, lineSearchStrategy);
        }
        final PointSample previous = currentPoint;
        currentPoint = lineSearchStrategy.step(direction, monitor);
        monitor.onStepComplete(new Step(currentPoint, currentIteration.get()));
        if (previous.sum == currentPoint.sum) {
          monitor.log(String.format("Iteration %s failed, ignoring. Error: %s", currentIteration.get(), currentPoint.sum));
        } else {
          monitor.log(String.format("Iteration %s complete. Error: %s", currentIteration.get(), currentPoint.sum));
        }
      }
      // A full pass that did not reduce the error: reseed and restart, or give up.
      if (previousOrientations.sum <= currentPoint.sum) {
        if (subject.reseed(System.nanoTime())) {
          monitor.log(String.format("MacroIteration %s failed, retrying. Error: %s", currentIteration.get(), currentPoint.sum));
          break subiterationLoop;
        } else {
          monitor.log(String.format("MacroIteration %s failed, aborting. Error: %s", currentIteration.get(), currentPoint.sum));
          break mainLoop;
        }
      }
    }
  }
  return null == currentPoint ? Double.NaN : currentPoint.sum;
}
Aggregations