
Example 1 with JFeatureSpec

Use of com.spotify.featran.java.JFeatureSpec in project zoltar by spotify.

Source: class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
    // createADummyTFGraph() is a helper in TensorFlowGraphModelTest (not shown in this excerpt).
    final Path graphFile = createADummyTFGraph();
    // A single feature named "feature" that passes each Double through unchanged.
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create()
        .required(d -> d, Identity.apply("feature"));
    // Read the serialized featran settings for this feature spec from the test resources.
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    // inputOpName and mulResult are fields of the test class (not shown in this excerpt).
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream()
        .map(vector -> {
            // Feed the single feature value as a scalar tensor; try-with-resources closes it.
            try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
                List<Tensor<?>> results = null;
                try {
                    results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                    return Prediction.create(vector.input(), results.get(0).doubleValue());
                } finally {
                    // Output tensors hold native memory and must be closed explicitly.
                    if (results != null) {
                        results.forEach(Tensor::close);
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        })
        .collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    // The test expects the dummy graph to double each input.
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
        .predictor()
        .predict(input)
        .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
        .toCompletableFuture();
    // JUnit's argument order is (expected, actual, delta); a delta of Double.MIN_VALUE is effectively exact.
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used: Path (java.nio.file.Path), PredictFn (com.spotify.zoltar.PredictFns.PredictFn), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Arrays (java.util.Arrays), Tensors (org.tensorflow.Tensors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), Graph (org.tensorflow.Graph), Assert.assertArrayEquals (org.junit.Assert.assertArrayEquals), Output (org.tensorflow.Output), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), Tensor (org.tensorflow.Tensor), Prediction (com.spotify.zoltar.Prediction), Files (java.nio.file.Files), IOException (java.io.IOException), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), List (java.util.List), Identity (com.spotify.featran.transformers.Identity), JFeatureSpec (com.spotify.featran.java.JFeatureSpec), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Assert.assertEquals (org.junit.Assert.assertEquals)
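The test's predict function follows a fixed shape: extract a feature vector, feed it to the model, close every tensor, and collect the prediction values from an asynchronous result. The sketch below reproduces that shape without TensorFlow, zoltar, or featran so it can run on a plain JDK. It is a minimal illustration under stated assumptions: `PredictPipelineSketch`, `FakeTensor`, `Pred`, `extract`, `runModel`, and `predictOne` are hypothetical stand-ins, not the zoltar or TensorFlow API, and the "model" simply doubles its input, matching what the test expects of the dummy graph.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

public class PredictPipelineSketch {

    // Counts closed FakeTensors so the cleanup can be checked.
    public static final AtomicInteger CLOSED = new AtomicInteger();

    // Hypothetical stand-in for a TensorFlow Tensor: a resource that must be closed.
    static final class FakeTensor implements AutoCloseable {
        final double value;

        FakeTensor(double value) {
            this.value = value;
        }

        @Override
        public void close() {
            CLOSED.incrementAndGet();
        }
    }

    // Hypothetical (input, value) pair playing the role of zoltar's Prediction.
    static final class Pred {
        final Double input;
        final double value;

        Pred(Double input, double value) {
            this.input = input;
            this.value = value;
        }
    }

    // Hypothetical extract step: an identity transform producing a one-element vector.
    static double[] extract(Double d) {
        return new double[] { d };
    }

    // Hypothetical "model": returns one output tensor holding twice the input.
    static List<FakeTensor> runModel(FakeTensor in) {
        List<FakeTensor> out = new ArrayList<>();
        out.add(new FakeTensor(in.value * 2.0));
        return out;
    }

    // Same cleanup shape as the test: try-with-resources for the input,
    // finally for the outputs, so both are closed even if an exception is thrown.
    static Pred predictOne(Double input) {
        try (FakeTensor in = new FakeTensor(extract(input)[0])) {
            List<FakeTensor> results = null;
            try {
                results = runModel(in);
                return new Pred(input, results.get(0).value);
            } finally {
                if (results != null) {
                    results.forEach(FakeTensor::close);
                }
            }
        }
    }

    // Predict asynchronously, then keep only the prediction values.
    public static CompletableFuture<double[]> run(Double[] inputs) {
        return CompletableFuture
            .supplyAsync(() -> Arrays.stream(inputs)
                .map(PredictPipelineSketch::predictOne)
                .collect(Collectors.toList()))
            .thenApply(preds -> preds.stream().mapToDouble(p -> p.value).toArray());
    }

    public static void main(String[] args) {
        double[] out = run(new Double[] { 0.0, 1.0, 7.0 }).join();
        System.out.println(Arrays.toString(out)); // prints [0.0, 2.0, 14.0]
    }
}
```

Counting closes makes the cleanup visible: each input opens one input tensor and one result tensor, so three inputs close six in total. Closing the outputs in `finally` mirrors how the test handles `results`, which are released even if building the prediction throws.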
