Example 1 with PredictFn

Use of com.spotify.zoltar.PredictFns.PredictFn in project zoltar by spotify.

From class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
    // createADummyTFGraph(), inputOpName and mulResult are members of TensorFlowGraphModelTest
    // that are not shown here; inputOpName and mulResult name operations in the dummy graph.
    final Path graphFile = createADummyTFGraph();
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    // Runs the graph once per feature vector; the input Tensor and every fetched Tensor are
    // closed so their native memory is released even if reading the result fails.
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream().map(vector -> {
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
            List<Tensor<?>> results = null;
            try {
                results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                return Prediction.create(vector.input(), results.get(0).doubleValue());
            } finally {
                if (results != null) {
                    results.forEach(Tensor::close);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }).collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
        .predictor()
        .predict(input)
        .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
        .toCompletableFuture();
    // JUnit's argument order is (expected, actual, delta).
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used: Path(java.nio.file.Path) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns) Arrays(java.util.Arrays) Tensors(org.tensorflow.Tensors) Session(org.tensorflow.Session) CompletableFuture(java.util.concurrent.CompletableFuture) Graph(org.tensorflow.Graph) Assert.assertArrayEquals(org.junit.Assert.assertArrayEquals) Output(org.tensorflow.Output) PredictorsTest(com.spotify.zoltar.PredictorsTest) URI(java.net.URI) Tensor(org.tensorflow.Tensor) Prediction(com.spotify.zoltar.Prediction) Files(java.nio.file.Files) IOException(java.io.IOException) Test(org.junit.Test) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) Identity(com.spotify.featran.transformers.Identity) JFeatureSpec(com.spotify.featran.java.JFeatureSpec) Paths(java.nio.file.Paths) ModelLoader(com.spotify.zoltar.ModelLoader) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.assertEquals(org.junit.Assert.assertEquals)
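The resource handling in Example 1 (run the graph once per vector, then close every fetched result in a `finally` block) can be sketched without TensorFlow. This is a minimal, JDK-only illustration under stated assumptions: `FakeTensor`, `run`, and `predictAll` are hypothetical stand-ins for `Tensor`, the session runner, and the `PredictFn` lambda, and the fake session simply doubles its input to mirror the test's `expected` array.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical stand-ins only: FakeTensor, run() and predictAll() are not TensorFlow or zoltar APIs.
class ClosingPredictSketch {

  // Stand-in for a result tensor that holds a resource and must be closed after use.
  static final class FakeTensor implements AutoCloseable {
    final double value;
    boolean closed = false;

    FakeTensor(double value) {
      this.value = value;
    }

    @Override
    public void close() {
      closed = true;
    }
  }

  // Records every tensor the fake session creates, so callers can check they were all closed.
  static final List<FakeTensor> created = new ArrayList<>();

  // Stand-in for feeding one input and fetching one output; this fake graph doubles its input.
  static List<FakeTensor> run(double input) {
    FakeTensor output = new FakeTensor(input * 2.0);
    created.add(output);
    return Arrays.asList(output);
  }

  // Same shape as the PredictFn in Example 1: one run per vector, and the fetched
  // results are closed in a finally block whether or not reading them succeeds.
  static List<Double> predictAll(List<double[]> vectors) {
    return vectors.stream().map(vector -> {
      List<FakeTensor> results = null;
      try {
        results = run(vector[0]);
        return results.get(0).value;
      } finally {
        if (results != null) {
          results.forEach(FakeTensor::close);
        }
      }
    }).collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Double> predictions = predictAll(
        Arrays.asList(new double[] {0.0}, new double[] {1.0}, new double[] {7.0}));
    System.out.println(predictions); // [0.0, 2.0, 14.0]
    System.out.println(created.stream().allMatch(t -> t.closed)); // true
  }
}
```

Closing in `finally` rather than after the `return` matters because the read can throw; a `try`-with-resources per result would work too, but a fetched list has no single resource to bind.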

Example 2 with PredictFn

Use of com.spotify.zoltar.PredictFns.PredictFn in project zoltar by spotify.

From class PredictorTest, method timeout.

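@Test
public void timeout() {
    final Duration wait = Duration.ofSeconds(1);
    // A zero prediction timeout expires before the predict function below can finish.
    final Duration predictionTimeout = Duration.ZERO;
    final ExtractFn<Object, Object> extractFn = inputs -> Collections.emptyList();
    // Simulates a slow model by blocking for longer than the prediction timeout.
    final PredictFn<DummyModel, Object, Object, Object> predictFn = (model, vectors) -> {
        Thread.sleep(wait.toMillis());
        return Collections.emptyList();
    };
    try {
        final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);
        DefaultPredictorBuilder.create(loader, extractFn, predictFn)
            .predictor()
            .predict(predictionTimeout, new Object())
            .toCompletableFuture()
            .get(wait.toMillis(), TimeUnit.MILLISECONDS);
        // fail() throws an AssertionError, which is not an Exception and so escapes the catch below.
        fail("should throw TimeoutException");
    } catch (final Exception e) {
        // get() wraps the failed prediction in an ExecutionException, so the
        // TimeoutException is its cause rather than the caught exception itself.
        assertTrue(e.getCause() instanceof TimeoutException);
    }
}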
Also used: SingleExtractFn(com.spotify.zoltar.FeatureExtractFns.SingleExtractFn) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) Assert.assertTrue(org.junit.Assert.assertTrue) TimeoutException(java.util.concurrent.TimeoutException) CompletableFuture(java.util.concurrent.CompletableFuture) Test(org.junit.Test) AsyncPredictFn(com.spotify.zoltar.PredictFns.AsyncPredictFn) Collectors(java.util.stream.Collectors) Assert.assertThat(org.junit.Assert.assertThat) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Duration(java.time.Duration) Is.is(org.hamcrest.core.Is.is) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.fail(org.junit.Assert.fail) Collections(java.util.Collections)
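The timeout test relies on `get()` reporting a prediction that did not finish in time as an `ExecutionException` whose cause is a `TimeoutException`. The sketch below reproduces that failure shape with the JDK alone. It assumes Java 9+ for `CompletableFuture.orTimeout`, and `slowPrediction` and `causeOfFailure` are hypothetical helpers; it does not claim to show how zoltar itself enforces the prediction timeout.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

// JDK-only sketch (Java 9+ for orTimeout); slowPrediction and causeOfFailure are hypothetical.
class TimeoutSketch {

  // Stand-in for a slow predict function: blocks for sleepMillis, then returns a result.
  static CompletableFuture<String> slowPrediction(long sleepMillis) {
    return CompletableFuture.supplyAsync(() -> {
      try {
        Thread.sleep(sleepMillis);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
      return "prediction";
    });
  }

  // Returns null if the prediction finishes in time; otherwise returns the cause of
  // the ExecutionException thrown by get(), which orTimeout sets to a TimeoutException.
  static Throwable causeOfFailure(long sleepMillis, long timeoutMillis) {
    CompletableFuture<String> prediction =
        slowPrediction(sleepMillis).orTimeout(timeoutMillis, TimeUnit.MILLISECONDS);
    try {
      prediction.get(5, TimeUnit.SECONDS);
      return null;
    } catch (ExecutionException e) {
      return e.getCause();
    } catch (TimeoutException e) {
      // The outer wait expired: a different failure from the per-prediction timeout.
      throw new IllegalStateException("outer wait expired", e);
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IllegalStateException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(causeOfFailure(1000, 50) instanceof TimeoutException); // true
    System.out.println(causeOfFailure(0, 2000)); // null
  }
}
```

Keeping the outer `get` wait longer than the per-prediction timeout, as the test does, is what lets the two kinds of `TimeoutException` be told apart.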

Example 3 with PredictFn

Use of com.spotify.zoltar.PredictFns.PredictFn in project zoltar by spotify.

From class PredictorTest, method nonEmpty.

@Test
public void nonEmpty() throws InterruptedException, ExecutionException, TimeoutException {
    final Duration wait = Duration.ofSeconds(1);
    // Maps each Integer input to a single Float feature: 1 -> 0.1f.
    final SingleExtractFn<Integer, Float> extractFn = input -> (float) input / 10;
    // Doubles each feature value and pairs it with the original input.
    final PredictFn<DummyModel, Integer, Float, Float> predictFn = (model, vectors) ->
        vectors.stream()
            .map(vector -> Prediction.create(vector.input(), vector.value() * 2))
            .collect(Collectors.toList());
    final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);
    final List<Prediction<Integer, Float>> predictions = DefaultPredictorBuilder.create(loader, extractFn, predictFn)
        .predictor()
        .predict(1)
        .toCompletableFuture()
        .get(wait.toMillis(), TimeUnit.MILLISECONDS);
    assertThat(predictions.size(), is(1));
    assertThat(predictions.get(0), is(Prediction.create(1, 0.2f)));
}
Also used: SingleExtractFn(com.spotify.zoltar.FeatureExtractFns.SingleExtractFn) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) Assert.assertTrue(org.junit.Assert.assertTrue) TimeoutException(java.util.concurrent.TimeoutException) CompletableFuture(java.util.concurrent.CompletableFuture) Test(org.junit.Test) AsyncPredictFn(com.spotify.zoltar.PredictFns.AsyncPredictFn) Collectors(java.util.stream.Collectors) Assert.assertThat(org.junit.Assert.assertThat) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Duration(java.time.Duration) Is.is(org.hamcrest.core.Is.is) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.fail(org.junit.Assert.fail) Collections(java.util.Collections)
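Example 3 chains two steps: a per-input extract function turns `1` into the feature `0.1f`, and the predict function doubles each feature to `0.2f`. A minimal JDK-only model of that flow might look like the following; `Vector`, `Prediction`, `predict`, and `runExample` are simplified hypothetical stand-ins, not zoltar's classes or the actual implementation behind `DefaultPredictorBuilder`.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical JDK-only model of extract-then-predict; not zoltar's Vector, Prediction or builder.
class PipelineSketch {

  // Pairs an original input with its extracted feature value.
  static final class Vector<I, V> {
    final I input;
    final V value;

    Vector(I input, V value) {
      this.input = input;
      this.value = value;
    }
  }

  // Pairs an original input with the model's output for it.
  static final class Prediction<I, P> {
    final I input;
    final P value;

    Prediction(I input, P value) {
      this.input = input;
      this.value = value;
    }

    @Override
    public String toString() {
      return "Prediction(" + input + ", " + value + ")";
    }
  }

  // Extracts one vector per input on another thread, then hands the batch to predictFn.
  static <I, V, P> CompletableFuture<List<Prediction<I, P>>> predict(
      Function<I, V> extract,
      Function<List<Vector<I, V>>, List<Prediction<I, P>>> predictFn,
      List<I> inputs) {
    CompletableFuture<List<Vector<I, V>>> vectors = CompletableFuture.supplyAsync(
        () -> inputs.stream()
            .map(in -> new Vector<I, V>(in, extract.apply(in)))
            .collect(Collectors.toList()));
    return vectors.thenApply(predictFn);
  }

  // Mirrors Example 3: 1 is extracted as 0.1f, and the predict step doubles it to 0.2f.
  static List<Prediction<Integer, Float>> runExample() {
    Function<Integer, Float> extract = input -> (float) input / 10;
    Function<List<Vector<Integer, Float>>, List<Prediction<Integer, Float>>> predictFn =
        vs -> vs.stream()
            .map(v -> new Prediction<Integer, Float>(v.input, v.value * 2))
            .collect(Collectors.toList());
    return predict(extract, predictFn, Arrays.asList(1)).join();
  }

  public static void main(String[] args) {
    System.out.println(runExample()); // [Prediction(1, 0.2)]
  }
}
```

Keeping the original input inside each vector and prediction is what lets the test compare against `Prediction.create(1, 0.2f)` without tracking positions separately.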

Example 4 with PredictFn

Use of com.spotify.zoltar.PredictFns.PredictFn in project zoltar by spotify.

From class PredictorTest, method empty.

@Test
public void empty() throws InterruptedException, ExecutionException, TimeoutException {
    final Duration wait = Duration.ofSeconds(1);
    final ExtractFn<Object, Object> extractFn = inputs -> Collections.emptyList();
    // Unlike a PredictFn, this AsyncPredictFn returns a future; here it is already complete with no predictions.
    final AsyncPredictFn<DummyModel, Object, Object, Object> predictFn = (model, vectors) ->
        CompletableFuture.completedFuture(Collections.emptyList());
    final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);
    // No inputs are passed; the test passes if the future completes within `wait` without an exception.
    DefaultPredictorBuilder.create(loader, extractFn, predictFn)
        .predictor()
        .predict()
        .toCompletableFuture()
        .get(wait.toMillis(), TimeUnit.MILLISECONDS);
}
Also used: SingleExtractFn(com.spotify.zoltar.FeatureExtractFns.SingleExtractFn) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) Assert.assertTrue(org.junit.Assert.assertTrue) TimeoutException(java.util.concurrent.TimeoutException) CompletableFuture(java.util.concurrent.CompletableFuture) Test(org.junit.Test) AsyncPredictFn(com.spotify.zoltar.PredictFns.AsyncPredictFn) Collectors(java.util.stream.Collectors) Assert.assertThat(org.junit.Assert.assertThat) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Duration(java.time.Duration) Is.is(org.hamcrest.core.Is.is) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.fail(org.junit.Assert.fail) Collections(java.util.Collections)
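Example 4 passes an asynchronous predict function, which returns a future rather than a list. One common way to relate the two shapes is to wrap a blocking function so that it runs on another thread. The sketch below assumes hypothetical `SyncPredict` and `AsyncPredict` interfaces and a hypothetical `lift` helper; it illustrates the idea and is not zoltar's API.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.stream.Collectors;

// Hypothetical interfaces shaped like blocking and asynchronous predict functions; not zoltar's API.
class AsyncLiftSketch {

  @FunctionalInterface
  interface SyncPredict<M, V, P> {
    List<P> apply(M model, List<V> vectors) throws Exception;
  }

  @FunctionalInterface
  interface AsyncPredict<M, V, P> {
    CompletionStage<List<P>> apply(M model, List<V> vectors);
  }

  // Runs a blocking predict function on the common pool; any exception it throws
  // completes the returned stage exceptionally instead of reaching the caller directly.
  static <M, V, P> AsyncPredict<M, V, P> lift(SyncPredict<M, V, P> fn) {
    return (model, vectors) -> CompletableFuture.supplyAsync(() -> {
      try {
        return fn.apply(model, vectors);
      } catch (Exception e) {
        throw new CompletionException(e);
      }
    });
  }

  // Mirrors Example 4: an async predict function that is already complete with no predictions.
  static List<Object> emptyExample() {
    AsyncPredict<Object, Object, Object> empty =
        (model, vectors) -> CompletableFuture.completedFuture(Collections.emptyList());
    return empty.apply(new Object(), Collections.emptyList()).toCompletableFuture().join();
  }

  // Lifts a blocking "double every value" function and waits for the result.
  static List<Integer> doubledExample(List<Integer> inputs) {
    SyncPredict<Object, Integer, Integer> doubleAll =
        (model, vectors) -> vectors.stream().map(v -> v * 2).collect(Collectors.toList());
    return lift(doubleAll).apply(new Object(), inputs).toCompletableFuture().join();
  }

  public static void main(String[] args) {
    System.out.println(emptyExample()); // []
    System.out.println(doubledExample(Arrays.asList(1, 2, 3))); // [2, 4, 6]
  }
}
```

Running blocking model code on the common pool keeps callers non-blocking, but a real service would likely pass a dedicated `Executor` to `supplyAsync` so that slow models cannot starve other tasks.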

Aggregations

ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn): 4
PredictFn (com.spotify.zoltar.PredictFns.PredictFn): 4
List (java.util.List): 4
CompletableFuture (java.util.concurrent.CompletableFuture): 4
Collectors (java.util.stream.Collectors): 4
Test (org.junit.Test): 4
SingleExtractFn (com.spotify.zoltar.FeatureExtractFns.SingleExtractFn): 3
AsyncPredictFn (com.spotify.zoltar.PredictFns.AsyncPredictFn): 3
Duration (java.time.Duration): 3
Collections (java.util.Collections): 3
ExecutionException (java.util.concurrent.ExecutionException): 3
TimeUnit (java.util.concurrent.TimeUnit): 3
TimeoutException (java.util.concurrent.TimeoutException): 3
Is.is (org.hamcrest.core.Is.is): 3
Assert.assertThat (org.junit.Assert.assertThat): 3
Assert.assertTrue (org.junit.Assert.assertTrue): 3
Assert.fail (org.junit.Assert.fail): 3
JFeatureSpec (com.spotify.featran.java.JFeatureSpec): 1
Identity (com.spotify.featran.transformers.Identity): 1
ModelLoader (com.spotify.zoltar.ModelLoader): 1