use of com.spotify.zoltar.Prediction in project zoltar by spotify.
the class TensorFlowGraphModelTest method testModelInference.
/**
 * Runs a dummy TensorFlow graph through a predictor and checks that every input value comes back
 * multiplied by two.
 *
 * <p>The graph file is produced by {@code createADummyTFGraph()}; the fetched operation
 * ({@code mulResult}) and the fed operation ({@code inputOpName}) are defined outside this method.
 *
 * @throws Exception if the graph or settings cannot be loaded, or the prediction fails
 */
@Test
public void testModelInference() throws Exception {
  final Path graphFile = createADummyTFGraph();
  final JFeatureSpec<Double> featureSpec =
      JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
  final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
  final String settings =
      new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
  final ModelLoader<TensorFlowGraphModel> tfModel =
      TensorFlowGraphLoader.create(graphFile.toString(), null, null);

  final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn =
      (model, vectors) -> vectors.stream().map(vector -> {
        // The input tensor is released by try-with-resources; the fetched tensors are
        // released in the finally block once the scalar result has been read out of them.
        try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
          List<Tensor<?>> results = null;
          try {
            results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
            return Prediction.create(vector.input(), results.get(0).doubleValue());
          } finally {
            if (results != null) {
              results.forEach(Tensor::close);
            }
          }
        } catch (Exception e) {
          throw new RuntimeException(e);
        }
      }).collect(Collectors.toList());

  final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
  final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
  final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();

  final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
      .predictor()
      .predict(input)
      .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
      .toCompletableFuture();

  // JUnit's contract is (expected, actual, delta); keeping that order makes failure
  // messages report the right side as "expected".
  assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
use of com.spotify.zoltar.Prediction in project zoltar by spotify.
the class TensorFlowModelTest method getTFIrisPredictor.
/**
 * Builds a predictor for {@code Iris} inputs on top of the TensorFlow model found at the
 * {@code /trained_model} classpath resource, with feature extraction configured from
 * {@code /settings.json}.
 *
 * <p>Each feature vector is scored asynchronously through {@code predict(model, ...)} and the
 * resulting value is paired with the input it was computed from.
 *
 * @return a predictor producing a {@code Long} value per {@code Iris} input
 * @throws Exception if the resource URIs cannot be built or the settings file cannot be read
 */
public static Predictor<Iris, Long> getTFIrisPredictor() throws Exception {
  final URI modelLocation = TensorFlowModelTest.class.getResource("/trained_model").toURI();
  final URI settingsLocation = TensorFlowModelTest.class.getResource("/settings.json").toURI();
  final byte[] settingsBytes = Files.readAllBytes(Paths.get(settingsLocation));
  final String settingsJson = new String(settingsBytes, StandardCharsets.UTF_8);

  final ModelLoader<TensorFlowModel> loader = TensorFlowLoader.create(modelLocation.toString());
  final ExtractFn<Iris, Example> extractor =
      FeatranExtractFns.example(IrisFeaturesSpec.irisFeaturesSpec(), settingsJson);

  // One asynchronous task per feature vector; the tasks are merged into a single future list.
  final TensorFlowPredictFn<Iris, Long> predictFn = (model, vectors) -> {
    final List<CompletableFuture<Prediction<Iris, Long>>> pending = vectors.stream()
        .map(vector -> CompletableFuture
            .supplyAsync(() -> predict(model, vector.value()))
            .thenApply(label -> Prediction.create(vector.input(), label)))
        .collect(Collectors.toList());
    return CompletableFutures.allAsList(pending);
  };

  return PredictorsTest.newBuilder(loader, extractor, predictFn).predictor();
}
use of com.spotify.zoltar.Prediction in project zoltar by spotify.
the class XGBoostModelTest method getXGBoostIrisPredictor.
/**
 * Builds a predictor for {@code Iris} inputs on top of the XGBoost model found at the
 * {@code /iris.model} classpath resource, with feature extraction configured from
 * {@code /settings.json}.
 *
 * <p>Each labeled point is scored asynchronously; the predicted value is the index of the
 * highest entry in the first row of scores returned by the model (on ties the lower index wins).
 * Any exception raised while scoring is rethrown wrapped in a {@code RuntimeException}.
 *
 * @return a predictor producing a {@code Long} class index per {@code Iris} input
 * @throws Exception if the resource URIs cannot be built or the settings file cannot be read
 */
