Usage of org.nd4j.linalg.io.ClassPathResource in the nd4j project by deeplearning4j.
The class TensorFlowImportTest, method importGraph3.
/**
 * Smoke test: imports the text-format TensorFlow graph {@code tf_graphs/max_log_reg.pb.txt}
 * from the classpath and checks that a non-null {@link SameDiff} instance is produced.
 */
@Test
@Ignore // NOTE(review): the reason this test is disabled is not recorded here — confirm and document it.
public void importGraph3() throws Exception {
    // try-with-resources: the classpath stream was previously never closed.
    try (java.io.InputStream is = new ClassPathResource("tf_graphs/max_log_reg.pb.txt").getInputStream()) {
        SameDiff graph = TFGraphMapper.getInstance().importGraph(is);
        assertNotNull(graph);
    }
}
Usage of org.nd4j.linalg.io.ClassPathResource in the nd4j project by deeplearning4j.
The class TensorFlowImportTest, method testWhileMapping2.
/**
 * Imports the frozen {@code simplewhile_0} TensorFlow model, binds the scalar {@code 4.0}
 * to variable {@code input_1}, executes the graph and checks that the end result equals
 * a 2x2 array filled with {@code 2}.
 */
@Test
public void testWhileMapping2() throws Exception {
    // NOTE(review): presumably forces Nd4j backend initialisation before import — confirm.
    Nd4j.create(1);

    // Import inside try-with-resources so the classpath stream is always closed.
    final SameDiff tg;
    try (java.io.InputStream is = new ClassPathResource("tf_graphs/examples/simplewhile_0/frozen_model.pb").getInputStream()) {
        tg = TFGraphMapper.getInstance().importGraph(is);
    }
    assertNotNull(tg);

    val input = Nd4j.trueScalar(4.0);
    tg.associateArrayWithVariable(input, tg.getVariable("input_1"));

    val array = tg.execAndEndResult();
    val exp = Nd4j.create(2, 2).assign(2);
    assertNotNull(array);
    assertEquals(exp, array);
}
Usage of org.nd4j.linalg.io.ClassPathResource in the nd4j project by deeplearning4j.
The class TensorFlowImportTest, method testInferShape.
/**
 * Imports the frozen {@code bias_add} model, feeds a 10x4 placeholder {@code input} and checks
 * that the placeholder array and shape are recorded on the graph and that the output equals
 * {@code input} plus the constant bias row {@code [1, 2, 3, 4]}.
 */
@Test
public void testInferShape() throws IOException {
    /*
     * Text form of the imported graph:
     *
     * node {
     *   name: "input"
     *   op: "Placeholder"
     *   attr {
     *     key: "dtype"
     *     value {
     *       type: DT_FLOAT
     *     }
     *   }
     *   attr {
     *     key: "shape"
     *     value {
     *       shape {
     *         dim {
     *           size: -1
     *         }
     *         dim {
     *           size: 4
     *         }
     *       }
     *     }
     *   }
     * }
     * node {
     *   name: "bias"
     *   op: "Const"
     *   attr {
     *     key: "dtype"
     *     value {
     *       type: DT_FLOAT
     *     }
     *   }
     *   attr {
     *     key: "value"
     *     value {
     *       tensor {
     *         dtype: DT_FLOAT
     *         tensor_shape {
     *           dim {
     *             size: 4
     *           }
     *         }
     *         tensor_content: "\000\000\200?\000\000\000@\000\000@@\000\000\200@"
     *       }
     *     }
     *   }
     * }
     * node {
     *   name: "bias/read"
     *   op: "Identity"
     *   input: "bias"
     *   attr {
     *     key: "_class"
     *     value {
     *       list {
     *         s: "loc:@bias"
     *       }
     *     }
     *   }
     *   attr {
     *     key: "T"
     *     value {
     *       type: DT_FLOAT
     *     }
     *   }
     * }
     * node {
     *   name: "output"
     *   op: "BiasAdd"
     *   input: "input"
     *   input: "bias/read"
     *   attr {
     *     key: "data_format"
     *     value {
     *       s: "NHWC"
     *     }
     *   }
     *   attr {
     *     key: "T"
     *     value {
     *       type: DT_FLOAT
     *     }
     *   }
     * }
     * library {
     * }
     */
    // Import inside try-with-resources so the classpath stream is always closed.
    final SameDiff graph;
    try (java.io.InputStream is = new ClassPathResource("tf_graphs/examples/bias_add/frozen_model.pb").getInputStream()) {
        graph = TFGraphMapper.getInstance().importGraph(is);
    }
    assertNotNull(graph);

    INDArray input = Nd4j.linspace(1, 40, 40).reshape(10, 4);
    // Expected BiasAdd result: each row of the input plus the bias row [1, 2, 3, 4].
    INDArray expectedOutput = Nd4j.linspace(1, 40, 40).reshape(10, 4).addRowVector(Nd4j.linspace(1, 4, 4));
    INDArray actual = graph.execWithPlaceHolderAndEndResult(Collections.singletonMap("input", input));

    assertEquals(input, graph.getVariable("input").getArr());
    assertArrayEquals(input.shape(), graph.getShapeForVarName(graph.getVariable("input").getVarName()));
    assertEquals(expectedOutput, actual);
}
Usage of org.nd4j.linalg.io.ClassPathResource in the nd4j project by deeplearning4j.
The class TensorFlowImportTest, method testIntermediateLoop1.
/**
 * Imports the text-format graph {@code tf_graphs/simple_while.pb.txt}, serialises it to
 * FlatBuffers and checks that the resulting {@code FlatGraph} contains 6 variables.
 */
@Test
public void testIntermediateLoop1() throws Exception {
    // NOTE(review): presumably forces Nd4j backend initialisation before import — confirm.
    Nd4j.create(1);

    // Import inside try-with-resources so the classpath stream is always closed.
    final SameDiff tg;
    try (java.io.InputStream is = new ClassPathResource("tf_graphs/simple_while.pb.txt").getInputStream()) {
        tg = TFGraphMapper.getInstance().importGraph(is);
    }
    assertNotNull(tg);

    val graph = FlatGraph.getRootAsFlatGraph(tg.asFlatBuffers());
    assertEquals(6, graph.variablesLength());
}
Usage of org.nd4j.linalg.io.ClassPathResource in the nd4j project by deeplearning4j.
The class TensorFlowImportTest, method testIntermediate1.
/**
 * Imports the Inception graph, checks that it exposes a variable named {@code input},
 * replaces that variable's value with the array read from {@code tf_graphs/ipod.nd4}
 * and checks that the graph can be serialised to FlatBuffers.
 */
@Test
public void testIntermediate1() throws Exception {
    // NOTE(review): presumably forces Nd4j backend initialisation before import — confirm.
    Nd4j.create(1);

    // Both classpath streams are now closed via try-with-resources (previously leaked).
    final SameDiff tg;
    try (java.io.InputStream is = new ClassPathResource("tf_graphs/tensorflow_inception_graph.pb").getInputStream()) {
        tg = TFGraphMapper.getInstance().importGraph(is);
    }
    assertNotNull(tg);
    assertNotNull(tg.getVariable("input"));

    try (DataInputStream dis = new DataInputStream(new ClassPathResource("tf_graphs/ipod.nd4").getInputStream())) {
        val ipod = Nd4j.read(dis);
        tg.updateVariable("input", ipod);
    }

    val buffer = tg.asFlatBuffers();
    assertNotNull(buffer);
}
Aggregations