Search in sources:

Example 6 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From the class TrainingTester, method trainMagic.

/**
 * Trains {@code trainable} with an experimental nested optimizer.
 *
 * <p>The outer {@link IterativeTrainer} uses a {@link StaticLearningRate} of 1.0 and a
 * {@link RecursiveSubspace} orientation. That orientation's subspace step is optimized by an inner
 * {@link IterativeTrainer} using {@link QQN} with a {@link QuadraticSearch} line search, whose log
 * messages are forwarded (tab-indented) to the outer monitor. The outer run is bounded by a
 * 30-second timeout and 250 iterations.
 *
 * <p>Any failure during training is swallowed unless {@code isThrowExceptions()} returns true, in
 * which case it is rethrown wrapped in a {@link RuntimeException}.
 *
 * @param log       the notebook that receives the narrative text and the executed code
 * @param trainable the trainable to optimize
 * @return the list handed to {@code TrainingTester.getMonitor}; presumably filled with one
 *     {@link StepRecord} per step (TODO confirm), and possibly partial if training failed
 */
@Nonnull
public List<StepRecord> trainMagic(@Nonnull final NotebookOutput log, final Trainable trainable) {
    log.p("Now we train using an experimental optimizer:");
    @Nonnull final List<StepRecord> history = new ArrayList<>();
    @Nonnull final TrainingMonitor monitor = TrainingTester.getMonitor(history);
    try {
        log.code(() -> {
            return new IterativeTrainer(trainable).setLineSearchFactory(label -> new StaticLearningRate(1.0)).setOrientation(new RecursiveSubspace() {

                @Override
                public void train(@Nonnull TrainingMonitor monitor, Layer macroLayer) {
                    @Nonnull Tensor[][] nullData = { { new Tensor() } };
                    @Nonnull BasicTrainable inner = new BasicTrainable(macroLayer);
                    @Nonnull ArrayTrainable trainable1 = new ArrayTrainable(inner, nullData);
                    inner.freeRef();
                    try {
                        new IterativeTrainer(trainable1).setOrientation(new QQN()).setLineSearchFactory(n -> new QuadraticSearch().setCurrentRate(n.equals(QQN.CURSOR_NAME) ? 1.0 : 1e-4)).setMonitor(new TrainingMonitor() {

                            @Override
                            public void log(String msg) {
                                monitor.log("\t" + msg);
                            }
                        }).setMaxIterations(getIterations()).setIterationsPerSample(getIterations()).runAndFree();
                    } finally {
                        // Release our references even when the inner optimization throws.
                        trainable1.freeRef();
                        for (@Nonnull Tensor[] tensors : nullData) {
                            for (@Nonnull Tensor tensor : tensors) {
                                tensor.freeRef();
                            }
                        }
                    }
                }
            }).setMonitor(monitor).setTimeout(30, TimeUnit.SECONDS).setIterationsPerSample(100).setMaxIterations(250).setTerminateThreshold(0).runAndFree();
        });
    } catch (Throwable e) {
        // Deliberate best-effort: failures only propagate when exception throwing is enabled.
        if (isThrowExceptions())
            throw new RuntimeException(e);
    }
    return history;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Random(java.util.Random) Result(com.simiacryptus.mindseye.lang.Result) ArmijoWolfeSearch(com.simiacryptus.mindseye.opt.line.ArmijoWolfeSearch) ArrayList(java.util.ArrayList) Trainable(com.simiacryptus.mindseye.eval.Trainable) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) LBFGS(com.simiacryptus.mindseye.opt.orient.LBFGS) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) QQN(com.simiacryptus.mindseye.opt.orient.QQN) GradientDescent(com.simiacryptus.mindseye.opt.orient.GradientDescent) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) DoubleStream(java.util.stream.DoubleStream) java.awt(java.awt) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Stream(java.util.stream.Stream) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) TensorList(com.simiacryptus.mindseye.lang.TensorList) 
