Search in sources:

Example 1 with ReshapedTensorList

Use of com.simiacryptus.mindseye.lang.ReshapedTensorList in the MindsEye project by SimiaCryptus.

Class CudaTensorList, method addAndFree.

/**
 * Adds {@code right} element-wise to this list, consuming this list's reference.
 *
 * <p>Strategy, in order:
 * <ol>
 *   <li>A {@link ReshapedTensorList} operand is unwrapped to its inner list.</li>
 *   <li>If this list is still referenced elsewhere (reference count above one), it must not be
 *       mutated, so the sum is computed out of place with {@link #add(TensorList)}.</li>
 *   <li>If both lists are GPU-only ({@code heapCopy == null}) with equal precision and both have
 *       device memory, the sum runs on the GPU: in place when the executing device is the one
 *       holding this list's memory, otherwise out of place via {@link #add(TensorList)}.</li>
 *   <li>Otherwise the sum is computed element by element on the heap.</li>
 * </ol>
 *
 * @param right the list to add; must have the same length as this list
 * @return the sum; this list's reference is released (or transferred into the result)
 * @throws IllegalArgumentException if this list is empty while {@code right} is not
 */
@Override
public TensorList addAndFree(@Nonnull final TensorList right) {
    assertAlive();
    right.assertAlive();
    if (right instanceof ReshapedTensorList)
        return addAndFree(((ReshapedTensorList) right).getInner());
    // Another holder still references this list, so it must not be modified in place.
    if (1 < currentRefCount()) {
        TensorList sum = add(right);
        freeRef();
        return sum;
    }
    assert length() == right.length();
    if (heapCopy == null && right instanceof CudaTensorList) {
        @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
        if (nativeRight.getPrecision() == this.getPrecision() && nativeRight.heapCopy == null) {
            assert (!nativeRight.gpuCopy.equals(CudaTensorList.this.gpuCopy));
            // Each operand's own device memory. Previously both handles aliased this list's
            // memory, so the null check below never looked at the right-hand operand.
            final CudaMemory leftMem = gpuCopy.memory;
            final CudaMemory rightMem = nativeRight.gpuCopy.memory;
            if (null != leftMem && null != rightMem)
                return CudaSystem.run(gpu -> {
                    if (gpu.getDeviceId() == leftMem.getDeviceId()) {
                        return gpu.addInPlace(this, nativeRight);
                    } else {
                        // This list's memory lives on another device: compute out of place.
                        assertAlive();
                        right.assertAlive();
                        TensorList add = add(right);
                        freeRef();
                        return add;
                    }
                }, this, right);
        }
    }
    if (right.length() == 0)
        return this;
    if (length() == 0)
        throw new IllegalArgumentException();
    assert length() == right.length();
    final Tensor[] sums = IntStream.range(0, length()).mapToObj(i -> {
        Tensor a = get(i);
        Tensor b = right.get(i);
        @Nullable Tensor r = a.addAndFree(b);
        b.freeRef();
        return r;
    }).toArray(i -> new Tensor[i]);
    // Honour the "AndFree" contract like the other branches do; the sums above no longer
    // need this list, and skipping this release leaked it.
    freeRef();
    return TensorArray.wrap(sums);
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TestUtil(com.simiacryptus.mindseye.test.TestUtil) ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) Stream(java.util.stream.Stream) RegisteredObjectBase(com.simiacryptus.mindseye.lang.RegisteredObjectBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) TimedResult(com.simiacryptus.util.lang.TimedResult) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)

Example 2 with ReshapedTensorList

Use of com.simiacryptus.mindseye.lang.ReshapedTensorList in the MindsEye project by SimiaCryptus.

Class CudaTensorList, method add.

/**
 * Returns the element-wise sum of this list and {@code right} without consuming this list.
 *
 * <p>A {@link ReshapedTensorList} operand is unwrapped to its inner list. When both lists are
 * GPU-only ({@code heapCopy == null}) with equal precision, the sum is computed on the GPU;
 * otherwise it is computed element by element on the heap.
 *
 * @param right the list to add; must have the same length as this list
 * @return a list the caller owns a reference to
 * @throws IllegalArgumentException if this list is empty while {@code right} is not
 */
