Search in sources:

Example 91 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

In class OnnxGraphMapperTests, method test1dCnn.

/**
 * Imports the ONNX model {@code onnx_graphs/sm_cnn.onnx} from the classpath through
 * {@link OnnxGraphMapper#getInstance()} and prints the variables of the imported graph.
 */
@Test
public void test1dCnn() throws Exception {
    // try-with-resources: the classpath stream was previously never closed (leaked on every run,
    // and also when importGraph threw).
    try (val loadedFile = new ClassPathResource("onnx_graphs/sm_cnn.onnx").getInputStream()) {
        val mapped = OnnxGraphMapper.getInstance().importGraph(loadedFile);
        System.out.println(mapped.variables());
    }
}
Also used : lombok.val(lombok.val) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 92 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

In class OnnxGraphMapperTests, method testLoadResnet.

/**
 * Smoke test: imports {@code onnx_graphs/resnet50/model.pb} from the classpath through
 * {@link OnnxGraphMapper#getInstance()}. The test passes if the import completes without throwing.
 */
@Test
public void testLoadResnet() throws Exception {
    // try-with-resources so the classpath stream is closed even if the import fails.
    try (val loadedFile = new ClassPathResource("onnx_graphs/resnet50/model.pb").getInputStream()) {
        // The result is intentionally unused; only successful import is checked here.
        OnnxGraphMapper.getInstance().importGraph(loadedFile);
    }
}
Also used : lombok.val(lombok.val) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 93 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

In class OnnxGraphMapperTests, method testMapper.

/**
 * Parses {@code onnx_graphs/embedding_only.onnx} and checks the {@link OnnxGraphMapper} views of it:
 * <ul>
 *   <li>the node list has the same size as the proto's;</li>
 *   <li>there are 4 variables;</li>
 *   <li>the first initializer converts to an INDArray (the test is skipped if it is null);</li>
 *   <li>every node's attribute count is preserved;</li>
 *   <li>the imported SameDiff graph has exactly one function.</li>
 * </ul>
 */
@Test
public void testMapper() throws Exception {
    try (val modelStream = new ClassPathResource("onnx_graphs/embedding_only.onnx").getInputStream()) {
        OnnxProto3.GraphProto graph = OnnxProto3.ModelProto.parseFrom(modelStream).getGraph();
        OnnxGraphMapper mapper = new OnnxGraphMapper();

        // The mapper must expose every node of the proto and the expected number of variables.
        int protoNodeCount = graph.getNodeList().size();
        int mappedNodeCount = mapper.getNodeList(graph).size();
        assertEquals(protoNodeCount, mappedNodeCount);
        assertEquals(4, mapper.variablesForGraph(graph).size());

        // NOTE(review): the tensor type is taken from input 0 while the name comes from
        // initializer 0 -- presumably they describe the same tensor in this model; confirm.
        val firstInputType = graph.getInput(0).getType().getTensorType();
        String firstInitializerName = graph.getInitializer(0).getName();
        INDArray firstInitializer = mapper.getNDArrayFromTensor(firstInitializerName, firstInputType, graph);
        assumeNotNull(firstInitializer);

        // Every node's attributes must survive the mapping unchanged in count.
        for (val graphNode : graph.getNodeList()) {
            int protoAttrCount = graphNode.getAttributeList().size();
            int mappedAttrCount = mapper.getAttrMap(graphNode).size();
            assertEquals(protoAttrCount, mappedAttrCount);
        }

        val imported = mapper.importGraph(graph);
        assertEquals(1, imported.functions().length);
        System.out.println(imported);
    }
}
Also used : lombok.val(lombok.val) INDArray(org.nd4j.linalg.api.ndarray.INDArray) OnnxGraphMapper(org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper) OnnxProto3(onnx.OnnxProto3) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 94 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

In class TFGraphTestAllHelper, method readVars.

/**
 * Loads saved variables of a test model from the classpath.
 *
 * <p>Each resource matching {@code <base_dir>/<modelName>/<pattern>.shape} is processed as follows:
 * <ul>
 *   <li>The {@code .shape} file is read with {@link Nd4j#readNumpy} (comma separator) to get the
 *       shape.</li>
 *   <li>The sibling file is the same name with the trailing {@code .shape} swapped for
 *       {@code .csv}. It is read the same way to get the contents.</li>
 *   <li>A shape of {@code [0]} yields a scalar, any other 1-element shape yields a vector, and
 *       every other shape is reshaped directly.</li>
 * </ul>
 *
 * <p>The variable name is the file name with its last two dot-separated components removed.
 * Every {@code "____"} in the name is replaced with {@code "/"}; such names are intermediate node
 * outputs.
 *
 * @param modelName model directory name under {@code base_dir}
 * @param base_dir  classpath directory that contains the model directories
 * @param pattern   glob, without the {@code .shape} suffix, selecting which variables to load
 * @return map from variable name to value. Variables whose CSV contents cannot be parsed as
 *         numbers (a {@link NumberFormatException}) are skipped.
 * @throws IOException if resources cannot be listed or read
 */
