Example use of com.simiacryptus.mindseye.eval.ArrayTrainable in the project MindsEye by SimiaCryptus.
From the class QQNTest, method train.
/**
 * Trains the given network through a {@code ValidatingTrainer} whose first regimen entry
 * is switched to {@code QQN} orientation.
 *
 * <p>The network is wrapped with an {@code EntropyLossLayer}; the same training data backs
 * both the sampled training subject and the validation subject.
 *
 * @param log          notebook whose code cell hosts the training run
 * @param network      layer to be trained
 * @param trainingData tensor rows handed to both trainables
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final SampledArrayTrainable trainingSubject = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 10000);
    @Nonnull final ArrayTrainable validationSubject = new ArrayTrainable(trainingData, lossNetwork);
    @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, validationSubject).setMonitor(monitor);
    // Only the first regimen entry gets the QQN orientation.
    validatingTrainer.getRegimen().get(0).setOrientation(new QQN());
    return validatingTrainer
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .run();
  });
}
Example use of com.simiacryptus.mindseye.eval.ArrayTrainable in the project MindsEye by SimiaCryptus.
From the class ClassifyProblem, method run.
/**
 * Runs the classification experiment and records each stage in the notebook.
 *
 * <p>Stages, in order: obtain the training data and build the network, render the network
 * diagram, train through the configured optimizer, plot the recorded history, store the
 * history plot and the network JSON as notebook files, then report the accuracy over the
 * validation data and tabulate up to ten non-null rows produced by {@code toRow}.
 *
 * @param log notebook receiving headings, code cells, files and front-matter properties
 * @return this instance
 * @throws RuntimeException wrapping an {@link IOException} raised while writing the plot
 *                          or while building the result table
 */
@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
  @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
  final Tensor[][] trainingData = getTrainingData(log);
  @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);

  log.h3("Network Diagram");
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(network))
        .height(400)
        .width(600)
        .render(Format.PNG)
        .toImage();
  });

  log.h3("Training");
  @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
  TestUtil.instrumentPerformance(lossNetwork);
  final int rowCount = trainingData.length;
  // At least a fifth of the rows; for small sets, min(10, half the rows) can be the larger bound.
  final int initialSampleSize = Math.max(rowCount / 5, Math.min(10, rowCount / 2));
  @Nonnull final SampledArrayTrainable trainingSubject = new SampledArrayTrainable(trainingData, lossNetwork, initialSampleSize, getBatchSize());
  @Nonnull final ArrayTrainable validationSubject = new ArrayTrainable(trainingData, lossNetwork, getBatchSize());
  @Nonnull final ValidatingTrainer trainer = optimizer.train(log, trainingSubject, validationSubject, monitor);
  log.code(() -> {
    trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
  });

  final boolean hasHistory = !history.isEmpty();
  if (hasHistory) {
    log.code(() -> {
      return TestUtil.plot(history);
    });
    log.code(() -> {
      return TestUtil.plotTime(history);
    });
  }

  // NOTE(review): this plot is written even when history is empty, unlike the guarded
  // plots just above — confirm TestUtil.plot/Util.toImage cope with an empty history.
  try {
    @Nonnull final String plotFileName = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
    ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(plotFileName));
    @Nonnull final String plotPath = new File(log.getResourceDir(), plotFileName).toString();
    log.appendFrontMatterProperty("result_plot", plotPath, ";");
  } catch (IOException e) {
    throw new RuntimeException(e);
  }

  TestUtil.extractPerformance(log, lossNetwork);
  // NOTE(review): modelNo is incremented a second time here, so the plot file and the
  // model file carry different numbers — confirm this is intended.
  @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
  log.appendFrontMatterProperty("result_model", modelName, ";");
  final String modelJson = network.getJson().toString();
  log.p("Saved model as " + log.file(modelJson, modelName, modelName));

  log.h3("Validation");
  log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
  log.code(() -> {
    // Each sample scores 1 when the first predicted entry equals the parsed label, else 0.
    final double hitRate = data.validationData()
        .mapToDouble(sample -> predict(network, sample)[0] == parse(sample.label) ? 1 : 0)
        .average()
        .getAsDouble();
    return hitRate * 100;
  });

  log.p("Let's examine some incorrectly predicted results in more detail:");
  log.code(() -> {
    try {
      @Nonnull final TableOutput table = new TableOutput();
      // Evaluate the validation data in batches of 100; keep the first ten non-null rows.
      Lists.partition(data.validationData().collect(Collectors.toList()), 100)
          .stream()
          .flatMap(batch -> {
            @Nonnull final TensorList batchIn = TensorArray.create(batch.stream().map(sample -> sample.data).toArray(Tensor[]::new));
            final TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
            return IntStream.range(0, batchOut.length())
                .mapToObj(index -> toRow(log, batch.get(index), batchOut.get(index).getData()));
          })
          .filter(row -> row != null)
          .limit(10)
          .forEach(table::putRow);
      return table;
    } catch (@Nonnull final IOException e) {
      throw new RuntimeException(e);
    }
  });
  return this;
}
Example use of com.simiacryptus.mindseye.eval.ArrayTrainable in the project MindsEye by SimiaCryptus.
From the class LBFGSTest, method train.
/**
 * Trains the given network through a {@code ValidatingTrainer} whose first regimen entry
 * uses {@code LBFGS} orientation with a {@code QuadraticSearch} line search.
 *
 * <p>The network is wrapped with an {@code EntropyLossLayer}. The validation subject is the
 * {@code cached()} form of an {@code ArrayTrainable} over the same training data.
 *
 * @param log          notebook whose code cell hosts the training run
 * @param network      layer to be trained
 * @param trainingData tensor rows handed to both trainables
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final SampledArrayTrainable trainingSubject = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 10000);
    @Nonnull final ArrayTrainable validationSource = new ArrayTrainable(trainingData, lossNetwork);
    @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, validationSource.cached()).setMonitor(monitor);
    validatingTrainer.getRegimen().get(0)
        .setOrientation(new LBFGS())
        .setLineSearchFactory(label -> {
          // Searches whose label mentions LBFGS start from a current rate of 1.0.
          if (label.toString().contains("LBFGS")) {
            return new QuadraticSearch().setCurrentRate(1.0);
          }
          return new QuadraticSearch();
        });
    return validatingTrainer
        .setTimeout(5, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .run();
  });
}
Example use of com.simiacryptus.mindseye.eval.ArrayTrainable in the project MindsEye by SimiaCryptus.
From the class RecursiveSubspaceTest, method train.
/**
 * Trains the given network through a {@code ValidatingTrainer} whose first regimen entry
 * uses the orientation supplied by {@code getOrientation()}.
 *
 * <p>The network is wrapped with an {@code EntropyLossLayer}. The validation subject is the
 * {@code cached()} form of an {@code ArrayTrainable} over the same training data. Line
 * searches whose label contains "LBFGS" use {@code StaticLearningRate(1.0)}; all others use
 * a {@code QuadraticSearch}.
 *
 * @param log          notebook whose code cell hosts the training run
 * @param network      layer to be trained
 * @param trainingData tensor rows handed to both trainables
 * @param monitor      monitor installed on the trainer
 */
