Example use of com.simiacryptus.mindseye.opt.IterativeTrainer in the project MindsEye by SimiaCryptus.
From the class TrainingTester, method trainLBFGS.
/**
 * Optimizes the given trainable with an L-BFGS orientation and an Armijo-Wolfe line search,
 * recording every step reported to the monitor.
 * <p>
 * The run is capped at 30 seconds and 250 iterations, uses 100 iterations per sample and a
 * terminate threshold of 0. Any throwable raised while training is wrapped in a
 * {@link RuntimeException} and rethrown only when {@code isThrowExceptions()} is true;
 * otherwise it is dropped and whatever history was collected so far is returned.
 *
 * @param log       the notebook output that receives the description and the executed code
 * @param trainable the trainable to optimize
 * @return the step records collected by the monitor during the run
 */
@Nonnull
public List<StepRecord> trainLBFGS(@Nonnull final NotebookOutput log, final Trainable trainable) {
  log.p("Next, we apply the same optimization using L-BFGS, which is nearly ideal for purely second-order or quadratic functions.");
  @Nonnull final List<StepRecord> steps = new ArrayList<>();
  // The monitor is obtained from the step list so that the list can be returned as the history.
  @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(steps);
  try {
    log.code(() -> {
      return new IterativeTrainer(trainable)
          .setLineSearchFactory(label -> new ArmijoWolfeSearch())
          .setOrientation(new LBFGS())
          .setMonitor(monitor)
          .setTimeout(30, TimeUnit.SECONDS)
          .setIterationsPerSample(100)
          .setMaxIterations(250)
          .setTerminateThreshold(0)
          .runAndFree();
    });
  } catch (Throwable failure) {
    // Propagate only when the tester is configured to throw; otherwise fall through and
    // return the (possibly partial) history.
    if (isThrowExceptions()) {
      throw new RuntimeException(failure);
    }
  }
  return steps;
}
Example use of com.simiacryptus.mindseye.opt.IterativeTrainer in the project MindsEye by SimiaCryptus.
From the class TrustSphereTest, method train.
/**
 * Trains the network against an entropy loss using a trust-region orientation that assigns a
 * new {@code AdaptiveTrustSphere} to every layer.
 * <p>
 * Training data is wrapped in a {@code SampledArrayTrainable} with a sample size argument of
 * 10000; the run uses 100 iterations per sample and is limited to 3 minutes and 500 iterations.
 *
 * @param log          the notebook output in which the training code is executed
 * @param network      the network to train
 * @param trainingData the training examples
 * @param monitor      the monitor that receives training progress
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
    // Same region policy for all layers: a fresh AdaptiveTrustSphere per request.
    @Nonnull final TrustRegionStrategy orientation = new TrustRegionStrategy() {
      @Override
      public TrustRegion getRegionPolicy(final Layer layer) {
        return new AdaptiveTrustSphere();
      }
    };
    return new IterativeTrainer(sampled)
        .setIterationsPerSample(100)
        .setMonitor(monitor)
        .setOrientation(orientation)
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Example use of com.simiacryptus.mindseye.opt.IterativeTrainer in the project MindsEye by SimiaCryptus.
From the class QuadraticLineSearchTest, method train.
/**
 * Trains the network against an entropy loss using gradient descent with a
 * {@code QuadraticSearch} line search.
 * <p>
 * Training data is wrapped in a {@code SampledArrayTrainable} with a sample size argument of
 * 1000; the run is limited to 3 minutes and 500 iterations.
 *
 * @param log          the notebook output in which the training code is executed
 * @param network      the network to train
 * @param trainingData the training examples
 * @param monitor      the monitor that receives training progress
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final Trainable sampled = new SampledArrayTrainable(
        trainingData,
        new SimpleLossNetwork(network, new EntropyLossLayer()),
        1000);
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        .setLineSearchFactory((@Nonnull final CharSequence name) -> new QuadraticSearch())
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Example use of com.simiacryptus.mindseye.opt.IterativeTrainer in the project MindsEye by SimiaCryptus.
From the class RecursiveSubspace, method train.
/**
 * Optimizes the given macro layer with L-BFGS and an Armijo-Wolfe line search.
 * <p>
 * The layer is wrapped in a {@code BasicTrainable}, which is in turn wrapped in an
 * {@code ArrayTrainable} built over a single empty data row. Both the maximum iteration count
 * and the iterations per sample are taken from {@code getIterations()}, and the terminate
 * threshold from the {@code terminateThreshold} field. Messages emitted by the trainer are
 * forwarded to the supplied monitor prefixed with a tab character.
 *
 * @param monitor    the monitor that receives the (tab-indented) log messages
 * @param macroLayer the macro layer to optimize
 */
public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
  @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
  @Nonnull ArrayTrainable trainable = new ArrayTrainable(inner, new Tensor[][] { {} });
  // The local reference to the inner trainable is no longer needed once it has been wrapped.
  inner.freeRef();
  try {
    new IterativeTrainer(trainable)
        .setOrientation(new LBFGS())
        .setLineSearchFactory(n -> new ArmijoWolfeSearch())
        .setMonitor(new TrainingMonitor() {
          @Override
          public void log(String msg) {
            // Indent nested training output so it is distinguishable in the outer log.
            monitor.log("\t" + msg);
          }
        })
        .setMaxIterations(getIterations())
        .setIterationsPerSample(getIterations())
        .setTerminateThreshold(terminateThreshold)
        .runAndFree();
  } finally {
    // Release our reference even when training throws; previously it was released only on
    // the success path, leaking the trainable on failure.
    trainable.freeRef();
  }
}
Example use of com.simiacryptus.mindseye.opt.IterativeTrainer in the project MindsEye by SimiaCryptus.
From the class StaticRateTest, method train.
/**
 * Trains the network against an entropy loss using gradient descent with a fixed
 * {@code StaticLearningRate} constructed with 0.001.
 * <p>
 * Training data is wrapped in a {@code SampledArrayTrainable} with a sample size argument of
 * 1000; the run is limited to 3 minutes and 500 iterations.
 *
 * @param log          the notebook output in which the training code is executed
 * @param network      the network to train
 * @param trainingData the training examples
 * @param monitor      the monitor that receives training progress
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
    return new IterativeTrainer(sampled)
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        .setLineSearchFactory((@Nonnull final CharSequence name) -> new StaticLearningRate(0.001))
        .setTimeout(3, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();
  });
}
Aggregations