Usage of com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor in the MindsEye project by SimiaCryptus.
Class IterativeTrainer, method step.
/**
 * Performs one line-search step along the given direction.
 *
 * <p>The {@link LineSearchStrategy} is looked up in {@code lineSearchStrategyMap} by
 * {@code directionType}. On a miss, a new strategy is built with {@code lineSearchFactory},
 * cached, and logged. The direction is wrapped in a {@link FailsafeLineSearchCursor} so that
 * the best point seen during the search can be recovered. The wrapper's reference is always
 * released, even if the line search throws.
 *
 * @param direction     the line-search cursor describing the search direction
 * @param directionType the key under which the line-search strategy is cached
 * @param previous      the starting point of this step
 * @return the best point sample found by the failsafe cursor
 */
public PointSample step(@Nonnull final LineSearchCursor direction, final CharSequence directionType, @Nonnull final PointSample previous) {
  final LineSearchStrategy lineSearchStrategy;
  if (lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = lineSearchStrategyMap.get(directionType);
  } else {
    log.info(String.format("Constructing line search parameters: %s", directionType));
    // NOTE(review): the cache key is `directionType`, but the strategy is built from
    // direction.getDirectionType(); confirm callers always pass matching values.
    lineSearchStrategy = lineSearchFactory.apply(direction.getDirectionType());
    lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  @Nonnull final FailsafeLineSearchCursor wrapped = new FailsafeLineSearchCursor(direction, previous, monitor);
  try {
    lineSearchStrategy.step(wrapped, monitor).freeRef();
    // The return expression is evaluated before the finally block releases the cursor.
    return wrapped.getBest(monitor);
  } finally {
    // Release the wrapper even if the line search or getBest throws.
    wrapped.freeRef();
  }
}
Usage of com.simiacryptus.mindseye.opt.line.FailsafeLineSearchCursor in the MindsEye project by SimiaCryptus.
Class ValidatingTrainer, method runStep.
/**
 * Runs one training iteration for the given phase: orientation followed by a line search.
 *
 * <p>The method does the following, in order:
 * <ul>
 *   <li>Increments {@code currentIteration}.</li>
 *   <li>Times the phase's orientation to obtain a search direction.</li>
 *   <li>Looks up (or builds and caches) the line-search strategy for that direction type.</li>
 *   <li>Times a line search through a {@link FailsafeLineSearchCursor}, restoring the best
 *       point found.</li>
 * </ul>
 * The cursor and the strategy's step result are always released.
 *
 * @param previousPoint the point the step starts from
 * @param phase         the training phase supplying orientation, subject and line-search factory
 * @return a {@link StepResult} holding the previous point, the best point, and the orientation
 *     and line-search wall times in seconds
 * @throws IllegalStateException if the best point's mean is greater than the previous point's mean
 */
@Nonnull
protected StepResult runStep(@Nonnull final PointSample previousPoint, @Nonnull final TrainingPhase phase) {
  currentIteration.incrementAndGet();
  @Nonnull final TimedResult<LineSearchCursor> timedOrientation = TimedResult.time(() -> phase.orientation.orient(phase.trainingSubject, previousPoint, monitor));
  final LineSearchCursor direction = timedOrientation.result;
  final CharSequence directionType = direction.getDirectionType();
  final LineSearchStrategy lineSearchStrategy;
  if (phase.lineSearchStrategyMap.containsKey(directionType)) {
    lineSearchStrategy = phase.lineSearchStrategyMap.get(directionType);
  } else {
    monitor.log(String.format("Constructing line search parameters: %s", directionType));
    lineSearchStrategy = phase.lineSearchFactory.apply(directionType);
    phase.lineSearchStrategyMap.put(directionType, lineSearchStrategy);
  }
  @Nonnull final TimedResult<PointSample> timedLineSearch = TimedResult.time(() -> {
    @Nonnull final FailsafeLineSearchCursor cursor = new FailsafeLineSearchCursor(direction, previousPoint, monitor);
    try {
      // Release the strategy's step result, matching IterativeTrainer.step.
      lineSearchStrategy.step(cursor, monitor).freeRef();
      return cursor.getBest(monitor).restore();
    } finally {
      // Previously the cursor was never released.
      cursor.freeRef();
    }
  });
  final PointSample bestPoint = timedLineSearch.result;
  if (bestPoint.getMean() > previousPoint.getMean()) {
    throw new IllegalStateException(bestPoint.getMean() + " > " + previousPoint.getMean());
  }
  monitor.log(compare(previousPoint, bestPoint));
  return new StepResult(previousPoint, bestPoint, new double[] { timedOrientation.timeNanos / 1e9, timedLineSearch.timeNanos / 1e9 });
}
Aggregations