Search in sources :

Example 1 with Session

use of org.tensorflow.Session in project tensorflow-java-yolo by szaza.

Class ObjectDetector, method normalizeImage.

/**
 * Pre-processes the input image: decodes it, adds a batch dimension, resizes it and scales its pixels.
 *
 * <p>The pipeline is built in a temporary {@link Graph} and evaluated in a short-lived
 * {@link Session}; both are closed before this method returns, while the returned tensor is not.</p>
 *
 * @param imageBytes encoded input image bytes (decoded via {@code decodeJpeg} with 3 channels)
 * @return {@code Tensor<Float>} with shape [1][SIZE][SIZE][3] (previously documented as [1][416][416][3])
 */
private Tensor<Float> normalizeImage(final byte[] imageBytes) {
    try (Graph preprocessGraph = new Graph()) {
        final GraphBuilder builder = new GraphBuilder(preprocessGraph);
        // Decode the image bytes into 3 channels and cast the pixel values to float.
        final Output<Float> decoded = builder.cast(
                builder.decodeJpeg(builder.constant("input", imageBytes), 3), Float.class);
        // Insert a leading batch dimension at axis 0.
        final Output<Float> batched = builder.expandDims(decoded, builder.constant("make_batch", 0));
        // Resize to SIZE x SIZE using bilinear interpolation.
        final Output<Float> resized = builder.resizeBilinear(
                batched, builder.constant("size", new int[] { SIZE, SIZE }));
        // Divide every pixel value by MEAN.
        final Output<Float> normalized = builder.div(resized, builder.constant("scale", MEAN));
        try (Session session = new Session(preprocessGraph)) {
            final String fetchName = normalized.op().name();
            return session.runner().fetch(fetchName).run().get(0).expect(Float.class);
        }
    }
}
Also used : Graph(org.tensorflow.Graph) GraphBuilder(edu.ml.tensorflow.util.GraphBuilder) Session(org.tensorflow.Session)

Example 2 with Session

use of org.tensorflow.Session in project zoltar by spotify.

Class TensorFlowGraphModelTest, method testDummyLoadOfTensorFlowGraph.

/**
 * Loads the dummy graph produced by {@code createADummyTFGraph()}, feeds 3.0 into
 * {@code inputOpName} and checks that fetching {@code mulResult} yields 6.0.
 */
@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
    final Path graphFile = createADummyTFGraph();
    try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, null);
        final Session session = model.instance();
        final Tensor<Double> double3 = Tensors.create(3.0D)) {
        List<Tensor<?>> result = null;
        try {
            result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
            // JUnit's argument order is (expected, actual, delta); keeping it correct makes
            // a failure report "expected 6.0 but was X" instead of the reverse.
            assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
        } finally {
            // Tensors returned by run() must be closed by the caller.
            if (result != null) {
                result.forEach(Tensor::close);
            }
        }
    }
}
Also used : Path(java.nio.file.Path) Tensor(org.tensorflow.Tensor) Session(org.tensorflow.Session) PredictorsTest(com.spotify.zoltar.PredictorsTest) Test(org.junit.Test)

Example 3 with Session

use of org.tensorflow.Session in project zoltar by spotify.

Class TensorFlowModelTest, method predict.

/**
 * Runs a serialized {@code Example} through the model and returns the predicted class id.
 *
 * <p>The example is wrapped in a batch of size one, fed to {@code input_example_tensor}, and the
 * first element of the fetched {@code linear/head/predictions/class_ids} tensor is returned.
 * All fetched output tensors are closed before returning.</p>
 *
 * @param model model whose session is used for inference
 * @param example input example to classify
 * @return the class id read from the first output tensor
 */
private static long predict(final TensorFlowModel model, final Example example) {
    // The byte[][] gives the string tensor rank 1, i.e. a batch dimension holding one example.
    final byte[][] batch = { example.toByteArray() };
    try (final Tensor<String> inputTensor = Tensors.create(batch)) {
        final List<Tensor<?>> fetched = model.instance()
            .session()
            .runner()
            .feed("input_example_tensor", inputTensor)
            .fetch("linear/head/predictions/class_ids")
            .run();
        final LongBuffer classIdBuffer = LongBuffer.allocate(1);
        try {
            fetched.get(0).writeTo(classIdBuffer);
        } finally {
            for (final Tensor<?> tensor : fetched) {
                tensor.close();
            }
        }
        return classIdBuffer.get(0);
    }
}
Also used : LongBuffer(java.nio.LongBuffer) Tensor(org.tensorflow.Tensor) Session(org.tensorflow.Session)

