Use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.
In class RankingExpressionWithTensorFlowTestCase, method assertLargeConstant:
/**
 * Verifies that the ranking constant with the given name exists and points at the expected
 * generated file, and - only if an expected size is given - that the constant file has been
 * written and decodes to a tensor with the expected number of cells.
 *
 * @param name         the name of the ranking constant to check
 * @param search       the fixture whose search definition is expected to contain the constant
 * @param expectedSize if present, the size the decoded constant tensor must have;
 *                     if empty, the constant file content is not checked
 * @throws UncheckedIOException if the constant file exists but cannot be read
 */
private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
    try {
        Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
        RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
        // Fail with a descriptive assertion instead of a NullPointerException when the constant is missing
        assertTrue("Ranking constant '" + name + "' is present", rankingConstant != null);
        assertEquals(name, rankingConstant.getName());
        assertTrue("File name of constant '" + name + "' ends with '" + constantApplicationPackagePath + "'",
                   rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));

        if (expectedSize.isPresent()) {
            Path constantPath = applicationDir.append(constantApplicationPackagePath);
            assertTrue("Constant file '" + constantPath + "' has been written", constantPath.toFile().exists());
            Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
            assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
        }
    } catch (IOException e) {
        // Reading the file is a test-environment failure, not an assertion failure; keep the cause
        throw new UncheckedIOException(e);
    }
}
Use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.
In class RankFeaturesTestCase, method requireThatSingleTensorIsBinaryEncoded:
@Test
public void requireThatSingleTensorIsBinaryEncoded() {
    TensorType tensorType = new TensorType.Builder()
            .mapped("x")
            .mapped("y")
            .mapped("z")
            .build();
    Tensor expected = Tensor.from(tensorType, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }");

    // Check the same tensor under both the query(...) and the $-prefixed feature name spellings
    for (String featureName : new String[] { "query(my_tensor)", "$my_tensor" }) {
        assertTensorEncodingAndDecoding(tensorType, featureName, "my_tensor", expected);
    }
}
Use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.
In class RankFeaturesTestCase, method requireThatMultipleTensorsAreBinaryEncoded:
@Test
public void requireThatMultipleTensorsAreBinaryEncoded() {
    TensorType tensorType = new TensorType.Builder()
            .mapped("x")
            .mapped("y")
            .mapped("z")
            .build();

    // One tensor addressed as query(...), the other with the $-prefixed shorthand
    Entry first = new Entry("query(tensor1)", "tensor1",
                            Tensor.from(tensorType, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }"));
    Entry second = new Entry("$tensor2", "tensor2",
                             Tensor.from(tensorType, "{ {x:a, y:b, z:c}:5.0 }"));

    assertTensorEncodingAndDecoding(tensorType, Arrays.asList(first, second));
}
Use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.
In class Select, method lazyGetFunction:
@Override
protected TensorFunction lazyGetFunction() {
    // Inputs: 0 = condition, 1 = value when true, 2 = value when false
    if ( ! allInputFunctionsPresent(3)) {
        return null;
    }

    TensorFlowOperation conditionOperation = inputs().get(0);
    TensorFunction whenTrue = inputs().get(1).function().get();
    TensorFunction whenFalse = inputs().get(2).function().get();

    // When the condition is a constant known at import time and holds a single value
    // (a scalar, or a rank-1 tensor whose only dimension has size 1), pick the branch statically.
    if (conditionOperation.getConstantValue().isPresent()) {
        Tensor condition = conditionOperation.getConstantValue().get().asTensor();
        if (condition.type().rank() == 0) {
            boolean conditionIsFalse = (int) condition.asDouble() == 0;
            return conditionIsFalse ? whenFalse : whenTrue;
        }
        boolean singleCellVector = condition.type().rank() == 1
                                   && dimensionSize(condition.type().dimensions().get(0)) == 1;
        if (singleCellVector) {
            boolean conditionIsFalse = condition.cellIterator().next().getValue().intValue() == 0;
            return conditionIsFalse ? whenFalse : whenTrue;
        }
    }

    // General case, computed cell-wise as
    //   whenTrue * condition + whenFalse * (1 - condition)
    // which selects 'whenTrue' where the condition is 1 and 'whenFalse' where it is 0.
    // NOTE(review): condition values other than 0/1 would blend the two inputs.
    TensorFunction conditionFunction = conditionOperation.function().get();
    TensorFunction trueContribution =
            new com.yahoo.tensor.functions.Join(whenTrue, conditionFunction, ScalarFunctions.multiply());
    TensorFunction falseContribution =
            new com.yahoo.tensor.functions.Join(whenFalse, conditionFunction, new DoubleBinaryOperator() {
                @Override
                public double applyAsDouble(double value, double selector) {
                    return value * (1.0 - selector);
                }

                // Rendered into the generated ranking expression
                @Override
                public String toString() {
                    return "f(a,b)(a * (1-b))";
                }
            });
    return new com.yahoo.tensor.functions.Join(trueContribution, falseContribution, ScalarFunctions.add());
}
Use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.
In class MnistSoftmaxImportTestCase, method testMnistSoftmaxImport:
@Test
public void testMnistSoftmaxImport() {
    TestableTensorFlowModel model =
            new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");
    // Type of the "Placeholder" input, used for both the required macro and the signature input
    TensorType placeholderType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();

    // Large constants: a 784x10 tensor and a 10-cell tensor
    assertEquals(2, model.get().largeConstants().size());

    Tensor variable = model.get().largeConstants().get("test_Variable_read");
    assertNotNull(variable);
    assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), variable.type());
    assertEquals(7840, variable.size());

    Tensor variable1 = model.get().largeConstants().get("test_Variable_1_read");
    assertNotNull(variable1);
    assertEquals(new TensorType.Builder().indexed("d1", 10).build(), variable1.type());
    assertEquals(10, variable1.size());

    // The model provides no macros, and requires exactly one: "Placeholder"
    assertEquals(0, model.get().macros().size());
    assertEquals(1, model.get().requiredMacros().size());
    assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
    assertEquals(placeholderType, model.get().requiredMacros().get("Placeholder"));

    // Exactly one signature, "serving_default", with a single input "x" and a single output "y"
    assertEquals(1, model.get().signatures().size());
    TensorFlowModel.Signature servingDefault = model.get().signatures().get("serving_default");
    assertNotNull(servingDefault);

    assertEquals(1, servingDefault.inputs().size());
    TensorType inputX = servingDefault.inputArgument("x");
    assertNotNull(inputX);
    assertEquals(placeholderType, inputX);

    assertEquals(1, servingDefault.outputs().size());
    RankingExpression outputY = servingDefault.outputExpression("y");
    assertNotNull(outputY);
    assertEquals("add", outputY.getName());
    assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
                 outputY.getRoot().toString());

    // Execution checks for the MatMul and add nodes, both fed from Placeholder
    for (String operationName : new String[] { "MatMul", "add" }) {
        model.assertEqualResult("Placeholder", operationName);
    }
}
Aggregations