Usage of org.tensorflow.Graph in the project tensorflow-java-yolo by szaza.
Class ObjectDetector, method normalizeImage.
/**
 * Pre-processes an input image for the detector.
 *
 * <p>The JPEG bytes are decoded into a 3-channel image and cast to float. A batch dimension is
 * prepended, the image is resized with bilinear interpolation to SIZE x SIZE, and every pixel is
 * divided by MEAN. The whole pipeline runs in a throw-away graph and session.
 *
 * @param imageBytes JPEG-encoded input image
 * @return Tensor<Float> with shape [1][SIZE][SIZE][3] ([1][416][416][3] when SIZE is 416)
 */
private Tensor<Float> normalizeImage(final byte[] imageBytes) {
    try (Graph graph = new Graph()) {
        final GraphBuilder builder = new GraphBuilder(graph);

        // Stage 1: decode the JPEG bytes (3 channels) and cast the pixels to float.
        final Output<Float> pixels =
                builder.cast(builder.decodeJpeg(builder.constant("input", imageBytes), 3), Float.class);

        // Stage 2: prepend a batch dimension at axis 0.
        final Output<Float> batched = builder.expandDims(pixels, builder.constant("make_batch", 0));

        // Stage 3: resize to SIZE x SIZE using bilinear interpolation.
        final Output<Float> resized =
                builder.resizeBilinear(batched, builder.constant("size", new int[] { SIZE, SIZE }));

        // Stage 4: divide every pixel by MEAN.
        final Output<Float> normalized = builder.div(resized, builder.constant("scale", MEAN));

        try (Session session = new Session(graph)) {
            final String fetchName = normalized.op().name();
            return session.runner().fetch(fetchName).run().get(0).expect(Float.class);
        }
    }
}
Usage of org.tensorflow.Graph in the project zoltar by spotify.
Class TensorFlowExtrasTest, method createDummyGraph.
/**
 * Builds a small test graph with a Double placeholder named "input" and two "Mul" ops, named by
 * {@code mul2} and {@code mul3}. They multiply the input by the constants 2.0 and 3.0.
 *
 * <p>The constant tensors are only needed while their values are set as attributes of the
 * "Const" ops, so they are closed before this method returns. Previously they were never closed
 * and leaked native memory. The caller owns the returned graph and must close it.
 */
private static Graph createDummyGraph() {
  final Graph graph = new Graph();
  try (final Tensor<Double> t2 = Tensors.create(2.0);
      final Tensor<Double> t3 = Tensors.create(3.0)) {
    final Output<Double> input =
        graph.opBuilder("Placeholder", "input").setAttr("dtype", DataType.DOUBLE).build().output(0);
    final Output<Double> two =
        graph
            .opBuilder("Const", "two")
            .setAttr("dtype", t2.dataType())
            .setAttr("value", t2)
            .build()
            .output(0);
    final Output<Double> three =
        graph
            .opBuilder("Const", "three")
            .setAttr("dtype", t3.dataType())
            .setAttr("value", t3)
            .build()
            .output(0);
    graph.opBuilder("Mul", mul2).addInput(input).addInput(two).build();
    graph.opBuilder("Mul", mul3).addInput(input).addInput(three).build();
    return graph;
  } catch (final RuntimeException e) {
    // Do not leak the native graph when building it fails; the caller never receives it.
    graph.close();
    throw e;
  }
}
Usage of org.tensorflow.Graph in the project zoltar by spotify.
Class TensorFlowGraphModelTest, method createADummyTFGraph.
/**
 * Builds a tiny TensorFlow graph that multiplies a Double placeholder (named {@code inputOpName})
 * by the constant 2.0. The product is produced by the "Mul" op named {@code mulResult}.
 *
 * <p>The serialized GraphDef is written to a newly created temporary file, and the path of that
 * file is returned. The graph and the constant tensor are closed before returning.
 *
 * @return path of the temporary file holding the serialized graph definition
 * @throws IOException if the temporary file cannot be created or written
 */
