use of org.tensorflow.Session in project tensorflow-java-yolo by szaza.
the class ObjectDetector method normalizeImage.
/**
 * Pre-processes the input image by building and running a small TensorFlow graph:
 * the bytes are decoded as a 3-channel JPEG, cast to Float, given an extra leading
 * dimension, resized with bilinear interpolation to SIZE x SIZE and divided by MEAN.
 *
 * @param imageBytes JPEG-encoded input image
 * @return the pre-processed image as a Float tensor (presumably of shape
 *         [1][SIZE][SIZE][3]; the caller is responsible for closing it)
 */
private Tensor<Float> normalizeImage(final byte[] imageBytes) {
    try (Graph preprocessGraph = new Graph()) {
        final GraphBuilder builder = new GraphBuilder(preprocessGraph);
        final Output<Float> normalized =
            // Divide every pixel by MEAN
            builder.div(
                // Resize using bilinear interpolation
                builder.resizeBilinear(
                    // Add a dimension at index 0
                    builder.expandDims(
                        // Cast the decoded pixels to Float
                        builder.cast(
                            builder.decodeJpeg(builder.constant("input", imageBytes), 3),
                            Float.class),
                        builder.constant("make_batch", 0)),
                    builder.constant("size", new int[] { SIZE, SIZE })),
                builder.constant("scale", MEAN));
        try (Session runSession = new Session(preprocessGraph)) {
            final String fetchName = normalized.op().name();
            // Only one operation is fetched, so the result list has a single entry.
            final Tensor<?> fetched = runSession.runner().fetch(fetchName).run().get(0);
            return fetched.expect(Float.class);
        }
    }
}
use of org.tensorflow.Session in project zoltar by spotify.
the class TensorFlowGraphModelTest method testDummyLoadOfTensorFlowGraph.
/**
 * Loads the graph file produced by {@code createADummyTFGraph()}, feeds 3.0 to
 * {@code inputOpName} and checks that fetching {@code mulResult} yields 6.0.
 */
@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
    final Path graphFile = createADummyTFGraph();
    try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, null);
        final Session session = model.instance();
        final Tensor<Double> double3 = Tensors.create(3.0D)) {
        List<Tensor<?>> result = null;
        try {
            result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
            // Expected value goes first, actual second, so a failure message reports
            // them the right way round (assumes JUnit's assertEquals argument order).
            assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
        } finally {
            // Fetched tensors hold native memory and must be closed explicitly.
            if (result != null) {
                result.forEach(Tensor::close);
            }
        }
    }
}
use of org.tensorflow.Session in project zoltar by spotify.
the class TensorFlowModelTest method predict.
/**
 * Runs the model on a single example and returns the first value written out of the
 * "linear/head/predictions/class_ids" output.
 *
 * @param model model whose session is used for the run
 * @param example example that is serialized and fed as "input_example_tensor"
 * @return the first long read from the fetched output tensor
 */
private static long predict(final TensorFlowModel model, final Example example) {
    // The serialized example is wrapped in an outer array of length 1 (the batch).
    final byte[][] batch = new byte[][] { example.toByteArray() };
    try (final Tensor<String> input = Tensors.create(batch)) {
        final List<Tensor<?>> fetched = model.instance()
            .session()
            .runner()
            .feed("input_example_tensor", input)
            .fetch("linear/head/predictions/class_ids")
            .run();
        final LongBuffer classIds = LongBuffer.allocate(1);
        try {
            fetched.get(0).writeTo(classIds);
            return classIds.get(0);
        } finally {
            // Release every tensor returned by the run, even if the copy failed.
            for (final Tensor<?> tensor : fetched) {
                tensor.close();
            }
        }
    }
}
use of org.tensorflow.Session in project zoltar by spotify.
the class TensorFlowGraphModel method create.
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.</p>
 *
 * <p>If session creation or the graph import fails, the {@link Graph} and {@link Session}
 * allocated by this method are closed before the exception propagates.</p>
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; when {@code null} the
 *               session is created with a {@code null} config.
 * @param prefix a prefix that will be prepended to names in graphDef; when {@code null}
 *               the graph is imported without a prefix.
 * @return a model wrapping the imported graph and its session.
 */
public static TensorFlowGraphModel create(final byte[] graphDef, @Nullable final ConfigProto config, @Nullable final String prefix) throws IOException {
    final Graph graph = new Graph();
    Session session = null;
    boolean created = false;
    try {
        session = new Session(graph, config != null ? config.toByteArray() : null);
        final long loadStart = System.currentTimeMillis();
        // NOTE(review): the log messages use printf-style placeholders; confirm the
        // logger in use actually formats them.
        if (prefix == null) {
            logger.debug("Loading graph definition without prefix");
            graph.importGraphDef(graphDef);
        } else {
            logger.debug("Loading graph definition with prefix: %s", prefix);
            graph.importGraphDef(graphDef, prefix);
        }
        logger.info("TensorFlow graph loaded in %d ms", System.currentTimeMillis() - loadStart);
        final TensorFlowGraphModel model = new AutoValue_TensorFlowGraphModel(graph, session);
        created = true;
        return model;
    } finally {
        // On failure nobody else owns the native resources, so release them here.
        // The session is closed first because it references the graph.
        if (!created) {
            if (session != null) {
                session.close();
            }
            graph.close();
        }
    }
}
use of org.tensorflow.Session in project tensorflow-example-java by szaza.
the class ObjectDetector method executeYOLOGraph.
/**
 * Imports GRAPH_DEF into a fresh graph, feeds the given image to the "input" operation
 * and copies the fetched "output" tensor into a float array.
 *
 * @param image pre-processed image tensor
 * @return the values of the "output" tensor as a flat float array
 */
private float[] executeYOLOGraph(final Tensor<Float> image) {
    try (Graph yoloGraph = new Graph()) {
        yoloGraph.importGraphDef(GRAPH_DEF);
        try (Session session = new Session(yoloGraph)) {
            try (Tensor<Float> prediction =
                     session.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
                // The array length is derived from the shape of the fetched tensor.
                final int length = YOLOClassifier.getInstance().getOutputSizeByShape(prediction);
                final float[] values = new float[length];
                // Wrapping the array lets writeTo fill it directly.
                prediction.writeTo(FloatBuffer.wrap(values));
                return values;
            }
        }
    }
}
Aggregations