Search in sources:

Example 86 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

The class TensorFlowImportTest, method testIfStatementNodes.

/**
 * Parses the {@code simple_cond} TF graph and checks that {@code nodesForIf} splits the nodes
 * feeding {@code cond5/Merge} and {@code cond6/Merge} into the expected false/true/cond sets.
 * It then checks that an {@link If} op initialised from {@code cond5/Merge} exposes non-null
 * body, false-body and predicate executions.
 */
@Test
public void testIfStatementNodes() throws Exception {
    // Fixture from dl4j-test-resources: tf_graphs/examples/simple_cond/frozen_model.pb
    try (java.io.InputStream resourceInputStream =
                 new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream()) {
        val mapper = TFGraphMapper.getInstance();
        val readGraph = mapper.parseGraphFrom(resourceInputStream);
        val nodes = mapper.nodesByName(readGraph);

        /*
         * Work backwards starting from the condition id (usually a node whose name contains
         * condid/pred_id).
         */
        val firstInput = nodes.get("cond5/Merge");
        val ifNodes = mapper.nodesForIf(firstInput, readGraph);
        assertEquals(5, ifNodes.getFalseNodes().size());
        assertEquals(5, ifNodes.getTrueNodes().size());
        assertEquals(10, ifNodes.getCondNodes().size());

        val secondInput = nodes.get("cond6/Merge");
        val ifNodesTwo = mapper.nodesForIf(secondInput, readGraph);
        assertEquals(5, ifNodesTwo.getFalseNodes().size());
        assertEquals(5, ifNodesTwo.getTrueNodes().size());
        assertEquals(6, ifNodesTwo.getCondNodes().size());

        val parentContext = SameDiff.create();
        val ifStatement = new If();
        ifStatement.initFromTensorFlow(firstInput, parentContext, Collections.emptyMap(), readGraph);
        assertNotNull(ifStatement.getLoopBodyExecution());
        assertNotNull(ifStatement.getFalseBodyExecution());
        assertNotNull(ifStatement.getPredicateExecution());
    }
}
Also used : lombok.val(lombok.val) If(org.nd4j.linalg.api.ops.impl.controlflow.If) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 87 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

The class TensorFlowImportTest, method testImportIris.

/**
 * Smoke test: importing the {@code train_iris.pb} TF graph must yield a non-null
 * {@link SameDiff}. The resource stream is closed once the import completes.
 */
@Test
@Ignore
public void testImportIris() throws Exception {
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()) {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(graph);
    }
}
Also used : SameDiff(org.nd4j.autodiff.samediff.SameDiff) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Ignore(org.junit.Ignore) Test(org.junit.Test)

Example 88 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

The class TensorFlowImportTest, method testCondMapping2.

/**
 * Imports the {@code simpleif_0} TF graph and feeds a 2x2 array filled with -1 into
 * {@code input_0}. It then asserts that the final execution result is a 2x2 array filled with 1.
 */
@Test
public void testCondMapping2() throws Exception {
    // NOTE(review): presumably done to force ND4J backend initialisation before import — confirm.
    Nd4j.create(1);
    try (java.io.InputStream in =
                 new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);
        assertNotNull(tg);

        val input = Nd4j.create(2, 2).assign(-1);
        tg.associateArrayWithVariable(input, tg.getVariable("input_0"));

        val array = tg.execAndEndResult();
        val exp = Nd4j.create(2, 2).assign(1);
        assertNotNull(array);
        assertEquals(exp, array);
    }
}
Also used : lombok.val(lombok.val) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 89 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

The class TwoPointApproximationTest, method testLinspaceDerivative.

/**
 * Compares {@code TwoPointApproximation.approximateDerivative} of {@code f(x) = x + 1} against
 * a numpy-generated reference. The x/y inputs and the expected result are loaded from
 * {@code .npy} classpath resources.
 */
