Use of org.tensorflow.Session in project zoltar by spotify.
In class TensorFlowExtrasTest, method testExtract1.
/**
 * Checks that {@code TensorFlowExtras.runAndExtract} fetches the requested op and extracts its
 * output: feeding 10.0 into "input" must yield exactly the key {@code mul2} with scalar value 20.0.
 */
@Test
public void testExtract1() {
  // try-with-resources releases the graph, the session and the fed input tensor even when an
  // assertion fails; they are closed in reverse declaration order (tensor, session, graph).
  try (final Graph graph = createDummyGraph();
      final Session session = new Session(graph);
      final org.tensorflow.Tensor<Double> input = Tensors.create(10.0)) {
    final Session.Runner runner = session.runner();
    runner.feed("input", input);
    final Map<String, JTensor> result = TensorFlowExtras.runAndExtract(runner, mul2);
    assertEquals(Sets.newHashSet(mul2), result.keySet());
    assertScalar(result.get(mul2), 20.0);
  }
}
Use of org.tensorflow.Session in project zoltar by spotify.
In class TensorFlowGraphModelTest, method testDummyLoadOfTensorFlowGraphWithPrefix.
/**
 * Loads a dummy graph under the name prefix "test" and checks that the prefixed ops are reachable:
 * feeding 3.0 into {@code prefix/inputOpName} must produce 6.0 at {@code prefix/mulResult}.
 */
@Test
public void testDummyLoadOfTensorFlowGraphWithPrefix() throws Exception {
  final String prefix = "test";
  final Path graphFile = createADummyTFGraph();
  // NOTE(review): the session comes from model.instance() and is closed here as well as
  // (presumably) by model.close(); confirm which side owns it.
  try (final TensorFlowGraphModel model =
          TensorFlowGraphModel.create(graphFile.toUri(), null, prefix);
      final Session session = model.instance();
      final Tensor<Double> double3 = Tensors.create(3.0D)) {
    List<Tensor<?>> result = null;
    try {
      result = session.runner()
          .fetch(prefix + "/" + mulResult)
          .feed(prefix + "/" + inputOpName, double3)
          .run();
      // JUnit's signature is assertEquals(expected, actual, delta): expected value goes first so
      // failure messages report the values the right way round.
      assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
    } finally {
      // Output tensors from Runner.run() are owned by the caller and must be released explicitly.
      if (result != null) {
        result.forEach(Tensor::close);
      }
    }
  }
}
Use of org.tensorflow.Session in project tensorflow-example-java by szaza.
In class ObjectDetector, method normalizeImage.
/**
 * Pre-processes the input image: decodes the JPEG bytes with 3 channels, casts the pixels to
 * float, adds a leading batch dimension, resizes to SIZE x SIZE with bilinear interpolation and
 * divides every pixel value by MEAN.
 *
 * @param imageBytes JPEG-encoded input image
 * @return the pre-processed image tensor; documented as shape [1][416][416][3], which presumes
 *     SIZE == 416 (TODO confirm)
 */
private Tensor<Float> normalizeImage(final byte[] imageBytes) {
    try (Graph preprocessing = new Graph()) {
        final GraphBuilder builder = new GraphBuilder(preprocessing);
        // Ops are created innermost-first, exactly as the nested calls are evaluated.
        final Output<Float> scaled = builder.div(
                // bilinear resize to SIZE x SIZE
                builder.resizeBilinear(
                        // insert the batch dimension at axis 0
                        builder.expandDims(
                                // decode the JPEG, then convert the pixels to float
                                builder.cast(
                                        builder.decodeJpeg(builder.constant("input", imageBytes), 3),
                                        Float.class),
                                builder.constant("make_batch", 0)),
                        builder.constant("size", new int[] { SIZE, SIZE })),
                // divisor applied to every pixel value
                builder.constant("scale", MEAN));
        try (Session preprocessingSession = new Session(preprocessing)) {
            final Tensor<?> fetched = preprocessingSession.runner()
                    .fetch(scaled.op().name())
                    .run()
                    .get(0);
            return fetched.expect(Float.class);
        }
    }
}
Aggregations