use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.
the class ImageClassifierTestBase method run.
/**
 * Runs the image-classification report and writes it to the given notebook.
 *
 * <p>Images are fetched through {@code EncodingUtil.getImages} on a background thread while the
 * classifier is obtained and its network diagram is rendered. The network and the loaded images
 * are then handed to {@code predict}, and the predictions are rendered as one table row per
 * image in the "Result" section.
 *
 * @param log the notebook that receives the headings, the network diagram and the result table
 * @throws RuntimeException if the background image load fails, or if the calling thread is
 *     interrupted while waiting for it (the interrupt flag is restored before throwing)
 */
public void run(@Nonnull NotebookOutput log) {
  // Load images in the background. The per-image transform returns the image unchanged (no resizing).
  java.util.concurrent.ExecutorService imageLoader = Executors.newSingleThreadExecutor();
  Future<Tensor[][]> submit = imageLoader.submit(() -> Arrays.stream(EncodingUtil.getImages(log, img -> img, 10, new CharSequence[] {})).toArray(i -> new Tensor[i][]));
  // shutdown() lets the already-submitted task finish and then releases the worker thread;
  // without it the executor's thread would outlive this method.
  imageLoader.shutdown();

  ImageClassifier vgg16 = getImageClassifier(log);
  @Nonnull Layer network = vgg16.getNetwork();
  @Nonnull Map<CharSequence, List<LinkedHashMap<CharSequence, Double>>> modelPredictions = new HashMap<>();
  Tensor[][] images;
  try {
    log.h1("Network Diagram");
    log.p("This is a diagram of the imported network:");
    log.code(() -> {
      return Graphviz.fromGraph(TestUtil.toGraph((DAGNetwork) network)).height(4000).width(800).render(Format.PNG).toImage();
    });

    log.h1("Predictions");
    try {
      images = submit.get();
    } catch (InterruptedException e) {
      // Preserve the interrupt for callers further up the stack before converting to unchecked.
      Thread.currentThread().interrupt();
      throw new RuntimeException(e);
    } catch (java.util.concurrent.ExecutionException e) {
      throw new RuntimeException(e);
    }
    modelPredictions.put("Source", predict(log, vgg16, network, images));
  } finally {
    // Release the network reference even when rendering or prediction fails.
    network.freeRef();
  }

  log.h1("Result");
  log.code(() -> {
    @Nonnull TableOutput tableOutput = new TableOutput();
    for (int i = 0; i < images.length; i++) {
      int index = i;
      @Nonnull HashMap<CharSequence, Object> row = new HashMap<>();
      // NOTE(review): element 1 of each image's tensor array is the one displayed - confirm
      // against EncodingUtil.getImages which tensor that is.
      row.put("Image", log.image(images[i][1].toImage(), ""));
      modelPredictions.forEach((model, predictions) -> {
        // One "label -> percent" entry per line; an empty prediction map yields an empty cell
        // instead of failing the whole table.
        row.put(model, predictions.get(index).entrySet().stream().map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue())).reduce((a, b) -> a + "<br/>" + b).orElse(""));
      });
      tableOutput.putRow(row);
    }
    return tableOutput;
  }, 256 * 1024);
}
use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.
the class CudaLayerTestBase method getReferenceIOTester.
/**
 * Wraps the superclass's reference I/O tester so that each run registers a dedicated log stream
 * with {@code CudaSystem} for the duration of the test.
 *
 * @return a tester that delegates to the superclass's tester while the log stream is registered,
 *     or {@code null} when the superclass supplies no tester
 */
@Nullable
@Override
protected ComponentTest<ToleranceStatistics> getReferenceIOTester() {
  @Nullable final ComponentTest<ToleranceStatistics> inner = super.getReferenceIOTester();
  if (null == inner) {
    // Nothing to wrap; returning a wrapper here would only fail later with a NullPointerException.
    return null;
  }
  return new ComponentTestBase<ToleranceStatistics>() {
    @Override
    protected void _free() {
      inner.freeRef();
      super._free();
    }

    @Override
    public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
      @Nonnull String logName = "cuda_" + log.getName() + "_io.log";
      // Emit the notebook entry labelled "GPU Log" for the file before opening it.
      log.p(log.file((String) null, logName, "GPU Log"));
      try (PrintStream apiLog = new PrintStream(log.file(logName))) {
        CudaSystem.addLog(apiLog);
        try {
          return inner.test(log, component, inputPrototype);
        } finally {
          // Unregister before try-with-resources closes the stream, so CudaSystem never
          // holds a reference to a closed stream.
          CudaSystem.apiLog.remove(apiLog);
        }
      }
    }
  };
}
use of com.simiacryptus.util.io.NotebookOutput in project MindsEye by SimiaCryptus.
the class CudaLayerTestBase method getPerformanceTester.
/**
 * Builds a {@code PerformanceTester} configured with {@code testingBatchSize} batches and wraps
 * it so that each run registers a dedicated log stream with {@code CudaSystem} for the duration
 * of the test.
 *
 * @return a tester that delegates to the performance tester while the log stream is registered
 */
@Nullable
@Override
public ComponentTest<ToleranceStatistics> getPerformanceTester() {
  @Nullable final ComponentTest<ToleranceStatistics> inner = new PerformanceTester().setBatches(testingBatchSize);
  return new ComponentTestBase<ToleranceStatistics>() {
    @Override
    protected void _free() {
      inner.freeRef();
      super._free();
    }

    @Override
    public ToleranceStatistics test(@Nonnull NotebookOutput log, Layer component, Tensor... inputPrototype) {
      @Nonnull String logName = "cuda_" + log.getName() + "_perf.log";
      // Emit the notebook entry labelled "GPU Log" for the file before opening it.
      log.p(log.file((String) null, logName, "GPU Log"));
      try (PrintStream apiLog = new PrintStream(log.file(logName))) {
        CudaSystem.addLog(apiLog);
        try {
          return inner.test(log, component, inputPrototype);
        } finally {
          // Unregister before try-with-resources closes the stream, so CudaSystem never
          // holds a reference to a closed stream.
          CudaSystem.apiLog.remove(apiLog);
        }
      }
    }
  };
}
Aggregations