Search in sources :

Example 6 with ScalarStatistics

use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

the class TestUtil method printDataStatistics.

/**
 * Writes two statistical summaries to the notebook for every column of
 * {@code data} from index 1 upward (column 0 is not reported by this method).
 * <p>
 * The first summary pools every value of the column across all rows. The second
 * produces one line of metrics per (band, row) pair, where a band is a slice
 * along the third tensor dimension; the band count and slice size are taken from
 * the dimensions of the tensor in the first row.
 *
 * @param log  the notebook that receives the headings and the computed metrics
 * @param data the data, indexed as {@code data[row][column]}; must contain at least one row
 */
public static void printDataStatistics(@Nonnull final NotebookOutput log, @Nonnull final Tensor[][] data) {
    final int columnCount = data[0].length;
    for (int col = 1; col < columnCount; col++) {
        // Lambdas below need an effectively-final copy of the loop counter.
        final int column = col;

        log.out("Learned Representation Statistics for Column " + column + " (all bands)");
        log.code(() -> {
            @Nonnull final ScalarStatistics pooled = new ScalarStatistics();
            for (final Tensor[] row : data) {
                for (final double value : row[column].getData()) {
                    pooled.add(value);
                }
            }
            return pooled.getMetrics();
        });

        log.out("Learned Representation Statistics for Column " + column + " (by band)");
        log.code(() -> {
            // Shape is read from the first row only and applied to every row.
            @Nonnull final int[] shape = data[0][column].getDimensions();
            return IntStream.range(0, shape[2])
                .boxed()
                .flatMap(band -> Arrays.stream(data)
                    .map(row -> row[column])
                    .map(source -> {
                        // Copy the single band into a 2-D tensor, then summarise its values.
                        @Nonnull final ScalarStatistics bandStatistics = new ScalarStatistics();
                        bandStatistics.add(new Tensor(shape[0], shape[1])
                            .setByCoord(coord -> source.get(coord.getCoords()[0], coord.getCoords()[1], band))
                            .getData());
                        return bandStatistics;
                    }))
                .map(stats -> stats.getMetrics().toString())
                .reduce((left, right) -> left + "\n" + right)
                .get();
        });
    }
}
Also used : Arrays(java.util.Arrays) ScheduledFuture(java.util.concurrent.ScheduledFuture) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) IntUnaryOperator(java.util.function.IntUnaryOperator) BiFunction(java.util.function.BiFunction) LoggerFactory(org.slf4j.LoggerFactory) DoubleSummaryStatistics(java.util.DoubleSummaryStatistics) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) Map(java.util.Map) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) URI(java.net.URI) Graph(guru.nidi.graphviz.model.Graph) LongToIntFunction(java.util.function.LongToIntFunction) StochasticComponent(com.simiacryptus.mindseye.layers.java.StochasticComponent) BufferedImage(java.awt.image.BufferedImage) UUID(java.util.UUID) ComponentEvent(java.awt.event.ComponentEvent) WindowAdapter(java.awt.event.WindowAdapter) DAGNode(com.simiacryptus.mindseye.network.DAGNode) Collectors(java.util.stream.Collectors) WindowEvent(java.awt.event.WindowEvent) Executors(java.util.concurrent.Executors) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) LoggingWrapperLayer(com.simiacryptus.mindseye.layers.java.LoggingWrapperLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) IntStream(java.util.stream.IntStream) MonitoringWrapperLayer(com.simiacryptus.mindseye.layers.java.MonitoringWrapperLayer) ActionListener(java.awt.event.ActionListener) ScatterPlot(smile.plot.ScatterPlot) ByteArrayOutputStream(java.io.ByteArrayOutputStream) LinkSource(guru.nidi.graphviz.model.LinkSource) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) Supplier(java.util.function.Supplier) JsonUtil(com.simiacryptus.util.io.JsonUtil) MutableNode(guru.nidi.graphviz.model.MutableNode) Charset(java.nio.charset.Charset) Factory(guru.nidi.graphviz.model.Factory) ScheduledExecutorService(java.util.concurrent.ScheduledExecutorService) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) 
WeakReference(java.lang.ref.WeakReference) LinkTarget(guru.nidi.graphviz.model.LinkTarget) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) LongSummaryStatistics(java.util.LongSummaryStatistics) PrintStream(java.io.PrintStream) Logger(org.slf4j.Logger) PlotCanvas(smile.plot.PlotCanvas) RankDir(guru.nidi.graphviz.attribute.RankDir) IOException(java.io.IOException) FileFilter(javax.swing.filechooser.FileFilter) ActionEvent(java.awt.event.ActionEvent) PercentileStatistics(com.simiacryptus.util.data.PercentileStatistics) File(java.io.File) java.awt(java.awt) ComponentAdapter(java.awt.event.ComponentAdapter) TimeUnit(java.util.concurrent.TimeUnit) Consumer(java.util.function.Consumer) MonitoredObject(com.simiacryptus.util.MonitoredObject) IntToLongFunction(java.util.function.IntToLongFunction) Link(guru.nidi.graphviz.model.Link) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator) javax.swing(javax.swing) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics)

