
Example 1 with TrainableDataMask

Use of com.simiacryptus.mindseye.eval.TrainableDataMask in project MindsEye by SimiaCryptus.

From the class ImageDecompositionLab, method train.

/**
 * Trains the network on the given data within a wall-clock time budget.
 *
 * @param log            notebook output that records the training run
 * @param monitor        monitor that receives training progress callbacks
 * @param network        the layer to train
 * @param data           training rows; each row is an array of input tensors
 * @param timeoutMinutes training time limit, in minutes
 * @param mask           per-column flags passed to TrainableDataMask.setMask
 */
protected void train(@Nonnull final NotebookOutput log, final TrainingMonitor monitor, final Layer network, @Nonnull final Tensor[][] data, final int timeoutMinutes, final boolean... mask) {
    log.out("Training for %s minutes, mask=%s", timeoutMinutes, Arrays.toString(mask));
    log.code(() -> {
        // Sampled training subject over all rows, with the data mask applied.
        @Nonnull SampledTrainable trainingSubject = new SampledArrayTrainable(data, network, data.length);
        trainingSubject = (SampledTrainable) ((TrainableDataMask) trainingSubject).setMask(mask);
        // Validate against the full data set; the training sample size ranges from 5 to data.length.
        @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, new ArrayTrainable(data, network))
            .setMaxTrainingSize(data.length)
            .setMinTrainingSize(5)
            .setMonitor(monitor)
            .setTimeout(timeoutMinutes, TimeUnit.MINUTES)
            .setMaxIterations(1000);
        // Plain gradient descent; every line-search cursor gets a QuadraticSearch starting at rate 1.0.
        validatingTrainer.getRegimen().get(0)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory(name -> new QuadraticSearch().setCurrentRate(1.0));
        validatingTrainer.run();
    });
}
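The mask is what distinguishes this run from ordinary training. Assuming that a `true` entry in `mask` marks a column of each data row whose values are themselves optimized, while `false` columns stay fixed, the idea can be sketched in plain Java without MindsEye. Everything below is hypothetical and is not the MindsEye API: the class `MaskedDecompositionSketch`, its methods `loss`, `gradient`, `step` and `train`, the toy model that predicts a target column as the sum of two component columns, and the use of fixed-step gradient descent in place of `ValidatingTrainer` with `QuadraticSearch`.

```java
import java.util.Arrays;

// Hypothetical sketch, not MindsEye API: masked gradient descent over data columns.
// Each row is {componentA, componentB, target}; the toy "network" has no weights
// and predicts target as componentA + componentB.
public class MaskedDecompositionSketch {

    // Sum over rows of the squared residual (a + b - target)^2.
    public static double loss(double[][] data) {
        double total = 0;
        for (double[] row : data) {
            double r = row[0] + row[1] - row[2];
            total += r * r;
        }
        return total;
    }

    // Gradient of the loss with respect to every column of every row.
    public static double[][] gradient(double[][] data) {
        double[][] g = new double[data.length][3];
        for (int i = 0; i < data.length; i++) {
            double r = data[i][0] + data[i][1] - data[i][2];
            g[i][0] = 2 * r;
            g[i][1] = 2 * r;
            g[i][2] = -2 * r;
        }
        return g;
    }

    // One step: only columns whose mask flag is true are updated; the gradient
    // for masked-out columns (and columns beyond mask.length) is discarded.
    public static void step(double[][] data, boolean[] mask, double rate) {
        double[][] g = gradient(data);
        for (int i = 0; i < data.length; i++) {
            int cols = Math.min(mask.length, data[i].length);
            for (int c = 0; c < cols; c++) {
                if (mask[c]) {
                    data[i][c] -= rate * g[i][c];
                }
            }
        }
    }

    // Runs a fixed number of steps in place and returns the same array.
    public static double[][] train(double[][] data, boolean[] mask, int iterations, double rate) {
        for (int k = 0; k < iterations; k++) {
            step(data, mask, rate);
        }
        return data;
    }

    public static void main(String[] args) {
        double[][] data = { {0.0, 0.0, 3.0}, {1.0, -1.0, 4.0} };
        // Train both components, keep the target column fixed.
        train(data, new boolean[] {true, true, false}, 200, 0.1);
        System.out.println(Arrays.deepToString(data));
        System.out.println("loss = " + loss(data));
    }
}
```

With `mask = {true, true, false}`, each row's two components move toward values that sum to its fixed target, and the target column is never modified because its gradient is dropped. Setting every flag to `false` leaves the data untouched.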
Also used: TrainableDataMask (com.simiacryptus.mindseye.eval.TrainableDataMask), SampledTrainable (com.simiacryptus.mindseye.eval.SampledTrainable), Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), QuadraticSearch (com.simiacryptus.mindseye.opt.line.QuadraticSearch), GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent), ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer), ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)
