
Example 6 with Coordinate

Use of com.simiacryptus.mindseye.lang.Coordinate in project MindsEye by SimiaCryptus, from the eval method of class AvgPoolingLayer.

@Nonnull
@SuppressWarnings("unchecked")
@Override
public Result eval(@Nonnull final Result... inObj) {
    // Number of input cells averaged into each output cell.
    final int kernelSize = Tensor.length(kernelDims);
    final TensorList data = inObj[0].getData();
    @Nonnull final int[] inputDims = data.getDimensions();
    final int[] newDims = IntStream.range(0, inputDims.length).map(i -> {
        assert 0 == inputDims[i] % kernelDims[i] : inputDims[i] + ":" + kernelDims[i];
        return inputDims[i] / kernelDims[i];
    }).toArray();
    final Map<Coordinate, List<int[]>> coordMap = AvgPoolingLayer.getCoordMap(kernelDims, newDims);
    final Tensor[] outputValues = IntStream.range(0, data.length()).mapToObj(dataIndex -> {
        @Nullable final Tensor input = data.get(dataIndex);
        @Nonnull final Tensor output = new Tensor(newDims);
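        // Average the input cells mapped to each output coordinate; a non-finite sum is skipped, leaving that output at 0.
        for (@Nonnull final Entry<Coordinate, List<int[]>> entry : coordMap.entrySet()) {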
            double sum = entry.getValue().stream().mapToDouble(inputCoord -> input.get(inputCoord)).sum();
            if (Double.isFinite(sum)) {
                output.add(entry.getKey(), sum / kernelSize);
            }
        }
        input.freeRef();
        return output;
    }).toArray(i -> new Tensor[i]);
    Arrays.stream(inObj).forEach(nnResult -> nnResult.addRef());
    return new Result(TensorArray.wrap(outputValues), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (inObj[0].isAlive()) {
            final Tensor[] passback = IntStream.range(0, delta.length()).mapToObj(dataIndex -> {
                @Nullable Tensor tensor = delta.get(dataIndex);
                @Nonnull final Tensor backSignal = new Tensor(inputDims);
                // Spread each output gradient evenly over the input cells that contributed to it.
                for (@Nonnull final Entry<Coordinate, List<int[]>> outputMapping : coordMap.entrySet()) {
                    final double outputValue = tensor.get(outputMapping.getKey());
                    for (@Nonnull final int[] inputCoord : outputMapping.getValue()) {
                        backSignal.add(inputCoord, outputValue / kernelSize);
                    }
                }
                tensor.freeRef();
                return backSignal;
            }).toArray(i -> new Tensor[i]);
            @Nonnull TensorArray tensorArray = TensorArray.wrap(passback);
            inObj[0].accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(nnResult -> nnResult.freeRef());
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) LoadingCache(com.google.common.cache.LoadingCache) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) JsonUtil(com.simiacryptus.util.io.JsonUtil) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Logger(org.slf4j.Logger) Collectors(java.util.stream.Collectors) CacheLoader(com.google.common.cache.CacheLoader) ExecutionException(java.util.concurrent.ExecutionException) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Entry(java.util.Map.Entry) CacheBuilder(com.google.common.cache.CacheBuilder) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet)
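
To make the arithmetic in the eval method above easier to follow, here is a minimal sketch of the same idea on plain 2-D arrays: non-overlapping average pooling in the forward pass, and an even split of each output gradient in the backward pass. The class AvgPoolSketch and its methods forward and backward are hypothetical names chosen for illustration; they are not part of MindsEye, and the sketch leaves out the Coordinate map, reference counting, and the finite-sum check that the original performs.

```java
// Hypothetical illustration only; not MindsEye code.
final class AvgPoolSketch {

    // Forward pass: each output cell is the mean of one non-overlapping kh x kw window.
    static double[][] forward(double[][] input, int kh, int kw) {
        int rows = input.length;
        int cols = input[0].length;
        if (rows % kh != 0 || cols % kw != 0) {
            throw new IllegalArgumentException("input dims must be divisible by kernel dims");
        }
        double[][] out = new double[rows / kh][cols / kw];
        // Accumulate the sum of every window into its output cell.
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                out[r / kh][c / kw] += input[r][c];
            }
        }
        // Divide each sum by the number of cells in a window.
        int kernelSize = kh * kw;
        for (double[] row : out) {
            for (int c = 0; c < row.length; c++) {
                row[c] /= kernelSize;
            }
        }
        return out;
    }

    // Backward pass: every input cell receives delta / kernelSize from the output cell it fed.
    static double[][] backward(double[][] delta, int kh, int kw) {
        int rows = delta.length * kh;
        int cols = delta[0].length * kw;
        int kernelSize = kh * kw;
        double[][] back = new double[rows][cols];
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                back[r][c] = delta[r / kh][c / kw] / kernelSize;
            }
        }
        return back;
    }
}
```

For example, AvgPoolSketch.forward(input, 2, 2) on a 2x4 input yields a 1x2 output, and AvgPoolSketch.backward(delta, 2, 2) on a 1x2 gradient yields a 2x4 gradient in which each entry is a quarter of the corresponding output gradient.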

Example 7 with Coordinate

Use of com.simiacryptus.mindseye.lang.Coordinate in project MindsEye by SimiaCryptus, from the eval method of class MaxDropoutNoiseLayer.

@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result in0 = inObj[0];
    final TensorList data0 = in0.getData();
    final int itemCnt = data0.length();
    in0.addRef();
    data0.addRef();
    // Build a 0/1 mask per input tensor that keeps only the largest value in each cell.
    final Tensor[] mask = IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
        @Nullable final Tensor input = data0.get(dataIndex);
        @Nullable final Tensor output = input.map(x -> 0);
        final List<List<Coordinate>> cells = getCellMap_cached.apply(new IntArray(output.getDimensions()));
        cells.forEach(cell -> {
            output.set(cell.stream().max(Comparator.comparingDouble(c -> input.get(c))).get(), 1);
        });
        input.freeRef();
        return output;
    }).toArray(i -> new Tensor[i]);
    return new Result(TensorArray.wrap(IntStream.range(0, itemCnt).mapToObj(dataIndex -> {
        Tensor inputData = data0.get(dataIndex);
        @Nullable final double[] input = inputData.getData();
        @Nullable final double[] maskT = mask[dataIndex].getData();
        @Nonnull final Tensor output = new Tensor(inputData.getDimensions());
        @Nullable final double[] outputData = output.getData();
        for (int i = 0; i < outputData.length; i++) {
            outputData[i] = input[i] * maskT[i];
        }
        inputData.freeRef();
        return output;
    }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (in0.isAlive()) {
            @Nonnull TensorArray tensorArray = TensorArray.wrap(IntStream.range(0, delta.length()).mapToObj(dataIndex -> {
                Tensor deltaTensor = delta.get(dataIndex);
                @Nullable final double[] deltaData = deltaTensor.getData();
                @Nonnull final int[] dims = data0.getDimensions();
                @Nullable final double[] maskData = mask[dataIndex].getData();
                @Nonnull final Tensor passback = new Tensor(dims);
                // The gradient flows back only through the positions the mask kept.
                for (int i = 0; i < passback.length(); i++) {
                    passback.set(i, maskData[i] * deltaData[i]);
                }
                deltaTensor.freeRef();
                return passback;
            }).toArray(i -> new Tensor[i]));
            in0.accumulate(buffer, tensorArray);
        }
    }) {

        @Override
        protected void _free() {
            in0.freeRef();
            data0.freeRef();
            Arrays.stream(mask).forEach(ReferenceCounting::freeRef);
        }

        @Override
        public boolean isAlive() {
            return in0.isAlive() || !isFrozen();
        }
    };
}
Also used: IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) IntArray(com.simiacryptus.util.data.IntArray) Function(java.util.function.Function) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) Logger(org.slf4j.Logger) Collectors(java.util.stream.Collectors) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator)
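
The masking logic in the eval method above can be sketched on a flat array, assuming each cell is simply a run of cellSize consecutive elements (the original instead groups Coordinate objects through getCellMap_cached). The class MaxDropoutSketch and its methods are hypothetical names for illustration, not MindsEye API. The same mask is used for both the forward output and the backward gradient.

```java
// Hypothetical illustration only; not MindsEye code.
final class MaxDropoutSketch {

    // Build a 0/1 mask with a single 1 per cell, at the position of the cell's largest value.
    // Assumption: a cell is a run of cellSize consecutive elements (the last cell may be shorter).
    static double[] mask(double[] input, int cellSize) {
        if (cellSize <= 0) {
            throw new IllegalArgumentException("cellSize must be positive");
        }
        double[] mask = new double[input.length];
        for (int start = 0; start < input.length; start += cellSize) {
            int end = Math.min(start + cellSize, input.length);
            int best = start;
            for (int i = start + 1; i < end; i++) {
                if (input[i] > input[best]) {
                    best = i;
                }
            }
            mask[best] = 1.0;
        }
        return mask;
    }

    // Element-wise product; used for the forward output (input * mask)
    // and for the backward signal (delta * mask).
    static double[] apply(double[] values, double[] mask) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = values[i] * mask[i];
        }
        return out;
    }
}
```

With input {1, 3, 2, 0, 5, 4, 9, 7} and a cell size of 4, the mask keeps indices 1 and 6, so apply returns {0, 3, 0, 0, 0, 0, 9, 0}; passing a gradient through apply with the same mask zeroes every position except those two.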

Aggregations

Coordinate (com.simiacryptus.mindseye.lang.Coordinate): 7
Tensor (com.simiacryptus.mindseye.lang.Tensor): 7
TensorList (com.simiacryptus.mindseye.lang.TensorList): 7
List (java.util.List): 7
JsonObject (com.google.gson.JsonObject): 6
DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer): 6
DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 6
Layer (com.simiacryptus.mindseye.lang.Layer): 6
LayerBase (com.simiacryptus.mindseye.lang.LayerBase): 6
Result (com.simiacryptus.mindseye.lang.Result): 6
TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 6
Arrays (java.util.Arrays): 6
Map (java.util.Map): 6
IntStream (java.util.stream.IntStream): 6
Nonnull (javax.annotation.Nonnull): 6
Nullable (javax.annotation.Nullable): 6
Logger (org.slf4j.Logger): 6
LoggerFactory (org.slf4j.LoggerFactory): 6
JsonUtil (com.simiacryptus.util.io.JsonUtil): 3
ToDoubleFunction (java.util.function.ToDoubleFunction): 3