Search in sources:

Example 6 with DoubleStatistics

Use of com.simiacryptus.util.data.DoubleStatistics in project MindsEye by SimiaCryptus.

Method test of class PerformanceTester.

/**
 * Benchmarks a layer by timing repeated trials and logging summary statistics.
 *
 * <p>Runs {@code samples} trials through {@code testPerformance}. Each trial yields a pair whose
 * first element is reported as evaluation time and whose second is reported as learning time.
 * Statistics for each phase (average, standard deviation, min, max) are logged only when that
 * phase is enabled via {@code isTestEvaluation()} / {@code isTestLearning()}.
 *
 * @param component      the layer under test
 * @param inputPrototype example inputs; their dimensions are logged and they are passed to every trial
 */
public void test(@Nonnull final Layer component, @Nonnull final Tensor[] inputPrototype) {
    log.info(String.format("%s batch length, %s trials", batches, samples));
    log.info("Input Dimensions:");
    // Send the dimension lines to the same logger as their header instead of System.out,
    // so they are not split across two output sinks.
    Arrays.stream(inputPrototype)
        .map(t -> "\t" + Arrays.toString(t.getDimensions()))
        .forEach(line -> log.info(line));
    log.info("Performance:");
    final List<Tuple2<Double, Double>> performance = IntStream.range(0, samples)
        .mapToObj(i -> testPerformance(component, inputPrototype))
        .collect(Collectors.toList());
    if (isTestEvaluation()) {
        @Nonnull final DoubleStatistics statistics = new DoubleStatistics().accept(performance.stream().mapToDouble(x -> x._1).toArray());
        log.info(String.format("\tEvaluation performance: %.6fs +- %.6fs [%.6fs - %.6fs]", statistics.getAverage(), statistics.getStandardDeviation(), statistics.getMin(), statistics.getMax()));
    }
    if (isTestLearning()) {
        // The statistics object is freshly constructed, so no null check is needed.
        @Nonnull final DoubleStatistics statistics = new DoubleStatistics().accept(performance.stream().mapToDouble(x -> x._2).toArray());
        log.info(String.format("\tLearning performance: %.6fs +- %.6fs [%.6fs - %.6fs]", statistics.getAverage(), statistics.getStandardDeviation(), statistics.getMin(), statistics.getMax()));
    }
}
Also used : IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TestUtil(com.simiacryptus.mindseye.test.TestUtil) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) Tuple2(com.simiacryptus.util.lang.Tuple2) List(java.util.List) Stream(java.util.stream.Stream) ConstantResult(com.simiacryptus.mindseye.lang.ConstantResult) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) TimedResult(com.simiacryptus.util.lang.TimedResult) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DAGNetwork(com.simiacryptus.mindseye.network.DAGNetwork) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull) Tuple2(com.simiacryptus.util.lang.Tuple2) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics)

Example 7 with DoubleStatistics

Use of com.simiacryptus.util.data.DoubleStatistics in project MindsEye by SimiaCryptus.

Method test of class ReferenceIO.

/**
 * Writes example input/output pairs for a layer to the notebook.
 *
 * <p>If {@code referenceIO} has entries, each stored input is re-evaluated through the layer. The
 * element-wise difference between the new output and the stored output is summarised as a
 * {@link DoubleStatistics} error report. Otherwise the layer is evaluated once on
 * {@code inputPrototype} and the resulting output and derivatives are printed.
 *
 * @param log            notebook to write headings and code output into
 * @param layer          the layer under test
 * @param inputPrototype inputs used when no reference pairs are configured
 * @return always {@code null}; this test produces no tolerance statistics
 */
@Nullable
@Override
public ToleranceStatistics test(@Nonnull final NotebookOutput log, @Nonnull final Layer layer, @Nonnull final Tensor... inputPrototype) {
    if (!referenceIO.isEmpty()) {
        log.h1("Reference Input/Output Pairs");
        // NOTE(review): "pre-setBytes" looks like an automated-rename artifact in this message; left unchanged.
        log.p("Display pre-setBytes input/output example pairs:");
        referenceIO.forEach((input, output) -> {
            log.code(() -> {
                @Nonnull final SimpleEval eval = SimpleEval.run(layer, input);
                try {
                    // Difference = actual output - reference output.
                    final Tensor add = output.scale(-1).addAndFree(eval.getOutput());
                    @Nonnull final DoubleStatistics error;
                    try {
                        error = new DoubleStatistics().accept(add.getData());
                    } finally {
                        add.freeRef();
                    }
                    // orElse("") (not get()) so empty input/derivative arrays don't throw.
                    return String.format("--------------------\nInput: \n[%s]\n--------------------\nOutput: \n%s\nError: %s\n--------------------\nDerivative: \n%s", Arrays.stream(input).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""), eval.getOutput().prettyPrint(), error, Arrays.stream(eval.getDerivative()).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""));
                } finally {
                    // Release the evaluation even if formatting fails.
                    eval.freeRef();
                }
            });
        });
    } else {
        log.h1("Example Input/Output Pair");
        log.p("Display input/output pairs from random executions:");
        log.code(() -> {
            @Nonnull final SimpleEval eval = SimpleEval.run(layer, inputPrototype);
            try {
                return String.format("--------------------\nInput: \n[%s]\n--------------------\nOutput: \n%s\n--------------------\nDerivative: \n%s", Arrays.stream(inputPrototype).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""), eval.getOutput().prettyPrint(), Arrays.stream(eval.getDerivative()).map(t -> t.prettyPrint()).reduce((a, b) -> a + ",\n" + b).orElse(""));
            } finally {
                eval.freeRef();
            }
        });
    }
    return null;
}
Also used : Arrays(java.util.Arrays) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) ToleranceStatistics(com.simiacryptus.mindseye.test.ToleranceStatistics) Layer(com.simiacryptus.mindseye.lang.Layer) Tensor(com.simiacryptus.mindseye.lang.Tensor) HashMap(java.util.HashMap) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) NotebookOutput(com.simiacryptus.util.io.NotebookOutput) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) SimpleEval(com.simiacryptus.mindseye.test.SimpleEval) Nullable(javax.annotation.Nullable)

