Search in sources :

Example 1 with ArrayTrainable

use of com.simiacryptus.mindseye.eval.ArrayTrainable in project MindsEye by SimiaCryptus.

Source: method train of class QQNTest.

/**
 * Trains {@code network} with the QQN orientation inside a {@link ValidatingTrainer}.
 *
 * <p>The network is wrapped in a {@link SimpleLossNetwork} together with an
 * {@link EntropyLossLayer}. The trainer is built from two objectives:
 * <ul>
 *   <li>a {@link SampledArrayTrainable} constructed with the arguments {@code (1000, 10000)};</li>
 *   <li>an {@link ArrayTrainable} over the full training data (presumably used for validation).</li>
 * </ul>
 * The first regimen phase uses {@code QQN}. Training stops after 5 minutes or 500 iterations,
 * whichever comes first. All of this runs inside a {@code log.code} cell, so the call and its
 * result are recorded in the notebook.
 *
 * @param log          notebook receiving the training cell
 * @param network      the model being trained
 * @param trainingData training examples handed to both objectives
 * @param monitor      monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final SampledArrayTrainable sampledObjective = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 10000);
        @Nonnull final ArrayTrainable fullObjective = new ArrayTrainable(trainingData, lossNetwork);
        @Nonnull final ValidatingTrainer trainer = new ValidatingTrainer(sampledObjective, fullObjective).setMonitor(monitor);
        trainer.getRegimen().get(0).setOrientation(new QQN());
        return trainer.setTimeout(5, TimeUnit.MINUTES).setMaxIterations(500).run();
    });
}
Also used : Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork)

Example 2 with ArrayTrainable

use of com.simiacryptus.mindseye.eval.ArrayTrainable in project MindsEye by SimiaCryptus.

Source: method run of class ClassifyProblem.

/**
 * Trains and evaluates a classifier, recording every step in {@code log}.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Build the network with {@code fwdFactory.imageToVector(log, categories)} and render its graph.</li>
 *   <li>Wrap it in a {@link SimpleLossNetwork} with an {@link EntropyLossLayer}.</li>
 *   <li>Train with {@code optimizer} for at most {@code timeoutMinutes} minutes or 10000 iterations.</li>
 *   <li>Plot the training history and save the plot as a PNG.</li>
 *   <li>Save the model JSON.</li>
 *   <li>Report the validation accuracy as a percentage.</li>
 *   <li>Show a table of at most 10 validation rows produced by {@code toRow}.</li>
 * </ol>
 *
 * @param log notebook receiving code cells, images, files and front-matter properties
 * @return this instance, for chaining
 * @throws java.io.UncheckedIOException if writing the plot image or building the result table
 *                                      fails with an {@link IOException}
 */
@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    // A fifth of the training data, but at least min(10, half of it) for small data sets.
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()), new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    writeResultPlot(log);
    TestUtil.extractPerformance(log, supervisedNetwork);
    // NOTE(review): modelNo was already incremented by writeResultPlot, so the model index
    // written here differs from the plot index of the same run. Confirm whether that is intended.
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        return data.validationData().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            // Evaluate the validation set in batches of 100. Null rows from toRow are dropped
            // (presumably toRow returns null for correctly classified items); keep the first 10.
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
    });
    return this;
}

/**
 * Saves the current training-history plot as a PNG and records it in the front matter.
 *
 * <p>The PNG is written to {@code log.file(filename)}. Its path, resolved against the notebook
 * resource directory, is recorded under the {@code result_plot} front-matter property.
 * Increments {@code ClassifyProblem.modelNo} once.
 *
 * @param log notebook receiving the image file and the front-matter property
 * @throws java.io.UncheckedIOException if the image cannot be written
 */
private void writeResultPlot(@Nonnull final NotebookOutput log) {
    try {
        @Nonnull final String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        // ImageIO.write(..., OutputStream) leaves its target stream open, so close it here
        // instead of leaking the file handle.
        try (java.io.OutputStream out = log.file(filename)) {
            ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", out);
        }
        @Nonnull final File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new java.io.UncheckedIOException(e);
    }
}
Also used : IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) 
IOException(java.io.IOException) TensorList(com.simiacryptus.mindseye.lang.TensorList) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) File(java.io.File) Nonnull(javax.annotation.Nonnull)

Example 3 with ArrayTrainable

use of com.simiacryptus.mindseye.eval.ArrayTrainable in project MindsEye by SimiaCryptus.

Source: method train of class LBFGSTest.

