Usage of onnx.OnnxProto3 in the nd4j project by deeplearning4j.
Class OnnxGraphMapper, method variablesForGraph.
/**
 * Maps every variable name referenced by an ONNX graph to a tensor type.
 *
 * <p>Declared graph inputs are registered first, then declared graph outputs. An output whose
 * name matches an input replaces that input's entry. After that, every node contributes its own
 * name (its index in the node list when the name is empty) plus each of its input and output
 * names. Any of those not already present is registered through {@code addDummyTensor} as a
 * placeholder whose shape still needs to be filled in.
 *
 * @param graphProto the ONNX graph to scan
 * @return a map from variable name to tensor type
 */
@Override
public Map<String, onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
    // TODO: figure out why gpu_0/conv1_1 isn't present in VGG.
    Map<String, onnx.OnnxProto3.TypeProto.Tensor> variables = new HashMap<>();

    // Declared graph inputs/outputs carry their own type information.
    for (val graphInput : graphProto.getInputList()) {
        variables.put(graphInput.getName(), graphInput.getType().getTensorType());
    }
    for (val graphOutput : graphProto.getOutputList()) {
        variables.put(graphOutput.getName(), graphOutput.getType().getTensorType());
    }

    // Anything else a node references gets a placeholder entry (never overwriting a real one).
    int nodeIndex = 0;
    for (val node : graphProto.getNodeList()) {
        String nodeName = node.getName().isEmpty() ? String.valueOf(nodeIndex) : node.getName();
        nodeIndex++;

        if (!variables.containsKey(nodeName)) {
            addDummyTensor(nodeName, variables);
        }
        for (String inputName : node.getInputList()) {
            if (!variables.containsKey(inputName)) {
                addDummyTensor(inputName, variables);
            }
        }
        for (String outputName : node.getOutputList()) {
            if (!variables.containsKey(outputName)) {
                addDummyTensor(outputName, variables);
            }
        }
    }
    return variables;
}
Usage of onnx.OnnxProto3 in the nd4j project by deeplearning4j.
Class OnnxGraphMapper, method getNDArrayFromTensor.
/**
 * Materializes the graph initializer named {@code tensorName} as an ndarray.
 *
 * <p>The element data type and the shape come from {@code tensorProto}. The element bytes come
 * from the {@code raw_data} field of the first initializer in {@code graph} whose name equals
 * {@code tensorName}.
 *
 * @param tensorName  name of the initializer to load
 * @param tensorProto type descriptor supplying the data type and shape
 * @param graph       graph whose initializers are searched
 * @return the ndarray, or {@code null} if no initializer in {@code graph} has that name
 * @throws ND4JIllegalStateException if {@code tensorProto} is not initialized
 */
@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    // Check the descriptor before deriving anything from it, so an uninitialized descriptor
    // produces this error rather than whatever dataTypeForTensor might throw.
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }
    DataBuffer.Type type = dataTypeForTensor(tensorProto);

    OnnxProto3.TensorProto tensor = null;
    for (int i = 0; i < graph.getInitializerCount(); i++) {
        val initializer = graph.getInitializer(i);
        if (initializer.getName().equals(tensorName)) {
            tensor = initializer;
            break;
        }
    }
    if (tensor == null) {
        return null;
    }

    ByteString bytes = tensor.getRawData();
    // Size the copy from the payload length. The capacity of the view returned by
    // asReadOnlyByteBuffer() can exceed its remaining bytes when the ByteString is a slice of
    // a larger backing array, which would over-allocate and leave trailing zero bytes.
    ByteBuffer directAlloc = ByteBuffer.allocateDirect(bytes.size()).order(ByteOrder.nativeOrder());
    directAlloc.put(bytes.asReadOnlyByteBuffer());
    directAlloc.flip();
    // NOTE(review): the ONNX spec stores raw_data little-endian; nativeOrder() only matches
    // that on little-endian hosts. Confirm how Nd4j.createBuffer interprets the buffer order.

    int[] shape = getShapeFromTensor(tensorProto);
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    return Nd4j.create(buffer).reshape(shape);
}
Usage of onnx.OnnxProto3 in the nd4j project by deeplearning4j.
Class OnnxGraphMapperTests, method testMapper.
/**
 * End-to-end check of OnnxGraphMapper against the bundled embedding_only.onnx model.
 * It covers node listing, variable collection, initializer loading, attribute mapping and
 * SameDiff import.
 */
@Test
public void testMapper() throws Exception {
    try (val modelStream = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.GraphProto graph = OnnxProto3.ModelProto.parseFrom(modelStream).getGraph();
        OnnxGraphMapper mapper = new OnnxGraphMapper();

        // Node and variable bookkeeping.
        int protoNodeCount = graph.getNodeCount();
        assertEquals(protoNodeCount, mapper.getNodeList(graph).size());
        int variableCount = mapper.variablesForGraph(graph).size();
        assertEquals(4, variableCount);

        // Load the first initializer using the first input's type descriptor.
        val firstInputType = graph.getInput(0).getType().getTensorType();
        String firstInitializerName = graph.getInitializer(0).getName();
        INDArray loaded = mapper.getNDArrayFromTensor(firstInitializerName, firstInputType, graph);
        assumeNotNull(loaded);

        // Every declared attribute must survive mapping.
        for (val graphNode : graph.getNodeList()) {
            int declaredAttributes = graphNode.getAttributeCount();
            assertEquals(declaredAttributes, mapper.getAttrMap(graphNode).size());
        }

        val imported = mapper.importGraph(graph);
        assertEquals(1, imported.functions().length);
        System.out.println(imported);
    }
}
Aggregations