Example 8 with DoubleStatistics

Use of com.simiacryptus.util.data.DoubleStatistics in project MindsEye by SimiaCryptus.

Method getCovariance of class PCAUtil.

/**
 * Computes the matrix of averaged pairwise products {@code E[x_i * x_j]} over a stream of vectors.
 * Forked from Apache Commons Math.
 *
 * <p>No mean is subtracted, so this is the uncentered second-moment matrix. It equals the true
 * covariance only if the input vectors are already centered. The supplier is invoked twice: once
 * to read the dimension from any element, and once to accumulate products. Each array from the
 * second pass is handed to {@code RecycleBin.DOUBLES.recycle} after use, so callers must not reuse
 * those arrays.
 *
 * @param stream supplier of a fresh stream of equal-length vectors; called twice
 * @return a symmetric {@code dimension x dimension} matrix of averaged products
 * @throws java.util.NoSuchElementException if the supplied stream is empty
 */
@Nonnull
public static RealMatrix getCovariance(@Nonnull final Supplier<Stream<double[]>> stream) {
    final int dimension = stream.get().findAny().get().length;
    // Row-major d*d accumulators; only the lower triangle (j <= i) is populated and read.
    final List<DoubleStatistics> statList = IntStream.range(0, dimension * dimension).mapToObj(i -> new DoubleStatistics()).collect(Collectors.toList());
    stream.get().forEach(array -> {
        for (int i = 0; i < dimension; i++) {
            for (int j = 0; j <= i; j++) {
                statList.get(i * dimension + j).accept(array[i] * array[j]);
            }
        }
        RecycleBin.DOUBLES.recycle(array, array.length);
    });
    @Nonnull final RealMatrix covariance = new BlockRealMatrix(dimension, dimension);
    for (int i = 0; i < dimension; i++) {
        for (int j = 0; j <= i; j++) {
            // Must use the same (i, j) -> index mapping as the accumulation loop above;
            // reading i + dimension * j would hit never-populated upper-triangle slots.
            final double v = statList.get(i * dimension + j).getAverage();
            covariance.setEntry(i, j, v);
            covariance.setEntry(j, i, v);
        }
    }
    return covariance;
}
Also used : IntStream(java.util.stream.IntStream) BlockRealMatrix(org.apache.commons.math3.linear.BlockRealMatrix) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) Tensor(com.simiacryptus.mindseye.lang.Tensor) Supplier(java.util.function.Supplier) Collectors(java.util.stream.Collectors) RecycleBin(com.simiacryptus.mindseye.lang.RecycleBin) List(java.util.List) Stream(java.util.stream.Stream) EigenDecomposition(org.apache.commons.math3.linear.EigenDecomposition) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) BlockRealMatrix(org.apache.commons.math3.linear.BlockRealMatrix) RealMatrix(org.apache.commons.math3.linear.RealMatrix) Nonnull(javax.annotation.Nonnull) DoubleStatistics(com.simiacryptus.util.data.DoubleStatistics) BlockRealMatrix(org.apache.commons.math3.linear.BlockRealMatrix) Nonnull(javax.annotation.Nonnull)

Aggregations

DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics)8 Nonnull (javax.annotation.Nonnull)8 Layer (com.simiacryptus.mindseye.lang.Layer)7 Tensor (com.simiacryptus.mindseye.lang.Tensor)7 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)6 List (java.util.List)6 Collectors (java.util.stream.Collectors)6 IntStream (java.util.stream.IntStream)6 DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork)5 Arrays (java.util.Arrays)5 Stream (java.util.stream.Stream)5 Nullable (javax.annotation.Nullable)5 TrainingMonitor (com.simiacryptus.mindseye.opt.TrainingMonitor)4 BufferedImage (java.awt.image.BufferedImage)4 File (java.io.File)4 IOException (java.io.IOException)4 Comparator (java.util.Comparator)4 Supplier (java.util.function.Supplier)4 ImageIO (javax.imageio.ImageIO)4 Logger (org.slf4j.Logger)4