Search in sources:

Example 1 with Graph

Use of org.tensorflow.Graph in project tensorflow-java-yolo by szaza.

In class ObjectDetector, method normalizeImage.

/**
 * Pre-processes an encoded image for the detector.
 *
 * <p>Builds a throw-away graph that decodes the bytes with {@code decodeJpeg} (3 channels),
 * casts the pixels to {@code Float}, prepends a batch dimension, resizes to {@code SIZE x SIZE}
 * with bilinear interpolation and finally divides every pixel by {@code MEAN}.
 *
 * @param imageBytes encoded input image
 * @return Tensor<Float> with shape [1][416][416][3] (i.e. [1][SIZE][SIZE][3]); the caller owns
 *     the returned tensor and is responsible for closing it
 */
private Tensor<Float> normalizeImage(final byte[] imageBytes) {
    try (Graph preprocessing = new Graph()) {
        GraphBuilder builder = new GraphBuilder(preprocessing);
        // The nested calls are evaluated innermost-first, so the ops are added in pipeline order:
        // decode -> cast to Float -> add batch dimension -> bilinear resize -> divide by MEAN.
        final Output<Float> normalized =
            builder.div(
                builder.resizeBilinear(
                    builder.expandDims(
                        builder.cast(
                            builder.decodeJpeg(builder.constant("input", imageBytes), 3),
                            Float.class),
                        builder.constant("make_batch", 0)),
                    builder.constant("size", new int[] {SIZE, SIZE})),
                builder.constant("scale", MEAN));
        try (Session preprocessSession = new Session(preprocessing)) {
            // Only one output is fetched, so the single element of the run result is the image.
            return preprocessSession.runner()
                .fetch(normalized.op().name())
                .run()
                .get(0)
                .expect(Float.class);
        }
    }
}
Also used: Graph (org.tensorflow.Graph), GraphBuilder (edu.ml.tensorflow.util.GraphBuilder), Session (org.tensorflow.Session)

Example 2 with Graph

Use of org.tensorflow.Graph in project zoltar by spotify.

In class TensorFlowExtrasTest, method createDummyGraph.

/**
 * Builds an in-memory test graph with a {@code DOUBLE} placeholder named "input" and two Mul
 * ops named by {@code mul2} (input * 2.0) and {@code mul3} (input * 3.0).
 *
 * <p>The scalar tensors used as Const values are closed once the graph has been built: their
 * contents are copied into the ops' "value" attributes, so keeping them open would only leak
 * native memory. If building the graph fails, the partially built graph is closed too.
 *
 * @return the new graph; the caller owns it and must close it
 */
private static Graph createDummyGraph() {
    final Graph graph = new Graph();
    try (Tensor<Double> t2 = Tensors.create(2.0);
        Tensor<Double> t3 = Tensors.create(3.0)) {
        final Output<Double> input = graph.opBuilder("Placeholder", "input").setAttr("dtype", DataType.DOUBLE).build().output(0);
        final Output<Double> two = graph.opBuilder("Const", "two").setAttr("dtype", t2.dataType()).setAttr("value", t2).build().output(0);
        final Output<Double> three = graph.opBuilder("Const", "three").setAttr("dtype", t3.dataType()).setAttr("value", t3).build().output(0);
        graph.opBuilder("Mul", mul2).addInput(input).addInput(two).build();
        graph.opBuilder("Mul", mul3).addInput(input).addInput(three).build();
        return graph;
    } catch (RuntimeException | Error e) {
        // Don't leak the native graph when construction fails; the tensors are already closed.
        graph.close();
        throw e;
    }
}
Also used: Graph (org.tensorflow.Graph)

Example 3 with Graph

Use of org.tensorflow.Graph in project zoltar by spotify.

In class TensorFlowGraphModelTest, method createADummyTFGraph.

/**
 * Builds a minimal TensorFlow graph in which the {@code mulResult} op multiplies the Double fed
 * to the {@code inputOpName} placeholder by the constant 2.0, then serializes the graph
 * definition to a fresh temporary file.
 *
 * @return path of the temporary file holding the serialized graph definition
 * @throws IOException if the temporary file cannot be created or written
 */
