Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.
The class TensorFlowImportTest, method testIfStatementNodes.
/**
 * Checks that the TensorFlow importer can collect the true/false/condition nodes of the two
 * conditional subgraphs ("cond5" and "cond6") in the simple_cond example graph, and that an
 * {@code If} op initialised from the first one exposes its loop-body, false-body and predicate
 * executions.
 */
@Test
public void testIfStatementNodes() throws Exception {
    val mapper = TFGraphMapper.getInstance();
    // NOTE(review): the original author referenced a text dump of this graph at
    // dl4j-test-resources/src/main/resources/tf_graphs/examples/simple_cond/frozen_graph.pbtxt
    // (presumably the .pbtxt form of frozen_model.pb — verify).
    // try-with-resources: the classpath stream was previously never closed.
    try (java.io.InputStream resourceInputStream =
            new ClassPathResource("/tf_graphs/examples/simple_cond/frozen_model.pb").getInputStream()) {
        // Reuse the mapper fetched above instead of calling getInstance() a second time.
        val readGraph = mapper.parseGraphFrom(resourceInputStream);
        val nodes = mapper.nodesByName(readGraph);

        /*
         * Work backwards starting from the condition id (usually a name containing condid/pred_id).
         */
        val firstInput = nodes.get("cond5/Merge");
        // Fail with a clear message if the node is missing, rather than inside nodesForIf.
        assertNotNull("node cond5/Merge not found in graph", firstInput);
        val ifNodes = mapper.nodesForIf(firstInput, readGraph);
        assertEquals(5, ifNodes.getFalseNodes().size());
        assertEquals(5, ifNodes.getTrueNodes().size());
        assertEquals(10, ifNodes.getCondNodes().size());

        val secondInput = nodes.get("cond6/Merge");
        assertNotNull("node cond6/Merge not found in graph", secondInput);
        val ifNodesTwo = mapper.nodesForIf(secondInput, readGraph);
        assertEquals(5, ifNodesTwo.getFalseNodes().size());
        assertEquals(5, ifNodesTwo.getTrueNodes().size());
        assertEquals(6, ifNodesTwo.getCondNodes().size());

        // Build an If op from the first conditional and check that all three sub-executions exist.
        val parentContext = SameDiff.create();
        val ifStatement = new If();
        ifStatement.initFromTensorFlow(firstInput, parentContext, Collections.emptyMap(), readGraph);
        assertNotNull(ifStatement.getLoopBodyExecution());
        assertNotNull(ifStatement.getFalseBodyExecution());
        assertNotNull(ifStatement.getPredicateExecution());
    }
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.
The class TensorFlowImportTest, method testImportIris.
/**
 * Imports the {@code tf_graphs/train_iris.pb} graph and checks that a SameDiff instance is
 * produced. The test is currently marked {@code @Ignore}; the reason is not recorded here.
 */
@Test
@Ignore
public void testImportIris() throws Exception {
    // try-with-resources so the classpath stream is closed even if the import throws.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/train_iris.pb").getInputStream()) {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(graphStream);
        assertNotNull(graph);
    }
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.
The class TensorFlowImportTest, method testCondMapping2.
/**
 * Imports the simpleif_0 frozen graph, feeds a 2x2 input of -1s into variable "input_0",
 * executes it, and checks that the end result is a 2x2 array of 1s.
 */
@Test
public void testCondMapping2() throws Exception {
    // NOTE(review): the result is discarded; this appears to be a warm-up that initialises
    // the ND4J backend before import — confirm.
    Nd4j.create(1);

    // Close the classpath stream as soon as the graph has been imported.
    final SameDiff tg;
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/examples/simpleif_0/frozen_model.pb").getInputStream()) {
        tg = TFGraphMapper.getInstance().importGraph(graphStream);
    }
    assertNotNull(tg);

    val input = Nd4j.create(2, 2).assign(-1);
    tg.associateArrayWithVariable(input, tg.getVariable("input_0"));

    // Debugging aids (presumably for regenerating the libnd4j simpleif_0.fb fixture), kept disabled:
    // tg.asFlatFile(new File("../../../libnd4j/tests_cpu/resources/simpleif_0.fb"));
    // log.info("{}", tg.asFlatPrint());

    val array = tg.execAndEndResult();
    assertNotNull(array);

    val exp = Nd4j.create(2, 2).assign(1);
    assertEquals(exp, array);
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.
The class TwoPointApproximationTest, method testLinspaceDerivative.
/**
 * Compares {@code TwoPointApproximation.approximateDerivative} for f(x) = x + 1 against a
 * reference array stored as .npy fixtures (presumably generated with NumPy, per the directory
 * name — verify).
 */
@Test
public void testLinspaceDerivative() throws Exception {
    String basePath = "/two_points_approx_deriv_numpy/";

    // Load all fixtures up front; try-with-resources closes every stream (previously all leaked).
    final INDArray linspace;
    final INDArray yLinspace;
    final INDArray npLoad;
    try (java.io.InputStream xIn = new ClassPathResource(basePath + "x.npy").getInputStream();
         java.io.InputStream yIn = new ClassPathResource(basePath + "y.npy").getInputStream();
         java.io.InputStream expectedIn =
                 new ClassPathResource(basePath + "approx_deriv_small.npy").getInputStream()) {
        linspace = Nd4j.createNpyFromInputStream(xIn);
        yLinspace = Nd4j.createNpyFromInputStream(yIn);
        npLoad = Nd4j.createNpyFromInputStream(expectedIn);
    }

    Function<INDArray, INDArray> f = new Function<INDArray, INDArray>() {
        @Override
        public INDArray apply(INDArray indArray) {
            return indArray.add(1);
        }
    };

    // Last argument: bounds array [Float.MIN_VALUE, Float.MAX_VALUE].
    INDArray test = TwoPointApproximation.approximateDerivative(
            f, linspace, null, yLinspace, Nd4j.create(new double[] { Float.MIN_VALUE, Float.MAX_VALUE }));
    assertEquals(npLoad, test);
}
Use of org.nd4j.linalg.io.ClassPathResource in project nd4j by deeplearning4j.
The class ExecutionTests, method testStoredGraph_1.
/**
 * Imports the {@code tf_graphs/reduce_dim.pb.txt} graph and runs it twice through
 * {@code yetAnotherExecMethod}. Each run feeds a 3x3 "alpha" input filled with one value and
 * expects a 3x1 result filled with three times that value.
 */
@Test
public void testStoredGraph_1() throws Exception {
    // NOTE(review): the result is discarded; this appears to be a warm-up that initialises
    // the ND4J backend before import — confirm.
    Nd4j.create(1);

    // try-with-resources: the classpath stream was previously never closed.
    try (java.io.InputStream graphStream =
            new ClassPathResource("tf_graphs/reduce_dim.pb.txt").getInputStream()) {
        val tg = TFGraphMapper.getInstance().importGraph(graphStream);

        // Same graph instance, two different inputs (2 -> 6, 3 -> 9); presumably this also checks
        // that the imported graph can be re-executed with fresh placeholders.
        for (double fill : new double[] {2.0, 3.0}) {
            val inputs = new LinkedHashMap<String, INDArray>();
            inputs.put("alpha", Nd4j.create(3, 3).assign(fill));
            val result = tg.yetAnotherExecMethod(inputs);
            val expected = Nd4j.create(3, 1).assign(3 * fill);
            assertEquals("alpha filled with " + fill, expected, result);
        }
    }
}
Aggregations