Example 7 with ScalarStatistics

use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

the class EncodingProblem method run.

/**
 * Runs the encoding experiment and reports every stage to the notebook.
 * <p>
 * Steps, in order: build training pairs (a {@code new Tensor(features)} populated
 * via {@code this::random}, plus the labeled object's data), build the
 * vector-to-image network, assemble a two-input training network, run a priming
 * training pass on at most the first 1000 rows, run the main training pass,
 * plot the history, save the plot and the model JSON, and finally report result
 * images and statistics.
 *
 * @param log the notebook that receives all output
 * @return this problem instance
 * @throws RuntimeException wrapping any {@link IOException} raised while loading
 *                          the training data or writing the plot image
 */
@Nonnull
@Override
public EncodingProblem run(@Nonnull final NotebookOutput log) {
    @Nonnull final TrainingMonitor monitor = TestUtil.getMonitor(history);
    Tensor[][] trainingData;
    try {
        // Row layout: [0] = the encoding tensor, [1] = the source data.
        trainingData = data.trainingData().map(labeledObject -> {
            return new Tensor[] { new Tensor(features).set(this::random), labeledObject.data };
        }).toArray(i -> new Tensor[i][]);
    } catch (@Nonnull final IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final DAGNetwork imageNetwork = revFactory.vectorToImage(log, features);
    log.h3("Network Diagram");
    log.code(() -> {
        return Graphviz.fromGraph(TestUtil.toGraph(imageNetwork)).height(400).width(600).render(Format.PNG).toImage();
    });
    @Nonnull final PipelineNetwork trainingNetwork = new PipelineNetwork(2);
    @Nullable final DAGNode image = trainingNetwork.add(imageNetwork, trainingNetwork.getInput(0));
    @Nullable final DAGNode softmax = trainingNetwork.add(new SoftmaxActivationLayer(), trainingNetwork.getInput(0));
    // Loss = EntropyLoss(softmax, softmax) + (MeanSqLoss(image, input 1)) ^ (1/2)
    trainingNetwork.add(new SumInputsLayer(), trainingNetwork.add(new EntropyLossLayer(), softmax, softmax), trainingNetwork.add(new NthPowerActivationLayer().setPower(1.0 / 2.0), trainingNetwork.add(new MeanSqLossLayer(), image, trainingNetwork.getInput(1))));
    log.h3("Training");
    log.p("We start by training apply a very small population to improve initial convergence performance:");
    TestUtil.instrumentPerformance(trainingNetwork);
    // Arrays.copyOfRange pads the result with nulls when the requested range runs past
    // the end of the source array, so clamp the upper bound to the rows actually available.
    @Nonnull final Tensor[][] primingData = Arrays.copyOfRange(trainingData, 0, Math.min(1000, trainingData.length));
    @Nonnull final ValidatingTrainer preTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(primingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(primingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        // The priming pass gets half of the time budget of the main pass.
        preTrainer.setTimeout(timeoutMinutes / 2, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    log.p("Then our main training phase:");
    TestUtil.instrumentPerformance(trainingNetwork);
    @Nonnull final ValidatingTrainer mainTrainer = optimizer.train(log, (SampledTrainable) new SampledArrayTrainable(trainingData, trainingNetwork, trainingSize, batchSize).setMinSamples(trainingSize).setMask(true, false), new ArrayTrainable(trainingData, trainingNetwork, batchSize), monitor);
    log.code(() -> {
        mainTrainer.setTimeout(timeoutMinutes, TimeUnit.MINUTES).setMaxIterations(batchSize).run();
    });
    TestUtil.extractPerformance(log, trainingNetwork);
    if (!history.isEmpty()) {
        log.code(() -> {
            return TestUtil.plot(history);
        });
        log.code(() -> {
            return TestUtil.plotTime(history);
        });
    }
    try {
        // NOTE(review): modelNo is incremented here and again for the model file below,
        // so the plot and the model of one run carry different numbers - confirm this is intended.
        @Nonnull String filename = log.getName().toString() + EncodingProblem.modelNo++ + "_plot.png";
        ImageIO.write(Util.toImage(TestUtil.plot(history)), "png", log.file(filename));
        log.appendFrontMatterProperty("result_plot", filename, ";");
    } catch (IOException e) {
        throw new RuntimeException(e);
    }
    @Nonnull final String modelName = "encoding_model_" + EncodingProblem.modelNo++ + ".json";
    log.appendFrontMatterProperty("result_model", modelName, ";");
    log.p("Saved model as " + log.file(trainingNetwork.getJson().toString(), modelName, modelName));
    log.h3("Results");
    @Nonnull final PipelineNetwork testNetwork = new PipelineNetwork(2);
    testNetwork.add(imageNetwork, testNetwork.getInput(0));
    log.code(() -> {
        // Show the source image next to the network's reconstruction for the first 10 rows.
        @Nonnull final TableOutput table = new TableOutput();
        Arrays.stream(trainingData).map(tensorArray -> {
            @Nullable final Tensor predictionSignal = testNetwork.eval(tensorArray).getData().get(0);
            @Nonnull final LinkedHashMap<CharSequence, Object> row = new LinkedHashMap<>();
            row.put("Source", log.image(tensorArray[1].toImage(), ""));
            row.put("Echo", log.image(predictionSignal.toImage(), ""));
            return row;
        }).filter(x -> null != x).limit(10).forEach(table::putRow);
        return table;
    });
    log.p("Learned Model Statistics:");
    log.code(() -> {
        // Statistics over every value in the training network's state arrays.
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        trainingNetwork.state().stream().flatMapToDouble(x -> Arrays.stream(x)).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Learned Representation Statistics:");
    log.code(() -> {
        // Statistics over column 0 (the encoding tensor) of every training row.
        @Nonnull final ScalarStatistics scalarStatistics = new ScalarStatistics();
        Arrays.stream(trainingData).flatMapToDouble(row -> Arrays.stream(row[0].getData())).forEach(v -> scalarStatistics.add(v));
        return scalarStatistics.getMetrics();
    });
    log.p("Some rendered unit vectors:");
    for (int featureNumber = 0; featureNumber < features; featureNumber++) {
        // Feed a vector with a single 1 at featureNumber through the image network and render the result.
        @Nonnull final Tensor input = new Tensor(features).set(featureNumber, 1);
        @Nullable final Tensor tensor = imageNetwork.eval(input).getData().get(0);
        TestUtil.renderToImages(tensor, true).forEach(img -> {
            log.out(log.image(img, ""));
        });
    }
    return this;
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) Graphviz(guru.nidi.graphviz.engine.Graphviz) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) Arrays(java.util.Arrays) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) SampledTrainable(com.simiacryptus.mindseye.eval.SampledTrainable) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) Format(guru.nidi.graphviz.engine.Format) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) DAGNode(com.simiacryptus.mindseye.network.DAGNode) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) LinkedHashMap(java.util.LinkedHashMap) 
SoftmaxActivationLayer(com.simiacryptus.mindseye.layers.java.SoftmaxActivationLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) TableOutput(com.simiacryptus.util.TableOutput) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IOException(java.io.IOException) EntropyLossLayer(com.simiacryptus.mindseye.layers.java.EntropyLossLayer) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) ArrayTrainable(com.simiacryptus.mindseye.eval.ArrayTrainable) DAGNode(com.simiacryptus.mindseye.network.DAGNode) SampledArrayTrainable(com.simiacryptus.mindseye.eval.SampledArrayTrainable) NthPowerActivationLayer(com.simiacryptus.mindseye.layers.java.NthPowerActivationLayer) ValidatingTrainer(com.simiacryptus.mindseye.opt.ValidatingTrainer) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 8 with ScalarStatistics

