
Example 1 with SampledArrayTrainable

Use of com.simiacryptus.mindseye.eval.SampledArrayTrainable in project MindsEye by SimiaCryptus.

From class GDTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        return new IterativeTrainer(trainable)
                .setMonitor(monitor)
                .setOrientation(new GradientDescent())
                .setTimeout(5, TimeUnit.MINUTES)
                .setMaxIterations(500)
                .runAndFree();
    });
}
Also used: IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer), Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer), SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork), Trainable (com.simiacryptus.mindseye.eval.Trainable)
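
The third constructor argument to SampledArrayTrainable (1000 here) sets how many examples are subsampled from trainingData for each evaluation. As an illustrative variation rather than the GDTest source, the same gradient-descent run could use a larger sample together with setIterationsPerSample (which appears in a later example on this page), assuming the snippet sits inside the same train(...) override so that trainingData, supervisedNetwork, and monitor are in scope:

@Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
return new IterativeTrainer(trainable)
        .setIterationsPerSample(100)  // assumption: re-draws the 10000-example sample every 100 iterations
        .setMonitor(monitor)
        .setOrientation(new GradientDescent())
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();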

Example 2 with SampledArrayTrainable

Use of com.simiacryptus.mindseye.eval.SampledArrayTrainable in project MindsEye by SimiaCryptus.

From class MomentumTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        return new IterativeTrainer(trainable)
                .setMonitor(monitor)
                .setOrientation(new ValidatingOrientationWrapper(new MomentumStrategy(new GradientDescent()).setCarryOver(0.8)))
                .setTimeout(5, TimeUnit.MINUTES)
                .setMaxIterations(500)
                .runAndFree();
    });
}
Also used: IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer), Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer), SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork), Trainable (com.simiacryptus.mindseye.eval.Trainable)
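
The orientation strategies compose: MomentumStrategy wraps an inner GradientDescent, with setCarryOver(0.8) controlling how much of the previous search direction is retained, and ValidatingOrientationWrapper adds a derivative-validation check around it. A minimal sketch without the validation wrapper, assembled only from calls already shown above (illustrative, not the MomentumTest source):

@Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
return new IterativeTrainer(trainable)
        .setMonitor(monitor)
        // momentum-smoothed gradient descent; 0.8 is the carry-over factor
        .setOrientation(new MomentumStrategy(new GradientDescent()).setCarryOver(0.8))
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();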

Example 3 with SampledArrayTrainable

Use of com.simiacryptus.mindseye.eval.SampledArrayTrainable in project MindsEye by SimiaCryptus.

From class QQNTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        // return new IterativeTrainer(new SampledArrayTrainable(trainingData, supervisedNetwork, 10000))
        @Nonnull ValidatingTrainer trainer = new ValidatingTrainer(
                new SampledArrayTrainable(trainingData, supervisedNetwork, 1000, 10000),
                new ArrayTrainable(trainingData, supervisedNetwork))
                .setMonitor(monitor);
        trainer.getRegimen().get(0).setOrientation(new QQN());
        return trainer.setTimeout(5, TimeUnit.MINUTES).setMaxIterations(500).run();
    });
}
Also used: Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer), ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer), ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable), SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork)
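
ValidatingTrainer is driven by two trainables: the subsampled SampledArrayTrainable used for the optimization steps and a full ArrayTrainable used to validate progress, with the first phase of its regimen reoriented to QQN. The commented-out line above hints at a simpler single-trainer alternative; a sketch of that variant, built only from calls already used on this page (illustrative, not the QQNTest source):

@Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
return new IterativeTrainer(trainable)
        .setMonitor(monitor)
        .setOrientation(new QQN())  // quadratic quasi-Newton orientation
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .runAndFree();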

Example 4 with SampledArrayTrainable

Use of com.simiacryptus.mindseye.eval.SampledArrayTrainable in project MindsEye by SimiaCryptus.

From class LinearSumConstraintTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
        @Nonnull final TrustRegionStrategy trustRegionStrategy = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                return new LinearSumConstraint();
            }
        };
        return new IterativeTrainer(trainable)
                .setIterationsPerSample(100)
                .setMonitor(monitor)
                .setOrientation(trustRegionStrategy)
                .setTimeout(3, TimeUnit.MINUTES)
                .setMaxIterations(500)
                .runAndFree();
    });
}
Also used: IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer), Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer), SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork), Trainable (com.simiacryptus.mindseye.eval.Trainable), Layer (com.simiacryptus.mindseye.lang.Layer), TrustRegionStrategy (com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)
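
TrustRegionStrategy requests a TrustRegion policy per layer, so the constraint can vary by layer type. A hedged sketch of that dispatch point follows; treating a null return as "leave this layer unconstrained" is an assumption to check against TrustRegionStrategy's documentation, and the instanceof test is purely illustrative:

@Nonnull final TrustRegionStrategy trustRegionStrategy = new TrustRegionStrategy() {

    @Override
    public TrustRegion getRegionPolicy(final Layer layer) {
        // assumption: a null policy leaves this layer unconstrained
        if (layer instanceof EntropyLossLayer) return null;
        return new LinearSumConstraint();
    }
};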

Example 5 with SampledArrayTrainable

Use of com.simiacryptus.mindseye.eval.SampledArrayTrainable in project MindsEye by SimiaCryptus.

From class L2NormalizationTest, method train:

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.p("Training a model involves a few different components. First, our model is combined mapCoords a loss function. " + "Then we take that model and combine it mapCoords our training data to define a trainable object. " + "Finally, we use a simple iterative scheme to refine the weights of our model. " + "The final output is the last output value of the loss function when evaluating the last batch.");
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {

            @Override
            public Layer getLayer() {
                return inner.getLayer();
            }

            @Override
            protected double getL1(final Layer layer) {
                return 0.0;
            }

            @Override
            protected double getL2(final Layer layer) {
                return 1e4;
            }
        };
        return new IterativeTrainer(trainable)
                .setMonitor(monitor)
                .setTimeout(3, TimeUnit.MINUTES)
                .setMaxIterations(500)
                .runAndFree();
    });
}
Also used: IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer), Nonnull (javax.annotation.Nonnull), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), L12Normalizer (com.simiacryptus.mindseye.eval.L12Normalizer), EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer), SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork), Trainable (com.simiacryptus.mindseye.eval.Trainable), Layer (com.simiacryptus.mindseye.lang.Layer)
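
L12Normalizer decorates the wrapped Trainable with weight-decay terms: getL1 and getL2 supply per-layer regularization factors, and the example above applies an L2 penalty only. An elastic-net style variant would change just the two overrides; the coefficients below are arbitrary illustrative values, not taken from L2NormalizationTest:

@Nonnull final Trainable trainable = new L12Normalizer(new SampledArrayTrainable(trainingData, supervisedNetwork, 1000)) {

    @Override
    public Layer getLayer() {
        return inner.getLayer();
    }

    @Override
    protected double getL1(final Layer layer) {
        return 1e-5;  // illustrative L1 factor
    }

    @Override
    protected double getL2(final Layer layer) {
        return 1e-4;  // illustrative L2 factor
    }
};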

Aggregations (usage counts across matching examples)

SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 19
Nonnull (javax.annotation.Nonnull): 19
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer): 17
SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork): 16
Trainable (com.simiacryptus.mindseye.eval.Trainable): 12
IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer): 12
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 7
ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer): 7
Layer (com.simiacryptus.mindseye.lang.Layer): 6
GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent): 5
Tensor (com.simiacryptus.mindseye.lang.Tensor): 3
DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 3
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 3
QuadraticSearch (com.simiacryptus.mindseye.opt.line.QuadraticSearch): 3
TrustRegionStrategy (com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy): 3
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 3
TestUtil (com.simiacryptus.mindseye.test.TestUtil): 3
TableOutput (com.simiacryptus.util.TableOutput): 3
NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 3
Format (guru.nidi.graphviz.engine.Format): 3