Step(com.simiacryptus.mindseye.opt.Step) ProblemRun(com.simiacryptus.mindseye.test.ProblemRun) javax.swing(javax.swing) RecursiveSubspace(com.simiacryptus.mindseye.opt.orient.RecursiveSubspace) BasicTrainable(com.simiacryptus.mindseye.eval.BasicTrainable) IterativeTrainer(com.simiacryptus.mindseye.opt.IterativeTrainer) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) QuadraticSearch(com.simiacryptus.mindseye.opt.line.QuadraticSearch) ArrayList(java.util.ArrayList) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) Layer(com.simiacryptus.mindseye.lang.Layer) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) QQN(com.simiacryptus.mindseye.opt.orient.QQN) StepRecord(com.simiacryptus.mindseye.test.StepRecord) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) StaticLearningRate(com.simiacryptus.mindseye.opt.line.StaticLearningRate) Nonnull(javax.annotation.Nonnull)

Example 7 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From the class AutoencodingProblem, method run.

/**
 * Runs the autoencoding experiment and reports everything to {@code log}.
 *
 * <p>Builds an encoder ({@code fwdFactory.imageToVector}) and a decoder
 * ({@code revFactory.vectorToImage}). It then trains them end-to-end through a
 * {@link DropoutNoiseLayer} on a {@link MeanSqLossLayer} reconstruction loss against input 0. It
 * writes network diagrams, training plots, the saved encoder/decoder models, up to ten re-encoded
 * validation examples, and one rendered image per unit feature vector.
 *
 * @param log notebook that receives all diagrams, plots, model files and tables
 * @return this problem instance, for chaining
 */