use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

the class BatchDerivativeTester method testLearning.

/**
 * Checks the implemented learning (weight) gradient of a component against the
 * measured gradient, once per entry of {@code component.state()}.
 * <p>
 * The measured gradient is only computed when {@code verify} is set; otherwise
 * no element-wise comparison takes place. Each comparison fails with an
 * {@link AssertionError} unless the maximum absolute tolerance is below
 * {@code tolerance}. Gradient details are logged when {@code verbose} is set,
 * and always logged before a failure is rethrown.
 *
 * @param component  the component under test
 * @param IOPair     the io pair supplying the input and output prototypes
 * @param statistics the statistics accumulated so far
 * @return the per-weight results combined with {@code statistics}, or
 *         {@code statistics} itself when the component has no state entries
 */
public ToleranceStatistics testLearning(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
    return IntStream.range(0, component.state().size()).mapToObj(weightIndex -> {
        @Nullable final Tensor measuredGradient = verify
            ? measureLearningGradient(component, weightIndex, IOPair.getOutputPrototype(), IOPair.getInputPrototype())
            : null;
        @Nonnull final Tensor implementedGradient = getLearningGradient(component, weightIndex, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
        try {
            // Without a measured gradient there is nothing to compare element by element.
            final int comparedCount = null == measuredGradient ? 0 : measuredGradient.length();
            final ToleranceStatistics result = IntStream.range(0, comparedCount)
                .mapToObj(k -> new ToleranceStatistics().accumulate(measuredGradient.getData()[k], implementedGradient.getData()[k]))
                .reduce((left, right) -> left.combine(right))
                .orElse(new ToleranceStatistics());
            // Written as a negated "<" so that a NaN maximum also counts as a failure.
            if (!(result.absoluteTol.getMax() < tolerance)) {
                throw new AssertionError(result.toString());
            }
            if (verbose) {
                log.info(String.format("Learning Gradient for weight setByCoord %s", weightIndex));
                log.info(String.format("Weights: %s", new Tensor(component.state().get(weightIndex)).prettyPrint()));
                log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
                log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
                if (null != measuredGradient) {
                    log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                    log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                    log.info(String.format("Gradient Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
                    log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
                }
            }
            return result;
        } catch (@Nonnull final Throwable failure) {
            // Dump the gradients regardless of the verbose flag so a failure can be diagnosed.
            log.info(String.format("Learning Gradient for weight setByCoord %s", weightIndex));
            log.info(String.format("Implemented Gradient: %s", implementedGradient.prettyPrint()));
            log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
            if (null != measuredGradient) {
                log.info(String.format("Measured Gradient: %s", measuredGradient.prettyPrint()));
                log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                log.info(String.format("Gradient Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
                log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
            }
            throw failure;
        }
    }).reduce((left, right) -> left.combine(right))
        .map(combined -> combined.combine(statistics))
        .orElseGet(() -> statistics);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) Nullable(javax.annotation.Nullable)

Example 9 with ScalarStatistics

use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

the class BatchDerivativeTester method test.

/**
 * Runs the differential validation of a component and writes a report to the notebook.
 * <p>
 * The report consists of: an optional dump of inputs and outputs (when
 * {@code verbose} is set), the feedback validation (when {@code isTestFeedback()}),
 * the learning validation (when {@code isTestLearning()}), the overall accuracy,
 * and the frozen / unfrozen checks.
 *
 * @param log            the notebook that receives the report
 * @param component      the component under test
 * @param inputPrototype the input prototype; only the first element is used to build the io pair
 * @return the tolerance statistics accumulated by the validations that were run
 */
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer component, @Nonnull final Tensor... inputPrototype) {
    log.h1("Differential Validation");
    @Nonnull final IOPair pair = new IOPair(component, inputPrototype[0]).invoke();
    if (verbose) {
        // The notebook parameter shadows the class logger, hence the qualified references.
        log.code(() -> {
            BatchDerivativeTester.log.info(String.format("Inputs: %s",
                Arrays.stream(inputPrototype)
                    .map(tensor -> tensor.prettyPrint())
                    .reduce((left, right) -> left + ",\n" + right)
                    .get()));
            BatchDerivativeTester.log.info(String.format("Inputs Statistics: %s",
                Arrays.stream(inputPrototype)
                    .map(tensor -> new ScalarStatistics().add(tensor.getData()).toString())
                    .reduce((left, right) -> left + ",\n" + right)
                    .get()));
            BatchDerivativeTester.log.info(String.format("Output: %s", pair.getOutputPrototype().prettyPrint()));
            BatchDerivativeTester.log.info(String.format("Outputs Statistics: %s", new ScalarStatistics().add(pair.getOutputPrototype().getData())));
        });
    }

    // Reassigned below, so each lambda captures its own final snapshot.
    ToleranceStatistics accumulated = new ToleranceStatistics();

    if (isTestFeedback()) {
        log.h2("Feedback Validation");
        log.p("We validate the agreement between the implemented derivative _of the inputs_ apply finite difference estimations:");
        final ToleranceStatistics beforeFeedback = accumulated;
        accumulated = log.code(() -> {
            return testFeedback(component, pair, beforeFeedback);
        });
    }

    if (isTestLearning()) {
        log.h2("Learning Validation");
        log.p("We validate the agreement between the implemented derivative _of the internal weights_ apply finite difference estimations:");
        final ToleranceStatistics beforeLearning = accumulated;
        accumulated = log.code(() -> {
            return testLearning(component, pair, beforeLearning);
        });
    }

    log.h2("Total Accuracy");
    log.p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
    final ToleranceStatistics total = accumulated;
    log.code(() -> {
        BatchDerivativeTester.log.info(String.format("Finite-Difference Derivative Accuracy:"));
        BatchDerivativeTester.log.info(String.format("absoluteTol: %s", total.absoluteTol));
        BatchDerivativeTester.log.info(String.format("relativeTol: %s", total.relativeTol));
    });

    log.h2("Frozen and Alive Status");
    log.code(() -> {
        testFrozen(component, pair.getInputPrototype());
        testUnFrozen(component, pair.getInputPrototype());
    });

    return total;
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics)

Example 10 with ScalarStatistics

use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

the class BatchDerivativeTester method testFeedback.

/**
 * Checks the implemented feedback (input) gradient of a component against the
 * measured gradient, once per element of the io pair's input prototype.
 * <p>
 * The measured gradient is only computed when {@code verify} is set; otherwise
 * no element-wise comparison takes place. Each comparison fails with an
 * {@link AssertionError} unless the maximum absolute tolerance is below
 * {@code tolerance}. Gradient details are logged when {@code verbose} is set,
 * and always logged before a failure is rethrown.
 *
 * @param component  the component under test
 * @param IOPair     the io pair supplying the input and output prototypes
 * @param statistics the statistics accumulated so far
 * @return {@code statistics} combined with the per-input results, or
 *         {@code statistics} itself when the input prototype is empty
 */
public ToleranceStatistics testFeedback(@Nonnull Layer component, @Nonnull IOPair IOPair, ToleranceStatistics statistics) {
    return IntStream.range(0, IOPair.getInputPrototype().length).mapToObj(i -> {
        @Nullable final Tensor measuredGradient = !verify ? null : measureFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
        @Nonnull final Tensor implementedGradient = getFeedbackGradient(component, i, IOPair.getOutputPrototype(), IOPair.getInputPrototype());
        try {
            // Without a measured gradient the range is empty and an empty result is used.
            final ToleranceStatistics result = IntStream.range(0, null == measuredGradient ? 0 : measuredGradient.length()).mapToObj(i1 -> {
                return new ToleranceStatistics().accumulate(measuredGradient.getData()[i1], implementedGradient.getData()[i1]);
            }).reduce((a, b) -> a.combine(b)).orElse(new ToleranceStatistics());
            // Written as a negated "<" so that a NaN maximum also counts as a failure.
            if (!(result.absoluteTol.getMax() < tolerance)) {
                throw new AssertionError(result.toString());
            }
            if (verbose) {
                log.info(String.format("Feedback for input %s", i));
                log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
                log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
                log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
                log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
                if (null != measuredGradient) {
                    log.info(String.format("Measured Feedback: %s", measuredGradient.prettyPrint()));
                    log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                    log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
                    log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
                }
            }
            return result;
        } catch (@Nonnull final Throwable e) {
            // Dump the gradients regardless of the verbose flag so a failure can be diagnosed.
            log.info(String.format("Feedback for input %s", i));
            log.info(String.format("Inputs Values: %s", IOPair.getInputPrototype()[i].prettyPrint()));
            log.info(String.format("Value Statistics: %s", new ScalarStatistics().add(IOPair.getInputPrototype()[i].getData())));
            log.info(String.format("Implemented Feedback: %s", implementedGradient.prettyPrint()));
            log.info(String.format("Implemented Statistics: %s", new ScalarStatistics().add(implementedGradient.getData())));
            if (null != measuredGradient) {
                log.info(String.format("Measured: %s", measuredGradient.prettyPrint()));
                log.info(String.format("Measured Statistics: %s", new ScalarStatistics().add(measuredGradient.getData())));
                log.info(String.format("Feedback Error: %s", measuredGradient.minus(implementedGradient).prettyPrint()));
                log.info(String.format("Error Statistics: %s", new ScalarStatistics().add(measuredGradient.minus(implementedGradient).getData())));
            }
            throw e;
        }
    }).reduce((a, b) -> a.combine(b))
        // An empty input prototype yields an empty Optional; fall back to the incoming
        // statistics instead of failing on an unchecked Optional.get().
        .map(perInput -> statistics.combine(perInput))
        .orElse(statistics);
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) Nullable(javax.annotation.Nullable)

Aggregations

ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics)12 Nonnull (javax.annotation.Nonnull)12 Tensor (com.simiacryptus.mindseye.lang.Tensor)11 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)11 Arrays (java.util.Arrays)11 List (java.util.List)11 Nullable (javax.annotation.Nullable)11 Layer (com.simiacryptus.mindseye.lang.Layer)10 Collectors (java.util.stream.Collectors)10 IntStream (java.util.stream.IntStream)10 Logger (org.slf4j.Logger)10 LoggerFactory (org.slf4j.LoggerFactory)10 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)7 TensorList (com.simiacryptus.mindseye.lang.TensorList)7 SimpleEval (com.simiacryptus.mindseye.test.SimpleEval)7 ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics)7 ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult)6 Delta (com.simiacryptus.mindseye.lang.Delta)6 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)6 DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer)6