Search in sources :

Example 6 with Layer

use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

the class BinarySumLayer method evalAndFree.

/**
 * Evaluates the weighted sum {@code leftFactor * inObj[0] + rightFactor * inObj[1]}
 * through cuDNN and returns a result whose feedback handler passes the incoming delta,
 * scaled by the matching factor, back to each live input.
 * <p>
 * Input count handling:
 * <ul>
 *   <li>one input: returned unchanged; both factors must be 1;</li>
 *   <li>more than two inputs: both factors must be 1 and the inputs are folded pairwise
 *       through this same method;</li>
 *   <li>exactly two inputs: the CUDA path below, or the compatibility layer when CUDA is
 *       not enabled.</li>
 * </ul>
 *
 * @param inObj the input results to sum
 * @return the summed result (or {@code inObj[0]} itself for a single input)
 * @throws IllegalStateException    if one or more than two inputs are given while either
 *                                  factor differs from 1
 * @throws IllegalArgumentException if the left data has more than 3 dimensions, the batch
 *                                  lengths differ, or the per-item element counts differ
 */
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
    // Single input: nothing to add, so the input is handed straight back.
    if (inObj.length == 1) {
        if (rightFactor != 1)
            throw new IllegalStateException();
        if (leftFactor != 1)
            throw new IllegalStateException();
        return inObj[0];
    }
    // More than two inputs: fold them pairwise via the two-input path of this method.
    if (inObj.length > 2) {
        if (rightFactor != 1)
            throw new IllegalStateException();
        if (leftFactor != 1)
            throw new IllegalStateException();
        return Arrays.stream(inObj).reduce((a, b) -> evalAndFree(a, b)).get();
    }
    assert (inObj.length == 2);
    final TensorList leftData = inObj[0].getData();
    final TensorList rightData = inObj[1].getData();
    int[] leftDimensions = leftData.getDimensions();
    if (3 < leftDimensions.length) {
        throw new IllegalArgumentException("dimensions=" + Arrays.toString(leftDimensions));
    }
    // Pad the left dimensions to exactly three entries: a missing first entry becomes 0,
    // missing second/third entries become 1.
    @Nonnull final int[] dimensions = { leftDimensions.length < 1 ? 0 : leftDimensions[0], leftDimensions.length < 2 ? 1 : leftDimensions[1], leftDimensions.length < 3 ? 1 : leftDimensions[2] };
    final int length = leftData.length();
    if (length != rightData.length())
        throw new IllegalArgumentException();
    // NOTE(review): `dimensions` is built above as a three-element literal, so this
    // check can never fire.
    if (3 != dimensions.length) {
        throw new IllegalArgumentException("dimensions=" + Arrays.toString(dimensions));
    }
    // Only the Tensor.length(...) values are compared here, not the dimension arrays
    // themselves. NOTE(review): presumably that is the element count, so differently
    // shaped inputs of equal size would pass - confirm this is intended.
    for (int i = 1; i < inObj.length; i++) {
        if (Tensor.length(dimensions) != Tensor.length(inObj[i].getData().getDimensions())) {
            throw new IllegalArgumentException(Arrays.toString(dimensions) + " != " + Arrays.toString(inObj[i].getData().getDimensions()));
        }
    }
    // Without CUDA, delegate the whole evaluation to the compatibility layer.
    if (!CudaSystem.isEnabled())
        return getCompatibilityLayer().evalAndFree(inObj);
    return new Result(CudaSystem.run(gpu -> {
        // Forward pass: a single cudnnOpTensor call with the ADD op.
        @Nonnull final CudaResource<cudnnOpTensorDescriptor> opDescriptor = gpu.newOpDescriptor(cudnnOpTensorOp.CUDNN_OP_TENSOR_ADD, precision);
        // Output descriptor: `length` items of dimensions[2] x dimensions[1] x dimensions[0],
        // followed by four stride-like values derived from the same dimensions
        // (presumably batch/channel/row/column strides of a dense layout - TODO confirm
        // against CudaDevice.newTensorDescriptor).
        @Nonnull final CudaDevice.CudaTensorDescriptor outputDescriptor = gpu.newTensorDescriptor(precision, length, dimensions[2], dimensions[1], dimensions[0], dimensions[2] * dimensions[1] * dimensions[0], dimensions[1] * dimensions[0], dimensions[0], 1);
        // .getDenseAndFree(gpu);//.moveTo(gpu.getDeviceNumber());
        @Nullable final CudaTensor lPtr = gpu.getTensor(leftData, precision, MemoryType.Device, false);
        // .getDenseAndFree(gpu);//.moveTo(gpu.getDeviceNumber());
        @Nullable final CudaTensor rPtr = gpu.getTensor(rightData, precision, MemoryType.Device, false);
        @Nonnull final CudaMemory outputPtr = gpu.allocate(precision.size * Tensor.length(dimensions) * length, MemoryType.Managed, true);
        CudaMemory lPtrMemory = lPtr.getMemory(gpu);
        CudaMemory rPtrMemory = rPtr.getMemory(gpu);
        // leftFactor and rightFactor are passed as the scaling pointers of the two operands
        // and 0.0 as the scaling pointer of the destination.
        gpu.cudnnOpTensor(opDescriptor.getPtr(), precision.getPointer(leftFactor), lPtr.descriptor.getPtr(), lPtrMemory.getPtr(), precision.getPointer(rightFactor), rPtr.descriptor.getPtr(), rPtrMemory.getPtr(), precision.getPointer(0.0), outputDescriptor.getPtr(), outputPtr.getPtr());
        assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
        lPtrMemory.dirty();
        rPtrMemory.dirty();
        outputPtr.dirty();
        // Release the memory handles and, below, the op descriptor and both input tensors;
        // outputPtr and outputDescriptor are handed to the wrapped output tensor instead.
        rPtrMemory.freeRef();
        lPtrMemory.freeRef();
        CudaTensor cudaTensor = CudaTensor.wrap(outputPtr, outputDescriptor, precision);
        Stream.<ReferenceCounting>of(opDescriptor, lPtr, rPtr).forEach(ReferenceCounting::freeRef);
        return CudaTensorList.wrap(cudaTensor, length, dimensions, precision);
    }, leftData), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        // Backward pass, left input: passback is delta transformed with leftFactor as the
        // source scaling and 0.0 as the destination scaling.
        Runnable a = () -> {
            if (inObj[0].isAlive()) {
                CudaTensorList tensorList = CudaSystem.run(gpu -> {
                    @Nullable final CudaTensor lPtr = gpu.getTensor(delta, precision, MemoryType.Device, false);
                    // NOTE(review): the backward buffers use MemoryType.Managed.normalize()
                    // whereas the forward output uses MemoryType.Managed directly.
                    @Nonnull final CudaMemory passbackPtr = gpu.allocate(precision.size * Tensor.length(dimensions) * length, MemoryType.Managed.normalize(), true);
                    @Nonnull final CudaDevice.CudaTensorDescriptor passbackDescriptor = gpu.newTensorDescriptor(precision, length, dimensions[2], dimensions[1], dimensions[0], dimensions[2] * dimensions[1] * dimensions[0], dimensions[1] * dimensions[0], dimensions[0], 1);
                    CudaMemory lPtrMemory = lPtr.getMemory(gpu);
                    gpu.cudnnTransformTensor(precision.getPointer(leftFactor), lPtr.descriptor.getPtr(), lPtrMemory.getPtr(), precision.getPointer(0.0), passbackDescriptor.getPtr(), passbackPtr.getPtr());
                    assert CudaDevice.isThreadDeviceId(gpu.getDeviceId());
                    passbackPtr.dirty();
                    lPtrMemory.dirty();
                    lPtrMemory.freeRef();
                    CudaTensor cudaTensor = CudaTensor.wrap(passbackPtr, passbackDescriptor, precision);
                    lPtr.freeRef();
                    return CudaTensorList.wrap(cudaTensor, length, dimensions, precision);
                }, delta);
                inObj[0].accumulate(buffer, tensorList);
            }
        };
        // Backward pass, right input: same as above but scaled by rightFactor.
        // (This branch has no device-id assertion, unlike the left branch.)
        Runnable b = () -> {
            if (inObj[1].isAlive()) {
                CudaTensorList tensorList = CudaSystem.run(gpu -> {
                    @Nullable final CudaTensor lPtr = gpu.getTensor(delta, precision, MemoryType.Device, false);
                    @Nonnull final CudaMemory outputPtr = gpu.allocate(precision.size * Tensor.length(dimensions) * length, MemoryType.Managed.normalize(), true);
                    @Nonnull final CudaDevice.CudaTensorDescriptor passbackDescriptor = gpu.newTensorDescriptor(precision, length, dimensions[2], dimensions[1], dimensions[0], dimensions[2] * dimensions[1] * dimensions[0], dimensions[1] * dimensions[0], dimensions[0], 1);
                    CudaMemory lPtrMemory = lPtr.getMemory(gpu);
                    gpu.cudnnTransformTensor(precision.getPointer(rightFactor), lPtr.descriptor.getPtr(), lPtrMemory.getPtr(), precision.getPointer(0.0), passbackDescriptor.getPtr(), outputPtr.getPtr());
                    outputPtr.dirty();
                    lPtrMemory.dirty();
                    lPtrMemory.freeRef();
                    CudaTensor cudaTensor = CudaTensor.wrap(outputPtr, passbackDescriptor, precision);
                    lPtr.freeRef();
                    return CudaTensorList.wrap(cudaTensor, length, dimensions, precision);
                }, delta);
                inObj[1].accumulate(buffer, tensorList);
            }
        };
        // The two passbacks are independent; run them one after the other or concurrently
        // depending on the global single-threaded setting.
        if (CoreSettings.INSTANCE.isSingleThreaded())
            TestUtil.runAllSerial(a, b);
        else
            TestUtil.runAllParallel(a, b);
    }) {

        /**
         * Releases every input result plus the left and right data lists captured above.
         */
        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(x -> x.freeRef());
            leftData.freeRef();
            rightData.freeRef();
        }

        /**
         * @return {@code true} if at least one input is alive
         */
        @Override
        public boolean isAlive() {
            for (@Nonnull final Result element : inObj) if (element.isAlive()) {
                return true;
            }
            return false;
        }
    };
}
Also used : PipelineNetwork(com.simiacryptus.mindseye.network.PipelineNetwork) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) CudaMemory(com.simiacryptus.mindseye.lang.cudnn.CudaMemory) Tensor(com.simiacryptus.mindseye.lang.Tensor) SumInputsLayer(com.simiacryptus.mindseye.layers.java.SumInputsLayer) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) Precision(com.simiacryptus.mindseye.lang.cudnn.Precision) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) CudaResource(com.simiacryptus.mindseye.lang.cudnn.CudaResource) CudaDevice(com.simiacryptus.mindseye.lang.cudnn.CudaDevice) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) jcuda.jcudnn.cudnnOpTensorOp(jcuda.jcudnn.cudnnOpTensorOp) TestUtil(com.simiacryptus.mindseye.test.TestUtil) CoreSettings(com.simiacryptus.mindseye.lang.CoreSettings) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Stream(java.util.stream.Stream) CudaSystem(com.simiacryptus.mindseye.lang.cudnn.CudaSystem) TensorList(com.simiacryptus.mindseye.lang.TensorList) LinearActivationLayer(com.simiacryptus.mindseye.layers.java.LinearActivationLayer) MemoryType(com.simiacryptus.mindseye.lang.cudnn.MemoryType) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) jcuda.jcudnn.cudnnOpTensorDescriptor(jcuda.jcudnn.cudnnOpTensorDescriptor) CudaTensor(com.simiacryptus.mindseye.lang.cudnn.CudaTensor) Nonnull(javax.annotation.Nonnull) CudaMemory(com.simiacryptus.mindseye.lang.cudnn.CudaMemory) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) 
CudaTensorList(com.simiacryptus.mindseye.lang.cudnn.CudaTensorList) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) CudaResource(com.simiacryptus.mindseye.lang.cudnn.CudaResource) Nullable(javax.annotation.Nullable)

