
Example 1 with ExtractFn

Use of com.spotify.zoltar.FeatureExtractFns.ExtractFn in the zoltar project by Spotify.

From class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
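    // createADummyTFGraph() is a helper defined elsewhere in TensorFlowGraphModelTest (not shown here).
    final Path graphFile = createADummyTFGraph();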
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream().map(vector -> {
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
            List<Tensor<?>> results = null;
            try {
                // mulResult and inputOpName are defined elsewhere in TensorFlowGraphModelTest (not shown here).
                results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                return Prediction.create(vector.input(), results.get(0).doubleValue());
            } finally {
                if (results != null) {
                    results.forEach(Tensor::close);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }).collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn).predictor().predict(input).thenApply(predictions -> {
        return predictions.stream().mapToDouble(Prediction::value).toArray();
    }).toCompletableFuture();
    // JUnit's argument order is (expecteds, actuals, delta).
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used: Path(java.nio.file.Path) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns) Arrays(java.util.Arrays) Tensors(org.tensorflow.Tensors) Session(org.tensorflow.Session) CompletableFuture(java.util.concurrent.CompletableFuture) Graph(org.tensorflow.Graph) Assert.assertArrayEquals(org.junit.Assert.assertArrayEquals) Output(org.tensorflow.Output) PredictorsTest(com.spotify.zoltar.PredictorsTest) URI(java.net.URI) Tensor(org.tensorflow.Tensor) Prediction(com.spotify.zoltar.Prediction) Files(java.nio.file.Files) IOException(java.io.IOException) Test(org.junit.Test) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) List(java.util.List) Identity(com.spotify.featran.transformers.Identity) JFeatureSpec(com.spotify.featran.java.JFeatureSpec) Paths(java.nio.file.Paths) ModelLoader(com.spotify.zoltar.ModelLoader) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.assertEquals(org.junit.Assert.assertEquals)
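The test in Example 1 chains an extract step (raw inputs to feature vectors) with a predict step (vectors to predictions). The plain-Java sketch below mirrors that data flow with no zoltar or TensorFlow dependency. `Vector`, `Prediction`, `extract` and `predict` here are hypothetical stand-ins, not zoltar's types, and the "model" is assumed to double its single feature, which is what `expected` in the test encodes.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical, simplified stand-ins for zoltar's vector and prediction types.
class PipelineSketch {

  static final class Vector<I, V> {
    final I input;
    final V value;

    Vector(I input, V value) {
      this.input = input;
      this.value = value;
    }
  }

  static final class Prediction<I, P> {
    final I input;
    final P value;

    Prediction(I input, P value) {
      this.input = input;
      this.value = value;
    }
  }

  // Extract step: pair each raw input with its feature vector.
  static <I, V> List<Vector<I, V>> extract(Function<I, V> features, List<I> inputs) {
    return inputs.stream()
        .map(in -> new Vector<>(in, features.apply(in)))
        .collect(Collectors.toList());
  }

  // Predict step: apply the "model" (here a plain function) to each vector.
  static <I, V, P> List<Prediction<I, P>> predict(Function<V, P> model, List<Vector<I, V>> vectors) {
    return vectors.stream()
        .map(v -> new Prediction<>(v.input, model.apply(v.value)))
        .collect(Collectors.toList());
  }

  static double[] run(Double... inputs) {
    // Identity-style feature: each input becomes a one-element vector.
    List<Vector<Double, double[]>> vectors = extract(d -> new double[] {d}, Arrays.asList(inputs));
    // Assumed model: multiply the single feature by 2.
    List<Prediction<Double, Double>> predictions = predict(v -> v[0] * 2.0, vectors);
    return predictions.stream().mapToDouble(p -> p.value).toArray();
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(run(0.0, 1.0, 7.0))); // [0.0, 2.0, 14.0]
  }
}
```

Keeping the original input next to each value is what lets a prediction be traced back to the request that produced it, which is why the test can build `Prediction.create(vector.input(), ...)`.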

Example 2 with ExtractFn

Use of com.spotify.zoltar.FeatureExtractFns.ExtractFn in the zoltar project by Spotify.

From class TensorFlowModelTest, method getTFIrisPredictor.

public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
    final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream().map(vector -> {
            // predict(model, example) is a helper defined elsewhere in TensorFlowModelTest (not shown here).
            return CompletableFuture.supplyAsync(() -> predict(model, vector.value())).thenApply(value -> Prediction.create(vector.input(), value));
        }).collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI();
    final URI settingsUri = TensorFlowModelTest.class.getResource("/settings.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowModel> model = TensorFlowLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, Example> extractFn = FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: Example(org.tensorflow.example.Example) FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns) Tensors(org.tensorflow.Tensors) Session(org.tensorflow.Session) CompletableFuture(java.util.concurrent.CompletableFuture) IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec) Duration(java.time.Duration) Map(java.util.Map) IrisHelper(com.spotify.zoltar.IrisHelper) PredictorsTest(com.spotify.zoltar.PredictorsTest) URI(java.net.URI) Tensor(org.tensorflow.Tensor) CompletableFutures(com.spotify.futures.CompletableFutures) Prediction(com.spotify.zoltar.Prediction) ImmutableMap(com.google.common.collect.ImmutableMap) Files(java.nio.file.Files) Test(org.junit.Test) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris) List(java.util.List) Paths(java.nio.file.Paths) ModelLoader(com.spotify.zoltar.ModelLoader) Predictor(com.spotify.zoltar.Predictor) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) LongBuffer(java.nio.LongBuffer) Assert(org.junit.Assert)
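In getTFIrisPredictor, each vector gets its own `supplyAsync` task and `CompletableFutures.allAsList` (from spotify/completable-futures) folds the per-vector futures into one future of a list. The sketch below approximates that combinator using only the JDK's `CompletableFuture.allOf`; it is an illustration of the idea, not that library's implementation, and the `allAsList` method here is hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

class AllAsListSketch {

  // Combine futures into one future of their results, preserving input order.
  static <T> CompletableFuture<List<T>> allAsList(List<CompletableFuture<T>> futures) {
    CompletableFuture<?>[] array = futures.toArray(new CompletableFuture<?>[0]);
    // allOf completes once every future completes, and exceptionally if any of them fails.
    return CompletableFuture.allOf(array)
        // join() does not block here because every future is already complete.
        .thenApply(ignored -> futures.stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toList()));
  }

  public static void main(String[] args) {
    List<Integer> inputs = Arrays.asList(1, 2, 3);
    // One async task per input, like one supplyAsync(...) per vector in the example.
    List<CompletableFuture<Integer>> futures = inputs.stream()
        .map(i -> CompletableFuture.supplyAsync(() -> i * 10))
        .collect(Collectors.toList());
    System.out.println(allAsList(futures).join()); // [10, 20, 30]
  }
}
```

A single failed vector fails the whole batch in this sketch, which matches the all-or-nothing shape of returning one `List<Prediction<...>>`.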

Example 3 with ExtractFn

Use of com.spotify.zoltar.FeatureExtractFns.ExtractFn in the zoltar project by Spotify.

From class PredictorTest, method timeout.

@Test
public void timeout() {
    final Duration wait = Duration.ofSeconds(1);
    final Duration predictionTimeout = Duration.ZERO;
    final ExtractFn<Object, Object> extractFn = inputs -> Collections.emptyList();
    final PredictFn<DummyModel, Object, Object, Object> predictFn = (model, vectors) -> {
        Thread.sleep(wait.toMillis());
        return Collections.emptyList();
    };
    try {
        final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);
        DefaultPredictorBuilder.create(loader, extractFn, predictFn).predictor().predict(predictionTimeout, new Object()).toCompletableFuture().get(wait.toMillis(), TimeUnit.MILLISECONDS);
        fail("should throw TimeoutException");
    } catch (final Exception e) {
        assertTrue(e.getCause() instanceof TimeoutException);
    }
}
Also used: SingleExtractFn(com.spotify.zoltar.FeatureExtractFns.SingleExtractFn) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) Assert.assertTrue(org.junit.Assert.assertTrue) TimeoutException(java.util.concurrent.TimeoutException) CompletableFuture(java.util.concurrent.CompletableFuture) Test(org.junit.Test) AsyncPredictFn(com.spotify.zoltar.PredictFns.AsyncPredictFn) Collectors(java.util.stream.Collectors) Assert.assertThat(org.junit.Assert.assertThat) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Duration(java.time.Duration) Is.is(org.hamcrest.core.Is.is) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.fail(org.junit.Assert.fail) Collections(java.util.Collections)
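The timeout test expects a missed prediction deadline to surface as an `ExecutionException` whose cause is a `TimeoutException`. One JDK-only way to get that observable behavior is `CompletableFuture.orTimeout` (Java 9+). The sketch below assumes that API and a hypothetical `timesOut` helper; it does not claim to reflect how zoltar implements `predict(Duration, ...)`.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

class TimeoutSketch {

  // Returns true if a task taking taskMillis misses a deadlineMillis deadline.
  static boolean timesOut(long taskMillis, long deadlineMillis) {
    CompletableFuture<String> result = CompletableFuture.supplyAsync(() -> {
      try {
        Thread.sleep(taskMillis); // stands in for a slow predict function
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
      }
      return "done";
    }).orTimeout(deadlineMillis, TimeUnit.MILLISECONDS); // Java 9+
    try {
      result.get();
      return false;
    } catch (ExecutionException e) {
      // As in the test, the TimeoutException arrives as the cause.
      return e.getCause() instanceof TimeoutException;
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      return false;
    }
  }

  public static void main(String[] args) {
    System.out.println(timesOut(1000, 0)); // true: zero deadline, one-second task
    System.out.println(timesOut(0, 5000)); // false: task finishes well before the deadline
  }
}
```

The first call mirrors the test's setup (`Duration.ZERO` timeout, a predict function that sleeps one second). Note that `orTimeout` only fails the future; the sleeping task itself keeps running in the background.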

Example 4 with ExtractFn

Use of com.spotify.zoltar.FeatureExtractFns.ExtractFn in the zoltar project by Spotify.

From class PredictorTest, method nonEmpty.

@Test
public void nonEmpty() throws InterruptedException, ExecutionException, TimeoutException {
    final Duration wait = Duration.ofSeconds(1);
    final SingleExtractFn<Integer, Float> extractFn = input -> (float) input / 10;
    final PredictFn<DummyModel, Integer, Float, Float> predictFn = (model, vectors) -> {
        return vectors.stream().map(vector -> Prediction.create(vector.input(), vector.value() * 2)).collect(Collectors.toList());
    };
    final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);
    final List<Prediction<Integer, Float>> predictions = DefaultPredictorBuilder.create(loader, extractFn, predictFn).predictor().predict(1).toCompletableFuture().get(wait.toMillis(), TimeUnit.MILLISECONDS);
    assertThat(predictions.size(), is(1));
    assertThat(predictions.get(0), is(Prediction.create(1, 0.2f)));
}
Also used: SingleExtractFn(com.spotify.zoltar.FeatureExtractFns.SingleExtractFn) PredictFn(com.spotify.zoltar.PredictFns.PredictFn) Assert.assertTrue(org.junit.Assert.assertTrue) TimeoutException(java.util.concurrent.TimeoutException) CompletableFuture(java.util.concurrent.CompletableFuture) Test(org.junit.Test) AsyncPredictFn(com.spotify.zoltar.PredictFns.AsyncPredictFn) Collectors(java.util.stream.Collectors) Assert.assertThat(org.junit.Assert.assertThat) ExecutionException(java.util.concurrent.ExecutionException) TimeUnit(java.util.concurrent.TimeUnit) List(java.util.List) Duration(java.time.Duration) Is.is(org.hamcrest.core.Is.is) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Assert.fail(org.junit.Assert.fail) Collections(java.util.Collections)
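In nonEmpty, a `SingleExtractFn` maps one input to one feature value, and the predictor applies it across the batch. The sketch below shows that idea with plain `java.util.function.Function`: a hypothetical `lift` turns a per-input extractor into a list extractor, followed by the same "multiply by 2" predict step as the test. The names are illustrative, not zoltar's API.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

class SingleExtractSketch {

  // Hypothetical helper: apply a per-input extractor to every element of a batch.
  static <I, V> Function<List<I>, List<V>> lift(Function<I, V> single) {
    return inputs -> inputs.stream().map(single).collect(Collectors.toList());
  }

  static List<Float> predict(Integer... inputs) {
    // Same extraction as the test: scale the integer down by 10.
    Function<Integer, Float> extract = input -> (float) input / 10;
    Function<List<Integer>, List<Float>> extractAll = lift(extract);
    // Same "model" as the test: double each feature value.
    return extractAll.apply(Arrays.asList(inputs)).stream()
        .map(value -> value * 2)
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    System.out.println(predict(1)); // [0.2]
  }
}
```

For input `1` this reproduces the test's expectation: `1 / 10` gives `0.1f`, and doubling gives `0.2f`.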

Example 5 with ExtractFn

Use of com.spotify.zoltar.FeatureExtractFns.ExtractFn in the zoltar project by Spotify.

From class XGBoostModelTest, method getXGBoostIrisPredictor.

public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
    final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI();
    final URI settingsUri = XGBoostModelTest.class.getResource("/settings.json").toURI();
    final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream().map(vector -> {
            return CompletableFuture.supplyAsync(() -> {
                try {
                    final Iterator<LabeledPoint> iterator = Collections.singletonList(vector.value()).iterator();
                    final DMatrix dMatrix = new DMatrix(iterator, null);
                    final float[] score = model.instance().predict(dMatrix)[0];
                    int idx = IntStream.range(0, score.length).reduce((i, j) -> score[i] >= score[j] ? i : j).getAsInt();
                    return Prediction.create(vector.input(), (long) idx);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
        }).collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, LabeledPoint> extractFn = FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: IntStream(java.util.stream.IntStream) FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns) CompletableFuture(java.util.concurrent.CompletableFuture) DMatrix(ml.dmlc.xgboost4j.java.DMatrix) IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec) Duration(java.time.Duration) Map(java.util.Map) IrisHelper(com.spotify.zoltar.IrisHelper) PredictorsTest(com.spotify.zoltar.PredictorsTest) URI(java.net.URI) CompletableFutures(com.spotify.futures.CompletableFutures) Prediction(com.spotify.zoltar.Prediction) Iterator(java.util.Iterator) ImmutableMap(com.google.common.collect.ImmutableMap) Files(java.nio.file.Files) Assert.assertTrue(org.junit.Assert.assertTrue) Test(org.junit.Test) Collectors(java.util.stream.Collectors) StandardCharsets(java.nio.charset.StandardCharsets) Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris) List(java.util.List) Paths(java.nio.file.Paths) Predictor(com.spotify.zoltar.Predictor) ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn) Collections(java.util.Collections) LabeledPoint(ml.dmlc.xgboost4j.LabeledPoint)
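getXGBoostIrisPredictor turns the per-class score array returned by the XGBoost booster into a class index with an `IntStream.reduce` argmax. The sketch below lifts that exact expression into a standalone method so it can be tried without XGBoost. The `argmax` name and the empty-array check are additions for illustration; the original code would instead throw `NoSuchElementException` from `getAsInt()` on an empty array.

```java
import java.util.stream.IntStream;

class ArgmaxSketch {

  // Index of the largest score; on ties, the earliest index wins.
  static int argmax(float[] score) {
    if (score.length == 0) {
      throw new IllegalArgumentException("score must be non-empty");
    }
    // Same reduction as the example: keep i when score[i] >= score[j].
    return IntStream.range(0, score.length)
        .reduce((i, j) -> score[i] >= score[j] ? i : j)
        .getAsInt();
  }

  public static void main(String[] args) {
    System.out.println(argmax(new float[] {0.1f, 0.7f, 0.2f})); // 1
    System.out.println(argmax(new float[] {0.5f, 0.5f, 0.0f})); // 0
  }
}
```

Using `>=` rather than `>` is what makes ties resolve to the lower class index.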

Aggregations

ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn): 6
List (java.util.List): 6
CompletableFuture (java.util.concurrent.CompletableFuture): 6
Collectors (java.util.stream.Collectors): 6
Test (org.junit.Test): 6
Duration (java.time.Duration): 5
PredictFn (com.spotify.zoltar.PredictFns.PredictFn): 4
Collections (java.util.Collections): 4
Assert.assertTrue (org.junit.Assert.assertTrue): 4
SingleExtractFn (com.spotify.zoltar.FeatureExtractFns.SingleExtractFn): 3
AsyncPredictFn (com.spotify.zoltar.PredictFns.AsyncPredictFn): 3
Prediction (com.spotify.zoltar.Prediction): 3
PredictorsTest (com.spotify.zoltar.PredictorsTest): 3
FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns): 3
URI (java.net.URI): 3
StandardCharsets (java.nio.charset.StandardCharsets): 3
Files (java.nio.file.Files): 3
Paths (java.nio.file.Paths): 3
ExecutionException (java.util.concurrent.ExecutionException): 3
TimeUnit (java.util.concurrent.TimeUnit): 3