protected static Map<String, INDArray> readVars(String modelName, String base_dir, String pattern) throws IOException {
    Map<String, INDArray> varMap = new HashMap<>();
    String modelDir = base_dir + "/" + modelName;
    ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(modelDir).getClassLoader());
    Resource[] resources = resolver.getResources("classpath*:" + modelDir + "/" + pattern + ".shape");
    for (Resource resource : resources) {
        // The glob above guarantees the file name ends with ".shape".
        String fileName = resource.getFilename();
        String[] varNameArr = fileName.split("\\.");
        String varName = String.join(".", Arrays.copyOfRange(varNameArr, 0, varNameArr.length - 2));
        String shapePath = modelDir + "/" + fileName;
        // Swap only the trailing suffix. String.replace(".shape", ".csv") would also rewrite any
        // ".shape" occurring earlier in the directory part of the path.
        String csvPath = modelDir + "/" + fileName.substring(0, fileName.length() - ".shape".length()) + ".csv";
        int[] varShape = readNumpyResource(shapePath).data().asInt();
        try {
            float[] varContents = readNumpyResource(csvPath).data().asFloat();
            INDArray varValue;
            if (varShape.length == 1) {
                if (varShape[0] == 0) {
                    varValue = Nd4j.trueScalar(varContents[0]);
                } else {
                    varValue = Nd4j.trueVector(varContents);
                }
            } else {
                varValue = Nd4j.create(varContents, varShape);
            }
            // "____" marks intermediate node outputs; names without it are unaffected by replace.
            varMap.put(varName.replace("____", "/"), varValue);
        } catch (NumberFormatException ignored) {
            // FIXME: we can't parse boolean arrays right now :( -- such variables are skipped on purpose.
        }
    }
    return varMap;
}

/** Reads a comma-separated numeric text resource from the classpath, always closing the stream. */
private static INDArray readNumpyResource(String path) throws IOException {
    try (val in = new ClassPathResource(path).getInputStream()) {
        return Nd4j.readNumpy(in, ",");
    }
}
Also used : lombok.val(lombok.val) PathMatchingResourcePatternResolver(org.springframework.core.io.support.PathMatchingResourcePatternResolver) ResourcePatternResolver(org.springframework.core.io.support.ResourcePatternResolver) Resource(org.springframework.core.io.Resource) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) INDArray(org.nd4j.linalg.api.ndarray.INDArray) PathMatchingResourcePatternResolver(org.springframework.core.io.support.PathMatchingResourcePatternResolver)

Example 95 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

In class TFGraphTestAllHelper, method modelDirNames.

/**
 * Lists the test models found under {@code base_dir} on the classpath.
 *
 * <p>A model is any directory, at any depth, that contains a {@code frozen_model.pb}. A model's
 * name is built from the part of the resource URL after the first occurrence of
 * {@code base_dir + "/"}. Every occurrence of {@code executeWith.getDefaultBaseDir()} is then
 * removed, and so is the {@code /frozen_model.pb} suffix.
 *
 * @param base_dir    classpath directory to search
 * @param executeWith supplies the default base dir to strip from each name
 * @return one name per {@code frozen_model.pb} found, in resolver order
 * @throws IOException           if resources cannot be listed
 * @throws IllegalStateException if a resolved URL does not contain {@code base_dir + "/"}
 */
private static String[] modelDirNames(String base_dir, ExecuteWith executeWith) throws IOException {
    ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver(new ClassPathResource(base_dir).getClassLoader());
    Resource[] resources = resolver.getResources("classpath*:" + base_dir + "/**/frozen_model.pb");
    String marker = base_dir + "/";
    String[] exampleNames = new String[resources.length];
    for (int i = 0; i < resources.length; i++) {
        String url = resources[i].getURL().toString();
        int start = url.indexOf(marker);
        if (start < 0) {
            throw new IllegalStateException("Resource URL " + url + " does not contain " + marker);
        }
        // Literal substring instead of String.split(marker)[1]. split treats base_dir as a regex,
        // and it would also cut the name at any later occurrence of the marker.
        String nestedName = url.substring(start + marker.length());
        // Literal replacements; the previous replaceAll treated '.' in "frozen_model.pb" as a wildcard.
        exampleNames[i] = nestedName.replace(executeWith.getDefaultBaseDir(), "").replace("/frozen_model.pb", "");
    }
    return exampleNames;
}
Also used : PathMatchingResourcePatternResolver(org.springframework.core.io.support.PathMatchingResourcePatternResolver) ResourcePatternResolver(org.springframework.core.io.support.ResourcePatternResolver) Resource(org.springframework.core.io.Resource) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) PathMatchingResourcePatternResolver(org.springframework.core.io.support.PathMatchingResourcePatternResolver) ClassPathResource(org.nd4j.linalg.io.ClassPathResource)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource)112 Test (org.junit.Test)100 lombok.val (lombok.val)31 INDArray (org.nd4j.linalg.api.ndarray.INDArray)26 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)23 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)23 DataSet (org.nd4j.linalg.dataset.DataSet)23 File (java.io.File)22 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)20 FileSplit (org.datavec.api.split.FileSplit)18 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)14 Ignore (org.junit.Ignore)14 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)13 RecordReader (org.datavec.api.records.reader.RecordReader)12 NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)12 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)12 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)11 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)8 RecordMetaData (org.datavec.api.records.metadata.RecordMetaData)7 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)7