Example 7 with Layer

use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

the class MaxImageBandLayer method eval.

/**
 * Reduces every batch item to the maximum value of each band, yielding one
 * {@code 1 x 1 x bands} tensor per item.
 * <p>
 * The coordinate holding each band's maximum is located once up front. The feedback
 * handler reuses those coordinates: for each band it writes the incoming delta at the
 * recorded coordinate of a freshly created passback tensor and touches nothing else.
 *
 * @param inObj the inputs; only {@code inObj[0]} is read, and its data is asserted to
 *              have three dimensions with the band index last
 * @return a result holding the per-band maxima, alive exactly when {@code inObj[0]} is
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    assert 1 == inObj.length;
    final TensorList inputData = inObj[0].getData();
    inputData.addRef();
    inputData.length();
    @Nonnull final int[] inputDims = inputData.getDimensions();
    assert 3 == inputDims.length;
    for (final Result source : inObj) {
        source.addRef();
    }
    // maxCoords[item][band] is the coordinate of that band's largest value in that item.
    final Coordinate[][] maxCoords = inputData.stream().map(image -> {
        final Coordinate[] winners = IntStream.range(0, inputDims[2])
            .mapToObj(band -> image.coordStream(true)
                .filter(coord -> coord.getCoords()[2] == band)
                .max(Comparator.comparing(coord -> image.get(coord)))
                .get())
            .toArray(Coordinate[]::new);
        image.freeRef();
        return winners;
    }).toArray(Coordinate[][]::new);
    // Forward output: read the value stored at each recorded coordinate.
    final Tensor[] bandMaxima = IntStream.range(0, inputData.length()).mapToObj(dataIndex -> {
        final Tensor image = inputData.get(dataIndex);
        final DoubleStream values = IntStream.range(0, inputDims[2]).mapToDouble(band -> {
            final int[] at = maxCoords[dataIndex][band].getCoords();
            return image.get(at[0], at[1], band);
        });
        final Tensor reduced = new Tensor(1, 1, inputDims[2]).set(Tensor.getDoubles(values, inputDims[2]));
        image.freeRef();
        return reduced;
    }).toArray(Tensor[]::new);
    return new Result(TensorArray.wrap(bandMaxima), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList delta) -> {
        if (!inObj[0].isAlive()) {
            return;
        }
        // Route each band's delta to the coordinate that produced the maximum.
        final Tensor[] routed = IntStream.range(0, delta.length()).parallel().mapToObj(dataIndex -> {
            final Tensor deltaTensor = delta.get(dataIndex);
            @Nonnull final Tensor passback = new Tensor(inputData.getDimensions());
            IntStream.range(0, inputDims[2]).forEach(band -> {
                final int[] at = maxCoords[dataIndex][band].getCoords();
                passback.set(new int[] { at[0], at[1], band }, deltaTensor.get(0, 0, band));
            });
            deltaTensor.freeRef();
            return passback;
        }).toArray(Tensor[]::new);
        @Nonnull final TensorArray tensorArray = TensorArray.wrap(routed);
        inObj[0].accumulate(buffer, tensorArray);
    }) {

        /** Releases the references taken on the inputs and on the input data. */
        @Override
        protected void _free() {
            for (final Result source : inObj) {
                source.freeRef();
            }
            inputData.freeRef();
        }

        /** @return whether the single input is alive */
        @Override
        public boolean isAlive() {
            return inObj[0].isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) DoubleStream(java.util.stream.DoubleStream) JsonUtil(com.simiacryptus.util.io.JsonUtil) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) Coordinate(com.simiacryptus.mindseye.lang.Coordinate) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DoubleStream(java.util.stream.DoubleStream) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 8 with Layer

