Search in sources :

Example 1 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

the class ClassifyProblem method run.

@Nonnull
@Override
public ClassifyProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    final Tensor[][] trainingData = getTrainingData(log);
    @Nonnull final DAGNetwork network = fwdFactory.imageToVector(log, categories);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(network)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.h3("Training");
    @Nonnull final SimpleLossNetwork supervisedNetwork = new SimpleLossNetwork(network, new EntropyLossLayer());
    TestUtil.instrumentPerformance(supervisedNetwork);
    int initialSampleSize = Math.max(trainingData.length / 5, Math.min(10, trainingData.length / 2));
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, initialSampleSize, getBatchSize()), new ArrayTrainable(trainingData, supervisedNetwork, getBatchSize()), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName() + "_" + ClassifyProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        @Nonnull File file = new File(log.getResourceDir(), filename);
        log.appendFrontMatterProperty("result_plot", file.toString(), ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    @Nonnull final String modelName = "classification_model_" + ClassifyProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(network.getJson().toString(), modelName, modelName));
    log.h3("Validation");
    log.p("If we apply our model against the entire validation dataset, we get this accuracy:");
    log.code(() -> {
        return data.validationData().mapToDouble(labeledObject -> predict(network, labeledObject)[0] == parse(labeledObject.label) ? 1 : 0).average().getAsDouble() * 100;
    });
    log.p("Let's examine some incorrectly predicted results in more detail:");
    log.code(() -> {
        try {
            @Nonnull final TableOutput table = new TableOutput();
            Lists.partition(data.validationData().collect(Collectors.toList()), 100).stream().flatMap(batch -> {
                @Nonnull TensorList batchIn = TensorArray.create(batch.stream().map(x -> x.data).toArray(i -> new Tensor[i]));
                TensorList batchOut = network.eval(new ConstantResult(batchIn)).getData();
                return IntStream.range(0, batchOut.length()).mapToObj(i -> toRow(log, batch.get(i), batchOut.get(i).getData()));
            }).filter(x -> null != x).limit(10).forEach(table::putRow);
            return table;
        } catch (@Nonnull final IOException e) {
            throw new RuntimeException(e);
        }
    });
    return this;
}
Also used : IntStream(java.util.stream.IntStream) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Lists(com.google.common.collect.Lists) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Collectors(java.util.stream.Collectors) File(java.io.File) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) IOException(java.io.IOException) TensorList(com.simiacryptus.mindseye.lang.TensorList) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) File(java.io.File) Nonnull(javax.annotation.Nonnull)

Example 2 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

the class EncodingUtil method getMonitor.

/**
 * Gets monitor.
 *
 * @param history the history
 * @return the monitor
 */
public static TrainingMonitor getMonitor(@Nonnull final List<StepRecord> history) {
    return new TrainingMonitor() {

        @Override
        public void clear() {
            super.clear();
        }

        @Override
        public void log(final String msg) {
            // Logged MnistProblemData
            log.info(msg);
            // Realtime MnistProblemData
            EncodingUtil.rawOut.println(msg);
        }

        @Override
        public void onStepComplete(@Nonnull final Step currentPoint) {
            history.add(new StepRecord(currentPoint.point.getMean(), currentPoint.time, currentPoint.iteration));
        }
    };
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Nonnull(javax.annotation.Nonnull) Step(com.simiacryptus.mindseye.opt.Step)

Example 3 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

the class BasicTrainable method eval.

/**
 * Eval point sample.
 *
 * @param list    the list
 * @param monitor the monitor
 * @return the point sample
 */
@Nonnull
protected PointSample eval(@Nonnull final List<Tensor[]> list, @Nullable final TrainingMonitor monitor) {
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
        final Result[] nnContext = BasicTrainable.getNNContext(list, mask);
        final Result result = network.eval(nnContext);
        for (@Nonnull Result nnResult : nnContext) {
            nnResult.getData().freeRef();
            nnResult.freeRef();
        }
        final TensorList resultData = result.getData();
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        @Nonnull StateSet<Layer> stateSet = null;
        try {
            final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
                double[] array = Arrays.stream(x.getData()).toArray();
                x.freeRef();
                return Arrays.stream(array);
            }).summaryStatistics();
            final double sum = statistics.getSum();
            result.accumulate(deltaSet, 1.0);
            stateSet = new StateSet<>(deltaSet);
            // log.info(String.format("Evaluated to %s delta buffers, %s mag", DeltaSet<LayerBase>.getMap().size(), DeltaSet<LayerBase>.getMagnitude()));
            return new PointSample(deltaSet, stateSet, sum, 0.0, list.size());
        } finally {
            if (null != stateSet)
                stateSet.freeRef();
            resultData.freeRefAsync();
            result.freeRefAsync();
            deltaSet.freeRefAsync();
        }
    });
    if (null != monitor && verbosity() > 0) {
        monitor.log(String.format("Device completed %s items in %.3f sec", list.size(), timedResult.timeNanos / 1e9));
    }
    @Nonnull PointSample normalize = timedResult.result.normalize();
    timedResult.result.freeRef();
    return normalize;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) MutableResult(com.simiacryptus.mindseye.lang.MutableResult) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TimedResult(com.simiacryptus.util.lang.TimedResult) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull)