Example 4 with Session

use of org.tensorflow.Session in project zoltar by spotify.

Class TensorFlowGraphModel, method create.

/**
 * Note: Please use Models from zoltar-models module.
 *
 * <p>Creates a TensorFlow model based on a frozen, serialized TensorFlow {@link Graph}.</p>
 *
 * <p>If creating the {@link Session} or importing the graph definition fails, the {@link Graph}
 * and (if already created) the {@link Session} are closed before the exception propagates, so
 * they are not leaked.</p>
 *
 * @param graphDef byte array representing the TensorFlow {@link Graph} definition.
 * @param config ConfigProto config for TensorFlow {@link Session}; {@code null} passes no config.
 * @param prefix a prefix that will be prepended to names in graphDef; {@code null} imports the
 *     definition without a prefix.
 * @return a model wrapping the newly created {@link Graph} and {@link Session}.
 * @throws IOException declared by the signature; the code in this method does not throw it itself.
 */
public static TensorFlowGraphModel create(final byte[] graphDef, @Nullable final ConfigProto config, @Nullable final String prefix) throws IOException {
    final Graph graph = new Graph();
    Session session = null;
    boolean created = false;
    try {
        session = new Session(graph, config != null ? config.toByteArray() : null);
        final long loadStart = System.currentTimeMillis();
        if (prefix == null) {
            logger.debug("Loading graph definition without prefix");
            graph.importGraphDef(graphDef);
        } else {
            // {} placeholders: the logger takes (String, Object) arguments, SLF4J-style
            // (logger declaration not shown here); printf-style %s would not be substituted.
            logger.debug("Loading graph definition with prefix: {}", prefix);
            graph.importGraphDef(graphDef, prefix);
        }
        logger.info("TensorFlow graph loaded in {} ms", System.currentTimeMillis() - loadStart);
        final TensorFlowGraphModel model = new AutoValue_TensorFlowGraphModel(graph, session);
        created = true;
        return model;
    } finally {
        if (!created) {
            // Construction failed: release what was allocated so far instead of leaking it.
            if (session != null) {
                session.close();
            }
            graph.close();
        }
    }
}
Also used : Graph(org.tensorflow.Graph) Session(org.tensorflow.Session)

Example 5 with Session

use of org.tensorflow.Session in project tensorflow-example-java by szaza.

Class ObjectDetector, method executeYOLOGraph.

/**
 * Runs the YOLO graph on an already pre-processed image and returns the raw output values.
 *
 * <p>A fresh {@link Graph} is populated from {@code GRAPH_DEF} on every call. The image is fed
 * to the {@code input} node and the {@code output} node is fetched. The result tensor, session
 * and graph are all closed before returning.</p>
 *
 * @param image preprocessed image tensor
 * @return the output tensor's values copied into a float array sized by
 *     {@code YOLOClassifier.getOutputSizeByShape}
 */
private float[] executeYOLOGraph(final Tensor<Float> image) {
    try (Graph yoloGraph = new Graph()) {
        yoloGraph.importGraphDef(GRAPH_DEF);
        try (Session session = new Session(yoloGraph)) {
            try (Tensor<Float> output = session.runner()
                    .feed("input", image)
                    .fetch("output")
                    .run()
                    .get(0)
                    .expect(Float.class)) {
                final int length = YOLOClassifier.getInstance().getOutputSizeByShape(output);
                final float[] values = new float[length];
                output.writeTo(FloatBuffer.wrap(values));
                return values;
            }
        }
    }
}
Also used : Graph(org.tensorflow.Graph) FloatBuffer(java.nio.FloatBuffer) Session(org.tensorflow.Session)

Aggregations

Session (org.tensorflow.Session): 13 · Graph (org.tensorflow.Graph): 9 · Test (org.junit.Test): 6 · Tensor (org.tensorflow.Tensor): 5 · PredictorsTest (com.spotify.zoltar.PredictorsTest): 2 · GraphBuilder (edu.ml.tensorflow.util.GraphBuilder): 2 · FloatBuffer (java.nio.FloatBuffer): 2 · Path (java.nio.file.Path): 2 · LongBuffer (java.nio.LongBuffer): 1 · SavedModelBundle (org.tensorflow.SavedModelBundle): 1