
Example 1 with ModelLoader

Use of com.spotify.zoltar.ModelLoader in project zoltar by Spotify.

From class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
    // createADummyTFGraph(), inputOpName and mulResult are defined elsewhere in TensorFlowGraphModelTest (not shown).
    final Path graphFile = createADummyTFGraph();
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    // Featran settings are read from a test resource as a UTF-8 string.
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    // For each feature vector: feed its first element to the graph, fetch the product, and close every tensor.
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream().map(vector -> {
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
            List<Tensor<?>> results = null;
            try {
                results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                return Prediction.create(vector.input(), results.get(0).doubleValue());
            } finally {
                // Release every fetched tensor, even if reading the result failed.
                if (results != null) {
                    results.forEach(Tensor::close);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }).collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    // The expected output is each input multiplied by 2.
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
        .predictor()
        .predict(input)
        .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
        .toCompletableFuture();
    // JUnit's assertArrayEquals takes (expected, actual, delta).
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used: Path (java.nio.file.Path), PredictFn (com.spotify.zoltar.PredictFns.PredictFn), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Arrays (java.util.Arrays), Tensors (org.tensorflow.Tensors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), Graph (org.tensorflow.Graph), Assert.assertArrayEquals (org.junit.Assert.assertArrayEquals), Output (org.tensorflow.Output), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), Tensor (org.tensorflow.Tensor), Prediction (com.spotify.zoltar.Prediction), Files (java.nio.file.Files), IOException (java.io.IOException), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), List (java.util.List), Identity (com.spotify.featran.transformers.Identity), JFeatureSpec (com.spotify.featran.java.JFeatureSpec), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Assert.assertEquals (org.junit.Assert.assertEquals)
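The predict function in this test closes the input tensor with try-with-resources and closes every fetched result tensor in a finally block, so native memory is released even when reading a result throws. Below is a minimal sketch of that pattern using only the JDK. FakeTensor, tensor() and the run function are hypothetical stand-ins for org.tensorflow.Tensor and a session run; they are not the TensorFlow API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

public class CloseResultsSketch {

    // Hypothetical stand-in for a native-backed tensor; it only records whether it was closed.
    static final class FakeTensor implements AutoCloseable {
        final double value;
        boolean closed = false;

        FakeTensor(double value) {
            this.value = value;
        }

        double doubleValue() {
            if (closed) {
                throw new IllegalStateException("tensor already closed");
            }
            return value;
        }

        @Override
        public void close() {
            closed = true;
        }
    }

    // Every tensor created so far, so callers can check that none was left open.
    static final List<FakeTensor> created = new ArrayList<>();

    static FakeTensor tensor(double v) {
        FakeTensor t = new FakeTensor(v);
        created.add(t);
        return t;
    }

    // "run" is a hypothetical session run: it receives the input tensor and returns the fetched outputs.
    static double predict(double x, Function<FakeTensor, List<FakeTensor>> run) {
        try (FakeTensor input = tensor(x)) {
            List<FakeTensor> results = null;
            try {
                results = run.apply(input);
                return results.get(0).doubleValue();
            } finally {
                // Close every fetched tensor, including ones that were never read.
                if (results != null) {
                    results.forEach(FakeTensor::close);
                }
            }
        }
    }

    public static void main(String[] args) {
        // A fake "graph" that doubles its input and also returns a second, unused output.
        double y = predict(7.0, in -> List.of(tensor(in.doubleValue() * 2.0), tensor(0.0)));
        System.out.println(y + " " + created.stream().allMatch(t -> t.closed)); // prints "14.0 true"

        // Failure case: reading the first result throws, yet the finally block still closes the second.
        FakeTensor leftover = tensor(1.0);
        try {
            predict(1.0, in -> {
                FakeTensor broken = tensor(0.0);
                broken.close(); // reading this result will throw
                return List.of(broken, leftover);
            });
        } catch (IllegalStateException e) {
            System.out.println("read failed: " + e.getMessage()); // prints "read failed: tensor already closed"
        }
        System.out.println(leftover.closed); // prints "true"
    }
}
```

Because results is assigned only after the run call returns, outputs produced by a run that itself throws are not covered by this finally block; the sketch makes the same assumption as the test.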

Example 2 with ModelLoader

Use of com.spotify.zoltar.ModelLoader in project zoltar by Spotify.

From class TensorFlowModelTest, method getTFIrisPredictor.

public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
    // predict(model, example) is a helper defined elsewhere in TensorFlowModelTest (not shown).
    final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        // Start one asynchronous prediction per feature vector, pairing each result with its original input.
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream()
            .map(vector -> CompletableFuture.supplyAsync(() -> predict(model, vector.value()))
                .thenApply(value -> Prediction.create(vector.input(), value)))
            .collect(Collectors.toList());
        // Combine the per-vector futures into a single future of the whole list.
        return CompletableFutures.allAsList(predictions);
    };
    // The exported model and the Featran settings are both test resources.
    final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI();
    final URI settingsUri = TensorFlowModelTest.class.getResource("/settings.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowModel> model = TensorFlowLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, Example> extractFn = FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: Example (org.tensorflow.example.Example), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Tensors (org.tensorflow.Tensors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), Tensor (org.tensorflow.Tensor), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), LongBuffer (java.nio.LongBuffer), Assert (org.junit.Assert)
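getTFIrisPredictor fans out one CompletableFuture per feature vector and fans back in with CompletableFutures.allAsList from the com.spotify.futures library. Assuming only the JDK is available, a comparable combinator can be sketched with CompletableFuture.allOf followed by join on each (by then completed) future. The names allAsList and classify below are hypothetical: classify stands in for the test's predict(model, vector) helper, and this sketch does not describe how the Spotify library implements its own method.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

public class AllAsListSketch {

    // JDK-only combinator: once every input future has completed, yields all results in input order,
    // or completes exceptionally if any input future failed.
    static <T> CompletableFuture<List<T>> allAsList(List<CompletableFuture<T>> futures) {
        return CompletableFuture
            .allOf(futures.toArray(new CompletableFuture<?>[0]))
            .thenApply(ignored -> futures.stream()
                .map(CompletableFuture::join) // does not block: every future is already complete here
                .collect(Collectors.toList()));
    }

    // Hypothetical per-vector prediction standing in for the test's predict(model, vector) helper.
    static long classify(double[] vector) {
        return vector[0] > 1.0 ? 1L : 0L;
    }

    public static void main(String[] args) {
        List<double[]> vectors = List.of(new double[] {0.5}, new double[] {2.5}, new double[] {1.5});

        // Fan out: one asynchronous prediction per vector.
        List<CompletableFuture<Long>> predictions = vectors.stream()
            .map(v -> CompletableFuture.supplyAsync(() -> classify(v)))
            .collect(Collectors.toList());

        // Fan in: a single future holding every prediction, in input order.
        List<Long> labels = allAsList(predictions).join();
        System.out.println(labels); // prints "[0, 1, 1]"
    }
}
```

supplyAsync without an executor argument runs on ForkJoinPool.commonPool(); passing an explicit Executor is a common choice when each prediction is expensive or blocking.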

Aggregations (number of the examples above that use each class)

ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn): 2
ModelLoader (com.spotify.zoltar.ModelLoader): 2
Prediction (com.spotify.zoltar.Prediction): 2
PredictorsTest (com.spotify.zoltar.PredictorsTest): 2
FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns): 2
URI (java.net.URI): 2
StandardCharsets (java.nio.charset.StandardCharsets): 2
Files (java.nio.file.Files): 2
Paths (java.nio.file.Paths): 2
List (java.util.List): 2
CompletableFuture (java.util.concurrent.CompletableFuture): 2
Collectors (java.util.stream.Collectors): 2
Test (org.junit.Test): 2
Session (org.tensorflow.Session): 2
Tensor (org.tensorflow.Tensor): 2
Tensors (org.tensorflow.Tensors): 2
ImmutableMap (com.google.common.collect.ImmutableMap): 1
JFeatureSpec (com.spotify.featran.java.JFeatureSpec): 1
Identity (com.spotify.featran.transformers.Identity): 1
CompletableFutures (com.spotify.futures.CompletableFutures): 1