Use of com.simiacryptus.mindseye.opt.line.StaticLearningRate in the project MindsEye by SimiaCryptus — from the class RecursiveSubspaceTest, method train:
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    // Pair the network with an entropy loss head for supervised training.
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    // Validating trainer: sampled subset for training steps, full (cached) set for validation.
    @Nonnull ValidatingTrainer trainer = new ValidatingTrainer(
        new SampledArrayTrainable(trainingData, supervisedNetwork, 1000, 1000),
        new ArrayTrainable(trainingData, supervisedNetwork, 1000).cached())
        .setMonitor(monitor);
    // Configure the first (only) training phase: orientation under test, and a
    // line search that gives LBFGS cursors a fixed unit step while all other
    // cursors use a quadratic search.
    trainer.getRegimen().get(0)
        .setOrientation(getOrientation())
        .setLineSearchFactory(name -> {
          if (name.toString().contains("LBFGS")) {
            return new StaticLearningRate(1.0);
          }
          return new QuadraticSearch();
        });
    trainer.setTimeout(15, TimeUnit.MINUTES).setMaxIterations(500);
    return trainer.run();
  });
}
Use of com.simiacryptus.mindseye.opt.line.StaticLearningRate in the project MindsEye by SimiaCryptus — from the class TrainingTester, method trainMagic:
/**
 * Trains using the experimental {@link RecursiveSubspace} optimizer (a fixed unit learning
 * rate on the outer loop, with a QQN-driven inner optimization of each macro layer) and
 * records the optimization history.
 *
 * @param log       the notebook log
 * @param trainable the trainable objective
 * @return the recorded step history (empty or partial if training failed and
 *         {@link #isThrowExceptions()} is false)
 */
@Nonnull
public List<StepRecord> trainMagic(@Nonnull final NotebookOutput log, final Trainable trainable) {
  log.p("Now we train using an experimental optimizer:");
  @Nonnull final List<StepRecord> history = new ArrayList<>();
  @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
  try {
    log.code(() -> {
      return new IterativeTrainer(trainable).setLineSearchFactory(label -> new StaticLearningRate(1.0)).setOrientation(new RecursiveSubspace() {
        @Override
        public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
          // Placeholder dataset: the macro layer's objective is self-contained,
          // so a single empty tensor suffices as "input" data.
          @Nonnull Tensor[][] nullData = { { new Tensor() } };
          @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
          @Nonnull ArrayTrainable trainable1 = new ArrayTrainable(inner, nullData);
          inner.freeRef();
          try {
            // Inner optimization: QQN orientation; its own cursor gets a unit
            // starting rate, any other cursor starts at 1e-4.
            new IterativeTrainer(trainable1).setOrientation(new QQN()).setLineSearchFactory(n -> new QuadraticSearch().setCurrentRate(n.equals(QQN.CURSOR_NAME) ? 1.0 : 1e-4)).setMonitor(new TrainingMonitor() {
              @Override
              public void log(String msg) {
                // Indent inner-loop log lines under the outer monitor's output.
                monitor.log("\t" + msg);
              }
            }).setMaxIterations(getIterations()).setIterationsPerSample(getIterations()).runAndFree();
          } finally {
            // Release references even if the inner trainer throws, to avoid
            // leaking the trainable and the placeholder tensors.
            trainable1.freeRef();
            for (@Nonnull Tensor[] tensors : nullData) {
              for (@Nonnull Tensor tensor : tensors) {
                tensor.freeRef();
              }
            }
          }
        }
      }).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
    });
  } catch (Throwable e) {
    if (isThrowExceptions())
      throw new RuntimeException(e);
    // Don't swallow the failure silently; record it so the history's context is clear.
    monitor.log("Training failed: " + e);
  }
  return history;
}
Aggregations