Example 4 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

the class TensorListTrainable method eval.

/**
 * Eval point sample.
 *
 * @param list    the list
 * @param monitor the monitor
 * @return the point sample
 */
@Nonnull
protected PointSample eval(@Nonnull final TensorList[] list, @Nullable final TrainingMonitor monitor) {
    int inputs = data.length;
    assert 0 < inputs;
    int items = data[0].length();
    assert 0 < items;
    @Nonnull final TimedResult<PointSample> timedResult = TimedResult.time(() -> {
        final Result[] nnContext = TensorListTrainable.getNNContext(list, mask);
        final Result result = network.eval(nnContext);
        for (@Nonnull Result nnResult : nnContext) {
            nnResult.getData().freeRef();
            nnResult.freeRef();
        }
        final TensorList resultData = result.getData();
        final DoubleSummaryStatistics statistics = resultData.stream().flatMapToDouble(x -> {
            double[] array = Arrays.stream(x.getData()).toArray();
            x.freeRef();
            return Arrays.stream(array);
        }).summaryStatistics();
        final double sum = statistics.getSum();
        @Nonnull final DeltaSet<Layer> deltaSet = new DeltaSet<Layer>();
        @Nonnull PointSample pointSample;
        try {
            result.accumulate(deltaSet, 1.0);
            // log.info(String.format("Evaluated to %s delta buffers, %s mag", DeltaSet<LayerBase>.getMap().size(), DeltaSet<LayerBase>.getMagnitude()));
            @Nonnull StateSet<Layer> stateSet = new StateSet<>(deltaSet);
            pointSample = new PointSample(deltaSet, stateSet, sum, 0.0, items);
            stateSet.freeRef();
        } finally {
            resultData.freeRef();
            result.freeRef();
            deltaSet.freeRef();
        }
        return pointSample;
    });
    if (null != monitor && verbosity() > 0) {
        monitor.log(String.format("Device completed %s items in %.3f sec", items, timedResult.timeNanos / 1e9));
    }
    @Nonnull PointSample normalize = timedResult.result.normalize();
    timedResult.result.freeRef();
    return normalize;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCountingBase(com.simiacryptus.mindseye.lang.ReferenceCountingBase) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) Result(com.simiacryptus.mindseye.lang.Result) StateSet(com.simiacryptus.mindseye.lang.StateSet) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Result(com.simiacryptus.mindseye.lang.Result) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TimedResult(com.simiacryptus.util.lang.TimedResult) PointSample(com.simiacryptus.mindseye.lang.PointSample) StateSet(com.simiacryptus.mindseye.lang.StateSet) Nonnull(javax.annotation.Nonnull)

Example 5 with TrainingMonitor

use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

the class TrainingTester method trainGD.

/**
 * Train gd list.
 *
 * @param log       the log
 * @param trainable the trainable
 * @return the list
 */
@Nonnull
public List<StepRecord> trainGD(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("First, we train using basic gradient descent method apply weak line search conditions.");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new ArmijoWolfeSearch()).setOrientation(new GradientDescent()).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable e) {
        if (isThrowExceptions())
            throw new RuntimeException(e);
    }
    return history;
}
Also used : StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) Nonnull(javax.annotation.Nonnull) ArrayList(java.util.ArrayList) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) Nonnull(javax.annotation.Nonnull)

Aggregations

TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)19 Nonnull (javax.annotation.Nonnull)19 Layer (com.simiacryptus.mindseye.lang.Layer)13 List (java.util.List)12 Nullable (javax.annotation.Nullable)12 PointSample (com.simiacryptus.mindseye.lang.PointSample)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)9 Trainable (com.simiacryptus.mindseye.eval.Trainable)8 StepRecord (com.simiacryptus.mindseye.test.StepRecord)8 IntStream (java.util.stream.IntStream)8 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)7 TensorList (com.simiacryptus.mindseye.lang.TensorList)7 ArrayList (java.util.ArrayList)7 Result (com.simiacryptus.mindseye.lang.Result)6 StateSet (com.simiacryptus.mindseye.lang.StateSet)6 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)6 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)5 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)5