Usage of `com.spotify.zoltar.PredictFns.PredictFn` in the Spotify project zoltar.
Class `TensorFlowGraphModelTest`, method `testModelInference`.
/**
 * Runs end-to-end inference through a TensorFlow graph model and checks each prediction equals
 * its input multiplied by {@code 2.0}.
 *
 * <p>Each input is extracted into a feature vector via featran settings. The first feature is
 * fed into the graph under {@code inputOpName}, and the {@code mulResult} fetch is read back as
 * a double.
 *
 * <p>NOTE(review): relies on {@code createADummyTFGraph()}, {@code mulResult} and
 * {@code inputOpName}, which are declared elsewhere in this class. The doubling expectation
 * presumably mirrors the dummy graph's multiply op — confirm there.
 */
@Test
public void testModelInference() throws Exception {
    final Path graphFile = createADummyTFGraph();
    final JFeatureSpec<Double> featureSpec =
        JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
    final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
    final String settings =
        new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
    final ModelLoader<TensorFlowGraphModel> tfModel =
        TensorFlowGraphLoader.create(graphFile.toString(), null, null);

    final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn =
        (model, vectors) -> vectors.stream().map(vector -> {
            // Only the first extracted feature is fed into the graph.
            try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
                List<Tensor<?>> results = null;
                try {
                    results = model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                    return Prediction.create(vector.input(), results.get(0).doubleValue());
                } finally {
                    // Close every fetched tensor once its value has been read (or on failure).
                    if (results != null) {
                        results.forEach(Tensor::close);
                    }
                }
            } catch (Exception e) {
                // The stream lambda cannot throw checked exceptions; rethrow unchecked with cause.
                throw new RuntimeException(e);
            }
        }).collect(Collectors.toList());

    final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
    final Double[] input = new Double[] { 0.0D, 1.0D, 7.0D };
    final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();

    final CompletableFuture<double[]> result = PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
        .predictor()
        .predict(input)
        .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
        .toCompletableFuture();

    // JUnit's contract is (expected, actual, delta). Expected goes first so failure messages
    // label the arrays correctly. A delta of Double.MIN_VALUE makes this an effectively exact
    // comparison.
    assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Usage of `com.spotify.zoltar.PredictFns.PredictFn` in the Spotify project zoltar.
Class `PredictorTest`, method `timeout`.
/**
 * Verifies that a prediction whose timeout elapses before the predict function finishes
 * completes exceptionally, with a {@link TimeoutException} as the cause.
 *
 * <p>Only {@link ExecutionException} is caught, because that is the wrapper carrying the
 * expected cause. Any other exception propagates and fails the test with its own stack trace
 * instead of an opaque assertion failure.
 */
@Test
public void timeout() throws InterruptedException, TimeoutException {
    final Duration wait = Duration.ofSeconds(1);
    // A zero prediction timeout cannot be met by a predict function that sleeps for `wait`.
    final Duration predictionTimeout = Duration.ZERO;
    final ExtractFn<Object, Object> extractFn = inputs -> Collections.emptyList();
    final PredictFn<DummyModel, Object, Object, Object> predictFn = (model, vectors) -> {
        Thread.sleep(wait.toMillis());
        return Collections.emptyList();
    };
    final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);

    try {
        DefaultPredictorBuilder.create(loader, extractFn, predictFn)
            .predictor()
            .predict(predictionTimeout, new Object())
            .toCompletableFuture()
            .get(wait.toMillis(), TimeUnit.MILLISECONDS);
        fail("should throw TimeoutException");
    } catch (final ExecutionException e) {
        assertTrue(e.getCause() instanceof TimeoutException);
    }
}
Usage of `com.spotify.zoltar.PredictFns.PredictFn` in the Spotify project zoltar.
Class `PredictorTest`, method `nonEmpty`.
/**
 * Checks the full extract-then-predict pipeline on a single input. The extract function divides
 * the integer by ten, the predict function doubles that value, and exactly one prediction for
 * the original input is returned.
 */
@Test
public void nonEmpty() throws InterruptedException, ExecutionException, TimeoutException {
    final long timeoutMillis = Duration.ofSeconds(1).toMillis();

    final SingleExtractFn<Integer, Float> tenth = value -> (float) value / 10;
    final PredictFn<DummyModel, Integer, Float, Float> doubler = (model, vectors) -> vectors
        .stream()
        .map(v -> Prediction.create(v.input(), v.value() * 2))
        .collect(Collectors.toList());
    final ModelLoader<DummyModel> modelLoader = ModelLoader.lift(DummyModel::new);

    final List<Prediction<Integer, Float>> actual = DefaultPredictorBuilder
        .create(modelLoader, tenth, doubler)
        .predictor()
        .predict(1)
        .toCompletableFuture()
        .get(timeoutMillis, TimeUnit.MILLISECONDS);

    // 1 / 10 = 0.1f, doubled to 0.2f.
    assertThat(actual.size(), is(1));
    assertThat(actual.get(0), is(Prediction.create(1, 0.2f)));
}
Usage of `com.spotify.zoltar.PredictFns.PredictFn` in the Spotify project zoltar.
Class `PredictorTest`, method `empty`.
/**
 * Verifies that predicting with no inputs completes within the timeout and yields no
 * predictions.
 *
 * <p>The result is asserted to be empty, so the test also fails if the predictor fabricates
 * predictions for zero inputs, not only if it throws or hangs.
 */
@Test
public void empty() throws InterruptedException, ExecutionException, TimeoutException {
    final Duration wait = Duration.ofSeconds(1);
    final ExtractFn<Object, Object> extractFn = inputs -> Collections.emptyList();
    final AsyncPredictFn<DummyModel, Object, Object, Object> predictFn =
        (model, vectors) -> CompletableFuture.completedFuture(Collections.emptyList());
    final ModelLoader<DummyModel> loader = ModelLoader.lift(DummyModel::new);

    final List<Prediction<Object, Object>> predictions = DefaultPredictorBuilder
        .create(loader, extractFn, predictFn)
        .predictor()
        .predict()
        .toCompletableFuture()
        .get(wait.toMillis(), TimeUnit.MILLISECONDS);

    assertTrue(predictions.isEmpty());
}
Aggregations