Usage of com.simiacryptus.util.MonitoredObject in the MindsEye project by SimiaCryptus.
Class MnistTestBase, method report.
/**
 * Writes a post-training report to the notebook.
 *
 * <p>The report has three parts:
 * <ul>
 *   <li>If {@code history} is non-empty, a "Convergence Plot" scatter plot of
 *       {@code log10(step.point.getMean())} against {@code step.iteration}.</li>
 *   <li>The network's JSON, saved as {@code model<N>.json}. This increments {@code modelNo}.</li>
 *   <li>A "Metrics" section with the JSON serialization of {@code monitoringRoot.getMetrics()}.</li>
 * </ul>
 *
 * @param log            the notebook that receives the plot, the saved model file and the metrics
 * @param monitoringRoot the monitoring root whose metrics are dumped as JSON
 * @param history        the recorded training steps used for the convergence plot; may be empty
 * @param network        the trained network whose JSON representation is saved to a file
 */
public void report(@Nonnull final NotebookOutput log, @Nonnull final MonitoredObject monitoringRoot, @Nonnull final List<Step> history, @Nonnull final Layer network) {
  if (!history.isEmpty()) {
    log.code(() -> {
      // One point per step: x = iteration, y = log10 of the step's mean (axis label "log10(Fitness)").
      // NOTE(review): a non-positive mean yields -Infinity/NaN here; confirm the plot tolerates it.
      @Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step -> new double[] { step.iteration, Math.log10(step.point.getMean()) }).toArray(i -> new double[i][]));
      plot.setTitle("Convergence Plot");
      plot.setAxisLabels("Iteration", "log10(Fitness)");
      plot.setSize(600, 400);
      return plot;
    });
  }
  // The post-increment gives every report call in this instance a distinct file name.
  @Nonnull final String modelName = "model" + modelNo++ + ".json";
  log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
  log.h1("Metrics");
  log.code(() -> {
    try {
      @Nonnull final ByteArrayOutputStream out = new ByteArrayOutputStream();
      JsonUtil.writeJson(out, monitoringRoot.getMetrics());
      return out.toString();
    } catch (@Nonnull final IOException e) {
      // Lambdas cannot throw checked exceptions. Wrap in the dedicated unchecked I/O exception and
      // keep the cause. It is still a RuntimeException, so existing handlers keep working.
      // Whether it propagates out of report() depends on NotebookOutput.code (not visible here).
      throw new java.io.UncheckedIOException(e);
    }
  });
}
Usage of com.simiacryptus.util.MonitoredObject in the MindsEye project by SimiaCryptus.
Class MnistTestBase, method run.
/**
 * Runs the end-to-end test and writes the results to the notebook.
 *
 * <p>Steps, in order:
 * <ol>
 *   <li>Load the training data and build the model.</li>
 *   <li>Attach monitoring under a fresh {@link MonitoredObject}.</li>
 *   <li>Train the network while recording each step into a history list.</li>
 *   <li>Report, then validate.</li>
 * </ol>
 * Monitoring is always detached from the network afterwards, even if training, reporting or
 * validation throws.
 *
 * @param log the notebook that receives the training output, report and validation results
 */
public void run(@Nonnull NotebookOutput log) {
  @Nonnull final List<Step> history = new ArrayList<>();
  @Nonnull final MonitoredObject monitoringRoot = new MonitoredObject();
  // The monitor appends each training step to `history`, which report() plots afterwards.
  @Nonnull final TrainingMonitor monitor = getMonitor(history);
  final Tensor[][] trainingData = getTrainingData(log);
  final DAGNetwork network = buildModel(log);
  addMonitoring(network, monitoringRoot);
  try {
    log.h1("Training");
    train(log, network, trainingData, monitor);
    report(log, monitoringRoot, history, network);
    validate(log, network);
  } finally {
    // Pair every addMonitoring with removeMonitoring so a failure above cannot leave the
    // network instrumented.
    removeMonitoring(network);
  }
}
Aggregations