Example 1 with Iris

Use of com.spotify.zoltar.IrisFeaturesSpec.Iris in the zoltar project by Spotify.

From class TensorFlowModelTest, method getTFIrisPredictor.

public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
    final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream().map(vector -> {
            return CompletableFuture.supplyAsync(() -> predict(model, vector.value())).thenApply(value -> Prediction.create(vector.input(), value));
        }).collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI();
    final URI settingsUri = TensorFlowModelTest.class.getResource("/settings.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowModel> model = TensorFlowLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, Example> extractFn = FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: Example (org.tensorflow.example.Example), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Tensors (org.tensorflow.Tensors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), Tensor (org.tensorflow.Tensor), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), LongBuffer (java.nio.LongBuffer), Assert (org.junit.Assert)
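
The predict function in Example 1 starts one asynchronous task per feature vector and then merges the resulting futures into a single future of a list. The following is a minimal JDK-only sketch of that pattern. The class name `AsyncFanOut`, the helper `predictAll`, and the scoring function passed to it are hypothetical stand-ins for the TensorFlow call, and the local `allAsList` only approximates what a helper such as `CompletableFutures.allAsList` provides.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;

class AsyncFanOut {

    // Combines many futures into one future that holds all results in input order.
    static <T> CompletableFuture<List<T>> allAsList(final List<CompletableFuture<T>> futures) {
        return CompletableFuture
            .allOf(futures.toArray(new CompletableFuture[0]))
            .thenApply(ignored -> futures.stream()
                .map(CompletableFuture::join) // does not block: every future is already complete
                .collect(Collectors.toList()));
    }

    // Starts one asynchronous task per input and merges the results.
    // The score function is an assumption standing in for a real model call.
    static <I, O> CompletableFuture<List<O>> predictAll(final List<I> inputs,
                                                        final Function<I, O> score) {
        final List<CompletableFuture<O>> futures = inputs.stream()
            .map(input -> CompletableFuture.supplyAsync(() -> score.apply(input)))
            .collect(Collectors.toList());
        return allAsList(futures);
    }

    public static void main(final String[] args) {
        // Hypothetical scoring function: the length of the input string.
        final List<Integer> lengths =
            predictAll(Arrays.asList("setosa", "versicolor"), String::length).join();
        System.out.println(lengths); // prints [6, 10]
    }
}
```

If any single task fails, the combined future completes exceptionally, so callers only need to handle failure in one place.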

Example 2 with Iris

Use of com.spotify.zoltar.IrisFeaturesSpec.Iris in the zoltar project by Spotify.

From class XGBoostModelTest, method getXGBoostIrisPredictor.

public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
    final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI();
    final URI settingsUri = XGBoostModelTest.class.getResource("/settings.json").toURI();
    final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream().map(vector -> {
            return CompletableFuture.supplyAsync(() -> {
                try {
                    final Iterator<LabeledPoint> iterator = Collections.singletonList(vector.value()).iterator();
                    final DMatrix dMatrix = new DMatrix(iterator, null);
                    final float[] score = model.instance().predict(dMatrix)[0];
                    int idx = IntStream.range(0, score.length).reduce((i, j) -> score[i] >= score[j] ? i : j).getAsInt();
                    return Prediction.create(vector.input(), (long) idx);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            });
        }).collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString());
    final ExtractFn<Iris, LabeledPoint> extractFn = FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: IntStream (java.util.stream.IntStream), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), CompletableFuture (java.util.concurrent.CompletableFuture), DMatrix (ml.dmlc.xgboost4j.java.DMatrix), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), Iterator (java.util.Iterator), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Assert.assertTrue (org.junit.Assert.assertTrue), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Collections (java.util.Collections), LabeledPoint (ml.dmlc.xgboost4j.LabeledPoint)
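
In Example 2 the predicted class is the index of the highest score, found with an `IntStream` reduction. A standalone sketch of that step follows. The class name `ArgMax` and the sample scores are made up for illustration, and the empty-array check is an addition that the original snippet does not have.

```java
import java.util.stream.IntStream;

class ArgMax {

    // Returns the index of the largest score.
    // On ties the lowest index wins, because >= keeps the earlier index.
    static int argMax(final float[] scores) {
        return IntStream.range(0, scores.length)
            .reduce((i, j) -> scores[i] >= scores[j] ? i : j)
            .orElseThrow(() -> new IllegalArgumentException("scores must not be empty"));
    }

    public static void main(final String[] args) {
        // Made-up class scores for three classes.
        final float[] scores = {0.1f, 0.7f, 0.2f};
        System.out.println(argMax(scores)); // prints 1
    }
}
```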

Example 3 with Iris

Use of com.spotify.zoltar.IrisFeaturesSpec.Iris in the zoltar project by Spotify.

From class XGBoostModelTest, method testModelPrediction.

@Test
public void testModelPrediction() throws Exception {
    final Iris[] irisStream = IrisHelper.getIrisTestData();
    final Map<Integer, String> classToId = ImmutableMap.of(0, "Iris-setosa", 1, "Iris-versicolor", 2, "Iris-virginica");
    final CompletableFuture<Integer> sum = getXGBoostIrisPredictor().predict(Duration.ofSeconds(10), irisStream).thenApply(predictions -> {
        return predictions.stream().mapToInt(prediction -> {
            String className = prediction.input().className().get();
            int score = prediction.value().intValue();
            return classToId.get(score).equals(className) ? 1 : 0;
        }).sum();
    }).toCompletableFuture();
    assertTrue("Should be more than 0.8", sum.get() / (float) irisStream.length > .8);
}
Also used: IntStream (java.util.stream.IntStream), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), CompletableFuture (java.util.concurrent.CompletableFuture), DMatrix (ml.dmlc.xgboost4j.java.DMatrix), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), Iterator (java.util.Iterator), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Assert.assertTrue (org.junit.Assert.assertTrue), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Collections (java.util.Collections), LabeledPoint (ml.dmlc.xgboost4j.LabeledPoint)
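
The test in Example 3 counts a prediction as correct when the label mapped from the predicted index equals the expected class name, then requires the fraction of correct predictions to exceed 0.8. The accuracy calculation can be sketched on plain arrays as below. The class name `AccuracyCheck` and the sample data are hypothetical and do not come from the Iris test set.

```java
import java.util.HashMap;
import java.util.Map;

class AccuracyCheck {

    // Fraction of predictions whose mapped label equals the expected label.
    static double accuracy(final int[] predictedIds,
                           final String[] expectedLabels,
                           final Map<Integer, String> idToLabel) {
        int correct = 0;
        for (int k = 0; k < predictedIds.length; k++) {
            if (expectedLabels[k].equals(idToLabel.get(predictedIds[k]))) {
                correct++;
            }
        }
        // Cast before dividing so the result is not truncated by integer division.
        return correct / (double) predictedIds.length;
    }

    public static void main(final String[] args) {
        final Map<Integer, String> idToLabel = new HashMap<>();
        idToLabel.put(0, "Iris-setosa");
        idToLabel.put(1, "Iris-versicolor");
        idToLabel.put(2, "Iris-virginica");

        // Made-up predictions: the last one is wrong.
        final int[] predicted = {0, 1, 2, 2};
        final String[] expected = {"Iris-setosa", "Iris-versicolor", "Iris-virginica", "Iris-setosa"};
        System.out.println(accuracy(predicted, expected, idToLabel)); // prints 0.75
    }
}
```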

Example 4 with Iris

Use of com.spotify.zoltar.IrisFeaturesSpec.Iris in the zoltar project by Spotify.

From class IrisPrediction, method predict.

/**
 * Prediction endpoint. Takes a request in the form of a String containing iris features
 * separated by `-`, and returns a response in the form of the predicted iris class.
 */
public static Response<String> predict(final String requestFeatures) {
    if (requestFeatures == null) {
        return Response.forStatus(Status.BAD_REQUEST);
    }
    final String[] features = requestFeatures.split("-");
    if (features.length != 4) {
        return Response.forStatus(Status.BAD_REQUEST);
    }
    final Iris featureData = new Iris(Option.apply(Double.parseDouble(features[0])), Option.apply(Double.parseDouble(features[1])), Option.apply(Double.parseDouble(features[2])), Option.apply(Double.parseDouble(features[3])), Option.empty());
    final int[] predictions;
    try {
        predictions = predictor.predict(featureData).thenApply(p -> p.stream().mapToInt(prediction -> prediction.value().intValue()).toArray()).toCompletableFuture().get();
    } catch (final Exception e) {
        e.printStackTrace();
        // Prediction failed: report a server error rather than indexing into an empty array below.
        return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
    }
    final String predictedClass = idToClass.get(predictions[0]);
    return Response.forPayload(predictedClass);
}
Also used: Example (org.tensorflow.example.Example), Response (com.spotify.apollo.Response), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Tensors (org.tensorflow.Tensors), Predictors (com.spotify.zoltar.Predictors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), TensorFlowModel (com.spotify.zoltar.tf.TensorFlowModel), TensorFlowPredictFn (com.spotify.zoltar.tf.TensorFlowPredictFn), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Map (java.util.Map), URI (java.net.URI), Tensor (org.tensorflow.Tensor), Status (com.spotify.apollo.Status), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), FeatureSpec (com.spotify.featran.FeatureSpec), TensorFlowExtras (com.spotify.zoltar.tf.TensorFlowExtras), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), JTensor (com.spotify.zoltar.tf.JTensor), IOException (java.io.IOException), Option (scala.Option), Collectors (java.util.stream.Collectors), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Models (com.spotify.zoltar.Models)
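
The endpoint in Example 4 expects exactly four numeric features separated by `-`. The parsing step on its own could be sketched as follows. The class `FeatureParser` and its `parse` method are hypothetical helpers, not part of zoltar, and the sketch assumes non-negative feature values because `-` doubles as the minus sign. Unlike the original snippet, it also treats non-numeric input as malformed instead of letting `NumberFormatException` escape.

```java
import java.util.Optional;

class FeatureParser {

    private static final int FEATURE_COUNT = 4;

    // Returns the four parsed features, or an empty Optional for malformed input.
    static Optional<double[]> parse(final String request) {
        if (request == null) {
            return Optional.empty();
        }
        final String[] parts = request.split("-");
        if (parts.length != FEATURE_COUNT) {
            return Optional.empty();
        }
        final double[] values = new double[FEATURE_COUNT];
        try {
            for (int k = 0; k < FEATURE_COUNT; k++) {
                values[k] = Double.parseDouble(parts[k]);
            }
        } catch (final NumberFormatException e) {
            // A non-numeric field makes the whole request invalid.
            return Optional.empty();
        }
        return Optional.of(values);
    }

    public static void main(final String[] args) {
        System.out.println(parse("5.1-3.5-1.4-0.2").isPresent()); // prints true
        System.out.println(parse("5.1-3.5").isPresent());         // prints false
    }
}
```

A caller could map an empty result to a bad-request response and pass a present result on to the predictor.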
