Search in sources:

Example 1 with DropoutNoiseLayer

Use of com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer in the project MindsEye by SimiaCryptus.

Method `run` of the class AutoencodingProblem.

/**
 * Builds, trains and reports on an autoencoder for this problem.
 *
 * <p>The encoder ({@code fwdFactory.imageToVector}) and decoder ({@code revFactory.vectorToImage})
 * network objects are composed twice:
 * <ul>
 *   <li>an "echo" pipeline (encoder then decoder), used only to render reconstructions after training;</li>
 *   <li>a training pipeline that inserts a {@link DropoutNoiseLayer} between encoder and decoder and
 *       appends a {@link MeanSqLossLayer} comparing the decoder output with the pipeline's own input.</li>
 * </ul>
 * Both pipelines hold the same encoder/decoder objects, so training the second one changes what the
 * first one outputs. The dropout mask is reshuffled after every completed optimizer step.
 *
 * <p>Everything is written to {@code log}: network diagrams, training plots (when any history was
 * recorded), performance statistics, and the serialized encoder and decoder. It also writes a table
 * of re-encoded validation examples and one decoded image per unit feature vector.
 *
 * @param log notebook that receives all diagrams, plots, saved model files and sample outputs
 * @return this problem instance, for chaining
 * @throws IllegalStateException if the decoder yields no output tensor for a unit feature vector
 */
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
    // Plain encoder -> decoder composition, evaluated after training to show reconstructions.
    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    echoNetwork.add(fwdNetwork);
    echoNetwork.add(revNetwork);
    // Training graph: the same encoder/decoder objects with dropout noise between them,
    // scored by mean-squared error against the pipeline's own input (reconstruction loss).
    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    supervisedNetwork.add(fwdNetwork);
    @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
    supervisedNetwork.add(dropoutNoiseLayer);
    supervisedNetwork.add(revNetwork);
    supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
    log.h3("Network Diagrams");
    // Render encoder, decoder and training graph in that order.
    for (@Nonnull final DAGNetwork network : new DAGNetwork[]{fwdNetwork, revNetwork, supervisedNetwork}) {
        log.code(() -> {
            return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
        });
    }
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {

        @Nonnull
        TrainingMonitor inner = TestUtil.getMonitor(history);

        @Override
        public void log(final String msg) {
            inner.log(msg);
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
            // Draw a fresh dropout mask so each step trains against different noise.
            dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
            inner.onStepComplete(currentPoint);
        }
    };
    final Tensor[][] trainingData = getTrainingData(log);
    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork);
    // First trainable samples half of the data per evaluation; the second covers the full set
    // (presumably used by ValidatingTrainer for validation -- confirm against optimizer.train).
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize), new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    // Take one index per run so the encoder and decoder files of a pair carry the same number.
    @Nonnull final String modelSuffix = AutoencodingProblem.modelNo++ + ".json";
    @Nonnull final String encoderName = "encoder_model" + modelSuffix;
    log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), encoderName, encoderName));
    @Nonnull final String decoderName = "decoder_model" + modelSuffix;
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), decoderName, decoderName));
    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        data.validationData().map(labeledObject -> {
            return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Some rendered unit vectors:");
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        // One-hot feature vector: decoding it shows what that single feature "draws".
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
        if (null == tensor) {
            throw new IllegalStateException("Decoder produced no output for feature " + featureNumber);
        }
        log.out(log.image(tensor.toImage(), ""));
    }
    return this;
}
Also used: PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork), Graphviz (guru.nidi.graphviz.engine.Graphviz), TableOutput (com.simiacryptus.util.TableOutput), Tensor (com.simiacryptus.mindseye.lang.Tensor), ArrayList (java.util.ArrayList), LinkedHashMap (java.util.LinkedHashMap), Format (guru.nidi.graphviz.engine.Format), LabeledObject (com.simiacryptus.util.test.LabeledObject), TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor), SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable), ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer), StepRecord (com.simiacryptus.mindseye.test.StepRecord), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable), MeanSqLossLayer (com.simiacryptus.mindseye.layers.java.MeanSqLossLayer), StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent), DropoutNoiseLayer (com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer), IOException (java.io.IOException), TestUtil (com.simiacryptus.mindseye.test.TestUtil), TimeUnit (java.util.concurrent.TimeUnit), List (java.util.List), ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), Step (com.simiacryptus.mindseye.opt.Step)

Aggregations

ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 1, SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 1, Tensor (com.simiacryptus.mindseye.lang.Tensor): 1, DropoutNoiseLayer (com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer): 1, MeanSqLossLayer (com.simiacryptus.mindseye.layers.java.MeanSqLossLayer): 1, StochasticComponent (com.simiacryptus.mindseye.layers.java.StochasticComponent): 1, DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 1, PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 1, Step (com.simiacryptus.mindseye.opt.Step): 1, TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 1, ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer): 1, StepRecord (com.simiacryptus.mindseye.test.StepRecord): 1, TestUtil (com.simiacryptus.mindseye.test.TestUtil): 1, TableOutput (com.simiacryptus.util.TableOutput): 1, NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 1, LabeledObject (com.simiacryptus.util.test.LabeledObject): 1, Format (guru.nidi.graphviz.engine.Format): 1, Graphviz (guru.nidi.graphviz.engine.Graphviz): 1, IOException (java.io.IOException): 1, ArrayList (java.util.ArrayList): 1