@Test
public void testLinspaceDerivative() throws Exception {
    String basePath = "/two_points_approx_deriv_numpy/";
    INDArray linspace = readNpyResource(basePath + "x.npy");
    INDArray yLinspace = readNpyResource(basePath + "y.npy");
    Function<INDArray, INDArray> f = new Function<INDArray, INDArray>() {

        @Override
        public INDArray apply(INDArray indArray) {
            return indArray.add(1);
        }
    };
    // NOTE(review): the last argument looks like a [lower, upper] bounds pair. Float.MIN_VALUE is
    // the smallest *positive* float, not the most negative one; confirm this matches the numpy reference.
    INDArray test = TwoPointApproximation.approximateDerivative(f, linspace, null, yLinspace, Nd4j.create(new double[] { Float.MIN_VALUE, Float.MAX_VALUE }));
    INDArray npLoad = readNpyResource(basePath + "approx_deriv_small.npy");
    assertEquals(npLoad, test);
}

/** Loads a {@code .npy} classpath resource as an INDArray, always closing the underlying stream. */
private static INDArray readNpyResource(String path) throws Exception {
    try (java.io.InputStream in = new ClassPathResource(path).getInputStream()) {
        return Nd4j.createNpyFromInputStream(in);
    }
}
Also used : Function(org.nd4j.linalg.function.Function) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) Test(org.junit.Test)

Example 90 with ClassPathResource

use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.

The class ExecutionTests, method testStoredGraph_1.

/**
 * Imports {@code reduce_dim.pb.txt} and runs it twice, feeding a 3x3 {@code alpha} input.
 * A fill of 2.0 must produce a 3x1 array of 6.0; a fill of 3.0 must produce a 3x1 array of 9.0.
 */
@Test
public void testStoredGraph_1() throws Exception {
    // NOTE(review): presumably done to force ND4J backend initialisation before import — confirm.
    Nd4j.create(1);
    try (java.io.InputStream in = new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(in);

        // First run: every input element is 2.0.
        val array0 = Nd4j.create(3, 3).assign(2.0);
        val map0 = new LinkedHashMap<String, INDArray>();
        map0.put("alpha", array0);
        val result_0 = tg.yetAnotherExecMethod(map0);
        val exp_0 = Nd4j.create(3, 1).assign(6.0);
        assertEquals(exp_0, result_0);

        // Second run on the same graph with a different fill value (3.0).
        val array1 = Nd4j.create(3, 3).assign(3.0);
        val map1 = new LinkedHashMap<String, INDArray>();
        map1.put("alpha", array1);
        val result_1 = tg.yetAnotherExecMethod(map1);
        val exp_1 = Nd4j.create(3, 1).assign(9.0);
        assertEquals(exp_1, result_1);
    }
}
Also used : lombok.val(lombok.val) ClassPathResource(org.nd4j.linalg.io.ClassPathResource) LinkedHashMap(java.util.LinkedHashMap) Test(org.junit.Test) BaseNd4jTest(org.nd4j.linalg.BaseNd4jTest)

Aggregations

ClassPathResource (org.nd4j.linalg.io.ClassPathResource)112 Test (org.junit.Test)100 lombok.val (lombok.val)31 INDArray (org.nd4j.linalg.api.ndarray.INDArray)26 SequenceRecordReader (org.datavec.api.records.reader.SequenceRecordReader)23 CSVSequenceRecordReader (org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader)23 DataSet (org.nd4j.linalg.dataset.DataSet)23 File (java.io.File)22 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)20 FileSplit (org.datavec.api.split.FileSplit)18 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)14 Ignore (org.junit.Ignore)14 CSVRecordReader (org.datavec.api.records.reader.impl.csv.CSVRecordReader)13 RecordReader (org.datavec.api.records.reader.RecordReader)12 NumberedFileInputSplit (org.datavec.api.split.NumberedFileInputSplit)12 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)12 MultiDataSet (org.nd4j.linalg.dataset.api.MultiDataSet)11 MultiDataSetIterator (org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator)8 RecordMetaData (org.datavec.api.records.metadata.RecordMetaData)7 ImageRecordReader (org.datavec.image.recordreader.ImageRecordReader)7