Search in sources :

Example 1 with LabeledObject

use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus.

the class MnistTestBase method validate.

/**
 * Reports how the given network performs on the MNIST validation data.
 * <p>
 * Two notebook sections are produced: the overall accuracy as a percentage, followed by a
 * table of at most ten mispredicted samples, each showing the image and its three
 * highest-scoring categories.
 *
 * @param log     the notebook that receives the headings, text and code results
 * @param network the model whose predictions are evaluated
 */
public void validate(@Nonnull final NotebookOutput log, @Nonnull final Layer network) {
    log.h1("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        // Score each sample 1 for a hit and 0 for a miss; the mean is the hit rate.
        final double hitRate = MNIST.validationDataStream().mapToDouble(sample -> {
            final boolean correct = predict(network, sample)[0] == parse(sample.label);
            return correct ? 1 : 0;
        }).average().getAsDouble();
        return hitRate * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        @Nonnull final TableOutput mistakes = new TableOutput();
        MNIST.validationDataStream().map(sample -> {
            final int expected = parse(sample.label);
            @Nullable final double[] scores = network.eval(sample.data).getData().get(0).getData();
            // Category indices ordered from the strongest to the weakest signal.
            final int[] ranking = IntStream.range(0, 10).boxed().sorted(Comparator.comparing(category -> -scores[category])).mapToInt(category -> category).toArray();
            if (ranking[0] != expected) {
                @Nonnull final LinkedHashMap<CharSequence, Object> cells = new LinkedHashMap<>();
                cells.put("Image", log.image(sample.data.toGrayImage(), sample.label));
                final String topThree = Arrays.stream(ranking).limit(3).mapToObj(category -> String.format("%d (%.1f%%)", category, 100.0 * scores[category])).reduce((left, right) -> left + ", " + right).get();
                cells.put("Prediction", topThree);
                return cells;
            }
            // Correct predictions are of no interest here; they are filtered out below.
            return null;
        }).filter(cells -> null != cells).limit(10).forEach(mistakes::putRow);
        return mistakes;
    });
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) NotebookReportBase(com.simiacryptus.mindseye.test.NotebookReportBase) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) LabeledObject(com.simiacryptus.util.test.LabeledObject) Layer(com.simiacryptus.mindseye.lang.Layer) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) MNIST(com.simiacryptus.mindseye.test.data.MNIST) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) Test(org.junit.Test) IOException(java.io.IOException) Category(org.junit.experimental.categories.Category) MonitoredObject(com.simiacryptus.util.MonitoredObject) FullyConnectedLayer(com.simiacryptus.mindseye.layers.java.FullyConnectedLayer) List(java.util.List) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) BiasLayer(com.simiacryptus.mindseye.layers.java.BiasLayer) TestCategories(com.simiacryptus.util.test.TestCategories) TableOutput(com.simiacryptus.util.TableOutput) Nonnull(javax.annotation.Nonnull) LinkedHashMap(java.util.LinkedHashMap)

Example 2 with LabeledObject

use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus.

the class ClassifyProblem method run.

/**
 * Runs the classification experiment and records every stage in the notebook:
 * builds the network, renders its diagram, trains it, plots the training history,
 * saves the model as JSON and finally reports validation accuracy together with
 * a table of sample rows.
 *
 * @param log the notebook that receives all output
 * @return this problem instance
 */
