Example usage of com.simiacryptus.mindseye.eval.TrainableDataMask in the MindsEye project by SimiaCryptus.
Taken from class ImageDecompositionLab, method train.
/**
 * Trains {@code network} on {@code data} using a {@link ValidatingTrainer}, recording the
 * training code and its output in {@code log}.
 *
 * <p>The training subject is a {@link SampledArrayTrainable} over all of {@code data} (sample
 * size {@code data.length}). {@code mask} is applied to it via {@link TrainableDataMask#setMask}.
 * An {@link ArrayTrainable} over the same data is passed as the trainer's second trainable. This
 * is presumably the validation subject; confirm against {@code ValidatingTrainer}. The run is
 * bounded by {@code timeoutMinutes} and at most 1000 iterations. The first regimen phase uses
 * {@link GradientDescent} with a {@link QuadraticSearch} line search starting at rate 1.0.
 *
 * @param log            the notebook output that receives the log message and the training code
 * @param monitor        the monitor handed to the trainer via {@code setMonitor}
 * @param network        the network to train
 * @param data           the training rows; used for both trainables passed to the trainer
 * @param timeoutMinutes the training time budget, in minutes
 * @param mask           the mask passed to {@link TrainableDataMask#setMask}; presumably flags
 *                       which tensors of each data row are trainable (TODO confirm)
 */
protected void train(@Nonnull final NotebookOutput log, final TrainingMonitor monitor, final Layer network, @Nonnull final Tensor[][] data, final int timeoutMinutes, final boolean... mask) {
  log.out("Training for %s minutes, mask=%s", timeoutMinutes, Arrays.toString(mask));
  log.code(() -> {
    @Nonnull SampledTrainable trainingSubject = new SampledArrayTrainable(data, network, data.length);
    // setMask is reached through the TrainableDataMask view; the result is cast back to
    // SampledTrainable (assumes SampledArrayTrainable implements both interfaces).
    trainingSubject = (SampledTrainable) ((TrainableDataMask) trainingSubject).setMask(mask);
    @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, new ArrayTrainable(data, network))
        .setMaxTrainingSize(data.length)
        .setMinTrainingSize(5)
        .setMonitor(monitor)
        .setTimeout(timeoutMinutes, TimeUnit.MINUTES)
        .setMaxIterations(1000);
    // Plain gradient descent, using the same QuadraticSearch (initial rate 1.0) for every
    // line-search name. A per-name branch would be pointless because every name gets the same search.
    validatingTrainer.getRegimen().get(0)
        .setOrientation(new GradientDescent())
        .setLineSearchFactory(name -> new QuadraticSearch().setCurrentRate(1.0));
    validatingTrainer.run();
  });
}
Aggregations