Use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus: class MnistTestBase, method validate.
/**
 * Validates the trained network against the MNIST validation dataset, logging the overall
 * accuracy and a sample of misclassified digits.
 *
 * @param log     the notebook log
 * @param network the network to validate
 */
public void validate(@Nonnull final NotebookOutput log, @Nonnull final Layer network) {
  log.h1("Validation");
  log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
  log.code(() -> {
    return MNIST.validationDataStream()
        .mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
        .average().getAsDouble() * 100;
  });
  log.p("Let's examine some incorrectly predicted results in more detail:");
  log.code(() -> {
    @Nonnull final TableOutput table = new TableOutput();
    MNIST.validationDataStream().map(labeledObject -> {
      final int actualCategory = parse(labeledObject.label);
      @Nullable final double[] predictionSignal = network.eval(labeledObject.data).getData().get(0).getData();
      // Rank the ten digit categories by descending prediction signal.
      final int[] predictionList = IntStream.range(0, 10).boxed()
          .sorted(Comparator.comparing(i -> -predictionSignal[i]))
          .mapToInt(x -> x).toArray();
      // We will only examine mispredicted rows.
      if (predictionList[0] == actualCategory) return null;
      @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
      row.put("Image", log.image(labeledObject.data.toGrayImage(), labeledObject.label));
      row.put("Prediction", Arrays.stream(predictionList).limit(3)
          .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i]))
          .reduce((a, b) -> a + ", " + b).get());
      return row;
    }).filter(x -> null != x).limit(10).forEach(table::putRow);
    return table;
  });
}
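Every snippet on this page touches LabeledObject through the same small surface: the two-argument constructor and the public data and label fields. A minimal, self-contained sketch of that pattern, assuming only the constructor and fields visible in the snippets here (the String payload and the "7" label are illustrative placeholders, not MindsEye data):

import com.simiacryptus.util.test.LabeledObject;

public class LabeledObjectSketch {
  public static void main(String[] args) {
    // Pair an arbitrary payload with its label; the payload type is generic.
    final LabeledObject<String> sample = new LabeledObject<>("pixel-data-placeholder", "7");
    // The examples above read both parts back through the public fields.
    System.out.println(sample.label + " -> " + sample.data);
  }
}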
Use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus: class ClassifyProblem, method run.
@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
  @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
  final Tensor[][] trainingData = getTrainingData(log);
  @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
  log.h3("Network Diagram");
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
  });
  log.h3("Training");
  @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
  TestUtil.instrumentPerformance(supervisedNetwork);
  final int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
  @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
      new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()),
      new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
  log.code(() -> {
    trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
  });
  if (!history.isEmpty()) {
    log.code(() -> {
      return TestUtil.plot(history);
    });
    log.code(() -> {
      return TestUtil.plotTime(history);
    });
  }
  try {
    @Nonnull final String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
    ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
    @Nonnull final File file = new File(log.getResourceDir(), filename);
    log.appendFrontMatterProperty("result_plot", file.toString(), ";");
  } catch (IOException e) {
    throw new RuntimeException(e);
  }
  TestUtil.extractPerformance(log, supervisedNetwork);
  @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
  log.appendFrontMatterProperty("result_model", modelName, ";");
  log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
  log.h3("Validation");
  log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
  log.code(() -> {
    return data.validationData()
        .mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
        .average().getAsDouble() * 100;
  });
  log.p("Let's examine some incorrectly predicted results in more detail:");
  log.code(() -> {
    try {
      @Nonnull final TableOutput table = new TableOutput();
      // Evaluate the validation data in batches of 100 to limit memory use.
      Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
        @Nonnull final TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
        final TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
        return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
      }).filter(x -> null != x).limit(10).forEach(table::putRow);
      return table;
    } catch (@Nonnull final IOException e) {
      throw new RuntimeException(e);
    }
  });
  return this;
}
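The toRow helper invoked above is not reproduced on this page. Purely as a hypothetical sketch, inferred from the inline row construction in MnistTestBase.validate, it might look like the following; the body, the IOException signature, and the top-3 formatting are assumptions on my part, not the actual ClassifyProblem implementation:

@Nullable
public LinkedHashMap<CharSequence, Object> toRow(@Nonnull final NotebookOutput log,
    @Nonnull final LabeledObject<Tensor> labeledObject, final double[] predictionSignal) throws IOException {
  final int actualCategory = parse(labeledObject.label);
  // Rank categories by descending signal, as in MnistTestBase.validate.
  final int[] predictionList = IntStream.range(0, predictionSignal.length).boxed()
      .sorted(Comparator.comparing(i -> -predictionSignal[i])).mapToInt(x -> x).toArray();
  if (predictionList[0] == actualCategory) return null; // only keep mispredicted rows
  @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
  row.put("Image", log.image(labeledObject.data.toImage(), labeledObject.label));
  row.put("Prediction", Arrays.stream(predictionList).limit(3)
      .mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i]))
      .reduce((a, b) -> a + ", " + b).get());
  return row;
}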
Use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus: class CIFAR10, method toImage.
private static LabeledObject<BufferedImage> toImage(final byte[] b) {
  // CIFAR-10 record layout: 1 label byte followed by 1024 red, 1024 green, and 1024 blue bytes (row-major 32x32).
  @Nonnull final BufferedImage img = new BufferedImage(32, 32, BufferedImage.TYPE_INT_RGB);
  for (int x = 0; x < img.getWidth(); x++) {
    for (int y = 0; y < img.getHeight(); y++) {
      final int red = 0xFF & b[1 + 1024 * 0 + x + y * 32];
      final int green = 0xFF & b[1 + 1024 * 1 + x + y * 32];
      final int blue = 0xFF & b[1 + 1024 * 2 + x + y * 32];
      final int c = (red << 16) + (green << 8) + blue;
      img.setRGB(x, y, c);
    }
  }
  // The label is the single leading byte, rendered as e.g. "[3]".
  return new LabeledObject<>(img, Arrays.toString(new byte[] { b[0] }));
}
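For reference, the indexing used in the loops above can be restated as a tiny helper; this is only an illustrative restatement of toImage's arithmetic, and the channelOffset name is my own, not MindsEye API:

// Offset of channel c (0 = red, 1 = green, 2 = blue) at pixel (x, y) in a raw CIFAR-10 record:
// one label byte, then three 1024-byte colour planes, each stored row-major as 32x32.
private static int channelOffset(final int c, final int x, final int y) {
  return 1 + 1024 * c + y * 32 + x;
}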
Use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus: class AutoencodingProblem, method run.
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
  @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
  @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
  @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
  echoNetwork.add(fwdNetwork);
  echoNetwork.add(revNetwork);
  @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
  supervisedNetwork.add(fwdNetwork);
  @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
  supervisedNetwork.add(dropoutNoiseLayer);
  supervisedNetwork.add(revNetwork);
  supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
  log.h3("Network Diagrams");
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(fwdNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(revNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  log.code(() -> {
    return Graphviz.fromGraph(TestUtil.toGraph(supervisedNetwork)).height(400).width(600).render(Format.PNG).toImage();
  });
  @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {
    @Nonnull
    TrainingMonitor inner = TestUtil.getMonitor(history);

    @Override
    public void log(final String msg) {
      inner.log(msg);
    }

    @Override
    public void onStepComplete(final Step currentPoint) {
      // Re-randomize the dropout mask after every optimization step.
      dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
      inner.onStepComplete(currentPoint);
    }
  };
  final Tensor[][] trainingData = getTrainingData(log);
  // MonitoredObject monitoringRoot = new MonitoredObject();
  // TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
  log.h3("Training");
  TestUtil.instrumentPerformance(supervisedNetwork);
  @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
      new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize),
      new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
  log.code(() -> {
    trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
  });
  if (!history.isEmpty()) {
    log.code(() -> {
      return TestUtil.plot(history);
    });
    log.code(() -> {
      return TestUtil.plotTime(history);
    });
  }
  TestUtil.extractPerformance(log, supervisedNetwork);
  {
    @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
  }
  @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
  log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
  // log.h3("Metrics");
  // log.code(() -> {
  //   return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
  // });
  log.h3("Validation");
  log.p("Here are some re-encoded examples:");
  log.code(() -> {
    @Nonnull final TableOutput table = new TableOutput();
    data.validationData().map(labeledObject -> {
      return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
    }).filter(x -> null != x).limit(10).forEach(table::putRow);
    return table;
  });
  log.p("Some rendered unit vectors:");
  for (int featureNumber = 0; featureNumber < features; featureNumber++) {
    @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
    @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
    log.out(log.image(tensor.toImage(), ""));
  }
  return this;
}
Use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus: class AutoencodingProblem, method toRow.
/**
 * Builds a table row pairing the original image with its autoencoder reconstruction.
 *
 * @param log              the notebook log
 * @param labeledObject    the labeled source image
 * @param predictionSignal the flat output signal produced by the echo network
 * @return the row as a linked hash map
 */
@Nonnull
public LinkedHashMap<CharSequence, Object> toRow(@Nonnull final NotebookOutput log, @Nonnull final LabeledObject<Tensor> labeledObject, final double[] predictionSignal) {
  @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
  row.put("Image", log.image(labeledObject.data.toImage(), labeledObject.label));
  // Reshape the flat prediction signal back to the source image's dimensions before rendering.
  row.put("Echo", log.image(new Tensor(predictionSignal, labeledObject.data.getDimensions()).toImage(), labeledObject.label));
  return row;
}