
Example 6 with Tuple2

Use of com.simiacryptus.util.lang.Tuple2 in project MindsEye by SimiaCryptus.

From class PerformanceTester, method test.

/**
 * Runs the performance test: calls testPerformance once per trial ({@code samples} trials)
 * and logs evaluation and/or learning timing statistics.
 *
 * @param component      the layer to benchmark
 * @param inputPrototype the input prototype tensors
 */
public void test(@Nonnull final Layer component, @Nonnull final Tensor[] inputPrototype) {
    log.info(String.format("%s batch length, %s trials", batches, samples));
    log.info("Input Dimensions:");
    // Send dimensions through the logger (not System.out) so they stay with the other log lines
    Arrays.stream(inputPrototype).map(t -> "\t" + Arrays.toString(t.getDimensions())).forEach(log::info);
    log.info("Performance:");
    // One (evaluation seconds, learning seconds) pair per trial
    List<Tuple2<Double, Double>> performance = IntStream.range(0, samples)
        .mapToObj(i -> testPerformance(component, inputPrototype))
        .collect(Collectors.toList());
    if (isTestEvaluation()) {
        @Nonnull final DoubleStatistics statistics = new DoubleStatistics().accept(performance.stream().mapToDouble(x -> x._1).toArray());
        log.info(String.format("\tEvaluation performance: %.6fs +- %.6fs [%.6fs - %.6fs]", statistics.getAverage(), statistics.getStandardDeviation(), statistics.getMin(), statistics.getMax()));
    }
    if (isTestLearning()) {
        // statistics is freshly constructed and @Nonnull, so no null check is needed
        @Nonnull final DoubleStatistics statistics = new DoubleStatistics().accept(performance.stream().mapToDouble(x -> x._2).toArray());
        log.info(String.format("\tLearning performance: %.6fs +- %.6fs [%.6fs - %.6fs]", statistics.getAverage(), statistics.getStandardDeviation(), statistics.getMin(), statistics.getMax()));
    }
}
Also used: IntStream (java.util.stream.IntStream), Arrays (java.util.Arrays), Logger (org.slf4j.Logger), DoubleStatistics (com.simiacryptus.util.data.DoubleStatistics), LoggerFactory (org.slf4j.LoggerFactory), Tensor (com.simiacryptus.mindseye.lang.Tensor), TestUtil (com.simiacryptus.mindseye.test.TestUtil), Result (com.simiacryptus.mindseye.lang.Result), Collectors (java.util.stream.Collectors), Tuple2 (com.simiacryptus.util.lang.Tuple2), List (java.util.List), Stream (java.util.stream.Stream), ConstantResult (com.simiacryptus.mindseye.lang.ConstantResult), ToleranceStatistics (com.simiacryptus.mindseye.test.ToleranceStatistics), TimedResult (com.simiacryptus.util.lang.TimedResult), Layer (com.simiacryptus.mindseye.lang.Layer), TensorArray (com.simiacryptus.mindseye.lang.TensorArray), DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet), NotebookOutput (com.simiacryptus.util.io.NotebookOutput), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)
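The statistics step in test reduces each column of the Tuple2<Double, Double> timings to mean, standard deviation, min and max. A minimal, self-contained sketch of that reduction follows, written without the MindsEye or com.simiacryptus.util classes. The Trial class (standing in for Tuple2) and the timeSeconds helper (standing in for testPerformance) are hypothetical, and the population standard deviation used here is an assumption; DoubleStatistics may use a different formula.

```java
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Locale;
import java.util.function.ToDoubleFunction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class TimingSummary {
    // Hypothetical stand-in for Tuple2<Double, Double>: (evaluation seconds, learning seconds).
    public static final class Trial {
        public final double evalSeconds;
        public final double learnSeconds;

        public Trial(double evalSeconds, double learnSeconds) {
            this.evalSeconds = evalSeconds;
            this.learnSeconds = learnSeconds;
        }
    }

    // Mean, population standard deviation, min and max of one timing column.
    public static final class Summary {
        public final double mean, stdDev, min, max;

        Summary(double mean, double stdDev, double min, double max) {
            this.mean = mean;
            this.stdDev = stdDev;
            this.min = min;
            this.max = max;
        }
    }

    // The column selector plays the role of x -> x._1 / x -> x._2 in the original.
    public static Summary summarize(List<Trial> trials, ToDoubleFunction<Trial> column) {
        DoubleSummaryStatistics stats = trials.stream().mapToDouble(column).summaryStatistics();
        double mean = stats.getAverage();
        // Second pass: population variance = average squared deviation from the mean (assumption).
        double variance = trials.stream().mapToDouble(column)
                .map(x -> (x - mean) * (x - mean))
                .average().orElse(0.0);
        return new Summary(mean, Math.sqrt(variance), stats.getMin(), stats.getMax());
    }

    // Same layout as the log line in test(); Locale.ROOT keeps '.' as the decimal separator.
    public static String format(String label, Summary s) {
        return String.format(Locale.ROOT, "\t%s performance: %.6fs +- %.6fs [%.6fs - %.6fs]",
                label, s.mean, s.stdDev, s.min, s.max);
    }

    // Hypothetical timing helper: wall-clock seconds for one run of the task.
    public static double timeSeconds(Runnable task) {
        long start = System.nanoTime();
        task.run();
        return (System.nanoTime() - start) / 1e9;
    }

    public static void main(String[] args) {
        int samples = 5;
        // Mirrors IntStream.range(0, samples).mapToObj(...).collect(...) in test().
        List<Trial> trials = IntStream.range(0, samples)
                .mapToObj(i -> new Trial(timeSeconds(() -> Math.sqrt(i)), timeSeconds(() -> Math.cbrt(i))))
                .collect(Collectors.toList());
        System.out.println(format("Evaluation", summarize(trials, t -> t.evalSeconds)));
        System.out.println(format("Learning", summarize(trials, t -> t.learnSeconds)));
    }
}
```

Passing the column as a ToDoubleFunction lets one summary routine serve both the evaluation and the learning columns, just as the original reuses the same DoubleStatistics pattern for x._1 and x._2.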

Example 7 with Tuple2

Use of com.simiacryptus.util.lang.Tuple2 in project MindsEye by SimiaCryptus.

From class MaxPoolingLayer, method calcRegions.

private static List<Tuple2<Integer, int[]>> calcRegions(@Nonnull final MaxPoolingLayer.CalcRegionsParameter p) {
    @Nonnull final Tensor input = new Tensor(p.inputDims);
    // Output size per dimension is ceil(input / kernel); inputs need not be a multiple of
    // the kernel size, because the last window is clamped below.
    final int[] newDims = IntStream.range(0, p.inputDims.length)
        .map(i -> (int) Math.ceil(p.inputDims[i] * 1.0 / p.kernelDims[i]))
        .toArray();
    @Nonnull final Tensor output = new Tensor(newDims);
    // For each output cell, collect the flat input indices covered by its pooling window
    List<Tuple2<Integer, int[]>> tuple2s = output.coordStream(true).map(o -> {
        final int[] outputCoords = o.getCoords();
        // Temporary tensor used only to enumerate kernel-relative coordinates
        Tensor tensor = new Tensor(p.kernelDims);
        final int[] inCoords = tensor.coordStream(true).mapToInt(kernelCoord -> {
            @Nonnull final int[] result = new int[outputCoords.length];
            for (int index = 0; index < outputCoords.length; index++) {
                final int kernelSize = p.kernelDims[index];
                // Window start, shifted back so the window never runs past the input edge
                final int baseCoordinate = Math.min(outputCoords[index] * kernelSize, p.inputDims[index] - kernelSize);
                result[index] = baseCoordinate + kernelCoord.getCoords()[index];
            }
            return input.index(result);
        }).toArray();
        tensor.freeRef();
        // Pair the flat output index with the input indices it pools over
        return new Tuple2<>(o.getIndex(), inCoords);
    }).collect(Collectors.toList());
    // Release the tensor references once the index map has been built
    input.freeRef();
    output.freeRef();
    return tuple2s;
}
Also used: IntStream (java.util.stream.IntStream), JsonObject (com.google.gson.JsonObject), Util (com.simiacryptus.util.Util), Arrays (java.util.Arrays), Logger (org.slf4j.Logger), IntToDoubleFunction (java.util.function.IntToDoubleFunction), LoggerFactory (org.slf4j.LoggerFactory), Tensor (com.simiacryptus.mindseye.lang.Tensor), Result (com.simiacryptus.mindseye.lang.Result), Function (java.util.function.Function), Collectors (java.util.stream.Collectors), DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer), JsonUtil (com.simiacryptus.util.io.JsonUtil), Tuple2 (com.simiacryptus.util.lang.Tuple2), List (java.util.List), LayerBase (com.simiacryptus.mindseye.lang.LayerBase), TensorList (com.simiacryptus.mindseye.lang.TensorList), Map (java.util.Map), Layer (com.simiacryptus.mindseye.lang.Layer), TensorArray (com.simiacryptus.mindseye.lang.TensorArray), DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet), Nonnull (javax.annotation.Nonnull), Nullable (javax.annotation.Nullable)
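calcRegions precomputes, for every output cell, the flat input indices of its pooling window. The sketch below reproduces that index arithmetic on plain int[] shapes so it runs without MindsEye. It assumes a flat layout in which the first dimension varies fastest, and that each kernel dimension is no larger than the matching input dimension; Tensor.index and coordStream may order coordinates differently. The class PoolingRegions and all its methods, including the maxPool usage helper, are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PoolingRegions {
    // Output size per dimension: ceil(input / kernel), matching the Math.ceil in calcRegions.
    public static int[] outputDims(int[] inputDims, int[] kernelDims) {
        int[] out = new int[inputDims.length];
        for (int i = 0; i < inputDims.length; i++) {
            out[i] = (inputDims[i] + kernelDims[i] - 1) / kernelDims[i];
        }
        return out;
    }

    // Flat index with the first dimension varying fastest (assumed layout, not taken from Tensor).
    static int flatIndex(int[] coords, int[] dims) {
        int index = 0;
        int stride = 1;
        for (int i = 0; i < dims.length; i++) {
            index += coords[i] * stride;
            stride *= dims[i];
        }
        return index;
    }

    // Odometer-style increment, first dimension fastest; returns false after the last coordinate.
    static boolean next(int[] coords, int[] dims) {
        for (int i = 0; i < dims.length; i++) {
            if (++coords[i] < dims[i]) {
                return true;
            }
            coords[i] = 0;
        }
        return false;
    }

    // Element i holds the flat input indices pooled into flat output index i.
    // Assumes kernelDims[d] <= inputDims[d]; otherwise the clamped start would be negative.
    public static List<int[]> regions(int[] inputDims, int[] kernelDims) {
        int[] outDims = outputDims(inputDims, kernelDims);
        int windowSize = Arrays.stream(kernelDims).reduce(1, (a, b) -> a * b);
        List<int[]> result = new ArrayList<>();
        int[] out = new int[outDims.length];
        do {
            int[] window = new int[windowSize];
            int[] kernel = new int[kernelDims.length];
            int[] in = new int[inputDims.length];
            int w = 0;
            do {
                for (int d = 0; d < inputDims.length; d++) {
                    // Clamp the window start so the last window stays inside the input.
                    int base = Math.min(out[d] * kernelDims[d], inputDims[d] - kernelDims[d]);
                    in[d] = base + kernel[d];
                }
                window[w++] = flatIndex(in, inputDims);
            } while (next(kernel, kernelDims));
            result.add(window);
        } while (next(out, outDims));
        return result;
    }

    // Usage: max pooling over a flat input array with the precomputed regions.
    public static double[] maxPool(double[] input, List<int[]> regions) {
        double[] output = new double[regions.size()];
        for (int o = 0; o < output.length; o++) {
            double max = Double.NEGATIVE_INFINITY;
            for (int idx : regions.get(o)) {
                max = Math.max(max, input[idx]);
            }
            output[o] = max;
        }
        return output;
    }
}
```

For a length-5 input and a kernel of 2, the windows are {0, 1}, {2, 3} and {3, 4}: the Math.min clamp shifts the third window back instead of letting it run past the end, so element 3 falls in two windows. Precomputing these index lists once per shape means the pooling pass itself reduces to a lookup and a max over each list.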

Aggregations

Tuple2 (com.simiacryptus.util.lang.Tuple2): 7
Nonnull (javax.annotation.Nonnull): 7
Tensor (com.simiacryptus.mindseye.lang.Tensor): 5
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 4
Layer (com.simiacryptus.mindseye.lang.Layer): 4
Result (com.simiacryptus.mindseye.lang.Result): 4
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 4
Arrays (java.util.Arrays): 4
List (java.util.List): 4
Collectors (java.util.stream.Collectors): 4
IntStream (java.util.stream.IntStream): 4
Nullable (javax.annotation.Nullable): 4
Logger (org.slf4j.Logger): 4
LoggerFactory (org.slf4j.LoggerFactory): 4
MeanSqLossLayer (com.simiacryptus.mindseye.layers.cudnn.MeanSqLossLayer): 3
ValueLayer (com.simiacryptus.mindseye.layers.cudnn.ValueLayer): 3
DAGNetwork (com.simiacryptus.mindseye.network.DAGNetwork): 3
DAGNode (com.simiacryptus.mindseye.network.DAGNode): 3
ArrayList (java.util.ArrayList): 3
JsonObject (com.google.gson.JsonObject): 2