@Override
public TensorList add(@Nonnull final TensorList right) {
    assertAlive();
    right.assertAlive();
    assert length() == right.length();
    if (right instanceof ReshapedTensorList)
        return add(((ReshapedTensorList) right).getInner());
    if (heapCopy == null && right instanceof CudaTensorList) {
        @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
        if (nativeRight.getPrecision() == this.getPrecision() && nativeRight.heapCopy == null) {
            // Block lambda kept deliberately: it is value-only, avoiding overload ambiguity
            // an expression lambda could create for CudaSystem.run.
            return CudaSystem.run(gpu -> {
                return gpu.add(this, nativeRight);
            }, this);
        }
    }
    if (right.length() == 0) {
        // This method does not consume this list, yet callers release the returned list
        // (see addAndFree: add(...) followed by freeRef()). Take a reference before handing
        // this list out so that release does not free it from under its other holders.
        addRef();
        return this;
    }
    if (length() == 0)
        throw new IllegalArgumentException();
    assert length() == right.length();
    return TensorArray.wrap(IntStream.range(0, length()).mapToObj(i -> {
        Tensor a = get(i);
        Tensor b = right.get(i);
        @Nullable Tensor r = a.addAndFree(b);
        b.freeRef();
        return r;
    }).toArray(i -> new Tensor[i]));
}
Also used: IntStream(java.util.stream.IntStream) Arrays(java.util.Arrays) Logger(org.slf4j.Logger) LoggerFactory(org.slf4j.LoggerFactory) Tensor(com.simiacryptus.mindseye.lang.Tensor) TestUtil(com.simiacryptus.mindseye.test.TestUtil) ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) Stream(java.util.stream.Stream) RegisteredObjectBase(com.simiacryptus.mindseye.lang.RegisteredObjectBase) TensorList(com.simiacryptus.mindseye.lang.TensorList) TimedResult(com.simiacryptus.util.lang.TimedResult) TensorArray(com.simiacryptus.mindseye.lang.TensorArray) Nonnull(javax.annotation.Nonnull) Nullable(javax.annotation.Nullable)

Example 3 with ReshapedTensorList

Use of com.simiacryptus.mindseye.lang.ReshapedTensorList in the MindsEye project by SimiaCryptus.

Class CudnnHandle, method getTensor.

/**
 * Gets a {@link CudaTensor} holding {@code data} at the requested precision.
 *
 * <p>Three paths are taken depending on the input:
 * <ul>
 *   <li>{@link ReshapedTensorList}: the inner list is resolved (always densely) and its memory is
 *       reused under a new packed descriptor built from the reshaped dimensions.</li>
 *   <li>{@link CudaTensorList} with matching precision: delegated to the
 *       {@code getTensor(CudaTensorList, MemoryType, boolean)} overload.</li>
 *   <li>Anything else, including a CudaTensorList of a different precision (which logs a
 *       warning): new memory is allocated and every element is copied in from the heap.</li>
 * </ul>
 * Dimensions are read as {@code [width, height, channels]}; missing trailing entries default
 * to 1.
 *
 * @param data       the tensor list to convert; must still be alive
 * @param precision  the precision of the resulting tensor
 * @param memoryType the memory type used when allocating or fetching device memory
 * @param dense      forwarded only on the matching-precision CudaTensorList path; the other
 *                   paths always produce a packed layout
 * @return the CUDA tensor
 */
@Nonnull
public CudaTensor getTensor(@Nonnull final TensorList data, @Nonnull final Precision precision, final MemoryType memoryType, final boolean dense) {
    assert CudaDevice.isThreadDeviceId(getDeviceId());
    // Check liveness before reading anything from the list (the dimensions were previously
    // read first).
    data.assertAlive();
    final int[] inputSize = data.getDimensions();
    final int channels = inputSize.length < 3 ? 1 : inputSize[2];
    final int height = inputSize.length < 2 ? 1 : inputSize[1];
    final int width = inputSize.length < 1 ? 1 : inputSize[0];
    if (data instanceof ReshapedTensorList) {
        ReshapedTensorList reshapedTensorList = (ReshapedTensorList) data;
        // The inner data must be packed so the reshaped strides below describe it correctly.
        CudaTensor reshapedTensor = getTensor(reshapedTensorList.getInner(), precision, memoryType, true);
        CudaTensorDescriptor descriptor = newTensorDescriptor(precision, reshapedTensor.descriptor.batchCount, channels, height, width, channels * height * width, height * width, width, 1);
        CudaMemory tensorMemory = reshapedTensor.getMemory(this, memoryType);
        reshapedTensor.freeRef();
        CudaTensor cudaTensor = new CudaTensor(tensorMemory, descriptor, precision);
        tensorMemory.freeRef();
        descriptor.freeRef();
        return cudaTensor;
    }
    if (data instanceof CudaTensorList) {
        CudaTensorList cudaTensorList = (CudaTensorList) data;
        if (precision == cudaTensorList.getPrecision()) {
            return this.getTensor(cudaTensorList, memoryType, dense);
        } else {
            CudaTensorList.logger.warn(String.format("Incompatible precision types for Tensor %s in GPU at %s, created by %s", Integer.toHexString(System.identityHashCode(cudaTensorList)), TestUtil.toString(TestUtil.getStackTrace()).replaceAll("\n", "\n\t"), TestUtil.toString(cudaTensorList.createdBy).replaceAll("\n", "\n\t")));
        }
    }
    // Heap path: copy each element into one contiguous, packed allocation.
    final int listLength = data.length();
    final int elementLength = Tensor.length(inputSize);
    @Nonnull final CudaMemory ptr = this.allocate((long) elementLength * listLength * precision.size, memoryType, true);
    for (int i = 0; i < listLength; i++) {
        Tensor tensor = data.get(i);
        assert null != tensor;
        assert Arrays.equals(tensor.getDimensions(), inputSize) : Arrays.toString(tensor.getDimensions()) + " != " + Arrays.toString(inputSize);
        ptr.write(precision, tensor.getData(), (long) i * elementLength);
        tensor.freeRef();
    }
    @Nonnull final CudaDevice.CudaTensorDescriptor descriptor = newTensorDescriptor(precision, listLength, channels, height, width, channels * height * width, height * width, width, 1);
    return CudaTensor.wrap(ptr, descriptor, precision);
}
Also used: ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) Tensor(com.simiacryptus.mindseye.lang.Tensor) Nonnull(javax.annotation.Nonnull)

