Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
Class Transpose, method initFromTensorFlow.
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
    // permute dimensions are not specified as a second input
    if (nodeDef.getInputCount() < 2)
        return;

    NodeDef permuteDimsNode = null;
    for (int i = 0; i < graph.getNodeCount(); i++) {
        if (graph.getNode(i).getName().equals(nodeDef.getInput(1))) {
            permuteDimsNode = graph.getNode(i);
        }
    }

    val permuteArrayOp = TFGraphMapper.getInstance().getNDArrayFromTensor("value", permuteDimsNode, graph);
    if (permuteArrayOp != null) {
        this.permuteDims = permuteArrayOp.data().asInt();
        for (int i = 0; i < permuteDims.length; i++) {
            addIArgument(permuteDims[i]);
        }
    }

    // shape not mapped yet; handle this once it has been properly mapped
    if (arg().getShape() == null) {
        return;
    }

    INDArray arr = sameDiff.getArrForVarName(arg().getVarName());
    if (arr == null) {
        val arrVar = sameDiff.getVariable(arg().getVarName());
        arr = arrVar.getWeightInitScheme().create(arrVar.getShape());
        sameDiff.putArrayForVarName(arg().getVarName(), arr);
    }

    addInputArgument(arr);
    // default to a full reverse permutation when none was specified
    if (arr != null && permuteDims == null) {
        this.permuteDims = ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank()));
    }

    if (permuteDims != null && permuteDims.length < arg().getShape().length)
        throw new ND4JIllegalStateException("Illegal permute found. Not all dimensions specified");
}
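When no permutation is supplied, the method falls back to ArrayUtil.reverseCopy(ArrayUtil.range(0, arr.rank())), i.e. a full transpose. Below is a minimal standalone sketch of that default; the two helper methods re-implement the ArrayUtil calls for illustration and are not the nd4j originals.

// Standalone sketch of the default permutation logic above.
public class DefaultPermute {

    // Illustrative equivalent of ArrayUtil.range(0, rank): [0, 1, ..., rank - 1]
    static int[] range(int from, int to) {
        int[] out = new int[to - from];
        for (int i = 0; i < out.length; i++)
            out[i] = from + i;
        return out;
    }

    // Illustrative equivalent of ArrayUtil.reverseCopy: a reversed copy of the input
    static int[] reverseCopy(int[] in) {
        int[] out = new int[in.length];
        for (int i = 0; i < in.length; i++)
            out[i] = in[in.length - 1 - i];
        return out;
    }

    public static void main(String[] args) {
        // for a rank-3 array the default permutation is [2, 1, 0]: a full transpose
        int[] permuteDims = reverseCopy(range(0, 3));
        System.out.println(java.util.Arrays.toString(permuteDims)); // prints [2, 1, 0]
    }
}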
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
Class Concat, method resolvePropertiesFromSameDiffBeforeExecution.
@Override
public void resolvePropertiesFromSameDiffBeforeExecution() {
    val propertiesToResolve = sameDiff.propertiesToResolveForFunction(this);
    if (!propertiesToResolve.isEmpty()) {
        val varName = propertiesToResolve.get(0);
        val var = sameDiff.getVariable(varName);
        if (var == null) {
            throw new ND4JIllegalStateException("No variable found with name " + varName);
        } else if (var.getArr() == null) {
            throw new ND4JIllegalStateException("Array with variable name " + varName + " unset!");
        }

        concatDimension = var.getArr().getInt(0);
        addIArgument(concatDimension);
    }

    // don't pass both iArg and last axis down to libnd4j
    if (inputArguments().length == args().length) {
        val inputArgs = inputArguments();
        removeInputArgument(inputArgs[inputArguments().length - 1]);
    }
}
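The resolved concatDimension is simply the axis along which the inputs are joined. A short usage sketch via the public Nd4j API (shapes chosen for illustration; assumes nd4j is on the classpath):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ConcatExample {
    public static void main(String[] args) {
        INDArray a = Nd4j.ones(2, 3);
        INDArray b = Nd4j.zeros(2, 3);
        // concat along dimension 0 stacks rows: result shape [4, 3]
        INDArray c0 = Nd4j.concat(0, a, b);
        // concat along dimension 1 appends columns: result shape [2, 6]
        INDArray c1 = Nd4j.concat(1, a, b);
        System.out.println(java.util.Arrays.toString(c0.shape()));
        System.out.println(java.util.Arrays.toString(c1.shape()));
    }
}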
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
Class ArrowSerde, method fromTensor.
/**
 * Convert a {@link Tensor}
 * to an {@link INDArray}
 * @param tensor the input tensor
 * @return the equivalent {@link INDArray}
 */
public static INDArray fromTensor(Tensor tensor) {
    byte b = tensor.typeType();
    int[] shape = new int[tensor.shapeLength()];
    int[] stride = new int[tensor.stridesLength()];
    for (int i = 0; i < shape.length; i++) {
        shape[i] = (int) tensor.shape(i).size();
        stride[i] = (int) tensor.strides(i);
    }

    int length = ArrayUtil.prod(shape);
    Buffer buffer = tensor.data();
    if (buffer == null) {
        throw new ND4JIllegalStateException("Buffer was not serialized properly.");
    }

    // deduce the element size from the total byte length
    int elementSize = (int) buffer.length() / length;
    // Arrow strides are in bytes; nd4j strides are in elements
    for (int i = 0; i < stride.length; i++) {
        stride[i] /= elementSize;
    }

    DataBuffer.Type type = typeFromTensorType(b, elementSize);
    DataBuffer dataBuffer = DataBufferStruct.createFromByteBuffer(tensor.getByteBuffer(), (int) tensor.data().offset(), type, length, elementSize);
    INDArray arr = Nd4j.create(dataBuffer, shape);
    arr.setShapeAndStride(shape, stride);
    return arr;
}
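The stride division is the key step: Arrow serializes strides in bytes, while nd4j expects strides in elements. A standalone sketch of that arithmetic (plain Java; the concrete values are chosen for illustration):

public class StrideConversion {
    public static void main(String[] args) {
        int[] shape = {2, 3};              // a 2x3 tensor of doubles
        long bufferLengthBytes = 48;       // 6 elements * 8 bytes each
        int length = 2 * 3;                // ArrayUtil.prod(shape)
        int elementSize = (int) (bufferLengthBytes / length); // 8 -> double

        int[] byteStrides = {24, 8};       // C-order Arrow strides, in bytes
        int[] elementStrides = new int[byteStrides.length];
        for (int i = 0; i < byteStrides.length; i++)
            elementStrides[i] = byteStrides[i] / elementSize;

        // elementStrides is now {3, 1}: the usual nd4j C-order strides
        System.out.println(java.util.Arrays.toString(elementStrides));
    }
}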
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
Class CudaZeroHandler, method memcpySpecial.
/**
 * Special memcpy version, addressing shapeInfoDataBuffer copies
 *
 * PLEASE NOTE: Blocking H->H, Async H->D
 *
 * @param dstBuffer the destination buffer
 * @param srcPointer pointer to the source data
 * @param length number of bytes to copy
 * @param dstOffset byte offset into the destination buffer
 */
@Override
public void memcpySpecial(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
    CudaContext context = getCudaContext();
    AllocationPoint point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();

    // copy into the host-side buffer first
    Pointer dP = new CudaPointer((point.getPointers().getHostPointer().address()) + dstOffset);
    if (nativeOps.memcpyAsync(dP, srcPointer, length, CudaConstants.cudaMemcpyHostToHost, context.getOldStream()) == 0)
        throw new ND4JIllegalStateException("memcpyAsync failed");

    // if the buffer also lives on the device, push the staged data there
    if (point.getAllocationStatus() == AllocationStatus.DEVICE) {
        Pointer rDP = new CudaPointer(point.getPointers().getDevicePointer().address() + dstOffset);
        if (nativeOps.memcpyAsync(rDP, dP, length, CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");
        context.syncOldStream();
    }

    context.syncOldStream();
    point.tickDeviceWrite();
}
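The control flow above is a two-hop staging pattern: data always lands in the host buffer first, and is pushed to the device on the same stream only when the allocation lives there. A simplified sketch of the pattern follows; the NativeCopy interface and the two constants are hypothetical stand-ins, not the real nd4j NativeOps API:

// Hypothetical stand-in for the native copy call; returns 0 on failure,
// matching the error checks in memcpySpecial above.
interface NativeCopy {
    int memcpyAsync(long dst, long src, long lengthBytes, int kind, long stream);
}

class StagedCopy {
    // illustrative constants, not the real CudaConstants
    static final int HOST_TO_HOST = 0;
    static final int HOST_TO_DEVICE = 1;

    static void copy(NativeCopy ops, long hostDst, long deviceDst, long src,
                     long lengthBytes, long stream, boolean allocatedOnDevice) {
        // hop 1: stage the data in the host-side buffer
        if (ops.memcpyAsync(hostDst, src, lengthBytes, HOST_TO_HOST, stream) == 0)
            throw new IllegalStateException("memcpyAsync failed");
        // hop 2: push host -> device only if the allocation lives on the device
        if (allocatedOnDevice) {
            if (ops.memcpyAsync(deviceDst, hostDst, lengthBytes, HOST_TO_DEVICE, stream) == 0)
                throw new IllegalStateException("memcpyAsync failed");
        }
    }
}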
Use of org.nd4j.linalg.exception.ND4JIllegalStateException in project nd4j by deeplearning4j.
Class CudaZeroHandler, method memcpy.
/**
 * Synchronous version of memcpy.
 *
 * @param dstBuffer the destination buffer
 * @param srcBuffer the source buffer
 */
@Override
public void memcpy(DataBuffer dstBuffer, DataBuffer srcBuffer) {
    CudaContext context = getCudaContext();
    AllocationPoint dstPoint = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
    AllocationPoint srcPoint = ((BaseCudaDataBuffer) srcBuffer).getAllocationPoint();

    Pointer dP = new CudaPointer(dstPoint.getPointers().getHostPointer().address());
    Pointer sP = null;

    // stage the source data into the destination's host buffer,
    // reading from the device or host pointer depending on where the source lives
    if (srcPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
        sP = new CudaPointer(srcPoint.getPointers().getDevicePointer().address());
        if (nativeOps.memcpyAsync(dP, sP, srcBuffer.length() * srcBuffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    } else {
        sP = new CudaPointer(srcPoint.getPointers().getHostPointer().address());
        if (nativeOps.memcpyAsync(dP, sP, srcBuffer.length() * srcBuffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }

    // if the destination also lives on the device, push the staged data there
    if (dstPoint.getAllocationStatus() == AllocationStatus.DEVICE) {
        Pointer rDP = new CudaPointer(dstPoint.getPointers().getDevicePointer().address());
        if (nativeOps.memcpyAsync(rDP, dP, srcBuffer.length() * srcBuffer.getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getOldStream()) == 0) {
            throw new ND4JIllegalStateException("memcpyAsync failed");
        }
    }

    dstPoint.tickDeviceWrite();
    // it has to be a blocking call
    context.syncOldStream();
}
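At the call-site level this handler sits behind the memory manager, so a buffer-to-buffer copy can be expressed through the public API. A hedged usage sketch, assuming MemoryManager exposes memcpy(DataBuffer, DataBuffer) as it does in nd4j of this vintage, and that both buffers have matching length and element type:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class MemcpyExample {
    public static void main(String[] args) {
        INDArray src = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        INDArray dst = Nd4j.zeros(2, 3);
        // routes to the backend handler, e.g. CudaZeroHandler.memcpy on CUDA
        Nd4j.getMemoryManager().memcpy(dst.data(), src.data());
        System.out.println(dst);
    }
}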