@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    // Wrap the classifier with an entropy loss so it can be trained against the labels.
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    // A fifth of the data, but never less than min(10, half of the data).
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    // NOTE(review): presumably the sampled trainable drives training and the full
    // ArrayTrainable is used for validation - confirm against optimizer.train.
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()), new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    // Notebook plots are only emitted when the history is non-empty.
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        // modelNo is post-incremented here and again for the model file name below,
        // so the plot and the model written by one run carry different numbers.
        @Nonnull String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        // NOTE(review): unlike the plots above, this call is not guarded by
        // history.isEmpty() - confirm TestUtil.plot/Util.toImage cope with an empty history.
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        @Nonnull File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    // Persist the bare classifier (without the loss wrapper) as JSON.
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        // Each sample scores 1 when the first predicted category equals the parsed label,
        // otherwise 0; the mean is scaled to a percentage.
        return data.validationData().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            // The validation set is collected into a list and evaluated in batches of 100.
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                // Pair every input with the output at the same index in the batch.
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
                // Null rows are dropped (presumably toRow returns null for correct
                // predictions - confirm); at most ten rows reach the table.
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
    return this;
}
Also used : IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) 
IOException(java.io.IOException) TensorList(com.simiacryptus.mindseye.lang.TensorList) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) File(java.io.File) Nonnull(javax.annotation.Nonnull)

Example 3 with LabeledObject

use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus.

the class CIFAR10 method toImage.

/**
 * Decodes a single CIFAR-10 binary record into a labeled 32x32 RGB image.
 * <p>
 * The record layout read here is: one label byte at index 0, followed by three
 * consecutive 1024-byte color planes (red, green, blue), each stored row by row.
 * The label is rendered with {@link Arrays#toString(byte[])}, e.g. {@code "[3]"}.
 *
 * @param b the raw record; must hold at least 3073 bytes
 * @return the decoded image paired with its label string
 * @throws IllegalArgumentException if the record is too short to hold a full image
 */
private static LabeledObject<BufferedImage> toImage(final byte[] b) {
    final int width = 32;
    final int height = 32;
    final int planeSize = width * height;
    final int recordSize = 1 + 3 * planeSize;
    // Fail up front with a clear message instead of an index error halfway through decoding.
    if (b.length < recordSize) {
        throw new IllegalArgumentException("CIFAR-10 record requires " + recordSize + " bytes but got " + b.length);
    }
    @Nonnull final BufferedImage img = new BufferedImage(width, height, BufferedImage.TYPE_INT_RGB);
    for (int x = 0; x < width; x++) {
        for (int y = 0; y < height; y++) {
            // Skip the label byte, then address the pixel within a row-major plane.
            final int offset = 1 + x + y * width;
            // Mask with 0xFF to read each signed byte as an unsigned channel value.
            final int red = 0xFF & b[offset];
            final int green = 0xFF & b[offset + planeSize];
            final int blue = 0xFF & b[offset + 2 * planeSize];
            img.setRGB(x, y, (red << 16) | (green << 8) | blue);
        }
    }
    return new LabeledObject<>(img, Arrays.toString(new byte[] { b[0] }));
}
Also used : Nonnull(javax.annotation.Nonnull) LabeledObject(com.simiacryptus.util.test.LabeledObject) BufferedImage(java.awt.image.BufferedImage)

Example 4 with LabeledObject

use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus.

the class AutoencodingProblem method run.

