
Example 1 with MonitoredObject

Use of com.simiacryptus.util.MonitoredObject in project MindsEye by SimiaCryptus.

From class MnistTestBase, method report.

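/**
 * Writes the post-training report. If any training history was recorded, it plots
 * log10(mean fitness) against iteration. It then saves the network as a numbered
 * JSON model file and writes the metrics collected under the monitoring root.
 *
 * @param log            the notebook that receives the plot, model file link, and metrics
 * @param monitoringRoot the root object whose metrics are serialized as JSON
 * @param history        the training steps recorded during training
 * @param network        the trained network to serialize
 */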
public void report(@Nonnull final NotebookOutput log, @Nonnull final MonitoredObject monitoringRoot, @Nonnull final List<Step> history, @Nonnull final Layer network) {
    // Plot convergence only when at least one training step was recorded.
    if (!history.isEmpty()) {
        log.code(() -> {
            // One point per step: (iteration, log10 of the mean fitness at that step).
            @Nonnull final PlotCanvas plot = ScatterPlot.plot(history.stream().map(step -> new double[] { step.iteration, Math.log10(step.point.getMean()) }).toArray(i -> new double[i][]));
            plot.setTitle("Convergence Plot");
            plot.setAxisLabels("Iteration", "log10(Fitness)");
            plot.setSize(600, 400);
            return plot;
        });
    }
    // modelNo is a counter field not shown in this excerpt; each call saves a new file name.
    @Nonnull final String modelName = "model" + modelNo++ + ".json";
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h1("Metrics");
    log.code(() -> {
        try {
            // Serialize the metrics collected under monitoringRoot as JSON text.
            @Nonnull final ByteArrayOutputStream out = new ByteArrayOutputStream();
            JsonUtil.writeJson(out, monitoringRoot.getMetrics());
            return out.toString();
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
}
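The convergence plot in report maps each recorded step to the point (iteration, log10 of the mean fitness). A minimal sketch of just that mapping in plain Java, without Smile's ScatterPlot, is shown here. The Step record and its fields are hypothetical stand-ins for MindsEye's Step, not its real definition.

```java
import java.util.List;

public class ConvergencePoints {
    /** Hypothetical stand-in for MindsEye's Step: only the two values the plot reads. */
    record Step(long iteration, double meanFitness) {}

    /** Maps each step to {iteration, log10(meanFitness)}, the shape passed to ScatterPlot.plot. */
    static double[][] toPoints(List<Step> history) {
        return history.stream()
                .map(step -> new double[] { step.iteration(), Math.log10(step.meanFitness()) })
                .toArray(double[][]::new);
    }

    public static void main(String[] args) {
        List<Step> history = List.of(new Step(0, 1000.0), new Step(1, 100.0), new Step(2, 1.0));
        for (double[] p : toPoints(history)) {
            System.out.println(p[0] + " -> " + p[1]);
        }
    }
}
```

Because the y-axis is log10, a mean fitness of 0 maps to -Infinity and a negative one to NaN; the original code does not filter such values before plotting.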
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) LabeledObject(com.simiacryptus.util.test.LabeledObject) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MNIST(com.simiacryptus.mindseye.test.data.MNIST) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) Test(org.junit.Test) IOException(java.io.IOException) Category(org.junit.experimental.categories.Category) MonitoredObject(com.simiacryptus.util.MonitoredObject) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) List(java.util.List) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) TestCategories(com.simiacryptus.util.test.TestCategories)
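In report, monitoringRoot.getMetrics() supplies the object that JsonUtil.writeJson serializes. A rough sketch of what a hierarchical metrics registry of this kind can look like is given here, assuming lazily read fields plus nested child registries, rendered with a minimal JSON writer. MetricsTreeSketch, addField, addChild, and toJson are hypothetical names. This is not the real com.simiacryptus.util.MonitoredObject API, and the hand-written renderer (which escapes only quotes and backslashes) stands in for JsonUtil.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Supplier;

// Hypothetical metrics registry; NOT the real com.simiacryptus.util.MonitoredObject API.
public class MetricsTreeSketch {
    private final Map<String, Supplier<Object>> fields = new LinkedHashMap<>();
    private final Map<String, MetricsTreeSketch> children = new LinkedHashMap<>();

    /** Registers a value that is read lazily each time metrics are collected. */
    MetricsTreeSketch addField(String name, Supplier<Object> value) {
        fields.put(name, value);
        return this;
    }

    /** Registers a nested registry, e.g. one per monitored layer. */
    MetricsTreeSketch addChild(String name, MetricsTreeSketch child) {
        children.put(name, child);
        return this;
    }

    /** Snapshots the tree into nested maps: fields first, then children, in insertion order. */
    Map<String, Object> getMetrics() {
        Map<String, Object> out = new LinkedHashMap<>();
        fields.forEach((k, v) -> out.put(k, v.get()));
        children.forEach((k, v) -> out.put(k, v.getMetrics()));
        return out;
    }

    /** Minimal JSON rendering for maps, numbers, booleans, and strings (no pretty-printing). */
    static String toJson(Object value) {
        if (value instanceof Map<?, ?> map) {
            StringBuilder sb = new StringBuilder("{");
            boolean first = true;
            for (Map.Entry<?, ?> e : map.entrySet()) {
                if (!first) sb.append(',');
                first = false;
                sb.append(toJson(String.valueOf(e.getKey()))).append(':').append(toJson(e.getValue()));
            }
            return sb.append('}').toString();
        }
        if (value instanceof Number || value instanceof Boolean) return String.valueOf(value);
        // Escape only backslashes and double quotes; control characters are not handled.
        String s = String.valueOf(value).replace("\\", "\\\\").replace("\"", "\\\"");
        return "\"" + s + "\"";
    }

    public static void main(String[] args) {
        MetricsTreeSketch layer = new MetricsTreeSketch().addField("calls", () -> 42);
        MetricsTreeSketch root = new MetricsTreeSketch().addField("epoch", () -> 3).addChild("dense1", layer);
        System.out.println(toJson(root.getMetrics())); // {"epoch":3,"dense1":{"calls":42}}
    }
}
```

Reading fields through suppliers means each call to getMetrics reflects the current values, which suits a report written after training finishes.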

Example 2 with MonitoredObject

Use of com.simiacryptus.util.MonitoredObject in project MindsEye by SimiaCryptus.

From class MnistTestBase, method run.

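/**
 * Runs the full MNIST test: loads the training data, builds the model, attaches
 * monitoring, trains, writes the report, validates, and then removes the monitoring.
 *
 * @param log the notebook that receives all output
 */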
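public void run(@Nonnull NotebookOutput log) {
    @Nonnull final List<Step> history = new ArrayList<>();
    // Root of the metrics tree that report(...) later serializes.
    @Nonnull final MonitoredObject monitoringRoot = new MonitoredObject();
    // The monitor is given history so training steps can be recorded for the convergence plot.
    @Nonnull final TrainingMonitor monitor = getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    final DAGNetwork network = buildModel(log);
    // Attach monitoring to the network, registering its metrics under monitoringRoot.
    addMonitoring(network, monitoringRoot);
    log.h1("Training");
    train(log, network, trainingData, monitor);
    report(log, monitoringRoot, history, network);
    validate(log, network);
    // Detach the monitoring added above once training, reporting, and validation finish.
    removeMonitoring(network);
}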
Also used: Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) MonitoredObject(com.simiacryptus.util.MonitoredObject) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork)
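In run, removeMonitoring is reached only if train, report, and validate all return normally, so an exception from any of them leaves the monitoring attached. A try/finally variant of that ordering is sketched here as a design alternative, not what MnistTestBase actually does. Network, addMonitoring, removeMonitoring, and the Runnable steps are hypothetical stand-ins, not MindsEye types.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch of the attach/train/report/detach order; not MindsEye classes.
public class MonitoringLifecycleSketch {
    /** Hypothetical network that records whether monitoring is attached. */
    static class Network {
        boolean monitored;
        final List<String> events = new ArrayList<>();
    }

    static void addMonitoring(Network n) { n.monitored = true; n.events.add("add"); }
    static void removeMonitoring(Network n) { n.monitored = false; n.events.add("remove"); }

    /** Runs the steps; the finally block detaches monitoring even if a step throws. */
    static void run(Network network, Runnable train, Runnable report) {
        addMonitoring(network);
        try {
            train.run();
            report.run();
        } finally {
            removeMonitoring(network);
        }
    }

    public static void main(String[] args) {
        Network ok = new Network();
        run(ok, () -> ok.events.add("train"), () -> ok.events.add("report"));
        System.out.println(ok.events); // [add, train, report, remove]

        Network failing = new Network();
        try {
            run(failing, () -> { throw new IllegalStateException("diverged"); }, () -> {});
        } catch (IllegalStateException e) {
            System.out.println("monitored after failure: " + failing.monitored); // false
        }
    }
}
```

The exception still propagates to the caller; the finally block only guarantees the cleanup step runs first.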

Aggregations

DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 2
MonitoredObject (com.simiacryptus.util.MonitoredObject): 2
ArrayList (java.util.ArrayList): 2
Nonnull (javax.annotation.Nonnull): 2
Layer (com.simiacryptus.mindseye.lang.Layer): 1
Tensor (com.simiacryptus.mindseye.lang.Tensor): 1
BiasLayer (com.simiacryptus.mindseye.layers.java.BiasLayer): 1
FullyConnectedLayer (com.simiacryptus.mindseye.layers.java.FullyConnectedLayer): 1
MonitoringWrapperLayer (com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer): 1
SoftmaxActivationLayer (com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer): 1
PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork): 1
NotebookReportBase (com.simiacryptus.mindseye.test.NotebookReportBase): 1
MNIST (com.simiacryptus.mindseye.test.data.MNIST): 1
TableOutput (com.simiacryptus.util.TableOutput): 1
JsonUtil (com.simiacryptus.util.io.JsonUtil): 1
NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 1
LabeledObject (com.simiacryptus.util.test.LabeledObject): 1
TestCategories (com.simiacryptus.util.test.TestCategories): 1
ByteArrayOutputStream (java.io.ByteArrayOutputStream): 1
IOException (java.io.IOException): 1