@Override
public void train(@Nonnull final NotebookOutput log, @Nonnull final Layer network, @Nonnull final Tensor[][] trainingData, final TrainingMonitor monitor) {
  log.code(() -> {
    @Nonnull final SimpleLossNetwork lossNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    @Nonnull final SampledArrayTrainable trainingSubject = new SampledArrayTrainable(trainingData, lossNetwork, 1000, 1000);
    @Nonnull final ArrayTrainable validationSource = new ArrayTrainable(trainingData, lossNetwork, 1000);
    @Nonnull final ValidatingTrainer validatingTrainer = new ValidatingTrainer(trainingSubject, validationSource.cached()).setMonitor(monitor);
    validatingTrainer.getRegimen().get(0)
        .setOrientation(getOrientation())
        .setLineSearchFactory(label -> {
          if (label.toString().contains("LBFGS")) {
            return new StaticLearningRate(1.0);
          }
          return new QuadraticSearch();
        });
    return validatingTrainer
        .setTimeout(15, TimeUnit.MINUTES)
        .setMaxIterations(500)
        .run();
  });
}
Example use of com.simiacryptus.mindseye.eval.ArrayTrainable in the project MindsEye by SimiaCryptus.
From the class DeepDream, method train.
/**
 * Runs an iterative training pass over a canvas image and returns the canvas as an image.
 *
 * <p>The canvas image is converted to a tensor and supplied as the single data row of an
 * {@code ArrayTrainable} built on the given network. The network is frozen and set to the
 * requested precision first. The returned image is rendered from that same canvas tensor
 * once the trainer has finished.
 * NOTE(review): presumably {@code setMask(true)} is what lets the trainer modify the canvas
 * tensor itself — confirm against {@code ArrayTrainable}.
 *
 * @param server          optional server; when non-null it is passed to
 *                        {@code ArtistryUtil.addLayersHandler} together with the network
 * @param log             notebook whose code cell hosts the training run and history plot
 * @param canvasImage     image converted into the canvas tensor
 * @param network         network used by the trainable; frozen by this method
 * @param precision       precision applied to the network via {@code ArtistryUtil.setPrecision}
 * @param trainingMinutes trainer timeout, in minutes
 * @return the canvas tensor rendered as an image after training
 */
@Nonnull
public BufferedImage train(final StreamNanoHTTPD server, @Nonnull final NotebookOutput log, final BufferedImage canvasImage, final PipelineNetwork network, final Precision precision, final int trainingMinutes) {
  System.gc();
  final Tensor canvas = Tensor.fromRGB(canvasImage);
  TestUtil.monitorImage(canvas, false, false);
  network.setFrozen(true);
  ArtistryUtil.setPrecision(network, precision);
  @Nonnull final Trainable trainable = new ArrayTrainable(network, 1)
      .setVerbose(true)
      .setMask(true)
      .setData(Arrays.asList(new Tensor[][] { { canvas } }));
  TestUtil.instrumentPerformance(network);
  if (server != null) {
    ArtistryUtil.addLayersHandler(network, server);
  }
  log.code(() -> {
    @Nonnull final ArrayList<StepRecord> history = new ArrayList<>();
    new IterativeTrainer(trainable)
        .setMonitor(TestUtil.getMonitor(history))
        .setIterationsPerSample(100)
        .setOrientation(new TrustRegionStrategy() {
          @Override
          public TrustRegion getRegionPolicy(final Layer layer) {
            // Every layer receives a RangeConstraint region.
            return new RangeConstraint();
          }
        })
        .setLineSearchFactory(label -> new BisectionSearch().setSpanTol(1e-1).setCurrentRate(1e3))
        .setTimeout(trainingMinutes, TimeUnit.MINUTES)
        .setTerminateThreshold(Double.NEGATIVE_INFINITY)
        .runAndFree();
    return TestUtil.plot(history);
  });
  return canvas.toImage();
}
Aggregations