Search in sources :

Example 11 with SimpleLossNetwork

use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.

the class RecursiveSubspaceTest method train.

/**
 * Trains {@code network} with a {@link ValidatingTrainer}, executing the whole run inside
 * {@code log.code(...)}.
 *
 * <p>The network is paired with an {@link EntropyLossLayer} through a {@link SimpleLossNetwork}.
 * A {@link SampledArrayTrainable} is passed as the trainer's first subject and a cached
 * {@link ArrayTrainable} over the same data as its second (presumably the validation subject;
 * confirm against the {@code ValidatingTrainer} constructor).
 *
 * @param log          notebook output whose {@code code} method receives the training lambda
 * @param network      the layer to be trained
 * @param trainingData the tensors handed to both trainables
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final SampledArrayTrainable sampledSubject = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 1000);
        @Nonnull ValidatingTrainer validatingTrainer = new ValidatingTrainer(
            sampledSubject,
            new ArrayTrainable(trainingData, lossNetwork, 1000).cached()
        ).setMonitor(monitor);
        // Configure only the first regimen entry: orientation from this test, and a line search
        // chosen by whether the supplied name mentions "LBFGS".
        validatingTrainer.getRegimen().get(0)
            .setOrientation(getOrientation())
            .setLineSearchFactory(name -> {
                if (name.toString().contains("LBFGS")) {
                    return new StaticLearningRate(1.0);
                }
                return new QuadraticSearch();
            });
        return validatingTrainer
            .setTimeout(15, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .run();
    });
}
Also used : Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork)

Example 12 with SimpleLossNetwork

use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.

the class SingleOrthantTrustRegionTest method train.

/**
 * Trains {@code network} with an {@link IterativeTrainer} whose orientation is a
 * {@link TrustRegionStrategy} that returns a new {@link SingleOrthant} for every layer.
 *
 * <p>The run is executed inside {@code log.code(...)}; the network is paired with an
 * {@link EntropyLossLayer} through a {@link SimpleLossNetwork}.
 *
 * @param log          notebook output whose {@code code} method receives the training lambda
 * @param network      the layer to be trained
 * @param trainingData the tensors handed to the {@link SampledArrayTrainable}
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // The strategy does not depend on the data, so it is declared first.
        @Nonnull final TrustRegionStrategy orthantStrategy = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                // Same policy regardless of which layer is asked about.
                return new SingleOrthant();
            }
        };
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable subject = new SampledArrayTrainable(trainingData, lossNetwork, 10000);
        return new IterativeTrainer(subject)
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(orthantStrategy)
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)

Example 13 with SimpleLossNetwork

use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.

the class AutoencoderNetwork method train.

/**
 * Creates the training parameters for this autoencoder.
 *
 * <p>The returned object supplies a training network made of {@code encoder} followed by
 * {@code decoder}, wrapped with a {@link MeanSqLossLayer}, and wraps any monitor so that both
 * noise components are shuffled after every completed step.
 *
 * @return a new {@link AutoencoderNetwork.TrainingParameters} instance
 */
@Nonnull
public AutoencoderNetwork.TrainingParameters train() {
    return new AutoencoderNetwork.TrainingParameters() {

        @Nonnull
        @Override
        public SimpleLossNetwork getTrainingNetwork() {
            // Chain encoder -> decoder in one pipeline, then attach the loss.
            @Nonnull final PipelineNetwork encodeThenDecode = new PipelineNetwork();
            encodeThenDecode.add(encoder);
            encodeThenDecode.add(decoder);
            @Nonnull final SimpleLossNetwork withLoss = new SimpleLossNetwork(encodeThenDecode, new MeanSqLossLayer());
            return withLoss;
        }

        @Nonnull
        @Override
        protected TrainingMonitor wrap(@Nonnull final TrainingMonitor monitor) {
            @Nonnull final TrainingMonitor reshufflingMonitor = new TrainingMonitor() {

                @Override
                public void log(final String message) {
                    // Pure pass-through to the wrapped monitor.
                    monitor.log(message);
                }

                @Override
                public void onStepComplete(final Step completedStep) {
                    // Reshuffle both noise components before notifying the wrapped monitor.
                    inputNoise.shuffle();
                    final long seed = StochasticComponent.random.get().nextLong();
                    encodedNoise.shuffle(seed);
                    monitor.onStepComplete(completedStep);
                }
            };
            return reshufflingMonitor;
        }
    };
}
Also used : TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Step(com.simiacryptus.mindseye.opt.Step) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Nonnull(javax.annotation.Nonnull)

Example 14 with SimpleLossNetwork

use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.

the class TrustSphereTest method train.

/**
 * Trains {@code network} with an {@link IterativeTrainer} whose orientation is a
 * {@link TrustRegionStrategy} that returns a new {@link AdaptiveTrustSphere} for every layer.
 *
 * <p>The run is executed inside {@code log.code(...)}; the network is paired with an
 * {@link EntropyLossLayer} through a {@link SimpleLossNetwork}.
 *
 * @param log          notebook output whose {@code code} method receives the training lambda
 * @param network      the layer to be trained
 * @param trainingData the tensors handed to the {@link SampledArrayTrainable}
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork entropyNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable sampled = new SampledArrayTrainable(trainingData, entropyNetwork, 10000);
        @Nonnull final IterativeTrainer iterativeTrainer = new IterativeTrainer(sampled);
        return iterativeTrainer
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(new TrustRegionStrategy() {

                @Override
                public TrustRegion getRegionPolicy(final Layer layer) {
                    // Every layer gets its own fresh trust sphere.
                    return new AdaptiveTrustSphere();
                }
            })
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)

Example 15 with SimpleLossNetwork

use of com.simiacryptus.mindseye.network.SimpleLossNetwork in project MindsEye by SimiaCryptus.

the class QuadraticLineSearchTest method train.

/**
 * Trains {@code network} with an {@link IterativeTrainer} using {@link GradientDescent} as the
 * orientation and a new {@link QuadraticSearch} for every line search requested.
 *
 * <p>The run is executed inside {@code log.code(...)}; the network is paired with an
 * {@link EntropyLossLayer} through a {@link SimpleLossNetwork}.
 *
 * @param log          notebook output whose {@code code} method receives the training lambda
 * @param network      the layer to be trained
 * @param trainingData the tensors handed to the {@link SampledArrayTrainable}
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable subject = new SampledArrayTrainable(trainingData, lossNetwork, 1000);
        @Nonnull final IterativeTrainer descentTrainer = new IterativeTrainer(subject);
        return descentTrainer
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            // The factory ignores the name it is given and always hands back a QuadraticSearch.
            .setLineSearchFactory((@Nonnull final CharSequence searchName) -> new QuadraticSearch())
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used : IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable)

Aggregations

SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork)18 Nonnull (javax.annotation.Nonnull)18 EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer)17 SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable)16 Trainable (com.simiacryptus.mindseye.eval.Trainable)13 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)13 Layer (com.simiacryptus.mindseye.lang.Layer)6 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)5 ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer)4 GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent)4 TrustRegionStrategy (com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)3 L12Normalizer (com.simiacryptus.mindseye.eval.L12Normalizer)2 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)2 QuadraticSearch (com.simiacryptus.mindseye.opt.line.QuadraticSearch)2 ArrayList (java.util.ArrayList)2 Lists (com.google.common.collect.Lists)1 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)1 Tensor (com.simiacryptus.mindseye.lang.Tensor)1 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)1 TensorList (com.simiacryptus.mindseye.lang.TensorList)1