public static Predictor<Iris, Long> getXGBoostIrisPredictor() throws Exception {
  final URI modelLocation = XGBoostModelTest.class.getResource("/iris.model").toURI();
  final URI settingsLocation = XGBoostModelTest.class.getResource("/settings.json").toURI();

  final XGBoostPredictFn<Iris, Long> predictFn = (model, vectors) -> {
    final List<CompletableFuture<Prediction<Iris, Long>>> pending = vectors.stream()
        .map(vector -> CompletableFuture.supplyAsync(() -> {
          try {
            // Wrap the single labeled point in a one-row matrix for the booster.
            final Iterator<LabeledPoint> singleRow =
                Collections.singletonList(vector.value()).iterator();
            final DMatrix matrix = new DMatrix(singleRow, null);
            final float[] scores = model.instance().predict(matrix)[0];
            // Arg-max over the score row.
            final long winner = IntStream.range(0, scores.length)
                .reduce((best, candidate) -> scores[best] >= scores[candidate] ? best : candidate)
                .getAsInt();
            return Prediction.create(vector.input(), winner);
          } catch (Exception e) {
            throw new RuntimeException(e);
          }
        }))
        .collect(Collectors.toList());
    return CompletableFutures.allAsList(pending);
  };

  final byte[] settingsBytes = Files.readAllBytes(Paths.get(settingsLocation));
  final String settingsJson = new String(settingsBytes, StandardCharsets.UTF_8);
  final XGBoostLoader loader = XGBoostLoader.create(modelLocation.toString());
  final ExtractFn<Iris, LabeledPoint> extractor =
      FeatranExtractFns.labeledPoints(IrisFeaturesSpec.irisFeaturesSpec(), settingsJson);

  return PredictorsTest.newBuilder(loader, extractor, predictFn).predictor();
}
use of com.spotify.zoltar.Prediction in project zoltar by spotify.
the class XGBoostModelTest method testModelPrediction.
/**
 * Checks that the XGBoost Iris predictor assigns the correct class name to more than 80% of the
 * test data returned by {@code IrisHelper.getIrisTestData()}.
 *
 * @throws Exception if the predictor cannot be built or the prediction does not complete
 */
@Test
public void testModelPrediction() throws Exception {
  final Iris[] irisStream = IrisHelper.getIrisTestData();
  // Maps the class index produced by the model to the class name used in the data set.
  final Map<Integer, String> classNameById =
      ImmutableMap.of(0, "Iris-setosa", 1, "Iris-versicolor", 2, "Iris-virginica");

  final CompletableFuture<Integer> sum = getXGBoostIrisPredictor()
      .predict(Duration.ofSeconds(10), irisStream)
      .thenApply(predictions -> predictions.stream().mapToInt(prediction -> {
        final String className = prediction.input().className().get();
        final int score = prediction.value().intValue();
        // Compare from the known class name so that an unmapped class index counts as a
        // miss instead of failing with a NullPointerException.
        return className.equals(classNameById.get(score)) ? 1 : 0;
      }).sum())
      .toCompletableFuture();

  assertTrue("Should be more than 0.8", sum.get() / (float) irisStream.length > .8);
}
use of com.spotify.zoltar.Prediction in project zoltar by spotify.
the class IrisPrediction method predict.
/**
 * Prediction endpoint. Takes a request in the form of a String containing four iris features
 * separated by {@code -}, and returns a response carrying the predicted iris class.
 *
 * <p>Because {@code -} is the separator, negative feature values cannot be expressed in a
 * request.
 *
 * @param requestFeatures the four feature values joined with {@code -}; may be {@code null}
 * @return a response with the predicted class name as payload; {@code BAD_REQUEST} when the
 *     input is {@code null}, does not split into exactly four parts, or contains a part that is
 *     not a number; {@code INTERNAL_SERVER_ERROR} when the prediction fails or yields no
 *     known class
 */
public static Response<String> predict(final String requestFeatures) {
  if (requestFeatures == null) {
    return Response.forStatus(Status.BAD_REQUEST);
  }
  final String[] features = requestFeatures.split("-");
  if (features.length != 4) {
    return Response.forStatus(Status.BAD_REQUEST);
  }

  final Iris featureData;
  try {
    featureData = new Iris(
        Option.apply(Double.parseDouble(features[0])),
        Option.apply(Double.parseDouble(features[1])),
        Option.apply(Double.parseDouble(features[2])),
        Option.apply(Double.parseDouble(features[3])),
        Option.empty());
  } catch (final NumberFormatException e) {
    // A non-numeric feature is a client error, not a server failure.
    return Response.forStatus(Status.BAD_REQUEST);
  }

  final int[] predictions;
  try {
    predictions = predictor.predict(featureData)
        .thenApply(p -> p.stream().mapToInt(prediction -> prediction.value().intValue()).toArray())
        .toCompletableFuture()
        .get();
  } catch (final InterruptedException e) {
    // Restore the interrupt flag so the calling thread can still observe the interruption.
    Thread.currentThread().interrupt();
    e.printStackTrace();
    return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
  } catch (final Exception e) {
    e.printStackTrace();
    return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
  }

  // Guard against an empty result instead of failing on predictions[0].
  if (predictions.length == 0) {
    return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
  }
  final String predictedClass = idToClass.get(predictions[0]);
  if (predictedClass == null) {
    // The model produced a class index that has no registered name.
    return Response.forStatus(Status.INTERNAL_SERVER_ERROR);
  }
  return Response.forPayload(predictedClass);
}
Aggregations