Example 4 with ReshapedTensorList

Use of com.simiacryptus.mindseye.lang.ReshapedTensorList in the MindsEye project by SimiaCryptus.

Class ReshapeLayer, method evalAndFree.

/**
 * Presents the single input's data under {@code outputDims} and, on the backward pass,
 * re-presents each incoming delta under the input's original dimensions before forwarding it
 * to that input.
 *
 * @param inObj exactly one input result; released when the returned result is freed
 * @return a result wrapping the reshaped view; it is alive while its input is alive
 */
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
    assert 1 == inObj.length;
    final TensorList inputData = inObj[0].getData();
    @Nonnull final int[] originalDims = inputData.getDimensions();
    final ReshapedTensorList reshaped = new ReshapedTensorList(inputData, outputDims);
    // NOTE(review): assumes the ReshapedTensorList keeps its own reference to inputData,
    // which is why this local handle is released immediately — confirm.
    inputData.freeRef();
    return new Result(reshaped, (DeltaSet<Layer> deltaBuffer, TensorList outputDelta) -> {
        // Undo the reshape so the upstream layer sees deltas in its own shape.
        @Nonnull final ReshapedTensorList inputDelta = new ReshapedTensorList(outputDelta, originalDims);
        inObj[0].accumulate(deltaBuffer, inputDelta);
    }) {

        @Override
        protected void _free() {
            for (int idx = 0; idx < inObj.length; idx++) {
                inObj[idx].freeRef();
            }
        }

        @Override
        public boolean isAlive() {
            return inObj[0].isAlive();
        }
    };
}
Also used: ReshapedTensorList(com.simiacryptus.mindseye.lang.ReshapedTensorList) Nonnull(javax.annotation.Nonnull) DeltaSet(com.simiacryptus.mindseye.lang.DeltaSet) TensorList(com.simiacryptus.mindseye.lang.TensorList) Result(com.simiacryptus.mindseye.lang.Result) Nullable(javax.annotation.Nullable)

Aggregations

ReshapedTensorList (com.simiacryptus.mindseye.lang.ReshapedTensorList): 4; Nonnull (javax.annotation.Nonnull): 4; Tensor (com.simiacryptus.mindseye.lang.Tensor): 3; TensorList (com.simiacryptus.mindseye.lang.TensorList): 3; Nullable (javax.annotation.Nullable): 3; RegisteredObjectBase (com.simiacryptus.mindseye.lang.RegisteredObjectBase): 2; TensorArray (com.simiacryptus.mindseye.lang.TensorArray): 2; TestUtil (com.simiacryptus.mindseye.test.TestUtil): 2; TimedResult (com.simiacryptus.util.lang.TimedResult): 2; Arrays (java.util.Arrays): 2; IntStream (java.util.stream.IntStream): 2; Stream (java.util.stream.Stream): 2; Logger (org.slf4j.Logger): 2; LoggerFactory (org.slf4j.LoggerFactory): 2; DeltaSet (com.simiacryptus.mindseye.lang.DeltaSet): 1; Result (com.simiacryptus.mindseye.lang.Result): 1