/**
 * Trains {@code network} with the L-BFGS orientation inside a {@link ValidatingTrainer}.
 *
 * <p>The network is wrapped in a {@link SimpleLossNetwork} with an {@link EntropyLossLayer}.
 * The trainer's objectives are:
 * <ul>
 *   <li>a {@link SampledArrayTrainable} constructed with {@code (1000, 10000)};</li>
 *   <li>a cached {@link ArrayTrainable} over the full training data.</li>
 * </ul>
 * The first regimen phase uses {@code LBFGS}. Its line-search factory returns a
 * {@link QuadraticSearch} with current rate 1.0 when the requested name contains
 * {@code "LBFGS"}, and a default {@link QuadraticSearch} otherwise. Training stops after
 * 5 minutes or 500 iterations.
 *
 * @param log          notebook receiving the training cell
 * @param network      the model being trained
 * @param trainingData training examples handed to both objectives
 * @param monitor      monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final SampledArrayTrainable sampledObjective = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 10000);
        @Nonnull final ArrayTrainable fullObjective = new ArrayTrainable(trainingData, lossNetwork);
        @Nonnull final ValidatingTrainer trainer = new ValidatingTrainer(sampledObjective, fullObjective.cached()).setMonitor(monitor);
        trainer.getRegimen().get(0).setOrientation(new LBFGS()).setLineSearchFactory(name -> {
            if (name.toString().contains("LBFGS")) {
                return new QuadraticSearch().setCurrentRate(1.0);
            }
            return new QuadraticSearch();
        });
        return trainer.setTimeout(5, TimeUnit.MINUTES).setMaxIterations(500).run();
    });
}
Also used : Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork)

Example 4 with ArrayTrainable

use of com.simiacryptus.mindseye.eval.ArrayTrainable in project MindsEye by SimiaCryptus.

Source: method train of class RecursiveSubspaceTest.

/**
 * Trains {@code network} with the orientation from {@code getOrientation()} inside a
 * {@link ValidatingTrainer}.
 *
 * <p>The network is wrapped in a {@link SimpleLossNetwork} with an {@link EntropyLossLayer}.
 * The trainer's objectives are:
 * <ul>
 *   <li>a {@link SampledArrayTrainable} constructed with {@code (1000, 1000)};</li>
 *   <li>a cached {@link ArrayTrainable} constructed with {@code 1000}.</li>
 * </ul>
 * The line-search factory returns a {@link StaticLearningRate} of 1.0 when the requested name
 * contains {@code "LBFGS"}, and a {@link QuadraticSearch} otherwise. Training stops after
 * 15 minutes or 500 iterations.
 *
 * @param log          notebook receiving the training cell
 * @param network      the model being trained
 * @param trainingData training examples handed to both objectives
 * @param monitor      monitor attached to the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
    log.code(() -> {
        @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
        @Nonnull final SampledArrayTrainable sampledObjective = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 1000);
        @Nonnull final ArrayTrainable fullObjective = new ArrayTrainable(trainingData, lossNetwork, 1000);
        @Nonnull final ValidatingTrainer trainer = new ValidatingTrainer(sampledObjective, fullObjective.cached()).setMonitor(monitor);
        trainer.getRegimen().get(0).setOrientation(getOrientation()).setLineSearchFactory(name -> {
            if (name.toString().contains("LBFGS")) {
                return new StaticLearningRate(1.0);
            }
            return new QuadraticSearch();
        });
        return trainer.setTimeout(15, TimeUnit.MINUTES).setMaxIterations(500).run();
    });
}
Also used : Nonnull(javax.annotation.Nonnull) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork)

Example 5 with ArrayTrainable

use of com.simiacryptus.mindseye.eval.ArrayTrainable in project MindsEye by SimiaCryptus.

Source: method train of class DeepDream.

/**
 * Optimizes a canvas image against a frozen network and returns the resulting image.
 *
 * <p>Setup:
 * <ul>
 *   <li>{@code canvasImage} is converted to a tensor and registered with
 *       {@code TestUtil.monitorImage}.</li>
 *   <li>{@code network} is frozen and switched to {@code precision}.</li>
 *   <li>The tensor becomes the single input of an {@link ArrayTrainable} built with
 *       {@code setMask(true)}. Presumably this makes the input itself the trained parameter;
 *       verify against ArrayTrainable.</li>
 *   <li>If {@code server} is non-null, layer handlers are registered on it.</li>
 * </ul>
 *
 * <p>Training runs in a {@code log.code} cell with an {@link IterativeTrainer}:
 * <ul>
 *   <li>100 iterations per sample;</li>
 *   <li>a {@link TrustRegionStrategy} that returns a new {@link RangeConstraint} for every layer;</li>
 *   <li>a {@link BisectionSearch} line search (span tolerance 1e-1, initial rate 1e3);</li>
 *   <li>a timeout of {@code trainingMinutes} minutes;</li>
 *   <li>a terminate threshold of negative infinity. Presumably this means only the timeout
 *       ends training.</li>
 * </ul>
 * The cell's output is the step-history plot.
 *
 * @param server          optional HTTP server to attach layer handlers to; may be {@code null}
 * @param log             notebook receiving the training cell
 * @param canvasImage     starting image
 * @param network         network to optimize against; frozen by this call
 * @param precision       precision applied to {@code network}
 * @param trainingMinutes training time budget in minutes
 * @return the image rendered from the canvas tensor after training
 */
