Usage of `com.simiacryptus.mindseye.lang.ReshapedTensorList` in the MindsEye project (SimiaCryptus).
Class `CudaTensorList`, method `addAndFree`:
/**
 * Adds {@code right} to this list element-wise and consumes this list's reference.
 *
 * <p>Paths, in order:
 * <ul>
 *   <li>A {@link ReshapedTensorList} addend is unwrapped and its inner list is added instead.</li>
 *   <li>If this list is shared ({@code currentRefCount() > 1}) it is not modified. The sum comes from
 *       {@link #add(TensorList)} and this reference is released.</li>
 *   <li>If neither list has a heap copy, the precisions match and both GPU memories are present, the work
 *       runs through {@code CudaSystem.run}. On the device holding this list's memory it calls
 *       {@code gpu.addInPlace}; otherwise it falls back to {@code add} and releases this reference.</li>
 *   <li>Otherwise the sum is computed tensor-by-tensor on the heap and this reference is released.</li>
 * </ul>
 *
 * @param right the addend; this method never calls {@code right.freeRef()} itself
 * @return the sum; the caller owns the returned reference
 * @throws IllegalArgumentException if this list is empty while {@code right} is not
 */
@Override
public TensorList addAndFree(@Nonnull final TensorList right) {
  assertAlive();
  right.assertAlive();
  if (right instanceof ReshapedTensorList) {
    return addAndFree(((ReshapedTensorList) right).getInner());
  }
  if (1 < currentRefCount()) {
    // Other holders still see this list, so it must not be mutated in place.
    TensorList sum = add(right);
    freeRef();
    return sum;
  }
  assert length() == right.length();
  if (heapCopy == null && right instanceof CudaTensorList) {
    @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
    if (nativeRight.getPrecision() == this.getPrecision() && nativeRight.heapCopy == null) {
      assert (!nativeRight.gpuCopy.equals(CudaTensorList.this.gpuCopy));
      // Each operand's memory must be checked on its own: the GPU path needs both to be present.
      final CudaMemory leftMem = gpuCopy.memory;
      final CudaMemory rightMem = nativeRight.gpuCopy.memory;
      if (null != leftMem && null != rightMem) {
        return CudaSystem.run(gpu -> {
          if (gpu.getDeviceId() == leftMem.getDeviceId()) {
            return gpu.addInPlace(this, nativeRight);
          } else {
            // Not on the device that owns this list's memory: compute a fresh sum instead.
            assertAlive();
            right.assertAlive();
            TensorList add = add(right);
            freeRef();
            return add;
          }
        }, this, right);
      }
    }
  }
  if (right.length() == 0) {
    // Nothing to add: transfer this list's reference straight back to the caller.
    return this;
  }
  if (length() == 0) {
    throw new IllegalArgumentException("Cannot add a list of length " + right.length() + " to an empty list");
  }
  final TensorList sum = TensorArray.wrap(IntStream.range(0, length()).mapToObj(i -> {
    Tensor a = get(i);
    Tensor b = right.get(i);
    @Nullable Tensor r = a.addAndFree(b);
    b.freeRef();
    return r;
  }).toArray(i -> new Tensor[i]));
  // The heap result is an independent list, so release the reference this method consumes.
  freeRef();
  return sum;
}
Usage of `com.simiacryptus.mindseye.lang.ReshapedTensorList` in the MindsEye project (SimiaCryptus).
Class `CudaTensorList`, method `add`:
/**
 * Returns the element-wise sum of this list and {@code right} without consuming either operand.
 *
 * <p>A {@link ReshapedTensorList} addend is unwrapped and its inner list is added instead. If neither list
 * has a heap copy and the precisions match, the sum is produced by {@code gpu.add} inside
 * {@code CudaSystem.run}. Otherwise it is computed tensor-by-tensor on the heap.
 *
 * @param right the addend
 * @return the sum; the caller owns the returned reference (including when {@code this} itself is returned)
 * @throws IllegalArgumentException if this list is empty while {@code right} is not
 */
@Override
public TensorList add(@Nonnull final TensorList right) {
  assertAlive();
  right.assertAlive();
  assert length() == right.length();
  if (right instanceof ReshapedTensorList) {
    return add(((ReshapedTensorList) right).getInner());
  }
  if (heapCopy == null && right instanceof CudaTensorList) {
    @Nonnull final CudaTensorList nativeRight = (CudaTensorList) right;
    if (nativeRight.getPrecision() == this.getPrecision() && nativeRight.heapCopy == null) {
      return CudaSystem.run(gpu -> {
        return gpu.add(this, nativeRight);
      }, this);
    }
  }
  if (right.length() == 0) {
    // Every other path returns a new reference; returning this list must count as one too,
    // otherwise a caller that frees the result (e.g. addAndFree) double-frees this list.
    addRef();
    return this;
  }
  if (length() == 0) {
    throw new IllegalArgumentException("Cannot add a list of length " + right.length() + " to an empty list");
  }
  return TensorArray.wrap(IntStream.range(0, length()).mapToObj(i -> {
    Tensor a = get(i);
    Tensor b = right.get(i);
    @Nullable Tensor r = a.addAndFree(b);
    b.freeRef();
    return r;
  }).toArray(i -> new Tensor[i]));
}
Usage of `com.simiacryptus.mindseye.lang.ReshapedTensorList` in the MindsEye project (SimiaCryptus).
Class `CudnnHandle`, method `getTensor`:
/**
 * Gets a CUDA tensor holding the contents of {@code data} at the requested precision.
 *
 * <p>The calling thread must be bound to this handle's device; this is asserted. Three cases are handled:
 * <ul>
 *   <li>{@link ReshapedTensorList}: the inner list is fetched (always with {@code dense = true}). Its memory
 *       is then re-described with the reshaped dimensions using contiguous strides.</li>
 *   <li>{@link CudaTensorList} with the same precision: delegated to the list-specific overload. On a
 *       precision mismatch a warning is logged and the heap path below is used instead.</li>
 *   <li>Any other list: device memory is allocated and each tensor is copied into it in order.</li>
 * </ul>
 * Dimensions are read as {@code [width, height, channels]}; missing trailing dimensions default to 1.
 * NOTE(review): dimensions beyond the third are not reflected in the descriptor. Confirm callers never
 * pass a rank greater than 3.
 *
 * @param data       the tensors to upload or re-describe
 * @param precision  the precision of the returned tensor
 * @param memoryType the memory type used for allocation and memory lookup
 * @param dense      forwarded to the {@link CudaTensorList} overload; unused on the heap-copy path
 * @return a new CUDA tensor reference owned by the caller
 */
