Use of org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper in project nd4j by deeplearning4j.
The class OnnxGraphMapperTests, method testMapper.
/**
 * Exercises {@code OnnxGraphMapper} against the classpath resource
 * {@code onnx_graphs/embedding_only.onnx}: node listing, variable discovery,
 * tensor conversion, per-node attribute mapping and full graph import.
 *
 * @throws Exception if the resource cannot be opened or the model cannot be parsed
 */
@Test
public void testMapper() throws Exception {
    try (val modelStream = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.ModelProto model = OnnxProto3.ModelProto.parseFrom(modelStream);
        OnnxProto3.GraphProto graph = model.getGraph();
        OnnxGraphMapper mapper = new OnnxGraphMapper();

        // The mapper must report exactly as many nodes as the raw graph declares.
        final int declaredNodeCount = graph.getNodeList().size();
        final int mappedNodeCount = mapper.getNodeList(graph).size();
        assertEquals(declaredNodeCount, mappedNodeCount);

        // This fixture is expected to yield four variables.
        assertEquals(4, mapper.variablesForGraph(graph).size());

        // Convert the first initializer, described by the tensor type of the first graph input.
        val tensorType = graph.getInput(0).getType().getTensorType();
        String initializerName = graph.getInitializer(0).getName();
        INDArray converted = mapper.getNDArrayFromTensor(initializerName, tensorType, graph);
        // NOTE(review): if this is JUnit's Assume.assumeNotNull, a null array skips the
        // remainder of the test instead of failing it — confirm that is intended.
        assumeNotNull(converted);

        // Every attribute declared on a node must show up in the mapper's attribute map.
        for (val graphNode : graph.getNodeList()) {
            final int declaredAttributeCount = graphNode.getAttributeList().size();
            assertEquals(declaredAttributeCount, mapper.getAttrMap(graphNode).size());
        }

        // Importing the whole graph should produce exactly one function.
        val imported = mapper.importGraph(graph);
        assertEquals(1, imported.functions().length);
        System.out.println(imported);
    }
}
Use of org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper in project nd4j by deeplearning4j.
The class Reshape, method initFromOnnx.
/**
 * Initializes this op from an ONNX node by storing the shape that
 * {@code OnnxGraphMapper.getShape} reports for {@code node} in {@code this.shape}.
 *
 * @param node              the ONNX node the shape is read from
 * @param initWith          not read by this implementation
 * @param attributesForNode not read by this implementation
 * @param graph             not read by this implementation
 */
@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    // A fresh mapper is created for the lookup; only the node itself is consulted.
    this.shape = new OnnxGraphMapper().getShape(node);
}
Aggregations