use of com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer in project MindsEye by SimiaCryptus.
The run method of the class AutoencodingProblem:
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
  @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
  @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
  // Echo network: encoder followed by decoder, used later to re-encode validation images.
  @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
  echoNetwork.add(fwdNetwork);
  echoNetwork.add(revNetwork);
  // Supervised network: encoder -> dropout noise -> decoder -> mean-squared reconstruction loss.
  @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
  supervisedNetwork.add(fwdNetwork);
  @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
  supervisedNetwork.add(dropoutNoiseLayer);
  supervisedNetwork.add(revNetwork);
  supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
  log.h3("Network Diagrams");
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(fwdNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(revNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(supervisedNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {
    @Nonnull
    TrainingMonitor inner = TestUtil.getMonitor(history);

    @Override
    public void log(final String msg) {
      inner.log(msg);
    }

    @Override
    public void onStepComplete(final Step currentPoint) {
      // Re-randomize the dropout mask after every optimizer step.
      dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
      inner.onStepComplete(currentPoint);
    }
  };
  final Tensor[][] trainingData = getTrainingData(log);
  // MonitoredObject monitoringRoot = new MonitoredObject();
  // TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
  log.h3("Training");
  TestUtil.instrumentPerformance(supervisedNetwork);
  @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
      new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize),
      new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
  log.code(() -> {
    trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
  });
  if (!history.isEmpty()) {
    log.code(() -> {
      return TestUtil.plot(history);
    });
    log.code(() -> {
      return TestUtil.plotTime(history);
    });
  }
  TestUtil.extractPerformance(log, supervisedNetwork);
  {
    @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
  }
  @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
  log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
  // log.h3("Metrics");
  // log.code(() -> {
  //   return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
  // });
  log.h3("Validation");
  log.p("Here are some re-encoded examples:");
  log.code(() -> {
    @Nonnull final TableOutput table = new TableOutput();
    data.validationData().map(labeledObject -> {
      return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
    }).filter(x -> null != x).limit(10).forEach(table::putRow);
    return table;
  });
  log.p("Some rendered unit vectors:");
  for (int featureNumber = 0; featureNumber < features; featureNumber++) {
    @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
    @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
    log.out(log.image(tensor.toImage(), ""));
  }
  return this;
}
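For orientation, here is a minimal sketch of the DropoutNoiseLayer usage pattern from the method above, reduced to a single layer in a one-input PipelineNetwork. It reuses only calls that appear in the listing (setValue, add, eval, getData, shuffle); the import paths for PipelineNetwork and Tensor, the placeholder dropout value and input data, and the seed source are assumptions rather than part of the original code.

import com.simiacryptus.mindseye.lang.Tensor;                    // assumed package
import com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer;
import com.simiacryptus.mindseye.network.PipelineNetwork;        // assumed package

public class DropoutNoiseLayerSketch {
  public static void main(String[] args) {
    // Configure the dropout parameter, mirroring new DropoutNoiseLayer().setValue(dropout) in run(...).
    final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(0.5);

    // A one-input pipeline containing only the dropout layer.
    final PipelineNetwork net = new PipelineNetwork(1);
    net.add(dropoutNoiseLayer);

    // A small all-ones input tensor (placeholder data).
    final Tensor input = new Tensor(4).set(0, 1).set(1, 1).set(2, 1).set(3, 1);
    final Tensor masked = net.eval(input).getData().get(0);
    System.out.println(masked);

    // Re-randomize the dropout mask between evaluations, as onStepComplete does above
    // (the original code seeds from StochasticComponent.random; any long seed works).
    dropoutNoiseLayer.shuffle(System.nanoTime());
  }
}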