private Path createADummyTFGraph() throws IOException {
  try (final Graph g = new Graph();
      final Tensor<Double> factor = Tensors.create(2.0D)) {
    final Output<Double> placeholder =
        g.opBuilder("Placeholder", inputOpName)
            .setAttr("dtype", factor.dataType())
            .build()
            .output(0);
    final Output<Double> constTwo =
        g.opBuilder("Const", "two")
            .setAttr("dtype", factor.dataType())
            .setAttr("value", factor)
            .build()
            .output(0);
    g.opBuilder("Mul", mulResult).addInput(constTwo).addInput(placeholder).build();

    final Path target = Files.createTempFile("tf-graph", ".bin");
    Files.write(target, g.toGraphDef());
    return target;
  }
}
Usage of org.tensorflow.Graph in the project zoltar by spotify.
Class TensorFlowGraphModel, method create.
/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.</p>
 *
 * <p>If creating the {@link Session} or importing the graph definition fails, the native
 * {@link Session} and {@link Graph} allocated here are closed before the exception propagates,
 * so a failed call does not leak them.</p>
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; may be {@code null}, in which
 *     case no config bytes are passed to the {@link Session}.
 * @param prefix a prefix that will be prepended to names in graphDef; may be {@code null}, in
 *     which case the graph definition is imported without a prefix.
 * @return a model that owns the newly created graph and session.
 * @throws IOException declared by the signature; the statements in this method do not throw it.
 */
public static TensorFlowGraphModel create(final byte[] graphDef, @Nullable final ConfigProto config, @Nullable final String prefix) throws IOException {
  final Graph graph = new Graph();
  Session session = null;
  try {
    session = new Session(graph, config != null ? config.toByteArray() : null);
    // nanoTime is monotonic, so it is suitable for measuring elapsed time.
    final long loadStart = System.nanoTime();
    // NOTE: the logger's (String, Object) overloads are SLF4J-style and substitute "{}".
    // A printf-style "%s"/"%d" would be logged literally and the argument dropped.
    if (prefix == null) {
      logger.debug("Loading graph definition without prefix");
      graph.importGraphDef(graphDef);
    } else {
      logger.debug("Loading graph definition with prefix: {}", prefix);
      graph.importGraphDef(graphDef, prefix);
    }
    logger.info("TensorFlow graph loaded in {} ms", (System.nanoTime() - loadStart) / 1_000_000L);
    return new AutoValue_TensorFlowGraphModel(graph, session);
  } catch (final RuntimeException e) {
    // Close the session before the graph: Graph.close() blocks while sessions still use it.
    if (session != null) {
      session.close();
    }
    graph.close();
    throw e;
  }
}
Usage of org.tensorflow.Graph in the project tensorflow-example-java by szaza.
Class ObjectDetector, method executeYOLOGraph.
/**
 * Runs the YOLO graph, deserialized from {@code GRAPH_DEF} into a fresh graph, on a
 * pre-processed image.
 *
 * @param image pre-processed image tensor, fed to the graph's "input" node
 * @return the values of the graph's "output" node, copied into a flat float array whose length
 *     is computed by {@code YOLOClassifier.getOutputSizeByShape}
 */
private float[] executeYOLOGraph(final Tensor<Float> image) {
    try (Graph yoloGraph = new Graph()) {
        yoloGraph.importGraphDef(GRAPH_DEF);
        // The result tensor is declared after the session so it is closed first.
        try (Session session = new Session(yoloGraph);
             Tensor<Float> output = session.runner()
                     .feed("input", image)
                     .fetch("output")
                     .run()
                     .get(0)
                     .expect(Float.class)) {
            final int length = YOLOClassifier.getInstance().getOutputSizeByShape(output);
            final float[] values = new float[length];
            output.writeTo(FloatBuffer.wrap(values));
            return values;
        }
    }
}
Aggregations