use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

the class MaxPoolingLayer method eval.

/**
 * Max-pools the first input over the regions supplied by the cached region calculation
 * for the input dimensions and {@code kernelDims}.
 * <p>
 * For every batch item each region's largest input value is written to the region's
 * output index, and the winning input index is recorded. The feedback handler adds each
 * output delta onto the recorded input index of a fresh passback tensor.
 *
 * @param inObj the inputs; only {@code inObj[0]} is pooled
 * @return a result holding the pooled tensors, alive exactly when {@code inObj[0]} is
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    for (final Result source : inObj) {
        source.addRef();
    }
    final Result in = inObj[0];
    in.getData().length();
    @Nonnull final int[] inputDims = in.getData().getDimensions();
    final List<Tuple2<Integer, int[]>> regions = MaxPoolingLayer.calcRegionsCache.apply(new MaxPoolingLayer.CalcRegionsParameter(inputDims, kernelDims));
    // One empty output tensor per item; each axis shrinks to ceil(input / kernel).
    final Tensor[] pooled = new Tensor[in.getData().length()];
    for (int item = 0; item < pooled.length; item++) {
        final int[] pooledDims = new int[inputDims.length];
        for (int axis = 0; axis < pooledDims.length; axis++) {
            pooledDims[axis] = (int) Math.ceil(inputDims[axis] * 1.0 / kernelDims[axis]);
        }
        pooled[item] = new Tensor(pooledDims);
    }
    Arrays.stream(pooled).mapToInt(x -> x.length()).sum();
    // gradientMapA[item][outputIndex] is the input index that won that region.
    @Nonnull final int[][] gradientMapA = new int[in.getData().length()][];
    final int itemCount = in.getData().length();
    for (int item = 0; item < itemCount; item++) {
        @Nullable final Tensor input = in.getData().get(item);
        final Tensor output = pooled[item];
        @Nonnull final int[] gradientMap = new int[input.length()];
        regions.parallelStream().forEach(region -> {
            final Integer target = region.getFirst();
            int winner = -1;
            double best = Double.NEGATIVE_INFINITY;
            for (final int candidate : region.getSecond()) {
                final double value = input.get(candidate);
                // The first candidate always wins; afterwards only a strictly larger one does.
                if (winner == -1 || value > best) {
                    best = value;
                    winner = candidate;
                }
            }
            gradientMap[target] = winner;
            output.set(target, input.get(winner));
        });
        input.freeRef();
        gradientMapA[item] = gradientMap;
    }
    return new Result(TensorArray.wrap(pooled), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!in.isAlive()) {
            return;
        }
        final Tensor[] passback = IntStream.range(0, in.getData().length()).parallel().mapToObj(dataIndex -> {
            @Nonnull final Tensor backSignal = new Tensor(inputDims);
            final int[] winners = gradientMapA[dataIndex];
            @Nullable final Tensor datum = data.get(dataIndex);
            for (int outputIndex = 0; outputIndex < datum.length(); outputIndex++) {
                backSignal.add(winners[outputIndex], datum.get(outputIndex));
            }
            datum.freeRef();
            return backSignal;
        }).toArray(Tensor[]::new);
        @Nonnull final TensorArray tensorArray = TensorArray.wrap(passback);
        in.accumulate(buffer, tensorArray);
    }) {

        /** Releases the references taken on the inputs at the start of the evaluation. */
        @Override
        protected void _free() {
            for (final Result source : inObj) {
                source.freeRef();
            }
        }

        /** @return whether the pooled input is alive */
        @Override
        public boolean isAlive() {
            return in.isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Util(com.simiacryptus.util.Util) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) IntToDoubleFunction(java.util.function.IntToDoubleFunction) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) Function(java.util.function.Function) Collectors(java.util.stream.Collectors) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) JsonUtil(com.simiacryptus.util.io.JsonUtil) Tuple2(com.simiacryptus.util.lang.Tuple2) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) IntToDoubleFunction(java.util.function.IntToDoubleFunction) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Tuple2(com.simiacryptus.util.lang.Tuple2) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nullable(javax.annotation.Nullable) Nonnull(javax.annotation.Nonnull)

