Example 1 with IntToDoubleFunction

Use of java.util.function.IntToDoubleFunction in project MindsEye by SimiaCryptus.

Source: class MaxPoolingLayer, method eval.

@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    final Result in = inObj[0];
    @Nonnull final int[] inputDims = in.getData().getDimensions();
    final List<Tuple2<Integer, int[]>> regions = MaxPoolingLayer.calcRegionsCache.apply(new MaxPoolingLayer.CalcRegionsParameter(inputDims, kernelDims));
    // Output shape: each dimension is the input extent divided by the kernel extent,
    // rounded up. It is the same for every item in the batch, so compute it once.
    final int[] newDims = IntStream.range(0, inputDims.length)
        .map(i -> (int) Math.ceil(inputDims[i] * 1.0 / kernelDims[i]))
        .toArray();
    final Tensor[] outputA = IntStream.range(0, in.getData().length())
        .mapToObj(dataIndex -> new Tensor(newDims))
        .toArray(Tensor[]::new);
    @Nonnull final int[][] gradientMapA = new int[in.getData().length()][];
    IntStream.range(0, in.getData().length()).forEach(dataIndex -> {
        @Nullable final Tensor input = in.getData().get(dataIndex);
        final Tensor output = outputA[dataIndex];
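        // Maps a flat input index to the value stored there; the loop below uses it
        // to find the index with the largest value in each pooling region.
        @Nonnull final IntToDoubleFunction keyExtractor = inputCoords -> input.get(inputCoords);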
        @Nonnull final int[] gradientMap = new int[input.length()];
        regions.parallelStream().forEach(tuple -> {
            final Integer from = tuple.getFirst();
            final int[] toList = tuple.getSecond();
            int toMax = -1;
            double bestValue = Double.NEGATIVE_INFINITY;
            for (final int c : toList) {
                final double value = keyExtractor.applyAsDouble(c);
                if (-1 == toMax || bestValue < value) {
                    bestValue = value;
                    toMax = c;
                }
            }
            gradientMap[from] = toMax;
            output.set(from, input.get(toMax));
        });
        input.freeRef();
        gradientMapA[dataIndex] = gradientMap;
    });
    return new Result(TensorArray.wrap(outputA), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (in.isAlive()) {
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, in.getData().length()).parallel().mapToObj(dataIndex -> {
                @Nonnull final Tensor backSignal = new Tensor(inputDims);
                final int[] ints = gradientMapA[dataIndex];
                @Nullable final Tensor datum = data.get(dataIndex);
                // Route each output gradient to the input index that won the max in the forward pass.
                for (int i = 0; i < datum.length(); i++) {
                    backSignal.add(ints[i], datum.get(i));
                }
                datum.freeRef();
                return backSignal;
            }).toArray(i -> new Tensor[i]));
            in.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return in.isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Util(com.simiacryptus.util.Util) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) IntToDoubleFunction(java.util.function.IntToDoubleFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) JsonUtil(com.simiacryptus.util.io.JsonUtil) Tuple2(com.simiacryptus.util.lang.Tuple2) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)
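The core of eval is an argmax over each pooling region, with keyExtractor (an IntToDoubleFunction) supplying the value at each flat index, followed by a backward pass that sends each output gradient to the winning index. The following is a minimal one-dimensional sketch of that pattern on plain double arrays. It assumes non-overlapping windows whose stride equals the kernel width, which is what the Math.ceil output shape suggests; the class MaxPool1dSketch and its forward/backward methods are hypothetical illustrations, not MindsEye APIs.

```java
import java.util.Arrays;
import java.util.function.IntToDoubleFunction;

/** Hypothetical 1-D max-pooling sketch; not part of MindsEye. */
final class MaxPool1dSketch {

    /**
     * Pools {@code input} in non-overlapping windows of width {@code kernel} (kernel >= 1).
     * The last window may be shorter, matching a Math.ceil output size.
     * Writes the winning input index of each output cell into {@code argMax},
     * which must have at least ceil(input.length / kernel) elements.
     */
    static double[] forward(double[] input, int kernel, int[] argMax) {
        int outLen = (int) Math.ceil(input.length * 1.0 / kernel);
        double[] output = new double[outLen];
        // Key extractor: maps a flat input index to the value stored there.
        IntToDoubleFunction keyExtractor = index -> input[index];
        for (int out = 0; out < outLen; out++) {
            int best = -1;
            double bestValue = Double.NEGATIVE_INFINITY;
            int end = Math.min(input.length, (out + 1) * kernel);
            for (int i = out * kernel; i < end; i++) {
                double value = keyExtractor.applyAsDouble(i);
                if (best == -1 || bestValue < value) {
                    bestValue = value;
                    best = i;
                }
            }
            argMax[out] = best;
            output[out] = bestValue;
        }
        return output;
    }

    /** Routes each output gradient back to the input position that won the max. */
    static double[] backward(double[] outputGradient, int[] argMax, int inputLength) {
        double[] inputGradient = new double[inputLength];
        for (int out = 0; out < outputGradient.length; out++) {
            inputGradient[argMax[out]] += outputGradient[out];
        }
        return inputGradient;
    }

    public static void main(String[] args) {
        double[] input = {1, 5, 2, 8, 3};
        int[] argMax = new int[3];
        double[] pooled = forward(input, 2, argMax);
        System.out.println(Arrays.toString(pooled)); // [5.0, 8.0, 3.0]
        System.out.println(Arrays.toString(argMax)); // [1, 3, 4]
        double[] grad = backward(new double[]{1, 1, 1}, argMax, input.length);
        System.out.println(Arrays.toString(grad));   // [0.0, 1.0, 0.0, 1.0, 1.0]
    }
}
```

With input {1, 5, 2, 8, 3} and a kernel of 2, the last window holds a single element, so the output has three cells, and in the backward pass each unit gradient lands on indices 1, 3 and 4. Using IntToDoubleFunction rather than Function<Integer, Double> keeps the lookup on primitives, so the inner loop does no boxing.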

Example 2 with IntToDoubleFunction

Use of java.util.function.IntToDoubleFunction in project gatk by broadinstitute.

Source: class IndexRangeUnitTest, method testSum.

@Test(dataProvider = "correctFromToData", dependsOnMethods = "testCorrectConstruction")
public void testSum(final int from, final int to) {
    final IndexRange range = new IndexRange(from, to);
    final IntToDoubleFunction func = Math::exp;
    Assert.assertEquals(range.sum(func), IntStream.range(from, to).mapToDouble(func).sum(), 1.0e-8);
}
Also used : IntToDoubleFunction(java.util.function.IntToDoubleFunction) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)
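The test checks range.sum(func) against a stream-based reference with a tolerance of 1.0e-8. One plausible reason for the tolerance: DoubleStream.sum() is documented as possibly using compensated summation, so its result can differ in the last bits from a plain accumulation loop. Below is a minimal sketch of a range type with a loop-based sum. RangeSumSketch is hypothetical, and it only assumes that GATK's IndexRange covers the half-open interval [from, to), as the IntStream.range reference implies; the constructor check is likewise an assumption of this sketch.

```java
import java.util.function.IntToDoubleFunction;
import java.util.stream.IntStream;

/** Hypothetical half-open index range [from, to); not GATK's IndexRange. */
final class RangeSumSketch {
    final int from;
    final int to;

    RangeSumSketch(int from, int to) {
        if (from > to) {
            throw new IllegalArgumentException("from must not exceed to");
        }
        this.from = from;
        this.to = to;
    }

    /** Sums func(i) for every i in [from, to) with a plain loop. */
    double sum(IntToDoubleFunction func) {
        double total = 0.0;
        for (int i = from; i < to; i++) {
            total += func.applyAsDouble(i);
        }
        return total;
    }

    public static void main(String[] args) {
        RangeSumSketch range = new RangeSumSketch(0, 10);
        IntToDoubleFunction func = Math::exp;
        double loopSum = range.sum(func);
        // DoubleStream.sum() may use compensated summation, so compare with a tolerance.
        double streamSum = IntStream.range(0, 10).mapToDouble(func).sum();
        System.out.println(Math.abs(loopSum - streamSum) < 1.0e-8); // true
    }
}
```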

Example 3 with IntToDoubleFunction

Use of java.util.function.IntToDoubleFunction in project gatk by broadinstitute.

Source: class IndexRangeUnitTest, method testMapToDouble.

@Test(dataProvider = "correctFromToData", dependsOnMethods = "testCorrectConstruction")
public void testMapToDouble(final int from, final int to) {
    final IndexRange range = new IndexRange(from, to);
    final IntToDoubleFunction func = Math::exp;
    Assert.assertEquals(range.mapToDouble(func), IntStream.range(from, to).mapToDouble(func).toArray());
}
Also used : IntToDoubleFunction(java.util.function.IntToDoubleFunction) BaseTest(org.broadinstitute.hellbender.utils.test.BaseTest) Test(org.testng.annotations.Test)
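Here the expected value is IntStream.range(from, to).mapToDouble(func).toArray(), so range.mapToDouble(func) should return one element per index in [from, to), in order. A hypothetical static helper with that contract might look like the sketch below; RangeMapSketch and its method are illustrations, not GATK code.

```java
import java.util.Arrays;
import java.util.function.IntToDoubleFunction;
import java.util.stream.IntStream;

/** Hypothetical helper, not GATK's IndexRange. */
final class RangeMapSketch {
    private RangeMapSketch() {}

    /** Returns func(i) for each i in [from, to), in index order. */
    static double[] mapToDouble(int from, int to, IntToDoubleFunction func) {
        if (from > to) {
            throw new IllegalArgumentException("from must not exceed to");
        }
        double[] result = new double[to - from];
        for (int i = from; i < to; i++) {
            result[i - from] = func.applyAsDouble(i);
        }
        return result;
    }

    public static void main(String[] args) {
        IntToDoubleFunction half = i -> i / 2.0;
        double[] viaLoop = mapToDouble(2, 6, half);
        double[] viaStream = IntStream.range(2, 6).mapToDouble(half).toArray();
        System.out.println(Arrays.toString(viaLoop));          // [1.0, 1.5, 2.0, 2.5]
        System.out.println(Arrays.equals(viaLoop, viaStream)); // true
    }
}
```

Unlike the sum in Example 2, no accumulation happens here, which is why the test compares the two arrays without a tolerance.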

Example 4 with IntToDoubleFunction

Use of java.util.function.IntToDoubleFunction in project java-certification by springapidev.

Source: class IntToDoubleFunctionEx, method main.

public static void main(String[] args) {
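    // $x is not declared in main; presumably it is a static field declared
    // elsewhere in IntToDoubleFunctionEx, outside this excerpt.
    System.out.println("$x: " + $x);
    IntToDoubleFunction function = (a) -> (a / 3d);
    System.out.println(function.applyAsDouble(9));  // prints 3.0
    System.out.println(function.applyAsDouble(22)); // prints 7.333333333333333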
}
Also used : IntToDoubleFunction(java.util.function.IntToDoubleFunction)
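The 3d literal in Example 4 matters: with an int divisor, a / 3 is integer division, and only the truncated int result is widened to double when the lambda returns. The sketch below contrasts the two lambdas; DivisionSketch and its field names are hypothetical.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical sketch: integer versus floating-point division inside an IntToDoubleFunction. */
final class DivisionSketch {
    // int / int truncates first; the int result is then widened to double.
    static final IntToDoubleFunction TRUNCATING = a -> a / 3;
    // int / double promotes a to double, so the fractional part is kept.
    static final IntToDoubleFunction EXACT = a -> a / 3d;

    public static void main(String[] args) {
        System.out.println(TRUNCATING.applyAsDouble(22)); // 7.0
        System.out.println(EXACT.applyAsDouble(22));      // 7.333333333333333
    }
}
```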

Aggregations

IntToDoubleFunction (java.util.function.IntToDoubleFunction): 4
BaseTest (org.broadinstitute.hellbender.utils.test.BaseTest): 2
Test (org.testng.annotations.Test): 2
JsonObject (com.google.gson.JsonObject): 1
DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 1
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 1
Layer (com.simiacryptus.mindseye.lang.Layer): 1
LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 1
Result (com.simiacryptus.mindseye.lang.Result): 1
Tensor (com.simiacryptus.mindseye.lang.Tensor): 1
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 1
TensorList (com.simiacryptus.mindseye.lang.TensorList): 1
Util (com.simiacryptus.util.Util): 1
JsonUtil (com.simiacryptus.util.io.JsonUtil): 1
Tuple2 (com.simiacryptus.util.lang.Tuple2): 1
Arrays (java.util.Arrays): 1
List (java.util.List): 1
Map (java.util.Map): 1
Function (java.util.function.Function): 1
Collectors (java.util.stream.Collectors): 1