
Example 1 with ConstantResult

Use of com.simiacryptus.mindseye.lang.ConstantResult in project MindsEye by SimiaCryptus.

From class ClassifyProblem, method run:

@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    // At least a fifth of the data, but never fewer than min(10, n / 2) samples
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log,
            new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()),
            new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()),
            monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        @Nonnull File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        // Percentage of validation examples whose first predicted class ([0]) equals the parsed label
        return data.validationData()
                .mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0)
                .average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
    return this;
}
Also used: IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator)
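The arithmetic in run() can be checked in isolation. The plain-Java sketch below has no MindsEye dependency; the class ClassifySketch and its helpers are hypothetical. It reproduces the initialSampleSize expression, the top-1 accuracy percentage from the Validation step, and the batching that Lists.partition(list, 100) performs. It assumes, as the comparison with [0] suggests, that predict returns class indices with the best-scoring class first.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class ClassifySketch {

    // Same expression as in run(): a fifth of the data,
    // but never fewer than min(10, n / 2) samples.
    static int initialSampleSize(int n) {
        return Math.max(n / 5, Math.min(10, n / 2));
    }

    // Hypothetical helper: top-1 accuracy in percent, i.e. the mean of a 1/0
    // indicator times 100. Assumes predicted[i] is the top-ranked class for row i.
    static double accuracyPercent(int[] predicted, int[] labels) {
        if (predicted.length != labels.length || predicted.length == 0) {
            throw new IllegalArgumentException("need equal, non-empty arrays");
        }
        return IntStream.range(0, labels.length)
                .mapToDouble(i -> predicted[i] == labels[i] ? 1 : 0)
                .average()
                .getAsDouble() * 100;
    }

    // Plain-Java stand-in for Guava's Lists.partition(list, size):
    // consecutive sublists of `size` elements, the last one possibly shorter.
    static <T> List<List<T>> partition(List<T> list, int size) {
        List<List<T>> out = new ArrayList<>();
        for (int start = 0; start < list.size(); start += size) {
            out.add(list.subList(start, Math.min(start + size, list.size())));
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(initialSampleSize(1000)); // 200
        System.out.println(initialSampleSize(30));   // 10
        System.out.println(initialSampleSize(4));    // 2
        System.out.println(accuracyPercent(new int[]{1, 2, 3, 4}, new int[]{1, 2, 0, 4})); // 75.0
        List<Integer> rows = IntStream.range(0, 250).boxed().collect(Collectors.toList());
        System.out.println(partition(rows, 100).size()); // 3
    }
}
```

For 1,000 training rows the initial sample is 200; for 30 rows the floor of 10 applies; for 4 rows the min(10, n / 2) term caps it at 2.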

Example 2 with ConstantResult

Use of com.simiacryptus.mindseye.lang.ConstantResult in project MindsEye by SimiaCryptus.

From class TensorListTrainable, method getNNContext:

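/**
 * Wraps each network input column as a {@link Result}.
 *
 * @param data one TensorList per network input; the row count is taken from data[0]
 * @param mask per-input flags; where mask[col] is true the input is trainable and its
 *             gradient is accumulated into the DeltaSet, otherwise (or if mask is null
 *             or too short) the input is wrapped as a ConstantResult
 * @return one Result per input column
 */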
public static Result[] getNNContext(@Nullable final TensorList[] data, @Nullable final boolean[] mask) {
    if (null == data)
        throw new IllegalArgumentException();
    int inputs = data.length;
    assert 0 < inputs;
    int items = data[0].length();
    assert 0 < items;
    return IntStream.range(0, inputs).mapToObj(col -> {
        final Tensor[] tensors = IntStream.range(0, items).mapToObj(row -> data[col].get(row)).toArray(i -> new Tensor[i]);
        @Nonnull TensorArray tensorArray = TensorArray.create(tensors);
        if (null == mask || col >= mask.length || !mask[col]) {
            return new ConstantResult(tensorArray);
        } else {
            return new Result(tensorArray, (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
                for (int index = 0; index < delta.length(); index++) {
                    final Tensor dt = delta.get(index);
                    @Nullable final double[] d = dt.getData();
                    final Tensor t = tensors[index];
                    @Nullable final double[] p = t.getData();
                    @Nonnull PlaceholderLayer<double[]> layer = new PlaceholderLayer<>(p);
                    buffer.get(layer, p).addInPlace(d).freeRef();
                    dt.freeRef();
                    layer.freeRef();
                }
            }) {

                @Override
                public boolean isAlive() {
                    return true;
                }
            };
        }
    }).toArray(x1 -> new Result[x1]);
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable)
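The masking rule in getNNContext can be illustrated without MindsEye. In this sketch the Input interface and the constant(), trainable() and wrap() methods are hypothetical stand-ins, not the library's API: constant() plays the role of ConstantResult (its backward step discards gradients), while trainable() sums incoming gradients into a buffer keyed by the identity of the data array, loosely mirroring the PlaceholderLayer built around p in the original. Each input is simplified to a single double[] rather than a TensorList of rows, and reference counting (freeRef) is omitted.

```java
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.stream.IntStream;

public class MaskSketch {

    // Hypothetical stand-in for a Result: a value plus a backward step that
    // deposits gradients into a buffer keyed by the underlying data array.
    interface Input {
        double[] value();
        void backward(double[] delta, Map<double[], double[]> buffer);
    }

    // Analogue of ConstantResult: gradients are discarded.
    static Input constant(double[] data) {
        return new Input() {
            public double[] value() { return data; }
            public void backward(double[] delta, Map<double[], double[]> buffer) { }
        };
    }

    // Analogue of the anonymous Result in getNNContext: gradients are summed
    // into a buffer entry keyed by the identity of the data array.
    static Input trainable(double[] data) {
        return new Input() {
            public double[] value() { return data; }
            public void backward(double[] delta, Map<double[], double[]> buffer) {
                double[] acc = buffer.computeIfAbsent(data, k -> new double[k.length]);
                for (int i = 0; i < acc.length; i++) acc[i] += delta[i];
            }
        };
    }

    // Same masking rule as getNNContext: trainable only if mask is non-null,
    // long enough, and true at this column; otherwise constant.
    static Input[] wrap(double[][] columns, boolean[] mask) {
        return IntStream.range(0, columns.length)
                .mapToObj(col -> (mask == null || col >= mask.length || !mask[col])
                        ? constant(columns[col])
                        : trainable(columns[col]))
                .toArray(Input[]::new);
    }

    public static void main(String[] args) {
        double[][] cols = {{1, 2}, {3, 4}};
        Input[] inputs = wrap(cols, new boolean[]{false, true});
        Map<double[], double[]> buffer = new IdentityHashMap<>();
        for (Input in : inputs) in.backward(new double[]{0.5, 0.5}, buffer);
        System.out.println(buffer.size());              // 1
        System.out.println(buffer.containsKey(cols[1])); // true
        System.out.println(buffer.get(cols[1])[0]);      // 0.5
    }
}
```

Only the second column, whose mask flag is true, ends up in the buffer; the first is treated as a constant, which is what the check null == mask || col >= mask.length || !mask[col] encodes.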

Aggregations

Classes used across the examples above, with the number of examples that use each:

ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult): 2
Layer (com.simiacryptus.mindseye.lang.Layer): 2
Tensor (com.simiacryptus.mindseye.lang.Tensor): 2
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 2
TensorList (com.simiacryptus.mindseye.lang.TensorList): 2
TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor): 2
Arrays (java.util.Arrays): 2
IntStream (java.util.stream.IntStream): 2
Nonnull (javax.annotation.Nonnull): 2
Nullable (javax.annotation.Nullable): 2
Lists (com.google.common.collect.Lists): 1
ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable): 1
SampledArrayTrainable (com.simiacryptus.mindseye.eval.SampledArrayTrainable): 1
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 1
PointSample (com.simiacryptus.mindseye.lang.PointSample): 1
ReferenceCountingBase (com.simiacryptus.mindseye.lang.ReferenceCountingBase): 1
Result (com.simiacryptus.mindseye.lang.Result): 1
StateSet (com.simiacryptus.mindseye.lang.StateSet): 1
EntropyLossLayer (com.simiacryptus.mindseye.layers.java.EntropyLossLayer): 1
PlaceholderLayer (com.simiacryptus.mindseye.layers.java.PlaceholderLayer): 1