Example 9 with Layer

use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

the class MaxMetaLayer method eval.

/**
 * Collapses the batch of the first input into a single tensor holding, for every element
 * position, the largest value found at that position across all batch items.
 * <p>
 * The batch index that supplied each maximum is recorded; the feedback handler uses it
 * to add each element of the incoming delta (item 0) onto the same coordinate of the
 * matching batch item's feedback tensor.
 *
 * @param inObj the inputs; only {@code inObj[0]} is read
 * @return a one-item result of element-wise maxima, alive exactly when the input is
 */
@Nonnull
@Override
public Result eval(final Result... inObj) {
    final Result input = inObj[0];
    input.addRef();
    final int itemCnt = input.getData().length();
    final Tensor input0Tensor = input.getData().get(0);
    final int vectorSize = input0Tensor.length();
    // indicies[element] is the batch index whose tensor holds the maximum for that element.
    @Nonnull final int[] indicies = new int[vectorSize];
    Arrays.setAll(indicies, element -> IntStream.range(0, itemCnt).boxed().max(Comparator.comparing(item -> {
        final Tensor candidate = input.getData().get(item);
        final double value = candidate.getData()[element];
        candidate.freeRef();
        return value;
    })).get());
    final Tensor maxima = input0Tensor.mapIndex((v, c) -> {
        final Tensor winner = input.getData().get(indicies[c]);
        final double value = winner.getData()[c];
        winner.freeRef();
        return value;
    });
    return new Result(TensorArray.wrap(maxima), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (!input.isAlive()) {
            return;
        }
        @Nullable final Tensor delta = data.get(0);
        // One feedback tensor per batch item, shaped like the delta.
        @Nonnull final Tensor[] feedback = new Tensor[itemCnt];
        Arrays.parallelSetAll(feedback, slot -> new Tensor(delta.getDimensions()));
        input0Tensor.coordStream(true).forEach(coord -> feedback[indicies[coord.getIndex()]].add(coord, delta.get(coord)));
        @Nonnull final TensorArray tensorArray = TensorArray.wrap(feedback);
        input.accumulate(buffer, tensorArray);
        delta.freeRef();
    }) {

        /** Releases the input and the first input tensor kept for the feedback pass. */
        @Override
        protected void _free() {
            input.freeRef();
            input0Tensor.freeRef();
        }

        /** @return whether the input is alive */
        @Override
        public boolean isAlive() {
            return input.isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) Comparator(java.util.Comparator) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Example 10 with Layer

use of com.simiacryptus.mindseye.lang.Layer in project MindsEye by SimiaCryptus.

the class MeanSqLossLayer method eval.

/**
 * Computes the mean squared error between two inputs, producing one single-element
 * tensor per batch item: {@code sum((a - b)^2) / length}.
 * <p>
 * Either input may have a batch length of 1, in which case its only item is compared
 * against every item of the other input. The number of evaluated items is the batch
 * length of the non-broadcast side. (Iterating over the left length alone dropped all
 * but the first item when the left side was the broadcast one, and the feedback handed
 * to the right input then had the wrong batch length.)
 * <p>
 * Feedback: each live input receives {@code diff * delta * 2 / length} per item, negated
 * for the right input; a broadcast input receives the sum over all items.
 *
 * @param inObj exactly two inputs: left ({@code inObj[0]}) and right ({@code inObj[1]})
 * @return a result holding one mean-squared-error value per evaluated item
 * @throws IllegalArgumentException if the number of inputs is not 2, if the batch lengths
 *                                  differ and neither is 1, or if a compared pair of
 *                                  tensors differs in length
 */
@Nonnull
@Override
public Result eval(@Nonnull final Result... inObj) {
    if (2 != inObj.length)
        throw new IllegalArgumentException();
    final int leftLength = inObj[0].getData().length();
    final int rightLength = inObj[1].getData().length();
    // Validate before taking references so a rejected call does not leak them.
    if (leftLength != rightLength && leftLength != 1 && rightLength != 1) {
        throw new IllegalArgumentException(leftLength + " != " + rightLength);
    }
    Arrays.stream(inObj).forEach(ReferenceCounting::addRef);
    // When the left side is broadcast, the right side determines how many items exist.
    final int itemCount = 1 == leftLength ? rightLength : leftLength;
    // The per-item differences are kept for the feedback pass and released in _free().
    @Nonnull final Tensor[] diffs = new Tensor[itemCount];
    return new Result(TensorArray.wrap(IntStream.range(0, itemCount).mapToObj(dataIndex -> {
        @Nullable final Tensor a = inObj[0].getData().get(1 == leftLength ? 0 : dataIndex);
        @Nullable final Tensor b = inObj[1].getData().get(1 == rightLength ? 0 : dataIndex);
        if (a.length() != b.length()) {
            throw new IllegalArgumentException(String.format("%s != %s", Arrays.toString(a.getDimensions()), Arrays.toString(b.getDimensions())));
        }
        @Nonnull final Tensor r = a.minus(b);
        a.freeRef();
        b.freeRef();
        diffs[dataIndex] = r;
        @Nonnull Tensor statsTensor = new Tensor(new double[] { r.sumSq() / r.length() }, 1);
        return statsTensor;
    }).toArray(i -> new Tensor[i])), (@Nonnull final DeltaSet<Layer> buffer, @Nonnull final TensorList data) -> {
        if (inObj[0].isAlive()) {
            // Gradient w.r.t. the left input: diff * delta * 2 / length, per item.
            Stream<Tensor> tensorStream = IntStream.range(0, data.length()).parallel().mapToObj(dataIndex -> {
                @Nullable Tensor tensor = data.get(dataIndex);
                Tensor diff = diffs[dataIndex];
                @Nullable Tensor scale = diff.scale(tensor.get(0) * 2.0 / diff.length());
                tensor.freeRef();
                return scale;
            }).collect(Collectors.toList()).stream();
            if (1 == leftLength) {
                // A broadcast left input receives the sum of all per-item gradients.
                tensorStream = Stream.of(tensorStream.reduce((a, b) -> {
                    @Nullable Tensor c = a.addAndFree(b);
                    b.freeRef();
                    return c;
                }).get());
            }
            @Nonnull final TensorList array = TensorArray.wrap(tensorStream.toArray(i -> new Tensor[i]));
            inObj[0].accumulate(buffer, array);
        }
        if (inObj[1].isAlive()) {
            // Gradient w.r.t. the right input: same magnitude as the left, negated below.
            Stream<Tensor> tensorStream = IntStream.range(0, data.length()).parallel().mapToObj(dataIndex -> {
                @Nullable Tensor tensor = data.get(dataIndex);
                Tensor diff = diffs[dataIndex];
                @Nullable Tensor scale = diff.scale(tensor.get(0) * 2.0 / diff.length());
                tensor.freeRef();
                return scale;
            }).collect(Collectors.toList()).stream();
            if (1 == rightLength) {
                // A broadcast right input receives the sum of all per-item gradients.
                tensorStream = Stream.of(tensorStream.reduce((a, b) -> {
                    @Nullable Tensor c = a.addAndFree(b);
                    b.freeRef();
                    return c;
                }).get());
            }
            @Nonnull final TensorList array = TensorArray.wrap(tensorStream.map(x -> {
                @Nullable Tensor scale = x.scale(-1);
                x.freeRef();
                return scale;
            }).toArray(i -> new Tensor[i]));
            inObj[1].accumulate(buffer, array);
        }
    }) {

        /** Releases the references taken on both inputs and the retained differences. */
        @Override
        protected void _free() {
            Arrays.stream(inObj).forEach(ReferenceCounting::freeRef);
            Arrays.stream(diffs).forEach(ReferenceCounting::freeRef);
        }

        /** @return {@code true} if either input is alive */
        @Override
        public boolean isAlive() {
            return inObj[0].isAlive() || inObj[1].isAlive();
        }
    };
}
Also used : IntStream(java.util.stream.IntStream) JsonObject(com.google.gson.JsonObject) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) Result(com.simiacryptus.mindseye.lang.Result) Collectors(java.util.stream.Collectors) DataSerializer(com.simiacryptus.mindseye.lang.DataSerializer) List(java.util.List) LayerBase(com.simiacryptus.mindseye.lang.LayerBase) Stream(java.util.stream.Stream) TensorList(com.simiacryptus.mindseye.lang.TensorList) Map(java.util.Map) Layer(com.simiacryptus.mindseye.lang.Layer) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable) Tensor(com.simiacryptus.mindseye.lang.Tensor) ReferenceCounting(com.simiacryptus.mindseye.lang.ReferenceCounting) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Nullable(javax.annotation.Nullable) Result(com.simiacryptus.mindseye.lang.Result) Nonnull(javax.annotation.Nonnull)

Aggregations

Layer (com.simiacryptus.mindseye.lang.Layer)167 Nonnull (javax.annotation.Nonnull)159 Nullable (javax.annotation.Nullable)128 Arrays (java.util.Arrays)117 Tensor (com.simiacryptus.mindseye.lang.Tensor)116 List (java.util.List)108 Result (com.simiacryptus.mindseye.lang.Result)103 IntStream (java.util.stream.IntStream)98 DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet)95 TensorList (com.simiacryptus.mindseye.lang.TensorList)93 Map (java.util.Map)83 TensorArray (com.simiacryptus.mindseye.lang.TensorArray)76 Logger (org.slf4j.Logger)76 LoggerFactory (org.slf4j.LoggerFactory)76 JsonObject (com.google.gson.JsonObject)70 DataSerializer (com.simiacryptus.mindseye.lang.DataSerializer)66 LayerBase (com.simiacryptus.mindseye.lang.LayerBase)64 NotebookOutput (com.simiacryptus.util.io.NotebookOutput)51 Collectors (java.util.stream.Collectors)42 Stream (java.util.stream.Stream)37