@Nonnull
public CudaTensor getTensor(@Nonnull final TensorList data, @Nonnull final Precision precision, final MemoryType memoryType, final boolean dense) {
  assert CudaDevice.isThreadDeviceId(getDeviceId());
  final int[] inputSize = data.getDimensions();
  data.assertAlive();
  if (data instanceof ReshapedTensorList) {
    ReshapedTensorList reshapedTensorList = (ReshapedTensorList) data;
    int[] newDims = reshapedTensorList.getDimensions();
    CudaTensor reshapedTensor = getTensor(reshapedTensorList.getInner(), precision, memoryType, true);
    int channels = newDims.length < 3 ? 1 : newDims[2];
    int height = newDims.length < 2 ? 1 : newDims[1];
    int width = newDims.length < 1 ? 1 : newDims[0];
    CudaTensorDescriptor descriptor = newTensorDescriptor(precision, reshapedTensor.descriptor.batchCount, channels, height, width, channels * height * width, height * width, width, 1);
    CudaMemory tensorMemory = reshapedTensor.getMemory(this, memoryType);
    reshapedTensor.freeRef();
    CudaTensor cudaTensor = new CudaTensor(tensorMemory, descriptor, precision);
    tensorMemory.freeRef();
    descriptor.freeRef();
    return cudaTensor;
  }
  if (data instanceof CudaTensorList) {
    CudaTensorList cudaTensorList = (CudaTensorList) data;
    if (precision == cudaTensorList.getPrecision()) {
      return this.getTensor(cudaTensorList, memoryType, dense);
    } else {
      CudaTensorList.logger.warn(String.format("Incompatible precision types for Tensor %s in GPU at %s, created by %s", Integer.toHexString(System.identityHashCode(cudaTensorList)), TestUtil.toString(TestUtil.getStackTrace()).replaceAll("\n", "\n\t"), TestUtil.toString(cudaTensorList.createdBy).replaceAll("\n", "\n\t")));
    }
  }
  final int listLength = data.length();
  final int elementLength = Tensor.length(inputSize);
  @Nonnull final CudaMemory ptr = this.allocate((long) elementLength * listLength * precision.size, memoryType, true);
  try {
    for (int i = 0; i < listLength; i++) {
      final Tensor tensor = data.get(i);
      assert null != tensor;
      try {
        assert Arrays.equals(tensor.getDimensions(), inputSize) : Arrays.toString(tensor.getDimensions()) + " != " + Arrays.toString(inputSize);
        ptr.write(precision, tensor.getData(), (long) i * elementLength);
      } finally {
        tensor.freeRef();
      }
    }
  } catch (RuntimeException | Error e) {
    // Release the device allocation if the copy fails part-way through, instead of leaking it.
    ptr.freeRef();
    throw e;
  }
  final int channels = inputSize.length < 3 ? 1 : inputSize[2];
  final int height = inputSize.length < 2 ? 1 : inputSize[1];
  final int width = inputSize.length < 1 ? 1 : inputSize[0];
  @Nonnull final CudaDevice.CudaTensorDescriptor descriptor = newTensorDescriptor(precision, data.length(), channels, height, width, channels * height * width, height * width, width, 1);
  return CudaTensor.wrap(ptr, descriptor, precision);
}
Usage of `com.simiacryptus.mindseye.lang.ReshapedTensorList` in the MindsEye project (SimiaCryptus).
Class `ReshapeLayer`, method `evalAndFree`:
/**
 * Wraps the single input's data in a {@link ReshapedTensorList} with this layer's {@code outputDims}.
 *
 * <p>On the backward pass the incoming delta is wrapped in a {@link ReshapedTensorList} with the input's
 * original dimensions and handed to the input's {@code accumulate}. The returned result reports itself
 * alive while the input is alive.
 *
 * @param inObj exactly one input result (asserted); every element of this array is released when the
 *              returned result is freed
 * @return the reshaped result
 */
@Nullable
@Override
public Result evalAndFree(@Nonnull final Result... inObj) {
  assert 1 == inObj.length;
  final TensorList inputData = inObj[0].getData();
  @Nonnull final int[] originalDims = inputData.getDimensions();
  final ReshapedTensorList forwardList = new ReshapedTensorList(inputData, outputDims);
  inputData.freeRef();
  return new Result(forwardList, (DeltaSet<Layer> buffer, TensorList delta) -> {
    // Restore the input's shape before propagating the gradient upstream.
    @Nonnull final ReshapedTensorList backwardList = new ReshapedTensorList(delta, originalDims);
    inObj[0].accumulate(buffer, backwardList);
  }) {
    @Override
    public boolean isAlive() {
      return inObj[0].isAlive();
    }

    @Override
    protected void _free() {
      for (int idx = 0; idx < inObj.length; idx++) {
        inObj[idx].freeRef();
      }
    }
  };
}
Aggregations