
Example 6 with GradientDescent

Use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

From the class QuadraticLineSearchTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // Wrap the network with an entropy loss and draw 1000 samples per batch.
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        // Plain gradient descent orientation paired with a quadratic line search.
        return new IterativeTrainer(trainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory((@Nonnull final CharSequence name) -> new QuadraticSearch())
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
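The QuadraticSearch configured above chooses each step size by modeling the loss along the search direction as a parabola. A minimal, self-contained sketch of that idea (an illustration only, not MindsEye's actual QuadraticSearch implementation; the class and method names here are hypothetical):

```java
import java.util.function.DoubleUnaryOperator;

public class QuadraticLineSearchSketch {

    /**
     * Returns the step size at the vertex of the parabola fitted through
     * (a, f(a)), (b, f(b)), (c, f(c)). Assumes a, b, c are distinct and the
     * three points are not collinear, so the fitted parabola has a vertex.
     */
    public static double quadraticMin(DoubleUnaryOperator f, double a, double b, double c) {
        double fa = f.applyAsDouble(a);
        double fb = f.applyAsDouble(b);
        double fc = f.applyAsDouble(c);
        // Standard parabolic-interpolation vertex formula.
        double num = (b - a) * (b - a) * (fb - fc) - (b - c) * (b - c) * (fb - fa);
        double den = (b - a) * (fb - fc) - (b - c) * (fb - fa);
        return b - 0.5 * num / den;
    }

    public static void main(String[] args) {
        // For an exactly quadratic loss f(x) = (x - 3)^2 the minimizer is
        // recovered in a single fit.
        double step = quadraticMin(x -> (x - 3) * (x - 3), 0, 1, 5);
        System.out.println(step); // 3.0
    }
}
```

For an exactly quadratic objective one fit suffices; for general losses a real line search would iterate this fit, which is why pairing it with a simple GradientDescent orientation, as in the test above, still converges reasonably.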
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)

Example 7 with GradientDescent

Use of com.simiacryptus.mindseye.opt.orient.GradientDescent in project MindsEye by SimiaCryptus.

From the class StaticRateTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // Wrap the network with an entropy loss and draw 1000 samples per batch.
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        // Gradient descent with a constant learning rate of 0.001 instead of a line search.
        return new IterativeTrainer(trainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory((@Nonnull final CharSequence name) -> new StaticLearningRate(0.001))
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
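With StaticLearningRate, every iteration takes a step of the same fixed size along the negative gradient. A stand-alone sketch of that behavior on a one-dimensional objective (illustrative only; the class and method names here are hypothetical, not the MindsEye classes themselves):

```java
import java.util.function.DoubleUnaryOperator;

public class StaticRateSketch {

    /** Runs plain gradient descent with a constant step size. */
    public static double descend(DoubleUnaryOperator gradient, double x0,
                                 double learningRate, int iterations) {
        double x = x0;
        for (int i = 0; i < iterations; i++) {
            // Fixed step; no line search adapts the rate per iteration.
            x -= learningRate * gradient.applyAsDouble(x);
        }
        return x;
    }

    public static void main(String[] args) {
        // Minimize f(x) = (x - 2)^2, whose gradient is 2(x - 2).
        double x = descend(v -> 2 * (v - 2), 10.0, 0.1, 100);
        System.out.println(x); // converges toward 2.0
    }
}
```

The trade-off the StaticRateTest exercises: a fixed rate avoids the extra loss evaluations a line search costs per step, but a rate that is too large diverges and one that is too small (such as 0.001 on a poorly scaled problem) converges slowly.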
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)

Aggregations

GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent) — 7 uses
Nonnull (javax.annotation.Nonnull) — 7 uses
IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer) — 6 uses
SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable) — 5 uses
Trainable (com.simiacryptus.mindseye.eval.Trainable) — 4 uses
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer) — 4 uses
SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork) — 4 uses
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor) — 2 uses
QuadraticSearch (com.simiacryptus.mindseye.opt.line.QuadraticSearch) — 2 uses
StepRecord (com.simiacryptus.mindseye.test.StepRecord) — 2 uses
ArrayList (java.util.ArrayList) — 2 uses
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable) — 1 use
SampledTrainable (com.simiacryptus.mindseye.eval.SampledTrainable) — 1 use
TrainableDataMask (com.simiacryptus.mindseye.eval.TrainableDataMask) — 1 use
ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer) — 1 use
ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) — 1 use