Search in sources :

Example 1 with IntArray

Use of com.simiacryptus.util.data.IntArray in the MindsEye project by SimiaCryptus.

From the class MaxDropoutNoiseLayer, method eval.

/**
 * Evaluates this layer on the first input, keeping one element per cell and zeroing the rest.
 *
 * <p>For every item of the input batch a 0/1 mask tensor is built: it starts as an all-zero
 * copy of the item's shape, and for each cell returned by {@code getCellMap_cached} the
 * coordinate holding the cell's maximum input value is set to 1. The forward output is the
 * element-wise product of the input data and that mask.
 *
 * <p>The returned {@link Result} carries a backward callback that, when the input is alive,
 * multiplies each incoming delta element-wise by the same mask and accumulates the product
 * into the input. The result takes a reference on the input and its data here and releases
 * both, together with the masks, in {@code _free()}.
 *
 * @param inObj the layer inputs; only {@code inObj[0]} is read
 * @return the masked output together with its backward callback
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result source = inObj[0];
    final TensorList batch = source.getData();
    final int batchSize = batch.length();
    // Both references are held until the returned Result is freed (see _free below).
    source.addRef();
    batch.addRef();

    // Pass 1: one 0/1 mask per batch item, marking the maximum of every cell.
    final Tensor[] keepMasks = new Tensor[batchSize];
    for (int item = 0; item < batchSize; item++) {
        @Nullable final Tensor values = batch.get(item);
        @Nullable final Tensor keep = values.map(v -> 0);
        final List<List<Coordinate>> cellMap = getCellMap_cached.apply(new IntArray(keep.getDimensions()));
        for (final List<Coordinate> cell : cellMap) {
            final Coordinate strongest = cell.stream()
                .max(Comparator.comparingDouble(coord -> values.get(coord)))
                .get();
            keep.set(strongest, 1);
        }
        values.freeRef();
        keepMasks[item] = keep;
    }

    // Pass 2: forward output = input * mask, element by element.
    final Tensor[] forward = new Tensor[batchSize];
    for (int item = 0; item < batchSize; item++) {
        Tensor values = batch.get(item);
        @Nullable final double[] src = values.getData();
        @Nullable final double[] gate = keepMasks[item].getData();
        @Nonnull final Tensor gated = new Tensor(values.getDimensions());
        @Nullable final double[] dst = gated.getData();
        for (int k = 0; k < dst.length; k++) {
            dst[k] = src[k] * gate[k];
        }
        values.freeRef();
        forward[item] = gated;
    }

    return new Result(TensorArray.wrap(forward), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        // Nothing to propagate when the input does not accept gradients.
        if (!source.isAlive()) {
            return;
        }
        final int deltaCount = delta.length();
        final Tensor[] passbacks = new Tensor[deltaCount];
        for (int item = 0; item < deltaCount; item++) {
            Tensor incoming = delta.get(item);
            @Nullable final double[] incomingData = incoming.getData();
            @Nonnull final int[] shape = batch.getDimensions();
            @Nullable final double[] gate = keepMasks[item].getData();
            @Nonnull final Tensor gradient = new Tensor(shape);
            // Gradient only flows through the elements the mask kept.
            for (int k = 0; k < gradient.length(); k++) {
                gradient.set(k, gate[k] * incomingData[k]);
            }
            incoming.freeRef();
            passbacks[item] = gradient;
        }
        @Nonnull TensorArray wrapped = TensorArray.wrap(passbacks);
        source.accumulate(buffer, wrapped);
    }) {

        /** Releases the references taken in {@code eval} and every mask tensor. */
        @Override
        protected void _free() {
            source.freeRef();
            batch.freeRef();
            for (final Tensor keep : keepMasks) {
                keep.freeRef();
            }
        }

        /** Alive when the input is alive or the enclosing layer is not frozen. */
        @Override
        public boolean isAlive() {
            if (source.isAlive()) {
                return true;
            }
            return !isFrozen();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) IntArray(com.simiacryptus.util.data.IntArray) Function(java.util.function.Function) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) ArrayList(java.util.ArrayList) JsonUtil(com.simiacryptus.util.io.JsonUtil) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Util(com.simiacryptus.util.Util) Logger(org.slf4j.Logger) Collectors(java.util.stream.Collectors) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) IntArray(com.simiacryptus.util.data.IntArray) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) ArrayList(java.util.ArrayList) List(java.util.List) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Aggregations

JsonObject (com.google.gson.JsonObject)1 Coordinate (com.simiacryptus.mindseye.lang.Coordinate)1 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)1 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)1 Layer (com.simiacryptus.mindseye.lang.Layer)1 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)1 ReferenceCounting (com.simiacryptus.mindseye.lang.ReferenceCounting)1 Result (com.simiacryptus.mindseye.lang.Result)1 Tensor (com.simiacryptus.mindseye.lang.Tensor)1 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)1 TensorList (com.simiacryptus.mindseye.lang.TensorList)1 Util (com.simiacryptus.util.Util)1 IntArray (com.simiacryptus.util.data.IntArray)1 JsonUtil (com.simiacryptus.util.io.JsonUtil)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Comparator (java.util.Comparator)1 List (java.util.List)1 Map (java.util.Map)1 Function (java.util.function.Function)1