Search in sources :

Example 1 with TableOutput

use of com.simiacryptus.util.TableOutput in project MindsEye by SimiaCryptus.

In the class MnistTestBase, the method validate:

/**
 * Validates the network against the full MNIST validation dataset, logging
 * overall accuracy as a percentage and a table of sample mispredictions.
 *
 * @param log     the notebook log that receives headings, prose, and code output
 * @param network the trained network to evaluate
 */
public void validate(@Nonnull final NotebookOutput log, @Nonnull final Layer network) {
    log.h1("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        // Fraction of validation samples whose top prediction matches the parsed label, scaled to a percentage.
        return MNIST.validationDataStream().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        MNIST.validationDataStream().map(labeledObject -> {
            final int actualCategory = parse(labeledObject.label);
            @Nullable final double[] predictionSignal = network.eval(labeledObject.data).getData().get(0).getData();
            // Category indices 0-9 ordered by descending prediction signal strength.
            final int[] predictionList = IntStream.range(0, 10).boxed().sorted(Comparator.comparingDouble(i -> -predictionSignal[i])).mapToInt(Integer::intValue).toArray();
            // We will only examine mispredicted rows
            if (predictionList[0] == actualCategory)
                return null;
            // LinkedHashMap preserves the "Image"/"Prediction" column order in the rendered table.
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Image", log.image(labeledObject.data.toGrayImage(), labeledObject.label));
            row.put("Prediction", Arrays.stream(predictionList).limit(3).mapToObj(i -> String.format("%d (%.1f%%)", i, 100.0 * predictionSignal[i])).reduce((a, b) -> a + ", " + b).get());
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) LabeledObject(com.simiacryptus.util.test.LabeledObject) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MNIST(com.simiacryptus.mindseye.test.data.MNIST) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) Test(org.junit.Test) IOException(java.io.IOException) Category(org.junit.experimental.categories.Category) MonitoredObject(com.simiacryptus.util.MonitoredObject) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) List(java.util.List) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) TestCategories(com.simiacryptus.util.test.TestCategories) TableOutput(com.simiacryptus.util.TableOutput) Nonnull(javax.annotation.Nonnull) LinkedHashMap(java.util.LinkedHashMap)

Example 2 with TableOutput

use of com.simiacryptus.util.TableOutput in project MindsEye by SimiaCryptus.

In the class ClassifyProblem, the method run:

/**
 * Trains and validates an image-classification network: renders the network
 * diagram, runs the optimizer, plots training history, saves the plot image
 * and model JSON (recording both in the report front matter), reports overall
 * validation accuracy, and tabulates up to ten mispredicted examples.
 *
 * @param log the notebook log that receives all report output
 * @return this problem instance, for call chaining
 */
@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    // Wrap the classifier with an entropy loss head for supervised training.
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    // Start with a reduced sample size (at least 10 when available, at most a fifth of the data)
    // to speed up initial convergence before the full-dataset validation trainable takes over.
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()), new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        // NOTE(review): modelNo is incremented here and again below for the model file,
        // so the plot and the saved model carry different sequence numbers — confirm intended.
        @Nonnull String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        @Nonnull File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        return data.validationData().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            // Evaluate the validation set in batches of 100 to bound per-eval memory use;
            // toRow returns null for correctly-predicted rows, which are filtered out below.
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
    return this;
}
Also used : IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) 
IOException(java.io.IOException) TensorList(com.simiacryptus.mindseye.lang.TensorList) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) File(java.io.File) Nonnull(javax.annotation.Nonnull)

Example 3 with TableOutput

use of com.simiacryptus.util.TableOutput in project MindsEye by SimiaCryptus.

In the class AutoencodingProblem, the method run:

/**
 * Trains and validates an autoencoder: builds forward (encoder) and reverse
 * (decoder) networks, composes an echo network and a dropout-regularized
 * supervised network with mean-squared-error loss, trains it, saves both
 * model JSONs, shows re-encoded validation examples, and renders the decoder
 * output for each unit feature vector.
 *
 * @param log the notebook log that receives all report output
 * @return this problem instance, for call chaining
 */
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
    // Echo network: encode then decode, used below to visualize round-trip reconstruction.
    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    echoNetwork.add(fwdNetwork);
    echoNetwork.add(revNetwork);
    // Supervised network: encoder -> dropout noise -> decoder, with MSE against the original input.
    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    supervisedNetwork.add(fwdNetwork);
    @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
    supervisedNetwork.add(dropoutNoiseLayer);
    supervisedNetwork.add(revNetwork);
    supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
    log.h3("Network Diagrams");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(fwdNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(revNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(supervisedNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Wrap the standard monitor so the dropout mask is reshuffled after every optimizer step.
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {

        @Nonnull
        TrainingMonitor inner = TestUtil.getMonitor(history);

        @Override
        public void log(final String msg) {
            inner.log(msg);
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
            dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
            inner.onStepComplete(currentPoint);
        }
    };
    final Tensor[][] trainingData = getTrainingData(log);
    // MonitoredObject monitoringRoot = new MonitoredObject();
    // TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork);
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize), new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    // Save encoder and decoder separately; each save increments the shared model counter.
    {
        @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
        log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
    }
    @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
    // log.h3("Metrics");
    // log.code(() -> {
    // return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
    // });
    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        data.validationData().map(labeledObject -> {
            return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Some rendered unit vectors:");
    // Decode each one-hot feature vector to visualize what each latent feature encodes.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
        log.out(log.image(tensor.toImage(), ""));
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) 
ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 4 with TableOutput

use of com.simiacryptus.util.TableOutput in project MindsEye by SimiaCryptus.

In the class EncodingProblem, the method run:

/**
 * Learns an image encoding by optimizing both the decoder weights and the
 * latent input vectors (input masking enabled): builds the training network,
 * runs a small priming phase then the main training phase, saves plots and
 * the model JSON, and reports reconstruction examples plus weight and
 * representation statistics.
 *
 * @param log the notebook log that receives all report output
 * @return this problem instance, for call chaining
 */
@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    Tensor[][] trainingData;
    try {
        // Each training row pairs a randomly-initialized latent vector (to be learned) with the target image.
        trainingData = data.trainingData().map(labeledObject -> {
            return new Tensor[] { new Tensor(features).set(this::random), labeledObject.data };
        }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(imageNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Loss = entropy regularizer on the latent vector + sqrt of MSE between decoded and target image.
    @Nonnull final PipelineNetwork trainingNetwork = new PipelineNetwork(2);
    @Nullable final DAGNode image = trainingNetwork.add(imageNetwork, trainingNetwork.getInput(0));
    // NOTE(review): the softmax (and the entropy term below) is applied to the raw
    // latent input (getInput(0)), not to the image branch — confirm this is intended.
    @Nullable final DAGNode softmax = trainingNetwork.add(new SoftmaxActivationLayer(), trainingNetwork.getInput(0));
    trainingNetwork.add(new SumInputsLayer(), trainingNetwork.add(new EntropyLossLayer(), softmax, softmax), trainingNetwork.add(new NthPowerActivationLayer().setPower(1.0 / 2.0), trainingNetwork.add(new MeanSqLossLayer(), image, trainingNetwork.getInput(1))));
    log.h3("Training");
    log.p("We start by training apply a very small population to improve initial convergence performance:");
    TestUtil.instrumentPerformance(trainingNetwork);
    // Priming phase: train on the first 1000 rows only, with half the timeout budget.
    @Nonnull final Tensor[][] primingData = Arrays.copyOfRange(trainingData, 0, 1000);
    @Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(primingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(primingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    log.p("Then our main training phase:");
    TestUtil.instrumentPerformance(trainingNetwork);
    // setMask(true, false) marks the latent-vector input as trainable alongside the network weights.
    @Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(trainingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(trainingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName().toString() + EncodingProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        log.appendFrontMatterProperty("result_plot", filename, ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    // log.file()
    @Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
    log.h3("Results");
    // Test network: decoder only, driven by the learned latent vectors (row[0]).
    @Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
    testNetwork.add(imageNetwork, testNetwork.getInput(0));
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        Arrays.stream(trainingData).map(tensorArray -> {
            @Nullable final Tensor predictionSignal = testNetwork.eval(tensorArray).getData().get(0);
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Source", log.image(tensorArray[1].toImage(), ""));
            row.put("Echo", log.image(predictionSignal.toImage(), ""));
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Learned Model Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        trainingNetwork.state().stream().flatMapToDouble(x -> Arrays.stream(x)).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Learned Representation Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(trainingData).flatMapToDouble(row -> Arrays.stream(row[0].getData())).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Some rendered unit vectors:");
    // Decode each one-hot feature vector to visualize what each latent feature encodes.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = imageNetwork.eval(input).getData().get(0);
        TestUtil.renderToImages(tensor, true).forEach(img -> {
            log.out(log.image(img, ""));
        });
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) Format(guru.nidi.graphviz.engine.Format) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) LinkedHashMap(java.util.LinkedHashMap) 
SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IOException(java.io.IOException) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNode(com.simiacryptus.mindseye.network.DAGNode) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 5 with TableOutput

use of com.simiacryptus.util.TableOutput in project MindsEye by SimiaCryptus.

In the class ImageClassificationBase, the method run:

/**
 * Test.
 * <p>
 * Loads a pretrained image-recognition model, runs it over example images,
 * and logs the top-5 predictions for each image as a table.
 *
 * @param log the log
 */
public void run(@Nonnull NotebookOutput log) {
    log.h1("Model");
    log.p("In this demonstration, we will show how to load an image recognition network and use it to identify object in images.");
    log.p("We start by loading the VGG16 pretrained model using the HD5 importer. This downloads, if needed, the weights from a file in S3 and re-constructs the network architecture by custom code.");
    log.p("Next, we need an example image to analyze:");
    log.p("We pass this image to the categorization network, and get the following top-10 results. Note that multiple objects may be detected, and the total percentage may be greater than 100%.");
    log.p("Once we have categories identified, we can attempt to localize each object category within the image. We do this via a pipeline starting with the backpropagated input signal delta and applying several filters e.g. blurring and normalization to produce an alpha channel. When applied to the input image, we highlight the image areas related to the object type in question. Note that this produces a fuzzy blob, which does indicate object location but is a poor indicator of object boundaries. Below we perform this task for the top 5 object categories:");
    ImageClassifier vgg16 = loadModel(log);
    log.h1("Data");
    Tensor[] images = loadData(log);
    log.h1("Prediction");
    List<LinkedHashMap<CharSequence, Double>> predictions = log.code(() -> {
        return vgg16.predict(5, images);
    });
    log.h1("Results");
    log.code(() -> {
        @Nonnull TableOutput tableOutput = new TableOutput();
        for (int i = 0; i < images.length; i++) {
            // LinkedHashMap keeps the "Image"/"Prediction" column order deterministic,
            // consistent with the other TableOutput usages in this project (plain
            // HashMap gives hash-dependent column ordering).
            @Nonnull LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Image", log.image(images[i].toImage(), ""));
            row.put("Prediction", predictions.get(i).entrySet().stream().map(e -> String.format("%s -> %.2f", e.getKey(), 100 * e.getValue())).reduce((a, b) -> a + "<br/>" + b).get());
            tableOutput.putRow(row);
        }
        return tableOutput;
    }, 256 * 1024);
    log.setFrontMatterProperty("status", "OK");
}
Also used : BufferedImage(java.awt.image.BufferedImage) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Caltech101(com.simiacryptus.mindseye.test.data.Caltech101) LinkedHashMap(java.util.LinkedHashMap) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier) List(java.util.List) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) LinkedHashMap(java.util.LinkedHashMap) TableOutput(com.simiacryptus.util.TableOutput) ImageClassifier(com.simiacryptus.mindseye.applications.ImageClassifier)

Aggregations

Tensor (com.simiacryptus.mindseye.lang.Tensor)7 TableOutput (com.simiacryptus.util.TableOutput)7 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)7 LinkedHashMap (java.util.LinkedHashMap)7 List (java.util.List)7 Nonnull (javax.annotation.Nonnull)7 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)6 TestUtil (com.simiacryptus.mindseye.test.TestUtil)6 Nullable (javax.annotation.Nullable)6 IOException (java.io.IOException)5 Arrays (java.util.Arrays)5 Layer (com.simiacryptus.mindseye.lang.Layer)4 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)4 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)4 StepRecord (com.simiacryptus.mindseye.test.StepRecord)4 Format (guru.nidi.graphviz.engine.Format)4 Graphviz (guru.nidi.graphviz.engine.Graphviz)4 ArrayList (java.util.ArrayList)4 Comparator (java.util.Comparator)4 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)3