Search in sources:

Example 11 with ScalarStatistics

Use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

In class SingleDerivativeTester, method test.

/**
 * Validates the component's implemented derivatives against finite-difference estimates and
 * writes the report to the notebook.
 *
 * <p>Feedback validation (derivative of the inputs) runs when {@code isTestFeedback()} is true.
 * Learning validation (derivative of the internal weights) runs when {@code isTestLearning()} is
 * true. Each stage receives the statistics accumulated so far. Afterwards the overall accuracy is
 * logged, and {@code testFrozen} / {@code testUnFrozen} are run against the component.
 *
 * @param output         the notebook the validation report is written to
 * @param component      the component under test
 * @param inputPrototype the input prototype
 * @return the accumulated tolerance statistics
 */
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput output, @Nonnull final Layer component, @Nonnull final Tensor... inputPrototype) {
    output.h1("Differential Validation");
    ToleranceStatistics accumulated = new ToleranceStatistics();
    final Tensor outputPrototype = SimpleEval.run(component, inputPrototype).getOutputAndFree();
    try {
        if (verbose) {
            output.code(() -> {
                log.info(String.format("Inputs: %s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse("")));
                log.info(String.format("Inputs Statistics: %s", Arrays.stream(inputPrototype).map(x -> new ScalarStatistics().add(x.getData()).toString()).reduce((a, b) -> a + ",\n" + b).orElse("")));
                log.info(String.format("Output: %s", null == outputPrototype ? null : outputPrototype.prettyPrint()));
                // Guarded the same way as the line above: a missing output must not throw here.
                log.info(String.format("Outputs Statistics: %s", null == outputPrototype ? null : new ScalarStatistics().add(outputPrototype.getData())));
            });
        }
        if (isTestFeedback()) {
            output.h2("Feedback Validation");
            output.p("We validate the agreement between the implemented derivative _of the inputs_ apply finite difference estimations:");
            // Lambdas need an effectively-final snapshot of the running statistics.
            final ToleranceStatistics snapshot = accumulated;
            accumulated = output.code(() -> {
                return testFeedback(snapshot, component, inputPrototype, outputPrototype);
            });
        }
        if (isTestLearning()) {
            output.h2("Learning Validation");
            output.p("We validate the agreement between the implemented derivative _of the internal weights_ apply finite difference estimations:");
            final ToleranceStatistics snapshot = accumulated;
            accumulated = output.code(() -> {
                return testLearning(snapshot, component, inputPrototype, outputPrototype);
            });
        }
    } finally {
        // Release the evaluated output; skip it when evaluation produced none.
        if (null != outputPrototype) {
            outputPrototype.freeRef();
        }
    }
    output.h2("Total Accuracy");
    output.p("The overall agreement accuracy between the implemented derivative and the finite difference estimations:");
    final ToleranceStatistics totals = accumulated;
    output.code(() -> {
        log.info("Finite-Difference Derivative Accuracy:");
        log.info(String.format("absoluteTol: %s", totals.absoluteTol));
        log.info(String.format("relativeTol: %s", totals.relativeTol));
    });
    output.h2("Frozen and Alive Status");
    output.code(() -> {
        testFrozen(component, inputPrototype);
        testUnFrozen(component, inputPrototype);
    });
    return accumulated;
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) DoubleBuffer(com.simiacryptus.mindseye.lang.DoubleBuffer) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Delta(com.simiacryptus.mindseye.lang.Delta) List(java.util.List) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) TensorList(com.simiacryptus.mindseye.lang.TensorList) PlaceholderLayer(com.simiacryptus.mindseye.layers.java.PlaceholderLayer) Layer(com.simiacryptus.mindseye.lang.Layer) Optional(java.util.Optional) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)

Example 12 with ScalarStatistics

Use of com.simiacryptus.util.data.ScalarStatistics in project MindsEye by SimiaCryptus.

In class EncodingUtil, method printModel.

