Use of com.github.os72.protobuf351.ByteString in project nd4j by deeplearning4j.
The class OnnxGraphMapper, method mapTensorProto.
/**
 * Converts an ONNX {@code TensorProto} into an {@link INDArray}.
 *
 * <p>The element type is derived from the tensor's ONNX data type via
 * {@code nd4jTypeFromOnnxType}, and the shape via {@code getShapeFromTensor}. The element
 * bytes are copied from the tensor's raw data into a freshly allocated direct buffer in
 * native byte order.
 *
 * <p>NOTE(review): only the raw data field is read. Tensors that carry their values in
 * typed repeated fields instead (e.g. float_data) presumably yield an empty or incorrect
 * result here — confirm against the exporters in use.
 *
 * @param tensor the ONNX tensor to convert; may be {@code null}
 * @return the tensor's values reshaped to the shape reported by {@code getShapeFromTensor},
 *     or {@code null} if {@code tensor} is {@code null}
 */
public INDArray mapTensorProto(OnnxProto3.TensorProto tensor) {
    if (tensor == null) {
        return null;
    }
    DataBuffer.Type type = nd4jTypeFromOnnxType(tensor.getDataType());
    ByteString bytes = tensor.getRawData();
    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Size the copy by remaining(), not capacity(): when the ByteString is a slice of a
    // larger array, the returned view's capacity can span that whole array, which would
    // over-allocate direct memory. remaining() is exactly the number of data bytes.
    ByteBuffer directAlloc =
            ByteBuffer.allocateDirect(byteBuffer.remaining()).order(ByteOrder.nativeOrder());
    directAlloc.put(byteBuffer);
    directAlloc.rewind();
    int[] shape = getShapeFromTensor(tensor);
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
Use of com.github.os72.protobuf351.ByteString in project nd4j by deeplearning4j.
The class OnnxGraphMapper, method getNDArrayFromTensor.
/**
 * Looks up the initializer named {@code tensorName} in {@code graph} and returns its raw
 * data as an {@link INDArray}.
 *
 * <p>The element type and shape are taken from {@code tensorProto} (via
 * {@code dataTypeForTensor} and {@code getShapeFromTensor}). Only the element bytes come
 * from the matching initializer's raw data, copied into a direct buffer in native byte order.
 *
 * @param tensorName  name of the initializer to look up
 * @param tensorProto type information (data type and shape) used to interpret the bytes
 * @param graph       graph whose initializers are searched
 * @return the initializer's values reshaped to the shape derived from {@code tensorProto},
 *     or {@code null} if {@code graph} has no initializer named {@code tensorName}
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    // Validate before deriving anything from tensorProto so this error is the one reported.
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }
    DataBuffer.Type type = dataTypeForTensor(tensorProto);

    OnnxProto3.TensorProto tensor = null;
    for (int i = 0; i < graph.getInitializerCount(); i++) {
        val initializer = graph.getInitializer(i);
        if (initializer.getName().equals(tensorName)) {
            tensor = initializer;
            break;
        }
    }
    if (tensor == null) {
        return null;
    }

    ByteString bytes = tensor.getRawData();
    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    // Size the copy by remaining(), not capacity(): when the ByteString is a slice of a
    // larger array, the returned view's capacity can span that whole array, which would
    // over-allocate direct memory. remaining() is exactly the number of data bytes.
    ByteBuffer directAlloc =
            ByteBuffer.allocateDirect(byteBuffer.remaining()).order(ByteOrder.nativeOrder());
    directAlloc.put(byteBuffer);
    directAlloc.rewind();
    int[] shape = getShapeFromTensor(tensorProto);
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
Aggregations