private Path createADummyTFGraph() throws IOException {
    try (final Graph g = new Graph();
        final Tensor<Double> twoValue = Tensors.create(2.0D)) {
        final Output<Double> placeholder = g.opBuilder("Placeholder", inputOpName)
            .setAttr("dtype", twoValue.dataType())
            .build()
            .output(0);
        final Output<Double> constTwo = g.opBuilder("Const", "two")
            .setAttr("dtype", twoValue.dataType())
            .setAttr("value", twoValue)
            .build()
            .output(0);
        g.opBuilder("Mul", mulResult)
            .addInput(constTwo)
            .addInput(placeholder)
            .build();
        // Serialize while the graph is still open; both resources are released on return.
        final Path target = Files.createTempFile("tf-graph", ".bin");
        Files.write(target, g.toGraphDef());
        return target;
    }
}
Also used: Path (java.nio.file.Path), Graph (org.tensorflow.Graph)

Example 4 with Graph

Use of org.tensorflow.Graph in project zoltar by spotify.

In class TensorFlowGraphModel, method create.

/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.</p>
 *
 * <p>If importing the graph definition fails, the {@link Session} and {@link Graph} created
 * here are closed before the exception propagates, so no native resources are leaked.</p>
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; may be {@code null}.
 * @param prefix a prefix that will be prepended to names in graphDef; may be {@code null}.
 * @return a model wrapping the loaded graph and a session bound to it
 * @throws IOException declared for API compatibility; not thrown by this implementation
 */
public static TensorFlowGraphModel create(final byte[] graphDef, @Nullable final ConfigProto config, @Nullable final String prefix) throws IOException {
    final Graph graph = new Graph();
    Session session = null;
    try {
        session = new Session(graph, config != null ? config.toByteArray() : null);
        // nanoTime is monotonic, unlike currentTimeMillis, so the measured duration can't go negative.
        final long loadStart = System.nanoTime();
        if (prefix == null) {
            logger.debug("Loading graph definition without prefix");
            graph.importGraphDef(graphDef);
        } else {
            // {} placeholders: the logger's debug(String, Object)/info(String, Object) API
            // (SLF4J-style) does not interpret printf-style %s/%d, which were printed literally.
            logger.debug("Loading graph definition with prefix: {}", prefix);
            graph.importGraphDef(graphDef, prefix);
        }
        logger.info("TensorFlow graph loaded in {} ms", (System.nanoTime() - loadStart) / 1_000_000L);
        return new AutoValue_TensorFlowGraphModel(graph, session);
    } catch (RuntimeException | Error e) {
        // Close the session before the graph: Graph.close() waits for sessions referencing it.
        if (session != null) {
            session.close();
        }
        graph.close();
        throw e;
    }
}
Also used: Graph (org.tensorflow.Graph), Session (org.tensorflow.Session)

Example 5 with Graph

Use of org.tensorflow.Graph in project tensorflow-example-java by szaza.

In class ObjectDetector, method executeYOLOGraph.

/**
 * Runs the YOLO graph loaded from {@code GRAPH_DEF} on an already pre-processed image.
 *
 * <p>The image is fed to the "input" op and the "output" op is fetched. The result tensor is
 * copied into a float array sized by {@code YOLOClassifier.getOutputSizeByShape}. The graph,
 * session and result tensor are all closed before returning.
 *
 * @param image pre-processed image tensor; not closed by this method
 * @return the raw values of the "output" tensor
 */
private float[] executeYOLOGraph(final Tensor<Float> image) {
    try (Graph yolo = new Graph()) {
        yolo.importGraphDef(GRAPH_DEF);
        try (Session session = new Session(yolo);
            Tensor<Float> prediction = session.runner()
                .feed("input", image)
                .fetch("output")
                .run()
                .get(0)
                .expect(Float.class)) {
            final int length = YOLOClassifier.getInstance().getOutputSizeByShape(prediction);
            final float[] values = new float[length];
            prediction.writeTo(FloatBuffer.wrap(values));
            return values;
        }
    }
}
Also used: Graph (org.tensorflow.Graph), FloatBuffer (java.nio.FloatBuffer), Session (org.tensorflow.Session)

Aggregations

Graph (org.tensorflow.Graph): 11, Session (org.tensorflow.Session): 9, Test (org.junit.Test): 4, GraphBuilder (edu.ml.tensorflow.util.GraphBuilder): 2, FloatBuffer (java.nio.FloatBuffer): 2, Path (java.nio.file.Path): 1, Tensor (org.tensorflow.Tensor): 1