@Nonnull
@Override
public AutoencodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final DAGNetwork fwdNetwork = fwdFactory.imageToVector(log, features);
    @Nonnull final DAGNetwork revNetwork = revFactory.vectorToImage(log, features);
    // Plain encode -> decode chain, used below to render reconstructions of validation data.
    @Nonnull final PipelineNetwork echoNetwork = new PipelineNetwork(1);
    echoNetwork.add(fwdNetwork);
    echoNetwork.add(revNetwork);
    // Training graph: encode -> dropout noise -> decode, scored by mean-squared error against input 0.
    @Nonnull final PipelineNetwork supervisedNetwork = new PipelineNetwork(1);
    supervisedNetwork.add(fwdNetwork);
    @Nonnull final DropoutNoiseLayer dropoutNoiseLayer = new DropoutNoiseLayer().setValue(dropout);
    supervisedNetwork.add(dropoutNoiseLayer);
    supervisedNetwork.add(revNetwork);
    supervisedNetwork.add(new MeanSqLossLayer(), supervisedNetwork.getHead(), supervisedNetwork.getInput(0));
    log.h3("Network Diagrams");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(fwdNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(revNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(supervisedNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Delegates to the history-recording monitor. After each completed step it first reshuffles
    // the dropout layer with a fresh seed (presumably re-sampling the dropout mask).
    @Nonnull final TrainingMonitor monitor = new TrainingMonitor() {

        @Nonnull
        TrainingMonitor inner = TestUtil.getMonitor(history);

        @Override
        public void log(final String msg) {
            inner.log(msg);
        }

        @Override
        public void onStepComplete(final Step currentPoint) {
            dropoutNoiseLayer.shuffle(StochasticComponent.random.get().nextLong());
            inner.onStepComplete(currentPoint);
        }
    };
    final Tensor[][] trainingData = getTrainingData(log);
    // MonitoredObject monitoringRoot = new MonitoredObject();
    // TestUtil.addMonitoring(supervisedNetwork, monitoringRoot);
    log.h3("Training");
    TestUtil.instrumentPerformance(supervisedNetwork);
    // The sampled trainable uses trainingData.length / 2 as its sample size; the full-array trainable is passed
    // as the second argument (presumably the validation set for ValidatingTrainer — confirm).
    @Nonnull final ValidatingTrainer trainer = optimizer.train(log, new SampledArrayTrainable(trainingData, supervisedNetwork, trainingData.length / 2, batchSize), new ArrayTrainable(trainingData, supervisedNetwork, batchSize), monitor);
    log.code(() -> {
        trainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(10000).run();
    });
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    TestUtil.extractPerformance(log, supervisedNetwork);
    {
        // NOTE(review): modelNo is incremented once here and again for the decoder, so the two
        // files from one run get different indices — confirm that is intended.
        @Nonnull final String modelName = "encoder_model" + AutoencodingProblem.modelNo++ + ".json";
        log.p("Saved model as " + log.file(fwdNetwork.getJson().toString(), modelName, modelName));
    }
    @Nonnull final String modelName = "decoder_model" + AutoencodingProblem.modelNo++ + ".json";
    log.p("Saved model as " + log.file(revNetwork.getJson().toString(), modelName, modelName));
    // log.h3("Metrics");
    // log.code(() -> {
    // return TestUtil.toFormattedJson(monitoringRoot.getMetrics());
    // });
    log.h3("Validation");
    log.p("Here are some re-encoded examples:");
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        data.validationData().map(labeledObject -> {
            return toRow(log, labeledObject, echoNetwork.eval(labeledObject.data).getData().get(0).getData());
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Some rendered unit vectors:");
    // Decode each one-hot feature vector and render the result as an image.
    // NOTE(review): `input` and the eval result are never released, although other code in this
    // project calls freeRef() on Tensors; `tensor` is @Nullable but dereferenced — confirm.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = revNetwork.eval(input).getData().get(0);
        log.out(log.image(tensor.toImage(), ""));
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) Format(guru.nidi.graphviz.engine.Format) LabeledObject(com.simiacryptus.util.test.LabeledObject) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) 
ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) DropoutNoiseLayer(com.simiacryptus.mindseye.layers.java.DropoutNoiseLayer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 8 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From the class EncodingProblem, method run.

/**
 * Runs the encoding experiment and reports everything to {@code log}.
 *
 * <p>Each training row is {@code [featureVector, image]}, where the feature vector is a freshly
 * randomized {@code Tensor(features)}. The training network decodes input 0 into an image with
 * {@code revFactory.vectorToImage}. Its loss is the sum of (a) {@link EntropyLossLayer} applied to
 * the softmax of input 0 with itself and (b) the square root of the {@link MeanSqLossLayer} error
 * between the decoded image and input 1. The feature vectors are presumably optimized along with
 * the network via {@code setMask(true, false)} — confirm.
 *
 * <p>Training runs twice with the same monitor. The first is a priming pass on at most the first
 * 1000 rows with half the timeout; the second is the main pass on all rows.
 *
 * @param log notebook that receives diagrams, plots, the saved model and result tables
 * @return this problem instance, for chaining
 * @throws RuntimeException wrapping an {@link IOException} raised while loading the training data
 *     or writing the result plot
 */
@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    Tensor[][] trainingData;
    try {
        // Row layout: {random feature vector, target image}.
        trainingData = data.trainingData().map(labeledObject -> {
            return new Tensor[] { new Tensor(features).set(this::random), labeledObject.data };
        }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(imageNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    // Inputs: 0 = feature vector, 1 = target image.
    @Nonnull final PipelineNetwork trainingNetwork = new PipelineNetwork(2);
    @Nullable final DAGNode image = trainingNetwork.add(imageNetwork, trainingNetwork.getInput(0));
    @Nullable final DAGNode softmax = trainingNetwork.add(new SoftmaxActivationLayer(), trainingNetwork.getInput(0));
    // loss = EntropyLoss(softmax, softmax) + sqrt(MeanSq(decoded image, target image))
    trainingNetwork.add(new SumInputsLayer(), trainingNetwork.add(new EntropyLossLayer(), softmax, softmax), trainingNetwork.add(new NthPowerActivationLayer().setPower(1.0 / 2.0), trainingNetwork.add(new MeanSqLossLayer(), image, trainingNetwork.getInput(1))));
    log.h3("Training");
    log.p("We start by training apply a very small population to improve initial convergence performance:");
    TestUtil.instrumentPerformance(trainingNetwork);
    // Arrays.copyOfRange pads with null rows when the range runs past the end of the source, so
    // cap the priming set at the number of rows actually available.
    final int primingRows = Math.min(1000, trainingData.length);
    @Nonnull final Tensor[][] primingData = Arrays.copyOfRange(trainingData, 0, primingRows);
    @Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(primingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(primingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    log.p("Then our main training phase:");
    TestUtil.instrumentPerformance(trainingNetwork);
    @Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(trainingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(trainingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        @Nonnull String filename = log.getName().toString() + EncodingProblem.modelNo++ + "_plot.png";
        // ImageIO.write does not close the OutputStream it is given, so close it here.
        // Assumes log.file(name) returns a freshly opened OutputStream — TODO confirm.
        try (java.io.OutputStream plotOut = log.file(filename)) {
            ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", plotOut);
        }
        log.appendFrontMatterProperty("result_plot", filename, ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
    log.h3("Results");
    // Decoder-only view of the trained network: renders input 0 back into an image.
    @Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
    testNetwork.add(imageNetwork, testNetwork.getInput(0));
    log.code(() -> {
        @Nonnull final TableOutput table = new TableOutput();
        Arrays.stream(trainingData).map(tensorArray -> {
            @Nullable final Tensor predictionSignal = testNetwork.eval(tensorArray).getData().get(0);
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Source", log.image(tensorArray[1].toImage(), ""));
            row.put("Echo", log.image(predictionSignal.toImage(), ""));
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Learned Model Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        trainingNetwork.state().stream().flatMapToDouble(x -> Arrays.stream(x)).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Learned Representation Statistics:");
    log.code(() -> {
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(trainingData).flatMapToDouble(row -> Arrays.stream(row[0].getData())).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Some rendered unit vectors:");
    // Decode each one-hot feature vector and render every resulting image.
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = imageNetwork.eval(input).getData().get(0);
        TestUtil.renderToImages(tensor, true).forEach(img -> {
            log.out(log.image(img, ""));
        });
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) Format(guru.nidi.graphviz.engine.Format) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) LinkedHashMap(java.util.LinkedHashMap) 
SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IOException(java.io.IOException) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNode(com.simiacryptus.mindseye.network.DAGNode) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 9 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From the class AutoencoderNetwork, method train.

/**
 * Creates the training configuration for this autoencoder.
 *
 * <p>The returned parameters build a training network that chains {@code encoder} then
 * {@code decoder} and scores the result with a {@link MeanSqLossLayer} through a
 * {@link SimpleLossNetwork}. They also wrap the caller's monitor so that {@code inputNoise} and
 * {@code encodedNoise} are reshuffled after every completed step, before the step is forwarded.
 *
 * @return a new {@link AutoencoderNetwork.TrainingParameters} bound to this network
 */
@Nonnull
public AutoencoderNetwork.TrainingParameters train() {
    return new AutoencoderNetwork.TrainingParameters() {

        @Nonnull
        @Override
        public SimpleLossNetwork getTrainingNetwork() {
            @Nonnull final PipelineNetwork reconstruction = new PipelineNetwork();
            reconstruction.add(encoder);
            reconstruction.add(decoder);
            @Nonnull final MeanSqLossLayer lossLayer = new MeanSqLossLayer();
            return new SimpleLossNetwork(reconstruction, lossLayer);
        }

        @Nonnull
        @Override
        protected TrainingMonitor wrap(@Nonnull final TrainingMonitor delegate) {
            return new TrainingMonitor() {

                @Override
                public void onStepComplete(final Step step) {
                    // Reshuffle both noise layers (the encoded-side one with a fresh random seed)
                    // before handing the step to the wrapped monitor.
                    inputNoise.shuffle();
                    final long seed = StochasticComponent.random.get().nextLong();
                    encodedNoise.shuffle(seed);
                    delegate.onStepComplete(step);
                }

                @Override
                public void log(final String msg) {
                    delegate.log(msg);
                }
            };
        }
    };
}
Also used : TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Step(com.simiacryptus.mindseye.opt.Step) SimpleLossNetwork(com.simiacryptus.mindseye.network.SimpleLossNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Nonnull(javax.annotation.Nonnull)

Example 10 with TrainingMonitor

Use of com.simiacryptus.mindseye.opt.TrainingMonitor in project MindsEye by SimiaCryptus.

From the class QuantifyOrientationWrapper, method orient.

/**
 * Delegates to the wrapped orientation strategy and logs how large the proposed step is relative
 * to the current weights.
 *
 * <p>For a {@link SimpleLineSearchCursor}, weights are grouped by {@code getId}. For each weight,
 * the ratio {@code rms(direction delta) / rms(weight)} is computed. A missing direction entry
 * counts as 0, and a zero weight RMS is replaced by 1. A group with a single entry is logged as
 * that ratio; larger groups are summarized with {@link DoubleStatistics}. Any other cursor type is
 * logged as-is.
 *
 * @param subject     the trainable being optimized
 * @param measurement the current measured point
 * @param monitor     receives the statistics log line
 * @return the cursor produced by the wrapped strategy, unchanged
 */
@Override
public LineSearchCursor orient(final Trainable subject, final PointSample measurement, @Nonnull final TrainingMonitor monitor) {
    final LineSearchCursor result = inner.orient(subject, measurement, monitor);
    if (!(result instanceof SimpleLineSearchCursor)) {
        monitor.log(String.format("Non-simple cursor: %s", result));
        return result;
    }
    final SimpleLineSearchCursor simple = (SimpleLineSearchCursor) result;
    final DeltaSet<Layer> step = simple.direction;
    @Nonnull final StateSet<Layer> origin = simple.origin.weights;
    final Map<CharSequence, CharSequence> summary = origin.stream()
        .collect(Collectors.groupingBy(state -> getId(state), Collectors.toList()))
        .entrySet().stream()
        .collect(Collectors.toMap(entry -> entry.getKey(), entry -> {
            final List<Double> ratios = entry.getValue().stream().map(state -> {
                final DoubleBuffer<Layer> stepDelta = step.getMap().get(state.layer);
                final double weightRms = state.deltaStatistics().rms();
                final double stepRms = (stepDelta == null) ? 0 : stepDelta.deltaStatistics().rms();
                final double divisor = (weightRms == 0) ? 1 : weightRms;
                return stepRms / divisor;
            }).collect(Collectors.toList());
            if (ratios.size() != 1) {
                return new DoubleStatistics().accept(ratios.stream().mapToDouble(r -> r).toArray()).toString();
            }
            return Double.toString(ratios.get(0));
        }));
    monitor.log(String.format("Line search stats: %s", summary));
    return result;
}
Also used : DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Collectors(java.util.stream.Collectors) StateSet(com.simiacryptus.mindseye.lang.StateSet) Trainable(com.simiacryptus.mindseye.eval.Trainable) List(java.util.List) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) Nonnull(javax.annotation.Nonnull) PointSample(com.simiacryptus.mindseye.lang.PointSample) Nonnull(javax.annotation.Nonnull) Layer(com.simiacryptus.mindseye.lang.Layer) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) LineSearchCursor(com.simiacryptus.mindseye.opt.line.LineSearchCursor) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleLineSearchCursor(com.simiacryptus.mindseye.opt.line.SimpleLineSearchCursor) List(java.util.List)

Aggregations

TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)19 Nonnull (javax.annotation.Nonnull)19 Layer (com.simiacryptus.mindseye.lang.Layer)13 List (java.util.List)12 Nullable (javax.annotation.Nullable)12 PointSample (com.simiacryptus.mindseye.lang.PointSample)10 Tensor (com.simiacryptus.mindseye.lang.Tensor)10 Arrays (java.util.Arrays)10 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)9 Trainable (com.simiacryptus.mindseye.eval.Trainable)8 StepRecord (com.simiacryptus.mindseye.test.StepRecord)8 IntStream (java.util.stream.IntStream)8 ArrayTrainable (com.simiacryptus.mindseye.eval.ArrayTrainable)7 TensorList (com.simiacryptus.mindseye.lang.TensorList)7 ArrayList (java.util.ArrayList)7 Result (com.simiacryptus.mindseye.lang.Result)6 StateSet (com.simiacryptus.mindseye.lang.StateSet)6 IterativeTrainer (com.simiacryptus.mindseye.opt.IterativeTrainer)6 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)5 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)5