
Example 1 with Predictor

Use of com.spotify.zoltar.Predictor in project zoltar by Spotify.

From class TensorFlowModelTest, method getTFIrisPredictor.

public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
    // Score each feature vector asynchronously and pair the result with its original input.
    // predict(...) is a helper defined elsewhere in TensorFlowModelTest (not shown here).
    final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream()
            .map(vector -> CompletableFuture
                .supplyAsync(() -> predict(model, vector.value()))
                .thenApply(value -> Prediction.create(vector.input(), value)))
            .collect(Collectors.toList());
        // Combine the per-vector futures into one future holding the full list of predictions.
        return CompletableFutures.allAsList(predictions);
    };
    // Locate the exported TensorFlow model and the Featran settings in the test resources.
    final URI trainedModelUri = TensorFlowModelTest.class.getResource("/trained_model").toURI();
    final URI settingsUri = TensorFlowModelTest.class.getResource("/settings.json").toURI();
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowModel> model = TensorFlowLoader.create(trainedModelUri.toString());
    // Turn each Iris record into a tf.Example using the Featran feature spec and settings.
    final ExtractFn<Iris, Example> extractFn = FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: Example (org.tensorflow.example.Example), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), Tensors (org.tensorflow.Tensors), Session (org.tensorflow.Session), CompletableFuture (java.util.concurrent.CompletableFuture), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), Tensor (org.tensorflow.Tensor), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), ModelLoader (com.spotify.zoltar.ModelLoader), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), LongBuffer (java.nio.LongBuffer), Assert (org.junit.Assert)
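Example 1 relies on Spotify's `com.spotify.futures.CompletableFutures.allAsList` to turn the list of per-vector futures into a single future. The sketch below shows the same fan-out/join shape using only the JDK, assuming you want to avoid that extra dependency. `FutureJoins`, `allAsList`, and `squareAll` are hypothetical names, and `squareAll` simply squares numbers in place of a real model call; the exact failure semantics of Spotify's helper are not claimed to match.

```java
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

// Hypothetical helper: a JDK-only stand-in for CompletableFutures.allAsList.
final class FutureJoins {
  private FutureJoins() {}

  /** Completes with every result, in input order, once all futures have completed.
   *  If any input future fails, the returned future fails too. */
  static <T> CompletableFuture<List<T>> allAsList(final List<CompletableFuture<T>> futures) {
    final CompletableFuture<?>[] array = futures.toArray(new CompletableFuture<?>[0]);
    return CompletableFuture.allOf(array)
        // allOf yields Void; at this point every future is done, so join() does not block.
        .thenApply(ignored -> futures.stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toList()));
  }

  /** Fan-out: one async task per input, mirroring the predictFn shape in Example 1. */
  static CompletableFuture<List<Long>> squareAll(final List<Long> inputs) {
    final List<CompletableFuture<Long>> futures = inputs.stream()
        .map(x -> CompletableFuture.supplyAsync(() -> x * x))
        .collect(Collectors.toList());
    return allAsList(futures);
  }
}
```

Results come back in input order because they are read from the original list rather than in completion order. Without an explicit `Executor`, `supplyAsync` normally runs tasks on `ForkJoinPool.commonPool()`; for blocking model calls, passing a dedicated executor as the second argument is often the safer choice.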

Example 2 with Predictor

Use of com.spotify.zoltar.Predictor in project zoltar by Spotify.

From class XGBoostModelTest, method getXGBoostIrisPredictor.

public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
    final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI();
    final URI settingsUri = XGBoostModelTest.class.getResource("/settings.json").toURI();
    // Score each LabeledPoint asynchronously and pair the predicted class with its original input.
    final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
        final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream()
            .map(vector -> CompletableFuture.supplyAsync(() -> {
                try {
                    // Wrap the single LabeledPoint in a one-row DMatrix, XGBoost's input format.
                    final Iterator<LabeledPoint> iterator = Collections.singletonList(vector.value()).iterator();
                    final DMatrix dMatrix = new DMatrix(iterator, null);
                    // predict(...) returns one score array per row; keep the only row's per-class scores.
                    final float[] score = model.instance().predict(dMatrix)[0];
                    // The predicted class is the index of the highest score (ties go to the lower index).
                    final int idx = IntStream.range(0, score.length)
                        .reduce((i, j) -> score[i] >= score[j] ? i : j)
                        .getAsInt();
                    return Prediction.create(vector.input(), (long) idx);
                } catch (Exception e) {
                    // DMatrix and predict(...) throw the checked XGBoostError, which a Supplier cannot declare.
                    throw new RuntimeException(e);
                }
            }))
            .collect(Collectors.toList());
        return CompletableFutures.allAsList(predictions);
    };
    // Load the Featran settings and the trained XGBoost model from the test resources.
    final String settings = new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString());
    // Turn each Iris record into an XGBoost LabeledPoint using the Featran feature spec and settings.
    final ExtractFn<Iris, LabeledPoint> extractFn = FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings);
    return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Also used: IntStream (java.util.stream.IntStream), FeatranExtractFns (com.spotify.zoltar.featran.FeatranExtractFns), CompletableFuture (java.util.concurrent.CompletableFuture), DMatrix (ml.dmlc.xgboost4j.java.DMatrix), IrisFeaturesSpec (com.spotify.zoltar.IrisFeaturesSpec), Duration (java.time.Duration), Map (java.util.Map), IrisHelper (com.spotify.zoltar.IrisHelper), PredictorsTest (com.spotify.zoltar.PredictorsTest), URI (java.net.URI), CompletableFutures (com.spotify.futures.CompletableFutures), Prediction (com.spotify.zoltar.Prediction), Iterator (java.util.Iterator), ImmutableMap (com.google.common.collect.ImmutableMap), Files (java.nio.file.Files), Assert.assertTrue (org.junit.Assert.assertTrue), Test (org.junit.Test), Collectors (java.util.stream.Collectors), StandardCharsets (java.nio.charset.StandardCharsets), Iris (com.spotify.zoltar.IrisFeaturesSpec.Iris), List (java.util.List), Paths (java.nio.file.Paths), Predictor (com.spotify.zoltar.Predictor), ExtractFn (com.spotify.zoltar.FeatureExtractFns.ExtractFn), Collections (java.util.Collections), LabeledPoint (ml.dmlc.xgboost4j.LabeledPoint)
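Example 2 picks the predicted class with an `IntStream` argmax and wraps the checked XGBoost exception inside `supplyAsync`. The sketch below isolates those two pieces with no XGBoost dependency. `ArgMaxSketch`, `argMax`, `Scorer`, and `predictAsync` are hypothetical names, and `Scorer` merely stands in for a model call that can throw a checked exception.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.stream.IntStream;

// Hypothetical sketch; these names are illustrative and not part of zoltar or xgboost4j.
final class ArgMaxSketch {
  private ArgMaxSketch() {}

  /** Index of the highest score; on ties the lowest index wins because of the >= comparison. */
  static long argMax(final float[] scores) {
    return IntStream.range(0, scores.length)
        .reduce((i, j) -> scores[i] >= scores[j] ? i : j)
        // An empty array has no maximum, so fail with a descriptive error.
        .orElseThrow(() -> new IllegalArgumentException("scores must not be empty"));
  }

  /** Stand-in for a model call that throws a checked exception. */
  interface Scorer {
    float[] score() throws Exception;
  }

  /** supplyAsync only accepts a Supplier, so checked exceptions must be wrapped. */
  static CompletableFuture<Long> predictAsync(final Scorer scorer) {
    return CompletableFuture.supplyAsync(() -> {
      try {
        return argMax(scorer.score());
      } catch (Exception e) {
        // Wrapping in CompletionException makes join() report the original error as its cause.
        throw new CompletionException(e);
      }
    });
  }
}
```

Throwing `CompletionException` instead of `RuntimeException` is a design choice: `join()` then exposes the original error directly as its cause, whereas the `RuntimeException` used in Example 2 adds one more wrapping layer. Using `orElseThrow` instead of `getAsInt()` turns an empty score array into a descriptive `IllegalArgumentException` rather than a `NoSuchElementException`.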
