
Example 1 with Tensor

Use of org.tensorflow.Tensor in project zoltar by spotify.

From class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
    final Path graphFile = createADummyTFGraph();
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream().map(vector -> {
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
            List<Tensor<?>> results = null;
            try {
                results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                return Prediction.create(vector.input(), results.get(0).doubleValue());
            } finally {
                if (results != null) {
                    results.forEach(Tensor::close);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }).collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn).predictor().predict(input).thenApply(predictions -> {
        return predictions.stream().mapToDouble(Prediction::value).toArray();
    }).toCompletableFuture();
    // JUnit expects (expected, actual, delta).
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used : Path(java.nio.file.Path) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns) Arrays(java.util.Arrays) Tensors(org.tensorflow.Tensors) Session(org.tensorflow.Session) CompletableFuture(java.util.concurrent.CompletableFuture) Graph(org.tensorflow.Graph) Assert.assertArrayEquals(org.junit.Assert.assertArrayEquals) Output(org.tensorflow.Output) PredictorsTest(com.spotify.zoltar.PredictorsTest) URI(java.net.URI) Tensor(org.tensorflow.Tensor) Prediction(com.spotify.zoltar.Prediction) Files(java.nio.file.Files) IOException(java.io.IOException) Test(org.junit.Test) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) Identity(com.spotify.featran.transformers.Identity) JFeatureSpec(com.spotify.featran.java.JFeatureSpec) Paths(java.nio.file.Paths) ModelLoader(com.spotify.zoltar.ModelLoader) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.assertEquals(org.junit.Assert.assertEquals)
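The zoltar examples release every fetched tensor in a `finally` block with `results.forEach(Tensor::close)`. If one `close()` threw an unchecked exception, `forEach` would stop and the remaining tensors would leak. The stdlib-only sketch below generalizes that cleanup step so every element is closed regardless; `Resources.closeAll` and `TrackedResource` are hypothetical names standing in for the TensorFlow types, not part of zoltar or TensorFlow.

```java
import java.util.List;

// Hypothetical stand-in for org.tensorflow.Tensor: records whether it was closed
// and can be told to fail on close.
final class TrackedResource implements AutoCloseable {
    final String name;
    final boolean failOnClose;
    boolean closed;

    TrackedResource(String name, boolean failOnClose) {
        this.name = name;
        this.failOnClose = failOnClose;
    }

    @Override
    public void close() {
        closed = true;
        if (failOnClose) {
            throw new IllegalStateException("close failed: " + name);
        }
    }
}

final class Resources {
    private Resources() {}

    // Closes every element even if some close() calls throw. The first failure
    // is rethrown after the loop; later failures are attached as suppressed.
    static void closeAll(List<? extends AutoCloseable> resources) throws Exception {
        if (resources == null) {
            return; // mirrors the `results != null` guard in the examples
        }
        Exception first = null;
        for (AutoCloseable r : resources) {
            try {
                r.close();
            } catch (Exception e) {
                if (first == null) {
                    first = e;
                } else {
                    first.addSuppressed(e);
                }
            }
        }
        if (first != null) {
            throw first;
        }
    }
}
```

Under this assumption, the `finally` block of the prediction lambda could call `Resources.closeAll(results)`; the surrounding `catch (Exception e)` already handles the checked exception.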

Example 2 with Tensor

Use of org.tensorflow.Tensor in project zoltar by spotify.

From class TensorFlowGraphModelTest, method testDummyLoadOfTensorFlowGraph.

@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
    final Path graphFile = createADummyTFGraph();
    try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, null);
        final Session session = model.instance();
        final Tensor<Double> double3 = Tensors.create(3.0D)) {
        List<Tensor<?>> result = null;
        try {
            result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
            // JUnit expects (expected, actual, delta); the dummy graph doubles its input.
            assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
        } finally {
            if (result != null) {
                result.forEach(Tensor::close);
            }
        }
    }
}
Also used : Path(java.nio.file.Path) Tensor(org.tensorflow.Tensor) Session(org.tensorflow.Session) PredictorsTest(com.spotify.zoltar.PredictorsTest) Test(org.junit.Test)
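The test above declares the model, its session, and the input tensor in a single try-with-resources header. Java closes such resources in the reverse order of declaration, so the tensor is released first and the model last. The stdlib-only sketch below demonstrates that ordering with a hypothetical `Named` resource that logs its name when closed; it does not touch TensorFlow.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical resource that appends its name to a shared log when closed.
final class Named implements AutoCloseable {
    private final String name;
    private final List<String> log;

    Named(String name, List<String> log) {
        this.name = name;
        this.log = log;
    }

    @Override
    public void close() {
        log.add(name);
    }
}

final class CloseOrder {
    private CloseOrder() {}

    // Mirrors the resource header of testDummyLoadOfTensorFlowGraph and
    // returns the body marker followed by the order in which resources closed.
    static List<String> run() {
        List<String> log = new ArrayList<>();
        try (Named model = new Named("model", log);
             Named session = new Named("session", log);
             Named double3 = new Named("double3", log)) {
            log.add("body");
        }
        return log;
    }
}
```

`CloseOrder.run()` returns `[body, double3, session, model]`: the body runs first, then the resources close last-declared first.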

Example 3 with Tensor

Use of org.tensorflow.Tensor in project zoltar by spotify.

From class TensorFlowModelTest, method predict.

private static long predict(final TensorFlowModel model, final Example example) {
    // Rank 1 because we need to account for the batch dimension (a batch of one example).
    final byte[][] b = new byte[1][];
    b[0] = example.toByteArray();
    try (final Tensor<String> t = Tensors.create(b)) {
        final Session.Runner runner = model.instance().session().runner().feed("input_example_tensor", t).fetch("linear/head/predictions/class_ids");
        final List<Tensor<?>> output = runner.run();
        final LongBuffer incomingClassId = LongBuffer.allocate(1);
        try {
            output.get(0).writeTo(incomingClassId);
        } finally {
            output.forEach(Tensor::close);
        }
        return incomingClassId.get(0);
    }
}
Also used : LongBuffer(java.nio.LongBuffer) Tensor(org.tensorflow.Tensor) Session(org.tensorflow.Session)
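The `predict` helper reads the class id with `incomingClassId.get(0)`, an absolute read. Assuming, as in TensorFlow 1.x, that `Tensor.writeTo` fills the buffer with a relative `put`, the buffer's position has already advanced, so a relative `get()` without `flip()` or `rewind()` would underflow. The sketch below shows the difference with plain `java.nio.LongBuffer`; `put` stands in for `writeTo`, and `BufferReads` is a hypothetical helper.

```java
import java.nio.BufferUnderflowException;
import java.nio.LongBuffer;

final class BufferReads {
    private BufferReads() {}

    // Simulates Tensor.writeTo(buffer): a relative put advances the position to 1.
    static LongBuffer filled(long classId) {
        LongBuffer buf = LongBuffer.allocate(1);
        buf.put(classId);
        return buf;
    }

    // Absolute read, as in the predict helper: ignores the current position.
    static long absoluteRead(LongBuffer buf) {
        return buf.get(0);
    }

    // Relative read straight after the write: position == limit, so it underflows.
    static boolean relativeReadUnderflows(LongBuffer buf) {
        try {
            buf.get();
            return false;
        } catch (BufferUnderflowException e) {
            return true;
        }
    }
}
```

A relative read works only after `flip()`, which resets the position to 0 and sets the limit to the number of values written.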

Example 4 with Tensor

Use of org.tensorflow.Tensor in project sponge by softelnet.

From class TensorflowTest, method testTf.

@Test
public void testTf() throws UnsupportedEncodingException {
    try (Graph g = new Graph()) {
        final String value = "Hello from " + TensorFlow.version();
        // Build a graph with a single operation: a constant named "MyConst" holding `value`.
        try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
            // The Java API doesn't yet include convenience functions for adding operations.
            g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
        }
        // Execute the "MyConst" operation in a Session.
        try (Session s = new Session(g);
            Tensor output = s.runner().fetch("MyConst").run().get(0)) {
            logger.info(new String(output.bytesValue(), "UTF-8"));
            logger.info(output.toString());
        }
    }
}
Also used : Graph(org.tensorflow.Graph) Tensor(org.tensorflow.Tensor) Session(org.tensorflow.Session) Test(org.junit.Test)
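`testTf` encodes the string with `getBytes("UTF-8")`, which is why the method must declare `UnsupportedEncodingException`. A string tensor carries raw bytes, so the round trip is simply encode, then decode with the same charset. Passing `StandardCharsets.UTF_8` removes the checked exception. The stdlib sketch below shows that round trip; the version string in the usage note is a hypothetical stand-in for `TensorFlow.version()`.

```java
import java.nio.charset.StandardCharsets;

final class Utf8RoundTrip {
    private Utf8RoundTrip() {}

    // Bytes that would be stored in the string tensor; no checked exception.
    static byte[] encode(String value) {
        return value.getBytes(StandardCharsets.UTF_8);
    }

    // Decodes bytes read back (e.g. from bytesValue()) with the same charset.
    static String decode(byte[] bytes) {
        return new String(bytes, StandardCharsets.UTF_8);
    }
}
```

For example, `"Hello from 1.x \u00e9"` has 16 characters but encodes to 17 bytes, because `é` takes two bytes in UTF-8; decoding restores the original string.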

Example 5 with Tensor

Use of org.tensorflow.Tensor in project sponge by softelnet.

From class TensorflowTest, method testLoadModel.

// TODO @Test
public void testLoadModel() throws Exception {
    String modelDir = "examples/tensorflow/estimator/model";
    // Close the bundle (its session and graph) and both tensors to release native memory.
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir + "/" + SpongeUtils.getLastSubdirectory(modelDir), "serve");
        Tensor<?> x = Tensor.create(new float[] { 2, 5, 8, 1 });
        Tensor<?> y = bundle.session().runner().feed("x", x).fetch("y").run().get(0)) {
        logger.info("y = {}", y.floatValue());
    }
}
Also used : Tensor(org.tensorflow.Tensor) SavedModelBundle(org.tensorflow.SavedModelBundle) Session(org.tensorflow.Session)
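`testLoadModel` builds the SavedModel path with `SpongeUtils.getLastSubdirectory(modelDir)`, whose implementation is not shown here. Estimator exports are typically written to timestamp-named subdirectories, so one plausible reading is "pick the child directory whose name sorts last". The sketch below implements that assumption with `java.nio.file`; `ModelDirs.lastSubdirectory` is a hypothetical helper, not Sponge's code.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.Optional;
import java.util.stream.Stream;

final class ModelDirs {
    private ModelDirs() {}

    // Returns the child directory whose name sorts last, or empty if there is none.
    // Assumption: export directories are named so that lexicographic order matches
    // age (true for epoch-second timestamps with the same number of digits).
    static Optional<Path> lastSubdirectory(Path parent) throws IOException {
        try (Stream<Path> children = Files.list(parent)) {
            return children
                    .filter(Files::isDirectory)
                    .max(Comparator.comparing((Path p) -> p.getFileName().toString()));
        }
    }
}
```

Regular files are ignored, so a stray file such as `zzz.txt` next to the export directories does not win the comparison.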

Aggregations

Session (org.tensorflow.Session): 6
Tensor (org.tensorflow.Tensor): 6
Test (org.junit.Test): 4
PredictorsTest (com.spotify.zoltar.PredictorsTest): 3
Path (java.nio.file.Path): 3
Graph (org.tensorflow.Graph): 2
JFeatureSpec (com.spotify.featran.java.JFeatureSpec): 1
Identity (com.spotify.featran.transformers.Identity): 1
ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn): 1
ModelLoader (com.spotify.zoltar.ModelLoader): 1
PredictFn (com.spotify.zoltar.PredictFns.PredictFn): 1
Prediction (com.spotify.zoltar.Prediction): 1
FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns): 1
IOException (java.io.IOException): 1
URI (java.net.URI): 1
LongBuffer (java.nio.LongBuffer): 1
StandardCharsets (java.nio.charset.StandardCharsets): 1
Files (java.nio.file.Files): 1
Paths (java.nio.file.Paths): 1
Arrays (java.util.Arrays): 1