/**
 * Runs the autoencoder experiment and records every stage in the notebook:
 * builds the encoder and decoder networks, trains them end to end through a
 * dropout-noise layer against a mean-squared reconstruction loss, saves both
 * models as JSON and renders re-encoded validation samples plus the decoder's
 * output for each unit feature vector.
 *
 * @param log the notebook that receives all output
 * @return this problem instance
 */
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
    // Encoder followed directly by decoder; used below to show validation samples.
    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    echoNetwork.add(fwdNetwork);
    echoNetwork.add(revNetwork);
    // Training network: encoder -> dropout noise -> decoder -> loss against the original input.
    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    supervisedNetwork.add(fwdNetwork);
    @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
    supervisedNetwork.add(dropoutNoiseLayer);
    supervisedNetwork.add(revNetwork);
    supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
    log.h3("Network Diagrams");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(fwdNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(revNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(supervisedNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Delegates to the standard history monitor and additionally reshuffles the
    // dropout layer with a fresh random seed after every completed step.
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {

        @Nonnull
        TrainingMonitor inner = TestUtil.getMonitor(history);

        @Override
        public void log(final String msg) {
            inner.log(msg);
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
            dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
            inner.onStepComplete(currentPoint);
        }
    };
    final Tensor[][] trainingData = getTrainingData(log);
    // MonitoredObject monitoringRoot = new MonitoredObject();
    // TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork);
    // The sampled trainable draws half of the training data per sample.
    // NOTE(review): presumably the sampled trainable drives training and the full
    // ArrayTrainable is used for validation - confirm against optimizer.train.
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize), new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    // Plots are only emitted when the history is non-empty.
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    // The extra scope lets the encoder and decoder each declare a local named modelName.
    // modelNo is post-incremented for both, so the two files carry consecutive numbers.
    {
        @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
        log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
    }
    @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
    // log.h3("Metrics");
    // log.code(() -> {
    // return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
    // });
    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        // Run validation samples through encoder + decoder (no dropout) and tabulate
        // at most ten non-null rows.
        data.validationData().map(labeledObject -> {
            return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Some rendered unit vectors:");
    // Decode a vector with a single element set to 1 for every feature index.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        // NOTE(review): declared @Nullable but dereferenced without a check on the next line.
        @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
        log.out(log.image(tensor.toImage(), ""));
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) 
ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 5 with LabeledObject

use of com.simiacryptus.util.test.LabeledObject in project MindsEye by SimiaCryptus.

the class AutoencodingProblem method toRow.

/**
 * Builds one table row that places a sample next to its reconstruction.
 * <p>
 * The "Image" column holds the sample's own data rendered as an image; the "Echo"
 * column holds the prediction signal rendered as an image with the same dimensions
 * as the sample. Both images are captioned with the sample's label.
 *
 * @param log              the notebook used to render the images
 * @param labeledObject    the sample providing the original data and the label
 * @param predictionSignal the raw values of the reconstructed sample
 * @return an insertion-ordered map with the "Image" and "Echo" cells
 */
@Nonnull
public LinkedHashMap<CharSequence, Object> toRow(@Nonnull final NotebookOutput log, @Nonnull final LabeledObject<Tensor> labeledObject, final double[] predictionSignal) {
    @Nonnull final LinkedHashMap<CharSequence, Object> cells = new LinkedHashMap<>();
    final Tensor original = labeledObject.data;
    cells.put("Image", log.image(original.toImage(), labeledObject.label));
    // Give the flat signal the shape of the input so it renders as a comparable image.
    final Tensor reconstruction = new Tensor(predictionSignal, original.getDimensions());
    cells.put("Echo", log.image(reconstruction.toImage(), labeledObject.label));
    return cells;
}
Also used : Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) LabeledObject(com.simiacryptus.util.test.LabeledObject) LinkedHashMap(java.util.LinkedHashMap) Nonnull(javax.annotation.Nonnull)

Aggregations

LabeledObject (com.simiacryptus.util.test.LabeledObject)6 Nonnull (javax.annotation.Nonnull)5 Tensor (com.simiacryptus.mindseye.lang.Tensor)4 LinkedHashMap (java.util.LinkedHashMap)4 Nullable (javax.annotation.Nullable)4 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)3 TableOutput (com.simiacryptus.util.TableOutput)3 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)3 IOException (java.io.IOException)3 ArrayList (java.util.ArrayList)3 List (java.util.List)3 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)2 SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable)2 Layer (com.simiacryptus.mindseye.lang.Layer)2 PipelineNetwork (com.simiacryptus.mindseye.network.PipelineNetwork)2 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)2 ValidatingTrainer (com.simiacryptus.mindseye.opt.ValidatingTrainer)2 StepRecord (com.simiacryptus.mindseye.test.StepRecord)2 TestUtil (com.simiacryptus.mindseye.test.TestUtil)2 Format (guru.nidi.graphviz.engine.Format)2