
Example 21 with IterativeTrainer

Use of com.simiacryptus.mindseye.opt.IterativeTrainer in project MindsEye by SimiaCryptus.

From class L1NormalizationTest, method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {

            @Override
            public Layer getLayer() {
                return inner.getLayer();
            }

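            // Apply an L1 penalty with weight 1.0 to every layer...
            @Override
            protected double getL1(final Layer layer) {
                return 1.0;
            }

            // ...and no L2 penalty, so this is pure L1 regularization.
            @Override
            protected double getL2(final Layer layer) {
                return 0;
            }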
        };
        return new IterativeTrainer(trainable).setMonitor(monitor).setTimeout(3, TimeUnit.MINUTES).setMaxIterations(500).runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) L12Normalizer(com.simiacryptus.mindseye.eval.L12Normalizer) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) Layer(com.simiacryptus.mindseye.lang.Layer)
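The anonymous L12Normalizer above returns an L1 weight of 1.0 and an L2 weight of 0 for every layer. Assuming the normalizer adds a penalty of the usual form l1 * sum(|w|) + l2 * sum(w^2) to the loss (an assumption; this is not MindsEye's code), the sketch below shows what that penalty and its (sub)gradient look like on a plain weight array. The class L12PenaltySketch and its methods are hypothetical.

```java
import java.util.Arrays;

/** Hypothetical sketch of an L1/L2 weight penalty; not MindsEye's L12Normalizer. */
public class L12PenaltySketch {

    /** Penalty added to the base loss: l1 * sum(|w|) + l2 * sum(w^2). */
    static double penalty(double[] weights, double l1, double l2) {
        double sum = 0.0;
        for (double w : weights) {
            sum += l1 * Math.abs(w) + l2 * w * w;
        }
        return sum;
    }

    /** Gradient of the penalty; for L1 at w == 0, signum yields the subgradient 0. */
    static double[] penaltyGradient(double[] weights, double l1, double l2) {
        double[] grad = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            grad[i] = l1 * Math.signum(weights[i]) + 2.0 * l2 * weights[i];
        }
        return grad;
    }

    public static void main(String[] args) {
        double[] weights = {0.5, -2.0, 0.0};
        // Same settings as the example above: L1 weight 1.0, L2 weight 0.
        System.out.println(penalty(weights, 1.0, 0.0));                         // prints 2.5
        System.out.println(Arrays.toString(penaltyGradient(weights, 1.0, 0.0))); // prints [1.0, -1.0, 0.0]
    }
}
```

With L2 set to 0, every nonzero weight feels a constant pull of magnitude 1.0 toward zero regardless of its size, which is why L1 regularization tends to drive small weights to exactly zero.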

Example 22 with IterativeTrainer

Use of com.simiacryptus.mindseye.opt.IterativeTrainer in project MindsEye by SimiaCryptus.

From class SimpleGradientDescentTest, method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.p("Training a model involves a few different components. First, our model is combined with a loss function. " + "Then we take that model and combine it with our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final ArrayList<Tensor[]> trainingList = new ArrayList<>(Arrays.stream(trainingData).collect(Collectors.toList()));
        Collections.shuffle(trainingList);
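        // Train on a random subset of at most 10,000 examples (guards against smaller datasets).
        @Nonnull final Tensor[][] randomSelection = trainingList.subList(0, Math.min(10000, trainingList.size())).toArray(new Tensor[][] {});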
        @Nonnull final Trainable trainable = new ArrayTrainable(randomSelection, supervisedNetwork);
        return new IterativeTrainer(trainable).setMonitor(monitor).setTimeout(3, TimeUnit.MINUTES).setMaxIterations(500).runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)
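The "simple iterative scheme" in this example is bounded two ways: setMaxIterations(500) caps the number of steps and setTimeout(3, TimeUnit.MINUTES) caps wall-clock time. The sketch below is a hypothetical stand-in (IterativeLoopSketch is not part of MindsEye) that shows the same stopping logic with fixed-step gradient descent on a one-dimensional toy loss (w - 3)^2, returning the final loss value as the page text describes for the real trainer.

```java
import java.util.concurrent.TimeUnit;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical sketch of an iterative training loop; not MindsEye's IterativeTrainer. */
public class IterativeLoopSketch {

    /**
     * Runs fixed-step gradient descent until maxIterations steps have been taken
     * or the timeout expires, whichever comes first, and returns the final loss.
     */
    static double run(DoubleUnaryOperator loss, DoubleUnaryOperator gradient,
                      double start, double stepSize,
                      int maxIterations, long timeout, TimeUnit unit) {
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        double w = start;
        // Overflow-safe deadline check: compare the difference, not the raw values.
        for (int i = 0; i < maxIterations && System.nanoTime() - deadline < 0; i++) {
            w -= stepSize * gradient.applyAsDouble(w); // one refinement step
        }
        return loss.applyAsDouble(w); // loss at the final weights
    }

    public static void main(String[] args) {
        // Toy objective (w - 3)^2 with gradient 2(w - 3); same limits as the example.
        double finalLoss = run(w -> (w - 3) * (w - 3), w -> 2 * (w - 3),
                0.0, 0.1, 500, 3, TimeUnit.MINUTES);
        System.out.println(finalLoss); // a value very close to 0
    }
}
```

Whichever limit is hit first ends training, so a slow model stops at the timeout while a fast one stops at the iteration cap.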

Aggregations

IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer): 22
Nonnull (javax.annotation.Nonnull): 22
Trainable (com.simiacryptus.mindseye.eval.Trainable): 20
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer): 13
SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork): 13
SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 12
Layer (com.simiacryptus.mindseye.lang.Layer): 10
ArrayList (java.util.ArrayList): 9
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 8
GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent): 8
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 8
Tensor (com.simiacryptus.mindseye.lang.Tensor): 7
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 6
ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch): 6
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 5
Arrays (java.util.Arrays): 5
List (java.util.List): 5
Map (java.util.Map): 5
QQN (com.simiacryptus.mindseye.opt.orient.QQN): 4
TrustRegionStrategy (com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy): 4