/**
 * Reports summary statistics of a network's state to the notebook, then saves the network's JSON
 * form as a numbered model file.
 *
 * @param log     the notebook that receives the statistics and the saved-file link
 * @param network the network whose state buffers are summarized and whose JSON is saved
 * @param modelNo the number embedded in the saved file name ({@code model<modelNo>.json})
 */
public static void printModel(@Nonnull final NotebookOutput log, @Nonnull final Layer network, final int modelNo) {
    log.out("Learned Model Statistics: ");
    log.code(() -> {
        // Fold every value of every state buffer into a single running summary.
        @Nonnull final ScalarStatistics summary = new ScalarStatistics();
        network.state().stream()
            .flatMapToDouble(Arrays::stream)
            .forEach(summary::add);
        return summary.getMetrics();
    });
    @Nonnull final String fileName = "model" + modelNo + ".json";
    final String json = network.getJson().toString();
    log.p("Saved model as " + log.file(json, fileName, fileName));
}
Also used: PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) IntStream(java.util.stream.IntStream) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) GifSequenceWriter(com.simiacryptus.util.io.GifSequenceWriter) TableOutput(com.simiacryptus.util.TableOutput) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Caltech101(com.simiacryptus.mindseye.test.data.Caltech101) Function(java.util.function.Function) LinkedHashMap(java.util.LinkedHashMap) ImgBandScaleLayer(com.simiacryptus.mindseye.layers.java.ImgBandScaleLayer) TrainingMonitor(com.simiacryptus.mindseye.opt.TrainingMonitor) ImageIO(javax.imageio.ImageIO) Layer(com.simiacryptus.mindseye.lang.Layer) StepRecord(com.simiacryptus.mindseye.test.StepRecord) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) PCAUtil(com.simiacryptus.mindseye.test.PCAUtil) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) PrintStream(java.io.PrintStream) MeanSqLossLayer(com.simiacryptus.mindseye.layers.java.MeanSqLossLayer) Logger(org.slf4j.Logger) SysOutInterceptor(com.simiacryptus.util.test.SysOutInterceptor) BufferedImage(java.awt.image.BufferedImage) ImgBandSelectLayer(com.simiacryptus.mindseye.layers.java.ImgBandSelectLayer) IOException(java.io.IOException) TestUtil(com.simiacryptus.mindseye.test.TestUtil) FastRandom(com.simiacryptus.util.FastRandom) Collectors(java.util.stream.Collectors) File(java.io.File) DoubleStream(java.util.stream.DoubleStream) ConvolutionLayer(com.simiacryptus.mindseye.layers.cudnn.ConvolutionLayer) List(java.util.List) Stream(java.util.stream.Stream) ScalarStatistics(com.simiacryptus.util.data.ScalarStatistics) ToDoubleFunction(java.util.function.ToDoubleFunction) ImgReshapeLayer(com.simiacryptus.mindseye.layers.java.ImgReshapeLayer) ImgBandBiasLayer(com.simiacryptus.mindseye.layers.java.ImgBandBiasLayer) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) Step(com.simiacryptus.mindseye.opt.Step) Comparator(java.util.Comparator)

Aggregations

ScalarStatistics (com.simiacryptus.util.data.ScalarStatistics): 12, Nonnull (javax.annotation.Nonnull): 12, Tensor (com.simiacryptus.mindseye.lang.Tensor): 11, NotebookOutput (com.simiacryptus.util.io.NotebookOutput): 11, Arrays (java.util.Arrays): 11, List (java.util.List): 11, Nullable (javax.annotation.Nullable): 11, Layer (com.simiacryptus.mindseye.lang.Layer): 10, Collectors (java.util.stream.Collectors): 10, IntStream (java.util.stream.IntStream): 10, Logger (org.slf4j.Logger): 10, LoggerFactory (org.slf4j.LoggerFactory): 10, TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 7, TensorList (com.simiacryptus.mindseye.lang.TensorList): 7, SimpleEval (com.simiacryptus.mindseye.test.SimpleEval): 7, ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics): 7, ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult): 6, Delta (com.simiacryptus.mindseye.lang.Delta): 6, DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 6, DoubleBuffer (com.simiacryptus.mindseye.lang.DoubleBuffer): 6