Use of com.spotify.zoltar.Predictor in the project zoltar by Spotify.
From the class TensorFlowModelTest, method getTFIrisPredictor.
/**
 * Builds an Iris {@link Predictor} backed by the TensorFlow model found at the
 * {@code /trained_model} classpath resource.
 *
 * <p>Feature extraction uses the Featran settings read (as UTF-8) from the
 * {@code /settings.json} classpath resource. Every feature vector is scored in its own
 * asynchronous task and the resulting value is paired with the vector's original input.
 *
 * @return a predictor from {@code Iris} inputs to {@code Long} values
 * @throws Exception if a resource location cannot be converted to a URI, the settings file
 *     cannot be read, or any other setup step fails
 */
public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
  // Resolve the classpath resources and load the Featran settings first.
  final URI modelLocation = TensorFlowModelTest.class.getResource("/trained_model").toURI();
  final URI settingsLocation = TensorFlowModelTest.class.getResource("/settings.json").toURI();
  final byte[] settingsBytes = Files.readAllBytes(Paths.get(settingsLocation));
  final String featranSettings = new String(settingsBytes, StandardCharsets.UTF_8);

  final ModelLoader<TensorFlowModel> loader = TensorFlowLoader.create(modelLocation.toString());
  final ExtractFn<Iris, Example> extractor =
      FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), featranSettings);

  // One async task per feature vector; the combined future completes once all of them have.
  final TensorFlowPredictFn<Iris, Long> scorer = (tfModel, featureVectors) -> {
    final List<CompletableFuture<Prediction<Iris, Long>>> pending = featureVectors.stream()
        .map(fv -> CompletableFuture
            .supplyAsync(() -> predict(tfModel, fv.value()))
            .thenApply(result -> Prediction.create(fv.input(), result)))
        .collect(Collectors.toList());
    return CompletableFutures.allAsList(pending);
  };

  return PredictorsTest.newBuilder(loader, extractor, scorer).predictor();
}
Use of com.spotify.zoltar.Predictor in the project zoltar by Spotify.
From the class XGBoostModelTest, method getXGBoostIrisPredictor.
/**
 * Builds an Iris {@link Predictor} backed by the XGBoost model found at the
 * {@code /iris.model} classpath resource.
 *
 * <p>Feature extraction uses the Featran settings read (as UTF-8) from the
 * {@code /settings.json} classpath resource. Every feature vector is scored in its own
 * asynchronous task; the prediction value is the index of the highest score in the first
 * row returned by the booster (the lowest index wins on ties).
 *
 * @return a predictor from {@code Iris} inputs to {@code Long} values
 * @throws Exception if a resource location cannot be converted to a URI, the settings file
 *     cannot be read, or any other setup step fails
 */
public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
  final URI trainedModelUri = XGBoostModelTest.class.getResource("/iris.model").toURI();
  final URI settingsUri = XGBoostModelTest.class.getResource("/settings.json").toURI();

  final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
    final List<CompletableFuture<Prediction<Iris, Long>>> predictions = vectors.stream()
        .map(vector -> CompletableFuture.supplyAsync(() -> {
          try {
            final Iterator<LabeledPoint> iterator =
                Collections.singletonList(vector.value()).iterator();
            final DMatrix dMatrix = new DMatrix(iterator, null);
            try {
              final float[] score = model.instance().predict(dMatrix)[0];
              // Arg-max over the scores; ">=" keeps the earlier index when scores are equal.
              final int idx = IntStream.range(0, score.length)
                  .reduce((i, j) -> score[i] >= score[j] ? i : j)
                  .getAsInt();
              return Prediction.create(vector.input(), (long) idx);
            } finally {
              // Release the matrix's native memory right away instead of waiting for
              // finalization; the scores have already been copied into a Java array.
              dMatrix.dispose();
            }
          } catch (Exception e) {
            // The supplier cannot throw checked exceptions, so wrap and keep the cause.
            throw new RuntimeException(e);
          }
        }))
        .collect(Collectors.toList());
    return CompletableFutures.allAsList(predictions);
  };

  final String settings =
      new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
  final XGBoostLoader model = XGBoostLoader.create(trainedModelUri.toString());
  final ExtractFn<Iris, LabeledPoint> extractFn =
      FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settings);

  return PredictorsTest.newBuilder(model, extractFn, predictFn).predictor();
}
Aggregations