Example 6 with IterativeTrainer

Use of com.simiacryptus.mindseye.opt.IterativeTrainer in project MindsEye by SimiaCryptus.

From class BisectionLineSearchTest, method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // Attach an entropy loss and evaluate on a random sample of 1000 training rows
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 1000);
        // Gradient descent with a bisection line search; stop after 3 minutes or 500 iterations
        return new IterativeTrainer(trainable)
            .setMonitor(monitor)
            .setOrientation(new GradientDescent())
            .setLineSearchFactory((@Nonnull final CharSequence name) -> new BisectionSearch())
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)
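The configuration above pairs steepest descent with a bisection line search and stops after 3 minutes or 500 iterations, whichever comes first. The standalone sketch below shows that loop on a toy quadratic loss. The class name, the loss, and the bracketing rule (bisect on the sign of the slope along the search direction) are assumptions for illustration, not MindsEye's GradientDescent or BisectionSearch source.

```java
import java.util.Arrays;

// Hypothetical standalone sketch; not the MindsEye implementation.
class BisectionDescentSketch {
    // Toy loss f(x) = sum_i (x_i - TARGET_i)^2; TARGET is an assumption for illustration.
    static final double[] TARGET = {3.0, -1.0, 0.5};

    static double loss(double[] x) {
        double sum = 0;
        for (int i = 0; i < x.length; i++) {
            double d = x[i] - TARGET[i];
            sum += d * d;
        }
        return sum;
    }

    static double[] gradient(double[] x) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = 2 * (x[i] - TARGET[i]);
        return g;
    }

    static double dot(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < a.length; i++) s += a[i] * b[i];
        return s;
    }

    static double[] step(double[] x, double[] dir, double alpha) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) y[i] = x[i] + alpha * dir[i];
        return y;
    }

    // Slope of the loss along dir at step size alpha.
    static double slope(double[] x, double[] dir, double alpha) {
        return dot(gradient(step(x, dir, alpha)), dir);
    }

    // Grow hi until the slope turns non-negative, then bisect on the slope's sign.
    static double bisectionSearch(double[] x, double[] dir, double spanTol) {
        double lo = 0, hi = 1;
        for (int k = 0; k < 60 && slope(x, dir, hi) < 0; k++) hi *= 2;
        for (int k = 0; k < 200 && hi - lo > spanTol; k++) {
            double mid = 0.5 * (lo + hi);
            if (slope(x, dir, mid) < 0) lo = mid; else hi = mid;
        }
        return 0.5 * (lo + hi);
    }

    // Stops on whichever comes first: maxIterations, the timeout, or a vanishing gradient.
    static double[] train(double[] x0, int maxIterations, long timeoutMillis) {
        long deadline = System.nanoTime() + timeoutMillis * 1_000_000L;
        double[] x = x0.clone();
        for (int iter = 0; iter < maxIterations && System.nanoTime() - deadline < 0; iter++) {
            double[] g = gradient(x);
            if (dot(g, g) < 1e-20) break;
            double[] dir = new double[g.length];
            for (int i = 0; i < g.length; i++) dir[i] = -g[i]; // steepest-descent direction
            x = step(x, dir, bisectionSearch(x, dir, 1e-9));
        }
        return x;
    }

    public static void main(String[] args) {
        double[] result = train(new double[] {0, 0, 0}, 500, 3 * 60 * 1000L);
        System.out.println(Arrays.toString(result) + " loss=" + loss(result));
    }
}
```

On this quadratic the exact step along the negative gradient is 0.5, so the bisection converges to it and the loop finishes in a few iterations, well inside both limits.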

Example 7 with IterativeTrainer

From class OWLQNTest, method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        // Attach an entropy loss and evaluate on a random sample of 10000 training rows
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
        // OWL-QN inside a validating wrapper; refresh the sample every 100 iterations
        return new IterativeTrainer(trainable)
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(new ValidatingOrientationWrapper(new OwlQn()))
            .setTimeout(5, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable)
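OWL-QN (orthant-wise limited-memory quasi-Newton) is designed for objectives of the form f(x) + lambda * ||x||_1, whose L1 term has no derivative at zero. Its key ingredient is the pseudo-gradient, which takes whichever one-sided derivative permits descent. The sketch below computes it in isolation; the class name and the lambda value are hypothetical, and the snippet above does not show how MindsEye's OwlQn chooses its L1 weight.

```java
import java.util.Arrays;

// Hypothetical sketch of the OWL-QN pseudo-gradient; not MindsEye's OwlQn source.
class OwlQnPseudoGradient {
    // Pseudo-gradient of F(x) = f(x) + lambda * ||x||_1, where gradF is the gradient of the smooth part f.
    static double[] pseudoGradient(double[] x, double[] gradF, double lambda) {
        double[] pg = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            if (x[i] > 0) {
                pg[i] = gradF[i] + lambda;        // |x_i| has slope +1 here
            } else if (x[i] < 0) {
                pg[i] = gradF[i] - lambda;        // |x_i| has slope -1 here
            } else {
                double right = gradF[i] + lambda; // one-sided derivative moving up from 0
                double left = gradF[i] - lambda;  // one-sided derivative moving down from 0
                if (right < 0) {
                    pg[i] = right;                // increasing x_i lowers F
                } else if (left > 0) {
                    pg[i] = left;                 // decreasing x_i lowers F
                } else {
                    pg[i] = 0;                    // x_i = 0 is already optimal along this axis
                }
            }
        }
        return pg;
    }

    public static void main(String[] args) {
        double[] x = {1.0, -2.0, 0.0, 0.0};
        double[] g = {0.5, 0.5, -3.0, 0.2};
        System.out.println(Arrays.toString(pseudoGradient(x, g, 1.0))); // prints [1.5, -0.5, -2.0, 0.0]
    }
}
```

The last coordinate shows why the pseudo-gradient matters: its smooth gradient is 0.2, but with lambda = 1.0 neither direction away from zero lowers F, so the coordinate stays at zero and the solution remains sparse.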

Example 8 with IterativeTrainer

From class SingleOrthantTrustRegionTest, method train.

@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final Trainable trainable = new SampledArrayTrainable(trainingData, supervisedNetwork, 10000);
        @Nonnull final TrustRegionStrategy trustRegionStrategy = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                return new SingleOrthant();
            }
        };
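        // Use the SingleOrthant trust region for every layer; refresh the sample every 100 iterations
        return new IterativeTrainer(trainable)
            .setIterationsPerSample(100)
            .setMonitor(monitor)
            .setOrientation(trustRegionStrategy)
            .setTimeout(3, TimeUnit.MINUTES)
            .setMaxIterations(500)
            .runAndFree();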
    });
}
Also used: IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) Trainable(com.simiacryptus.mindseye.eval.Trainable) Layer(com.simiacryptus.mindseye.lang.Layer) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy)
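The anonymous TrustRegionStrategy above returns a SingleOrthant region for every layer. Assuming, as the name suggests, that this region stops each weight from changing sign during a step (the orthant constraint also used by OWL-QN), the projection could be sketched as follows. The names are hypothetical and this is not MindsEye's SingleOrthant source; leaving weights that start at exactly zero unconstrained is also an assumption.

```java
import java.util.Arrays;

// Hypothetical sketch of a single-orthant projection; not MindsEye's SingleOrthant source.
class SingleOrthantSketch {
    // Clips each coordinate of candidate that would leave the orthant of origin back to zero.
    static double[] project(double[] origin, double[] candidate) {
        double[] out = candidate.clone();
        for (int i = 0; i < out.length; i++) {
            double sign = Math.signum(origin[i]);
            // Weights starting at exactly zero are left unconstrained in this sketch (an assumption).
            if (sign != 0 && Math.signum(out[i]) != sign) {
                out[i] = 0.0; // the step would cross zero: stop at the orthant boundary
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] origin = {1.0, -0.5, 0.0};
        double[] candidate = {-0.2, -0.1, 0.3};
        System.out.println(Arrays.toString(project(origin, candidate))); // prints [0.0, -0.1, 0.3]
    }
}
```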

Example 9 with IterativeTrainer

From class DeepDream, method train.

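/**
 * Optimizes the pixels of a canvas image against a frozen network and returns the result.
 *
 * @param server          optional HTTP server; when non-null, layer handlers are registered on it
 * @param log             the notebook output that records the training code and loss plot
 * @param canvasImage     the initial image; its pixels are the trained parameters
 * @param network         the network whose output serves as the training objective
 * @param precision       the numeric precision applied to the network's layers
 * @param trainingMinutes the training time budget, in minutes
 * @return the optimized image
 */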
@Nonnull
public BufferedImage train(final StreamNanoHTTPD server, @Nonnull final NotebookOutput log, final BufferedImage canvasImage, final PipelineNetwork network, final Precision precision, final int trainingMinutes) {
    System.gc();
    Tensor canvas = Tensor.fromRGB(canvasImage);
    TestUtil.monitorImage(canvas, false, false);
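    // Freeze the network weights; only the canvas tensor is optimized
    network.setFrozen(true);
    ArtistryUtil.setPrecision(network, precision);
    // setMask(true) marks the single input (the canvas) as trainable
    @Nonnull Trainable trainable = new ArrayTrainable(network, 1).setVerbose(true).setMask(true).setData(Arrays.asList(new Tensor[][] { { canvas } }));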
    TestUtil.instrumentPerformance(network);
    if (null != server)
        ArtistryUtil.addLayersHandler(network, server);
    log.code(() -> {
        @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
        // RangeConstraint trust region for every layer, with a bisection line search.
        // A terminate threshold of -infinity disables loss-based stopping; trainingMinutes bounds the run.
        new IterativeTrainer(trainable)
            .setMonitor(TestUtil.getMonitor(history))
            .setIterationsPerSample(100)
            .setOrientation(new TrustRegionStrategy() {

                @Override
                public TrustRegion getRegionPolicy(final Layer layer) {
                    return new RangeConstraint();
                }
            })
            .setLineSearchFactory(name -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e3))
            .setTimeout(trainingMinutes, TimeUnit.MINUTES)
            .setTerminateThreshold(Double.NEGATIVE_INFINITY)
            .runAndFree();
        return TestUtil.plot(history);
    });
    return canvas.toImage();
}
Also used: TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Arrays(java.util.Arrays) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy) HashMap(java.util.HashMap) NullNotebookOutput(com.simiacryptus.util.io.NullNotebookOutput) MultiLayerImageNetwork(com.simiacryptus.mindseye.models.MultiLayerImageNetwork) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Tuple2(com.simiacryptus.util.lang.Tuple2) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SquareActivationLayer(com.simiacryptus.mindseye.layers.cudnn.SquareActivationLayer) Logger(org.slf4j.Logger) BufferedImage(java.awt.image.BufferedImage) AvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.AvgReducerLayer) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) TestUtil(com.simiacryptus.mindseye.test.TestUtil) UUID(java.util.UUID) DAGNode(com.simiacryptus.mindseye.network.DAGNode) StreamNanoHTTPD(com.simiacryptus.util.StreamNanoHTTPD) TimeUnit(java.util.concurrent.TimeUnit) BisectionSearch(com.simiacryptus.mindseye.opt.line.BisectionSearch) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) BinarySumLayer(com.simiacryptus.mindseye.layers.cudnn.BinarySumLayer) MultiLayerVGG16(com.simiacryptus.mindseye.models.MultiLayerVGG16) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) LayerEnum(com.simiacryptus.mindseye.models.LayerEnum) MultiLayerVGG19(com.simiacryptus.mindseye.models.MultiLayerVGG19)
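In this method the canvas pixels, not the network weights, are what the trainer changes, and setTerminateThreshold(Double.NEGATIVE_INFINITY) means no loss value is ever low enough to stop early, so the time budget bounds the run. Assuming RangeConstraint keeps values inside a fixed interval, the overall pattern resembles projected gradient descent with a deadline. The sketch below is hypothetical: the [0, 255] bounds, the toy squared-distance loss, the fixed step rate (in place of the bisection line search) and all names are assumptions, not MindsEye code.

```java
import java.util.Arrays;

// Hypothetical sketch: projected gradient descent on pixel values with a deadline-only stop.
class ProjectedPixelSketch {
    static final double MIN = 0.0, MAX = 255.0; // assumed pixel range, not taken from RangeConstraint

    // Toy objective: squared distance to target (a stand-in for the frozen network's output).
    static double[] gradient(double[] x, double[] target) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) g[i] = 2 * (x[i] - target[i]);
        return g;
    }

    // No loss threshold: the loop ends only when the time budget runs out.
    static double[] optimize(double[] pixels, double[] target, double rate, long timeoutMillis) {
        double[] x = pixels.clone();
        long deadline = System.nanoTime() + timeoutMillis * 1_000_000L;
        while (System.nanoTime() - deadline < 0) {
            double[] g = gradient(x, target);
            for (int i = 0; i < x.length; i++) {
                // Take a gradient step, then project back into [MIN, MAX].
                x[i] = Math.min(MAX, Math.max(MIN, x[i] - rate * g[i]));
            }
        }
        return x;
    }

    public static void main(String[] args) {
        double[] result = optimize(new double[] {10, 200, 128}, new double[] {-50, 300, 100}, 0.1, 100);
        System.out.println(Arrays.toString(result));
    }
}
```

Targets outside the range (-50 and 300 here) pull the first two pixels onto the bounds 0 and 255, while the in-range target of 100 is reached exactly.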

Example 10 with IterativeTrainer

From class ImageClassifier, method deepDream.

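/**
 * Runs a deep-dream optimization: the image passes through a [0, 255] clamp and the frozen
 * classifier network, and its pixels are trained with QQN and an Armijo-Wolfe line search
 * for up to 60 minutes. The loss history is plotted to the log.
 *
 * @param log   the notebook output that records the code and the loss plot
 * @param image the image to optimize
 */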
public void deepDream(@Nonnull final NotebookOutput log, final Tensor image) {
    log.code(() -> {
        @Nonnull ArrayList<StepRecord> history = new ArrayList<>();
        // Clamp pixels to [0, 255]: ReLU, then 255 - x, then ReLU, then 255 - x again
        @Nonnull PipelineNetwork clamp = new PipelineNetwork(1);
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        clamp.add(new ActivationLayer(ActivationLayer.Mode.RELU));
        clamp.add(new LinearActivationLayer().setBias(255).setScale(-1).freeze());
        // Feed the clamped image into the frozen classifier
        @Nonnull PipelineNetwork supervised = new PipelineNetwork(1);
        supervised.add(getNetwork().freeze(), supervised.wrap(clamp, supervised.getInput(0)));
        // CudaTensorList gpuInput = CudnnHandle.apply(gpu -> {
        // Precision precision = Precision.Float;
        // return CudaTensorList.wrap(gpu.getPtr(TensorArray.wrap(image), precision, MemoryType.Managed), 1, image.getDimensions(), precision);
        // });
        // @Nonnull Trainable trainable = new TensorListTrainable(supervised, gpuInput).setVerbosity(1).setMask(true);
        @Nonnull Trainable trainable = new ArrayTrainable(supervised, 1).setVerbose(true).setMask(true, false).setData(Arrays.<Tensor[]>asList(new Tensor[] { image }));
        // Train the image pixels with QQN and an Armijo-Wolfe line search for up to 60 minutes
        new IterativeTrainer(trainable)
            .setMonitor(getTrainingMonitor(history, supervised))
            .setOrientation(new QQN())
            .setLineSearchFactory(name -> new ArmijoWolfeSearch())
            .setTimeout(60, TimeUnit.MINUTES)
            .runAndFree();
        return TestUtil.plot(history);
    });
}
Also used: ActivationLayer(com.simiacryptus.mindseye.layers.cudnn.ActivationLayer) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) Tensor(com.simiacryptus.mindseye.lang.Tensor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable)
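Assuming LinearActivationLayer computes scale * x + bias, each setBias(255).setScale(-1) layer maps x to 255 - x, so the clamp pipeline in deepDream computes 255 - relu(255 - relu(x)), which equals min(255, max(0, x)). The plain-Java check below reproduces only that scalar arithmetic, not the MindsEye layers; the class and method names are hypothetical.

```java
// Hypothetical check of the clamp arithmetic; plain Java, not the MindsEye layers.
class ClampCompositionCheck {
    static double relu(double x) { return Math.max(0.0, x); }

    // Assumed behaviour of LinearActivationLayer with bias 255 and scale -1.
    static double flip(double x) { return 255.0 - x; }

    // ReLU -> flip -> ReLU -> flip, in the same order as the clamp pipeline.
    static double clampViaLayers(double x) { return flip(relu(flip(relu(x)))); }

    public static void main(String[] args) {
        double[] inputs = {-40.0, 0.0, 17.5, 255.0, 300.0};
        for (double x : inputs) {
            double expected = Math.min(255.0, Math.max(0.0, x));
            System.out.println(x + " -> " + clampViaLayers(x) + " (expected " + expected + ")");
        }
    }
}
```

Building the clamp from ReLUs and affine layers keeps every stage differentiable almost everywhere, so gradients can flow back to the pixels through the clamp.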

Aggregations

IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer): 22
Nonnull (javax.annotation.Nonnull): 22
Trainable (com.simiacryptus.mindseye.eval.Trainable): 20
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer): 13
SimpleLossNetwork (com.simiacryptus.mindseye.network.SimpleLossNetwork): 13
SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 12
Layer (com.simiacryptus.mindseye.lang.Layer): 10
ArrayList (java.util.ArrayList): 9
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 8
GradientDescent (com.simiacryptus.mindseye.opt.orient.GradientDescent): 8
StepRecord (com.simiacryptus.mindseye.test.StepRecord): 8
Tensor (com.simiacryptus.mindseye.lang.Tensor): 7
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 6
ArmijoWolfeSearch (com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch): 6
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 5
Arrays (java.util.Arrays): 5
List (java.util.List): 5
Map (java.util.Map): 5
QQN (com.simiacryptus.mindseye.opt.orient.QQN): 4
TrustRegionStrategy (com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy): 4