Usage of org.tensorflow.Tensor in project zoltar by Spotify.
Class TensorFlowGraphModelTest, method testModelInference.
/**
 * End-to-end inference check: builds a predictor over the graph returned by
 * {@code createADummyTFGraph()} and verifies that every input value comes back multiplied by two.
 */
@Test
public void testModelInference() throws Exception {
  final Path graphFile = createADummyTFGraph();
  final JFeatureSpec<Double> featureSpec =
      JFeatureSpec.<Double>create().required(d -> d, Identity.apply("feature"));
  final URI settingsUri = getClass().getResource("/settings_dummy.json").toURI();
  final String settings =
      new String(Files.readAllBytes(Paths.get(settingsUri)), StandardCharsets.UTF_8);
  final ModelLoader<TensorFlowGraphModel> tfModel =
      TensorFlowGraphLoader.create(graphFile.toString(), null, null);
  // For each vector, feed its first extracted feature into the graph and read back the scalar
  // output. The input tensor (try-with-resources) and every fetched tensor (finally) are closed
  // so their native buffers are released.
  final PredictFn<TensorFlowGraphModel, Double, double[], Double> predictFn =
      (model, vectors) ->
          vectors.stream()
              .map(
                  vector -> {
                    try (Tensor<Double> input = Tensors.create(vector.value()[0])) {
                      List<Tensor<?>> results = null;
                      try {
                        results =
                            model.instance().runner().fetch(mulResult).feed(inputOpName, input).run();
                        return Prediction.create(vector.input(), results.get(0).doubleValue());
                      } finally {
                        if (results != null) {
                          results.forEach(Tensor::close);
                        }
                      }
                    } catch (Exception e) {
                      // PredictFn's lambda cannot throw checked exceptions; keep the cause.
                      throw new RuntimeException(e);
                    }
                  })
              .collect(Collectors.toList());
  final ExtractFn<Double, double[]> extractFn = FeatranExtractFns.doubles(featureSpec, settings);
  final Double[] input = new Double[] {0.0D, 1.0D, 7.0D};
  final double[] expected = Arrays.stream(input).mapToDouble(d -> d * 2.0D).toArray();
  final CompletableFuture<double[]> result =
      PredictorsTest.newBuilder(tfModel, extractFn, predictFn)
          .predictor()
          .predict(input)
          .thenApply(predictions -> predictions.stream().mapToDouble(Prediction::value).toArray())
          .toCompletableFuture();
  // JUnit's argument order is (expected, actual, delta); getting it right keeps failure
  // messages accurate.
  assertArrayEquals(expected, result.get(), Double.MIN_VALUE);
}
Usage of org.tensorflow.Tensor in project zoltar by Spotify.
Class TensorFlowGraphModelTest, method testDummyLoadOfTensorFlowGraph.
/**
 * Loads the graph returned by {@code createADummyTFGraph()}, feeds {@code 3.0} into
 * {@code inputOpName} and checks that fetching {@code mulResult} yields {@code 6.0}.
 */
@Test
public void testDummyLoadOfTensorFlowGraph() throws Exception {
  final Path graphFile = createADummyTFGraph();
  try (final TensorFlowGraphModel model = TensorFlowGraphModel.create(graphFile.toUri(), null, null);
      final Session session = model.instance();
      final Tensor<Double> double3 = Tensors.create(3.0D)) {
    List<Tensor<?>> result = null;
    try {
      result = session.runner().fetch(mulResult).feed(inputOpName, double3).run();
      // JUnit's argument order is (expected, actual, delta).
      assertEquals(6.0D, result.get(0).doubleValue(), Double.MIN_VALUE);
    } finally {
      // Fetched tensors are not covered by try-with-resources; release them explicitly.
      if (result != null) {
        result.forEach(Tensor::close);
      }
    }
  }
}
Usage of org.tensorflow.Tensor in project zoltar by Spotify.
Class TensorFlowModelTest, method predict.
/**
 * Runs one serialized {@code Example} through the model's session and returns the first value of
 * the {@code linear/head/predictions/class_ids} output.
 *
 * @param model the model whose session is queried
 * @param example the example to feed as {@code input_example_tensor}
 * @return the first class id written by the fetched output tensor
 */
private static long predict(final TensorFlowModel model, final Example example) {
  // A batch holding exactly one serialized example (rank 1 to account for the batch dimension).
  final byte[][] batch = {example.toByteArray()};
  try (final Tensor<String> exampleTensor = Tensors.create(batch)) {
    final List<Tensor<?>> fetched =
        model.instance()
            .session()
            .runner()
            .feed("input_example_tensor", exampleTensor)
            .fetch("linear/head/predictions/class_ids")
            .run();
    final LongBuffer classIds = LongBuffer.allocate(1);
    try {
      fetched.get(0).writeTo(classIds);
    } finally {
      // Every fetched tensor owns native memory and must be released, even on failure.
      for (final Tensor<?> tensor : fetched) {
        tensor.close();
      }
    }
    return classIds.get(0);
  }
}
Usage of org.tensorflow.Tensor in project sponge by Softelnet.
Class TensorflowTest, method testTf.
/**
 * Builds a graph containing a single string constant, runs it in a session and logs the fetched
 * value both as decoded UTF-8 text and via {@code Tensor.toString()}.
 */
@Test
public void testTf() throws UnsupportedEncodingException {
  try (Graph g = new Graph()) {
    final String value = "Hello from " + TensorFlow.version();
    // Construct the computation graph with a single operation: a constant named "MyConst"
    // whose value is the UTF-8 encoding of the string above.
    try (Tensor<?> t = Tensor.create(value.getBytes("UTF-8"))) {
      // The Java API doesn't yet include convenience functions for adding operations.
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
    }
    // Execute the "MyConst" operation in a Session; the fetched tensor is closed with the session.
    try (Session s = new Session(g);
        Tensor<?> output = s.runner().fetch("MyConst").run().get(0)) {
      logger.info(new String(output.bytesValue(), "UTF-8"));
      logger.info(output.toString());
    }
  }
}
Usage of org.tensorflow.Tensor in project sponge by Softelnet.
Class TensorflowTest, method testLoadModel.
// TODO @Test
/**
 * Loads the SavedModel (tag {@code serve}) from the subdirectory of
 * {@code examples/tensorflow/estimator/model} returned by
 * {@code SpongeUtils.getLastSubdirectory}, feeds a float vector into {@code x} and logs the
 * {@code y} output as a float.
 *
 * <p>The bundle and both tensors are managed by try-with-resources, so their native resources are
 * released even if loading or the session run fails.
 */
public void testLoadModel() throws Exception {
  final String modelDir = "examples/tensorflow/estimator/model";
  final String exportDir = modelDir + "/" + SpongeUtils.getLastSubdirectory(modelDir);
  // Closing the bundle releases both its Graph and its Session, so the session is not closed
  // separately. Resources close in reverse order: y, x, then the bundle.
  try (SavedModelBundle bundle = SavedModelBundle.load(exportDir, "serve");
      Tensor<?> x = Tensor.create(new float[] {2, 5, 8, 1});
      Tensor<?> y = bundle.session().runner().feed("x", x).fetch("y").run().get(0)) {
    logger.info("y = {}", y.floatValue());
  }
}
Aggregations