
Example 1 with OnnxProto3

Use of onnx.OnnxProto3 in project nd4j by deeplearning4j.

From class OnnxGraphMapper, method variablesForGraph.

@Override
public Map<String, onnx.OnnxProto3.TypeProto.Tensor> variablesForGraph(OnnxProto3.GraphProto graphProto) {
    /**
     * Need to figure out why
     * gpu_0/conv1_1 isn't present in VGG
     */
    Map<String, onnx.OnnxProto3.TypeProto.Tensor> ret = new HashMap<>();
    for (int i = 0; i < graphProto.getInputCount(); i++) {
        ret.put(graphProto.getInput(i).getName(), graphProto.getInput(i).getType().getTensorType());
    }
    for (int i = 0; i < graphProto.getOutputCount(); i++) {
        ret.put(graphProto.getOutput(i).getName(), graphProto.getOutput(i).getType().getTensorType());
    }
    for (int i = 0; i < graphProto.getNodeCount(); i++) {
        val node = graphProto.getNode(i);
        val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
        // add a dummy tensor with shape -1 as a placeholder; the real shape still needs to be filled in
        if (!ret.containsKey(name)) {
            addDummyTensor(name, ret);
        }
        for (int j = 0; j < node.getInputCount(); j++) {
            if (!ret.containsKey(node.getInput(j))) {
                addDummyTensor(node.getInput(j), ret);
            }
        }
        for (int j = 0; j < node.getOutputCount(); j++) {
            if (!ret.containsKey(node.getOutput(j))) {
                addDummyTensor(node.getOutput(j), ret);
            }
        }
    }
    return ret;
}
Also used: lombok.val, onnx.OnnxProto3, com.github.os72.protobuf351.ByteString
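The method's logic can be sketched without ONNX or ND4J types. This is a minimal sketch, assuming a hypothetical `Node` record in place of `OnnxProto3.NodeProto` and plain `long[]` shapes in place of `TypeProto.Tensor`; the `{-1}` shape stands in for whatever `addDummyTensor` actually stores, which the excerpt does not show.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class GraphVariablesSketch {
    // Hypothetical stand-in for OnnxProto3.NodeProto, keeping only the fields used here.
    record Node(String name, List<String> inputs, List<String> outputs) {}

    // A fresh {-1} shape marks a variable whose real shape is filled in later (assumed placeholder).
    static long[] unknownShape() {
        return new long[]{-1};
    }

    // Collects every named value in the graph: declared inputs/outputs keep their shapes,
    // anything else referenced by a node gets an unknown-shape placeholder.
    static Map<String, long[]> variablesForGraph(Map<String, long[]> graphInputs,
                                                 Map<String, long[]> graphOutputs,
                                                 List<Node> nodes) {
        Map<String, long[]> ret = new LinkedHashMap<>(graphInputs);
        ret.putAll(graphOutputs);
        for (int i = 0; i < nodes.size(); i++) {
            Node node = nodes.get(i);
            // Unnamed nodes fall back to their index, as the original method does.
            String name = node.name().isEmpty() ? String.valueOf(i) : node.name();
            ret.putIfAbsent(name, unknownShape());
            for (String in : node.inputs()) {
                ret.putIfAbsent(in, unknownShape());
            }
            for (String out : node.outputs()) {
                ret.putIfAbsent(out, unknownShape());
            }
        }
        return ret;
    }

    public static void main(String[] args) {
        List<Node> nodes = List.of(
                new Node("gemm", List.of("x", "w"), List.of("h")),
                new Node("", List.of("h"), List.of("y")));
        Map<String, long[]> vars = variablesForGraph(
                Map.of("x", new long[]{1, 3}), Map.of("y", new long[]{1, 2}), nodes);
        vars.forEach((k, v) -> System.out.println(k + " -> " + Arrays.toString(v)));
    }
}
```

`putIfAbsent` has the same effect as the original's `containsKey` check followed by an insert, so declared inputs and outputs keep their real shapes. Like the original, the sketch also registers each node's own name (or index) as a variable.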

Example 2 with OnnxProto3

Use of onnx.OnnxProto3 in project nd4j by deeplearning4j.

From class OnnxGraphMapper, method getNDArrayFromTensor.

@Override
public INDArray getNDArrayFromTensor(String tensorName, OnnxProto3.TypeProto.Tensor tensorProto, OnnxProto3.GraphProto graph) {
    DataBuffer.Type type = dataTypeForTensor(tensorProto);
    if (!tensorProto.isInitialized()) {
        throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
    }
    // Find the initializer with this name; if there is none, there is no stored data to return.
    OnnxProto3.TensorProto tensor = null;
    for (int i = 0; i < graph.getInitializerCount(); i++) {
        val initializer = graph.getInitializer(i);
        if (initializer.getName().equals(tensorName)) {
            tensor = initializer;
            break;
        }
    }
    if (tensor == null)
        return null;
    // Copy raw_data into a direct (off-heap) buffer before handing it to Nd4j.createBuffer.
    ByteString bytes = tensor.getRawData();
    ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
    ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
    directAlloc.put(byteBuffer);
    directAlloc.rewind();
    // The element count passed to createBuffer is the product of the tensor's dimensions.
    int[] shape = getShapeFromTensor(tensorProto);
    DataBuffer buffer = Nd4j.createBuffer(directAlloc, type, ArrayUtil.prod(shape));
    INDArray arr = Nd4j.create(buffer).reshape(shape);
    return arr;
}
Also used: lombok.val, org.nd4j.linalg.api.ndarray.INDArray, com.github.os72.protobuf351.ByteString, onnx.OnnxProto3, org.nd4j.linalg.exception.ND4JIllegalStateException, java.nio.ByteBuffer, org.nd4j.linalg.api.buffer.DataBuffer
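The lookup-then-decode steps can be illustrated with only `java.nio`. This is a minimal sketch, assuming a hypothetical `Initializer` record in place of `OnnxProto3.TensorProto` and float32 data only; instead of `Nd4j.createBuffer` it reads the direct buffer through `asFloatBuffer()`. It decodes `raw_data` as little-endian, which is how the ONNX specification says `raw_data` elements are stored, and it adds a byte-count check that the original does not perform.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;

public class RawTensorSketch {
    // Hypothetical stand-in for OnnxProto3.TensorProto: a name, dims, and raw_data bytes.
    record Initializer(String name, long[] dims, byte[] rawData) {}

    // Linear search by name, like the initializer loop in getNDArrayFromTensor.
    static Initializer findInitializer(List<Initializer> initializers, String name) {
        for (Initializer init : initializers) {
            if (init.name().equals(name)) {
                return init;
            }
        }
        return null; // no stored data for this name
    }

    static long product(long[] dims) {
        long n = 1;
        for (long d : dims) {
            n *= d;
        }
        return n;
    }

    // Decodes float32 raw_data, reading elements as little-endian.
    static float[] decodeFloats(Initializer init) {
        ByteBuffer src = ByteBuffer.wrap(init.rawData()).order(ByteOrder.LITTLE_ENDIAN);
        long count = product(init.dims());
        if (src.remaining() != count * Float.BYTES) {
            throw new IllegalStateException("raw_data has " + src.remaining()
                    + " bytes, expected " + count * Float.BYTES);
        }
        // Bulk-copy into a direct buffer, as the original does before calling Nd4j.createBuffer.
        ByteBuffer direct = ByteBuffer.allocateDirect(src.remaining()).order(ByteOrder.LITTLE_ENDIAN);
        direct.put(src);
        direct.rewind();
        float[] out = new float[(int) count];
        direct.asFloatBuffer().get(out); // the view inherits the little-endian order
        return out;
    }

    public static void main(String[] args) {
        ByteBuffer bb = ByteBuffer.allocate(4 * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
        for (float f : new float[]{1.5f, -2f, 3f, 4f}) {
            bb.putFloat(f);
        }
        List<Initializer> inits = List.of(new Initializer("w", new long[]{2, 2}, bb.array()));
        System.out.println(Arrays.toString(decodeFloats(findInitializer(inits, "w")))); // prints [1.5, -2.0, 3.0, 4.0]
        System.out.println(findInitializer(inits, "missing")); // prints null, like the original's early return
    }
}
```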

Example 3 with OnnxProto3

Use of onnx.OnnxProto3 in project nd4j by deeplearning4j.

From class OnnxGraphMapperTests, method testMapper.

@Test
public void testMapper() throws Exception {
    try (val inputs = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.GraphProto graphProto = OnnxProto3.ModelProto.parseFrom(inputs).getGraph();
        OnnxGraphMapper onnxGraphMapper = new OnnxGraphMapper();
        assertEquals(graphProto.getNodeList().size(), onnxGraphMapper.getNodeList(graphProto).size());
        assertEquals(4, onnxGraphMapper.variablesForGraph(graphProto).size());
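        // Type info of graph input 0, paired below with the name of initializer 0
        // (this assumes both describe the same tensor).
        val initializer = graphProto.getInput(0).getType().getTensorType();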
        INDArray arr = onnxGraphMapper.getNDArrayFromTensor(graphProto.getInitializer(0).getName(), initializer, graphProto);
        assumeNotNull(arr);
        for (val node : graphProto.getNodeList()) {
            assertEquals(node.getAttributeList().size(), onnxGraphMapper.getAttrMap(node).size());
        }
        val sameDiff = onnxGraphMapper.importGraph(graphProto);
        assertEquals(1, sameDiff.functions().length);
        System.out.println(sameDiff);
    }
}
Also used: lombok.val, org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper, onnx.OnnxProto3, org.nd4j.linalg.io.ClassPathResource, org.junit.Test
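The test opens its fixture inside try-with-resources, so the stream is closed even when an assertion fails. Below is a minimal plain-JDK sketch of the same pattern. It uses `ClassLoader.getResourceAsStream` in place of nd4j's `ClassPathResource`, whose behavior for a missing file is not shown here, and the helper name `readFixture` is hypothetical. Because `getResourceAsStream` returns `null` rather than throwing, the sketch turns a missing resource into a `FileNotFoundException`.

```java
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;

public class FixtureLoadingSketch {
    // Hypothetical helper: reads a classpath resource fully and always closes the stream.
    static byte[] readFixture(String path) throws IOException {
        InputStream stream = FixtureLoadingSketch.class.getClassLoader().getResourceAsStream(path);
        if (stream == null) {
            // getResourceAsStream signals a missing resource with null, not an exception.
            throw new FileNotFoundException("classpath resource not found: " + path);
        }
        try (InputStream in = stream) {
            return in.readAllBytes(); // Java 9+
        }
    }

    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : "onnx_graphs/embedding_only.onnx";
        try {
            byte[] bytes = readFixture(path);
            System.out.println("read " + bytes.length + " bytes from " + path);
        } catch (FileNotFoundException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

The returned bytes could then be parsed with the `parseFrom(byte[])` overload that protobuf-java generates for every message type, e.g. `OnnxProto3.ModelProto.parseFrom(bytes)`.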

Aggregations

lombok.val: 3
onnx.OnnxProto3: 3
com.github.os72.protobuf351.ByteString: 2
org.nd4j.linalg.api.ndarray.INDArray: 2
java.nio.ByteBuffer: 1
org.junit.Test: 1
org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper: 1
org.nd4j.linalg.api.buffer.DataBuffer: 1
org.nd4j.linalg.exception.ND4JIllegalStateException: 1
org.nd4j.linalg.io.ClassPathResource: 1