
Example 1 with Prediction

Use of com.spotify.zoltar.Prediction in project zoltar by spotify.

From class TensorFlowGraphModelTest, method testModelInference.

@Test
public void testModelInference() throws Exception {
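    // createADummyTFGraph() is a helper of the test class (not shown) that writes a graph
    // multiplying its input by 2, which is why `expected` below doubles the input.
    final Path graphFile = createADummyTFGraph();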
    final JFeatureSpec<Double> featureSpec = JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel = TensorFlowGraphLoader.create(graphFile.toString(), null, null);
    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn = (model, vectors) -> vectors.stream().map(vector -> {
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
            List<Tensor<?>> results = null;
            try {
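                // mulResult and inputOpName are constants of the test class (not shown here):
                // the names of the graph's output op and input op.
                results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();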
                return Prediction.create(vector.input(), results.get(0).doubleValue());
            } finally {
                if (results != null) {
                    results.forEach(Tensor::close);
                }
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }).collect(Collectors.toList());
    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
    // Run the predictor and keep only the predicted values.
    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
        .predictor()
        .predict(input)
        .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
        .toCompletableFuture();
    // JUnit's assertArrayEquals takes (expected, actual, delta).
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Also used: Path(java.nio.file.Path), PredictFn(com.spotify.zoltar.PredictFns.PredictFn), FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns), Arrays(java.util.Arrays), Tensors(org.tensorflow.Tensors), Session(org.tensorflow.Session), CompletableFuture(java.util.concurrent.CompletableFuture), Graph(org.tensorflow.Graph), Assert.assertArrayEquals(org.junit.Assert.assertArrayEquals), Output(org.tensorflow.Output), PredictorsTest(com.spotify.zoltar.PredictorsTest), URI(java.net.URI), Tensor(org.tensorflow.Tensor), Prediction(com.spotify.zoltar.Prediction), Files(java.nio.file.Files), IOException(java.io.IOException), Test(org.junit.Test), Collectors(java.util.stream.Collectors), StandardCharsets(java.nio.charset.StandardCharsets), List(java.util.List), Identity(com.spotify.featran.transformers.Identity), JFeatureSpec(com.spotify.featran.java.JFeatureSpec), Paths(java.nio.file.Paths), ModelLoader(com.spotify.zoltar.ModelLoader), ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn), Assert.assertEquals(org.junit.Assert.assertEquals)
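The predictFn in this example maps each extracted vector to a Prediction and closes every TensorFlow tensor it opens, even when scoring fails. The sketch below shows the same shape without TensorFlow or zoltar on the classpath. `Vec`, `Pred`, `FakeTensor` and `BatchPredict` are hypothetical stand-ins (not zoltar or TensorFlow types), and the "model" is just a function.

```java
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical stand-in for an extracted feature vector: the raw input plus its features.
final class Vec<I, V> {
    private final I input;
    private final V value;

    Vec(I input, V value) {
        this.input = input;
        this.value = value;
    }

    I input() { return input; }
    V value() { return value; }
}

// Hypothetical stand-in for a prediction: the raw input paired with the model output.
final class Pred<I, P> {
    private final I input;
    private final P value;

    private Pred(I input, P value) {
        this.input = input;
        this.value = value;
    }

    static <I, P> Pred<I, P> create(I input, P value) {
        return new Pred<>(input, value);
    }

    I input() { return input; }
    P value() { return value; }
}

// Hypothetical per-call native resource (like a TensorFlow Tensor) that must be closed.
final class FakeTensor implements AutoCloseable {
    static int open = 0; // number of tensors currently open

    final double value;

    FakeTensor(double value) {
        this.value = value;
        open++;
    }

    @Override
    public void close() {
        open--;
    }
}

final class BatchPredict {
    // Scores the first feature of every vector; try-with-resources closes each
    // FakeTensor even if the model throws.
    static <I> List<Pred<I, Double>> predictAll(List<Vec<I, double[]>> vectors,
                                                Function<Double, Double> model) {
        return vectors.stream()
            .map(vector -> {
                try (FakeTensor in = new FakeTensor(vector.value()[0])) {
                    return Pred.create(vector.input(), model.apply(in.value));
                }
            })
            .collect(Collectors.toList());
    }
}
```

Passing `x -> x * 2.0` as the model mirrors the dummy graph: each prediction carries the original input and twice its first feature, and no `FakeTensor` is left open afterwards.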

Example 2 with Prediction

Use of com.spotify.zoltar.Prediction in project zoltar by spotify.

From class TensorFlowModelTest, method getTFIrisPredictor.

public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
    final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
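        // Score each vector asynchronously. predict(model, ...) is a static helper of
        // TensorFlowModelTest (not shown) that returns the predicted class as a Long.
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream()
            .map(vector -> CompletableFuture.supplyAsync(() -> predict(model, vector.value()))
                .thenApply(value -> Prediction.create(vector.input(), value)))
            .collect(Collectors.toList());
        // Combine into a single future of the whole list; it fails if any prediction fails.
        return CompletableFutures.allAsList(predictions);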
    };
    final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI();
    final URI settingsUri = TensorFlowModelTest.class.getResource("/settings.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowModel> model = TensorFlowLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, Example> extractFn = FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: Example(org.tensorflow.example.Example), FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns), Tensors(org.tensorflow.Tensors), Session(org.tensorflow.Session), CompletableFuture(java.util.concurrent.CompletableFuture), IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec), Duration(java.time.Duration), Map(java.util.Map), IrisHelper(com.spotify.zoltar.IrisHelper), PredictorsTest(com.spotify.zoltar.PredictorsTest), URI(java.net.URI), Tensor(org.tensorflow.Tensor), CompletableFutures(com.spotify.futures.CompletableFutures), Prediction(com.spotify.zoltar.Prediction), ImmutableMap(com.google.common.collect.ImmutableMap), Files(java.nio.file.Files), Test(org.junit.Test), Collectors(java.util.stream.Collectors), StandardCharsets(java.nio.charset.StandardCharsets), Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris), List(java.util.List), Paths(java.nio.file.Paths), ModelLoader(com.spotify.zoltar.ModelLoader), Predictor(com.spotify.zoltar.Predictor), ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn), LongBuffer(java.nio.LongBuffer), Assert(org.junit.Assert)
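Here each vector is scored with CompletableFuture.supplyAsync (which, without an explicit executor, runs on ForkJoinPool.commonPool()), and the per-item futures are merged with CompletableFutures.allAsList from spotify's completable-futures library. If you only need its order-preserving, fail-if-any-fails behaviour, a rough JDK-only approximation can be built on CompletableFuture.allOf. `FutureLists` and `scoreAll` are hypothetical names, and rounding stands in for a real model call.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

final class FutureLists {
    // Turns List<CompletableFuture<T>> into CompletableFuture<List<T>>, keeping the
    // input order. If any input future fails, the combined future fails too.
    static <T> CompletableFuture<List<T>> allAsList(List<CompletableFuture<T>> futures) {
        CompletableFuture<?>[] all = futures.toArray(new CompletableFuture<?>[0]);
        return CompletableFuture.allOf(all)
            .thenApply(ignored -> futures.stream()
                .map(CompletableFuture::join) // safe: every future has completed by now
                .collect(Collectors.toList()));
    }

    // Fans out one async task per input, as the predictFn above does with
    // supplyAsync; Math.round is a stand-in for a real model call.
    static CompletableFuture<List<Long>> scoreAll(List<Double> inputs) {
        List<CompletableFuture<Long>> perItem = inputs.stream()
            .map(x -> CompletableFuture.supplyAsync(() -> Math.round(x)))
            .collect(Collectors.toList());
        return allAsList(perItem);
    }
}
```

The results come back in input order even though the tasks may finish in any order, because the final list is built from the original list of futures rather than from completion events.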

Example 3 with Prediction

Use of com.spotify.zoltar.Prediction in project zoltar by spotify.

From class XGBoostModelTest, method getXGBoostIrisPredictor.

public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
    final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI();
    final URI settingsUri = XGBoostModelTest.class.getResource("/settings.json").toURI();
    final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream().map(vector -> {
            return CompletableFuture.supplyAsync(() -> {
                try {
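                    // Wrap the single LabeledPoint in a DMatrix (null cacheInfo: no cache file).
                    final Iterator<LabeledPoint> iterator = Collections.singletonList(vector.value()).iterator();
                    final DMatrix dMatrix = new DMatrix(iterator, null);
                    // Booster.predict returns one row of class scores per input row; take the only row.
                    final float[] score = model.instance().predict(dMatrix)[0];
                    // Arg-max over the class scores gives the predicted class id.
                    final int idx = IntStream.range(0, score.length).reduce((i, j) -> score[i] >= score[j] ? i : j).getAsInt();
                    return Prediction.create(vector.input(), (long) idx);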
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
        }).collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, LabeledPoint> extractFn = FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: IntStream(java.util.stream.IntStream), FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns), CompletableFuture(java.util.concurrent.CompletableFuture), DMatrix(ml.dmlc.xgboost4j.java.DMatrix), IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec), Duration(java.time.Duration), Map(java.util.Map), IrisHelper(com.spotify.zoltar.IrisHelper), PredictorsTest(com.spotify.zoltar.PredictorsTest), URI(java.net.URI), CompletableFutures(com.spotify.futures.CompletableFutures), Prediction(com.spotify.zoltar.Prediction), Iterator(java.util.Iterator), ImmutableMap(com.google.common.collect.ImmutableMap), Files(java.nio.file.Files), Assert.assertTrue(org.junit.Assert.assertTrue), Test(org.junit.Test), Collectors(java.util.stream.Collectors), StandardCharsets(java.nio.charset.StandardCharsets), Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris), List(java.util.List), Paths(java.nio.file.Paths), Predictor(com.spotify.zoltar.Predictor), ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn), Collections(java.util.Collections), LabeledPoint(ml.dmlc.xgboost4j.LabeledPoint)
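In the XGBoost predictFn, `model.instance().predict(dMatrix)` returns a `float[][]` (xgboost4j's Booster.predict yields one row of scores per input row), and the IntStream.reduce call picks the index of the largest score in that row. Pulled out as a standalone helper (the class name `ArgMax` is hypothetical), the arg-max looks like this. Because `>=` keeps the left index, ties resolve to the lowest class id.

```java
import java.util.stream.IntStream;

final class ArgMax {
    // Index of the highest score; on ties the lowest index wins because ">="
    // keeps the left operand of the reduction.
    static int argMax(float[] scores) {
        if (scores.length == 0) {
            throw new IllegalArgumentException("scores must not be empty");
        }
        return IntStream.range(0, scores.length)
            .reduce((i, j) -> scores[i] >= scores[j] ? i : j)
            .getAsInt();
    }
}
```

For scores `{0.1f, 0.7f, 0.2f}` the helper returns 1, and for `{0.5f, 0.5f, 0.1f}` it returns 0.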

Example 4 with Prediction

Use of com.spotify.zoltar.Prediction in project zoltar by spotify.

From class XGBoostModelTest, method testModelPrediction.

@Test
public void testModelPrediction() throws Exception {
    final Iris[] irisData = IrisHelper.getIrisTestData();
    // Maps the model's class id to the expected class name.
    final Map<Integer, String> idToClass = ImmutableMap.of(0, "Iris-setosa", 1, "Iris-versicolor", 2, "Iris-virginica");
    // Count the predictions whose class matches the labeled class name.
    final CompletableFuture<Integer> correct = getXGBoostIrisPredictor().predict(Duration.ofSeconds(10), irisData).thenApply(predictions -> {
        return predictions.stream().mapToInt(prediction -> {
            final String className = prediction.input().className().get();
            final int classId = prediction.value().intValue();
            return idToClass.get(classId).equals(className) ? 1 : 0;
        }).sum();
    }).toCompletableFuture();
    assertTrue("Accuracy should be more than 0.8", correct.get() / (float) irisData.length > .8);
}
Also used: IntStream(java.util.stream.IntStream), FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns), CompletableFuture(java.util.concurrent.CompletableFuture), DMatrix(ml.dmlc.xgboost4j.java.DMatrix), IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec), Duration(java.time.Duration), Map(java.util.Map), IrisHelper(com.spotify.zoltar.IrisHelper), PredictorsTest(com.spotify.zoltar.PredictorsTest), URI(java.net.URI), CompletableFutures(com.spotify.futures.CompletableFutures), Prediction(com.spotify.zoltar.Prediction), Iterator(java.util.Iterator), ImmutableMap(com.google.common.collect.ImmutableMap), Files(java.nio.file.Files), Assert.assertTrue(org.junit.Assert.assertTrue), Test(org.junit.Test), Collectors(java.util.stream.Collectors), StandardCharsets(java.nio.charset.StandardCharsets), Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris), List(java.util.List), Paths(java.nio.file.Paths), Predictor(com.spotify.zoltar.Predictor), ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn), Collections(java.util.Collections), LabeledPoint(ml.dmlc.xgboost4j.LabeledPoint)
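The test counts correct predictions and divides by the number of samples, casting to `float` so the division is not integer division. A self-contained sketch of that accuracy computation follows; the class and method names are hypothetical, and plain lists of predicted ids and expected class names stand in for zoltar's Prediction objects.

```java
import java.util.List;
import java.util.Map;

final class Accuracy {
    // Fraction of predicted class ids whose mapped name equals the expected class name.
    // idToClass plays the role of the ImmutableMap in the test above.
    static double accuracy(List<Integer> predictedIds, List<String> expected,
                           Map<Integer, String> idToClass) {
        if (predictedIds.size() != expected.size() || expected.isEmpty()) {
            throw new IllegalArgumentException("need equally sized, non-empty lists");
        }
        int correct = 0;
        for (int k = 0; k < expected.size(); k++) {
            // An unknown id maps to null, which never equals the expected name.
            if (expected.get(k).equals(idToClass.get(predictedIds.get(k)))) {
                correct++;
            }
        }
        // Cast before dividing; int / int would truncate to 0 or 1.
        return correct / (double) expected.size();
    }
}
```

With three of four ids mapping to the expected names, the result is 0.75, which would fail the `> .8` threshold used in the test.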

Example 5 with Prediction

Use of com.spotify.zoltar.Prediction in project zoltar by spotify.

From class IrisPrediction, method predict.

/**
 * Prediction endpoint. Takes a request in the form of a String containing iris features
 * separated by `-`, and returns a response in the form of a predicted iris class.
 */
public static Response<String> predict(final String requestFeatures) {
    if (requestFeatures == null) {
        return Response.forStatus(Status.BAD_REQUEST);
    }
    final String[] features = requestFeatures.split("-");
    if (features.length != 4) {
        return Response.forStatus(Status.BAD_REQUEST);
    }
    final Iris featureData;
    try {
        featureData = new Iris(Option.apply(Double.parseDouble(features[0])), Option.apply(Double.parseDouble(features[1])), Option.apply(Double.parseDouble(features[2])), Option.apply(Double.parseDouble(features[3])), Option.empty());
    } catch (final NumberFormatException e) {
        // A non-numeric feature is a client error.
        return Response.forStatus(Status.BAD_REQUEST);
    }
    // predictor and idToClass are static fields of IrisPrediction (not shown here).
    final int[] predictions;
    try {
        predictions = predictor.predict(featureData).thenApply(p -> p.stream().mapToInt(prediction -> prediction.value().intValue()).toArray()).toCompletableFuture().get();
    } catch (final Exception e) {
        e.printStackTrace();
        // Without this return, predictions would be empty and predictions[0] would throw.
        return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
    }
    final String predictedClass = idToClass.get(predictions[0]);
    return Response.forPayload(predictedClass);
}
Also used: Example(org.tensorflow.example.Example), Response(com.spotify.apollo.Response), FeatranExtractFns(com.spotify.zoltar.featran.FeatranExtractFns), Tensors(org.tensorflow.Tensors), Predictors(com.spotify.zoltar.Predictors), Session(org.tensorflow.Session), CompletableFuture(java.util.concurrent.CompletableFuture), TensorFlowModel(com.spotify.zoltar.tf.TensorFlowModel), TensorFlowPredictFn(com.spotify.zoltar.tf.TensorFlowPredictFn), IrisFeaturesSpec(com.spotify.zoltar.IrisFeaturesSpec), Map(java.util.Map), URI(java.net.URI), Tensor(org.tensorflow.Tensor), Status(com.spotify.apollo.Status), CompletableFutures(com.spotify.futures.CompletableFutures), Prediction(com.spotify.zoltar.Prediction), FeatureSpec(com.spotify.featran.FeatureSpec), TensorFlowExtras(com.spotify.zoltar.tf.TensorFlowExtras), ImmutableMap(com.google.common.collect.ImmutableMap), Files(java.nio.file.Files), JTensor(com.spotify.zoltar.tf.JTensor), IOException(java.io.IOException), Option(scala.Option), Collectors(java.util.stream.Collectors), Iris(com.spotify.zoltar.IrisFeaturesSpec.Iris), List(java.util.List), Paths(java.nio.file.Paths), ModelLoader(com.spotify.zoltar.ModelLoader), Predictor(com.spotify.zoltar.Predictor), ExtractFn(com.spotify.zoltar.FeatureExtractFns.ExtractFn), Models(com.spotify.zoltar.Models)
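The endpoint expects exactly four `-`-separated numbers. One consequence of that separator is that a negative feature value cannot be sent: "-1-2-3-4" splits into five parts, the first one empty. The hypothetical helper below isolates the parsing and validation step and returns Optional.empty() for anything the endpoint should answer with BAD_REQUEST; it is a sketch, not part of the IrisPrediction class.

```java
import java.util.Optional;

final class IrisRequestParser {
    // Splits "5.1-3.5-1.4-0.2" into four doubles. Returns Optional.empty() for a null
    // request, a wrong field count, or a non-numeric field.
    static Optional<double[]> parse(String request) {
        if (request == null) {
            return Optional.empty();
        }
        String[] parts = request.split("-");
        if (parts.length != 4) {
            return Optional.empty();
        }
        double[] features = new double[4];
        try {
            for (int i = 0; i < 4; i++) {
                features[i] = Double.parseDouble(parts[i]);
            }
        } catch (NumberFormatException e) {
            return Optional.empty();
        }
        return Optional.of(features);
    }
}
```

Keeping parsing separate from prediction makes it easy to unit-test the 400 cases without loading a model.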

Aggregations (number of the example classes above that use each type)

ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn): 5
Prediction (com.spotify.zoltar.Prediction): 5
FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns): 5
URI (java.net.URI): 5
Files (java.nio.file.Files): 5
Paths (java.nio.file.Paths): 5
List (java.util.List): 5
CompletableFuture (java.util.concurrent.CompletableFuture): 5
Collectors (java.util.stream.Collectors): 5
ImmutableMap (com.google.common.collect.ImmutableMap): 4
CompletableFutures (com.spotify.futures.CompletableFutures): 4
IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec): 4
Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris): 4
Predictor (com.spotify.zoltar.Predictor): 4
PredictorsTest (com.spotify.zoltar.PredictorsTest): 4
StandardCharsets (java.nio.charset.StandardCharsets): 4
Map (java.util.Map): 4
Test (org.junit.Test): 4
IrisHelper (com.spotify.zoltar.IrisHelper): 3
ModelLoader (com.spotify.zoltar.ModelLoader): 3