Example 1 with OnnxGraphMapper

Use of org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper in project nd4j by deeplearning4j.

From the class OnnxGraphMapperTests, method testMapper.

@Test
public void testMapper() throws Exception {
    // Load the serialized ONNX model from the classpath and extract its graph.
    try (val inputs = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.GraphProto graphProto = OnnxProto3.ModelProto.parseFrom(inputs).getGraph();
        OnnxGraphMapper onnxGraphMapper = new OnnxGraphMapper();
        // The mapper should expose the same nodes as the raw graph.
        assertEquals(graphProto.getNodeList().size(), onnxGraphMapper.getNodeList(graphProto).size());
        assertEquals(4, onnxGraphMapper.variablesForGraph(graphProto).size());
        // Tensor type of the first graph input, used when converting the first initializer.
        val initializer = graphProto.getInput(0).getType().getTensorType();
        INDArray arr = onnxGraphMapper.getNDArrayFromTensor(graphProto.getInitializer(0).getName(), initializer, graphProto);
        // assumeNotNull skips (rather than fails) the test if no array was produced.
        assumeNotNull(arr);
        // Each node's attribute map should have one entry per attribute.
        for (val node : graphProto.getNodeList()) {
            assertEquals(node.getAttributeList().size(), onnxGraphMapper.getAttrMap(node).size());
        }
        // Import the graph into SameDiff; this model is expected to yield exactly one function.
        val sameDiff = onnxGraphMapper.importGraph(graphProto);
        assertEquals(1, sameDiff.functions().length);
        System.out.println(sameDiff);
    }
}
Also used: lombok.val, INDArray (org.nd4j.linalg.api.ndarray.INDArray), OnnxGraphMapper (org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper), OnnxProto3 (onnx.OnnxProto3), ClassPathResource (org.nd4j.linalg.io.ClassPathResource), Test (org.junit.Test)
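
The loop in testMapper checks that each node's attribute map has as many entries as the node has attributes. A minimal sketch of why that check is meaningful, assuming the map is keyed by attribute name (an assumption here, not confirmed by the snippet). The Attr class and attrMap method below are hypothetical stand-ins, not the nd4j or ONNX API.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class AttrMapSketch {
    // Hypothetical stand-in for an ONNX node attribute (not OnnxProto3.AttributeProto).
    static final class Attr {
        final String name;
        final String value;

        Attr(String name, String value) {
            this.name = name;
            this.value = value;
        }
    }

    // Index attributes by name; a later attribute with the same name replaces an earlier one.
    static Map<String, Attr> attrMap(List<Attr> attrs) {
        Map<String, Attr> byName = new LinkedHashMap<>();
        for (Attr a : attrs) {
            byName.put(a.name, a);
        }
        return byName;
    }

    public static void main(String[] args) {
        List<Attr> unique = List.of(new Attr("axis", "1"), new Attr("shape", "[2, 3]"));
        List<Attr> clash = List.of(new Attr("axis", "1"), new Attr("axis", "2"));
        // Sizes agree only when every attribute name is distinct.
        System.out.println(attrMap(unique).size() == unique.size()); // prints true
        System.out.println(attrMap(clash).size() == clash.size());   // prints false
    }
}
```

Under that assumption, the size comparison in the test would fail if a node carried two attributes with the same name, since one would overwrite the other in the map.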

Example 2 with OnnxGraphMapper

Use of org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper in project nd4j by deeplearning4j.

From the class Reshape, method initFromOnnx.

@Override
public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    val shape = new OnnxGraphMapper().getShape(node);
    this.shape = shape;
}
Also used: lombok.val, OnnxGraphMapper (org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper)

Aggregations

lombok.val (lombok.val): 2
OnnxGraphMapper (org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper): 2
OnnxProto3 (onnx.OnnxProto3): 1
Test (org.junit.Test): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
ClassPathResource (org.nd4j.linalg.io.ClassPathResource): 1
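
The counts above appear to be the number of examples whose "Also used" list mentions each class. A minimal sketch of that tally, assuming this interpretation; the class and method names are hypothetical and only the two "Also used" lists come from this page.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class UsageAggregationSketch {
    // Count, for each class name, how many examples list it as used.
    static Map<String, Integer> countUsages(List<List<String>> examples) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (List<String> used : examples) {
            for (String cls : used) {
                counts.merge(cls, 1, Integer::sum);
            }
        }
        return counts;
    }

    public static void main(String[] args) {
        // The "Also used" lists of Example 1 and Example 2.
        List<List<String>> examples = List.of(
            List.of("lombok.val",
                    "org.nd4j.linalg.api.ndarray.INDArray",
                    "org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper",
                    "onnx.OnnxProto3",
                    "org.nd4j.linalg.io.ClassPathResource",
                    "org.junit.Test"),
            List.of("lombok.val",
                    "org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper"));
        countUsages(examples).forEach((cls, n) -> System.out.println(cls + ": " + n));
    }
}
```

With the two lists above, lombok.val and OnnxGraphMapper each come out at 2 and every other class at 1, which matches the aggregated figures.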