@Nonnull
public BufferedImage train(final StreamNanoHTTPD server, @Nonnull final NotebookOutput log, final BufferedImage canvasImage, final PipelineNetwork network, final Precision precision, final int trainingMinutes) {
    System.gc();
    final Tensor image = Tensor.fromRGB(canvasImage);
    TestUtil.monitorImage(image, false, false);
    network.setFrozen(true);
    ArtistryUtil.setPrecision(network, precision);
    @Nonnull final Trainable objective = new ArrayTrainable(network, 1).setVerbose(true).setMask(true).setData(Arrays.asList(new Tensor[][] { { image } }));
    TestUtil.instrumentPerformance(network);
    if (server != null) {
        ArtistryUtil.addLayersHandler(network, server);
    }
    log.code(() -> {
        @Nonnull final ArrayList<StepRecord> steps = new ArrayList<>();
        // Every layer gets a fresh RangeConstraint as its trust region.
        @Nonnull final TrustRegionStrategy orientation = new TrustRegionStrategy() {

            @Override
            public TrustRegion getRegionPolicy(final Layer layer) {
                return new RangeConstraint();
            }
        };
        new IterativeTrainer(objective)
            .setMonitor(TestUtil.getMonitor(steps))
            .setIterationsPerSample(100)
            .setOrientation(orientation)
            .setLineSearchFactory(name -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e3))
            .setTimeout(trainingMinutes, TimeUnit.MINUTES)
            .setTerminateThreshold(Double.NEGATIVE_INFINITY)
            .runAndFree();
        return TestUtil.plot(steps);
    });
    return image.toImage();
}
Also used : TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Arrays(java.util.Arrays) TrustRegion(com.simiacryptus.mindseye.opt.region.TrustRegion) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy) HashMap(java.util.HashMap) NullNotebookOutput(com.simiacryptus.util.io.NullNotebookOutput) MultiLayerImageNetwork(com.simiacryptus.mindseye.models.MultiLayerImageNetwork) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Tuple2(com.simiacryptus.util.lang.Tuple2) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) SquareActivationLayer(com.simiacryptus.mindseye.layers.cudnn.SquareActivationLayer) Logger(org.slf4j.Logger) BufferedImage(java.awt.image.BufferedImage) AvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.AvgReducerLayer) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) TestUtil(com.simiacryptus.mindseye.test.TestUtil) UUID(java.util.UUID) DAGNode(com.simiacryptus.mindseye.network.DAGNode) StreamNanoHTTPD(com.simiacryptus.util.StreamNanoHTTPD) TimeUnit(java.util.concurrent.TimeUnit) BisectionSearch(com.simiacryptus.mindseye.opt.line.BisectionSearch) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) BinarySumLayer(com.simiacryptus.mindseye.layers.cudnn.BinarySumLayer) MultiLayerVGG16(com.simiacryptus.mindseye.models.MultiLayerVGG16) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) 
DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) LayerEnum(com.simiacryptus.mindseye.models.LayerEnum) MultiLayerVGG19(com.simiacryptus.mindseye.models.MultiLayerVGG19) Tensor(com.simiacryptus.mindseye.lang.Tensor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer) Layer(com.simiacryptus.mindseye.lang.Layer) SquareActivationLayer(com.simiacryptus.mindseye.layers.cudnn.SquareActivationLayer) AvgReducerLayer(com.simiacryptus.mindseye.layers.cudnn.AvgReducerLayer) ValueLayer(com.simiacryptus.mindseye.layers.cudnn.ValueLayer) BinarySumLayer(com.simiacryptus.mindseye.layers.cudnn.BinarySumLayer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) RangeConstraint(com.simiacryptus.mindseye.opt.region.RangeConstraint) BisectionSearch(com.simiacryptus.mindseye.opt.line.BisectionSearch) Trainable(com.simiacryptus.mindseye.eval.Trainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TrustRegionStrategy(com.simiacryptus.mindseye.opt.orient.TrustRegionStrategy) Nonnull(javax.annotation.Nonnull)

Aggregations

ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)15 Nonnull (javax.annotation.Nonnull)15 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 ArrayList (java.util.ArrayList)10 StepRecord (com.simiacryptus.mindseye.test.StepRecord)9 Trainable (com.simiacryptus.mindseye.eval.Trainable)8 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)8 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)8 List (java.util.List)8 SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable)7 ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer)7 Arrays (java.util.Arrays)7 Layer (com.simiacryptus.mindseye.lang.Layer)6 EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer)6 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)6 TestUtil (com.simiacryptus.mindseye.test.TestUtil)6 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)6 TimeUnit (java.util.concurrent.TimeUnit)6 Nullable (javax.annotation.Nullable)6 DAGNode (com.simiacryptus.mindseye.network.DAGNode)5