Search in sources :

Example 26 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class RankingExpressionWithTensorFlowTestCase method assertLargeConstant.

/**
 * Asserts that a ranking constant called {@code name} is registered in the given search fixture
 * and that its file name ends with the expected ".tbf" path below the generated models directory.
 * When {@code expectedSize} is present, additionally asserts that the constant file exists below
 * the application directory and that the tensor decoded from it has exactly that size.
 *
 * @param name         the name of the ranking constant, also used as the base name of its file
 * @param search       the fixture whose ranking constants are inspected
 * @param expectedSize the expected size of the decoded tensor; if empty, the file content is not checked
 * @throws UncheckedIOException if reading the constant file fails with an IOException
 */
private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
    try {
        Path relativePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf");
        RankingConstant constant = search.search().getRankingConstants().get(name);
        assertEquals(name, constant.getName());
        assertTrue(constant.getFileName().endsWith(relativePath.toString()));

        // Without an expected size only the registration of the constant is verified
        if (!expectedSize.isPresent()) {
            return;
        }

        Path writtenPath = applicationDir.append(relativePath);
        assertTrue("Constant file '" + writtenPath + "' has been written", writtenPath.toFile().exists());

        byte[] serialized = IOUtils.readFileBytes(writtenPath.toFile());
        Tensor decoded = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(serialized));
        long wantedSize = expectedSize.get();
        assertEquals(wantedSize, decoded.size());
    } catch (IOException ioFailure) {
        throw new UncheckedIOException(ioFailure);
    }
}
Also used : Path(com.yahoo.path.Path) Tensor(com.yahoo.tensor.Tensor) IOException(java.io.IOException) UncheckedIOException(java.io.UncheckedIOException) RankingConstant(com.yahoo.searchdefinition.RankingConstant)

Example 27 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class RankFeaturesTestCase method requireThatSingleTensorIsBinaryEncoded.

/**
 * Runs assertTensorEncodingAndDecoding on one tensor with mapped dimensions x, y and z,
 * once for each of the feature names "query(my_tensor)" and "$my_tensor".
 */
@Test
public void requireThatSingleTensorIsBinaryEncoded() {
    TensorType mappedType = new TensorType.Builder().mapped("x").mapped("y").mapped("z").build();
    Tensor expected = Tensor.from(mappedType, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }");
    // Both feature name forms are checked against the same tensor and the same plain name
    for (String featureName : new String[] { "query(my_tensor)", "$my_tensor" }) {
        assertTensorEncodingAndDecoding(mappedType, featureName, "my_tensor", expected);
    }
}
Also used : Tensor(com.yahoo.tensor.Tensor) TensorType(com.yahoo.tensor.TensorType) Test(org.junit.Test)

Example 28 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class RankFeaturesTestCase method requireThatMultipleTensorsAreBinaryEncoded.

/**
 * Runs assertTensorEncodingAndDecoding on a list of two entries sharing one tensor type:
 * "query(tensor1)" holding two cells and "$tensor2" holding a single cell.
 */
@Test
public void requireThatMultipleTensorsAreBinaryEncoded() {
    TensorType mappedType = new TensorType.Builder().mapped("x").mapped("y").mapped("z").build();
    Tensor twoCells = Tensor.from(mappedType, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }");
    Tensor oneCell = Tensor.from(mappedType, "{ {x:a, y:b, z:c}:5.0 }");

    Entry queryEntry = new Entry("query(tensor1)", "tensor1", twoCells);
    Entry dollarEntry = new Entry("$tensor2", "tensor2", oneCell);
    assertTensorEncodingAndDecoding(mappedType, Arrays.asList(queryEntry, dollarEntry));
}
Also used : Tensor(com.yahoo.tensor.Tensor) TensorType(com.yahoo.tensor.TensorType) Test(org.junit.Test)

Example 29 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class Select method lazyGetFunction.

/**
 * Builds the tensor function of this select operation from its three inputs:
 * the condition (input 0) and the two tensors to choose between (inputs 1 and 2).
 *
 * @return one of the two input functions directly when the condition is a constant with a single
 *         value, otherwise a join combining both inputs weighted by the condition;
 *         null if allInputFunctionsPresent(3) does not hold
 */
@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputFunctionsPresent(3)) {
        return null;
    }
    TensorFlowOperation selector = inputs().get(0);
    TensorFunction whenTrue = inputs().get(1).function().get();
    TensorFunction whenFalse = inputs().get(2).function().get();

    // Shortcut: a constant condition holding a single value decides the outcome at import time.
    if (selector.getConstantValue().isPresent()) {
        Tensor constantCondition = selector.getConstantValue().get().asTensor();
        int rank = constantCondition.type().rank();
        if (rank == 0) {
            if ((int) constantCondition.asDouble() == 0) {
                return whenFalse;
            }
            return whenTrue;
        }
        if (rank == 1 && dimensionSize(constantCondition.type().dimensions().get(0)) == 1) {
            if (constantCondition.cellIterator().next().getValue().intValue() == 0) {
                return whenFalse;
            }
            return whenTrue;
        }
    }

    // General case: compute whenTrue * condition + whenFalse * (1 - condition).
    // Where the condition is 1 this yields the cell of 'whenTrue', where it is 0
    // the cell of 'whenFalse'. Each input is joined with the condition first,
    // and the two results are then added together.
    TensorFunction conditionFunction = selector.function().get();
    TensorFunction truePart = new com.yahoo.tensor.functions.Join(whenTrue, conditionFunction, ScalarFunctions.multiply());

    // An anonymous class rather than a lambda, since toString() is overridden as well
    DoubleBinaryOperator multiplyByInverse = new DoubleBinaryOperator() {

        @Override
        public double applyAsDouble(double value, double condition) {
            return value * (1.0 - condition);
        }

        @Override
        public String toString() {
            return "f(a,b)(a * (1-b))";
        }
    };
    TensorFunction falsePart = new com.yahoo.tensor.functions.Join(whenFalse, conditionFunction, multiplyByInverse);

    return new com.yahoo.tensor.functions.Join(truePart, falsePart, ScalarFunctions.add());
}
Also used : Tensor(com.yahoo.tensor.Tensor) DoubleBinaryOperator(java.util.function.DoubleBinaryOperator) TensorFunction(com.yahoo.tensor.functions.TensorFunction)

Example 30 with Tensor

use of com.yahoo.tensor.Tensor in project vespa by vespa-engine.

the class MnistSoftmaxImportTestCase method testMnistSoftmaxImport.

/**
 * Imports the saved mnist_softmax model from the test files and checks its large constants,
 * macros, required macros and the "serving_default" signature, before comparing results
 * through assertEqualResult.
 */
@Test
public void testMnistSoftmaxImport() {
    TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");

    // Large constants: two are expected, each with a given type and size
    assertEquals(2, model.get().largeConstants().size());

    Tensor firstVariable = model.get().largeConstants().get("test_Variable_read");
    assertNotNull(firstVariable);
    TensorType firstVariableType = new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build();
    assertEquals(firstVariableType, firstVariable.type());
    assertEquals(7840, firstVariable.size());

    Tensor secondVariable = model.get().largeConstants().get("test_Variable_1_read");
    assertNotNull(secondVariable);
    TensorType secondVariableType = new TensorType.Builder().indexed("d1", 10).build();
    assertEquals(secondVariableType, secondVariable.type());
    assertEquals(10, secondVariable.size());

    // Provided macros: none
    assertEquals(0, model.get().macros().size());

    // Required macros: only "Placeholder"
    assertEquals(1, model.get().requiredMacros().size());
    assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
    TensorType placeholderType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();
    assertEquals(placeholderType, model.get().requiredMacros().get("Placeholder"));

    // Signatures: only "serving_default"
    assertEquals(1, model.get().signatures().size());
    TensorFlowModel.Signature servingDefault = model.get().signatures().get("serving_default");
    assertNotNull(servingDefault);

    // ... its single input "x"
    assertEquals(1, servingDefault.inputs().size());
    TensorType inputX = servingDefault.inputArgument("x");
    assertNotNull(inputX);
    TensorType inputXType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();
    assertEquals(inputXType, inputX);

    // ... its single output "y"
    assertEquals(1, servingDefault.outputs().size());
    RankingExpression outputY = servingDefault.outputExpression("y");
    assertNotNull(outputY);
    assertEquals("add", outputY.getName());
    String expectedRoot = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))";
    assertEquals(expectedRoot, outputY.getRoot().toString());

    // Test execution
    model.assertEqualResult("Placeholder", "MatMul");
    model.assertEqualResult("Placeholder", "add");
}
Also used : Tensor(com.yahoo.tensor.Tensor) RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression) TensorType(com.yahoo.tensor.TensorType) Test(org.junit.Test)

Aggregations

Tensor (com.yahoo.tensor.Tensor)58 Test (org.junit.Test)26 TensorType (com.yahoo.tensor.TensorType)17 IndexedTensor (com.yahoo.tensor.IndexedTensor)10 TensorAddress (com.yahoo.tensor.TensorAddress)7 MixedTensor (com.yahoo.tensor.MixedTensor)5 HashMap (java.util.HashMap)5 Map (java.util.Map)4 MapContext (com.yahoo.searchlib.rankingexpression.evaluation.MapContext)3 TensorValue (com.yahoo.searchlib.rankingexpression.evaluation.TensorValue)3 OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)3 DimensionSizes (com.yahoo.tensor.DimensionSizes)3 JsonNode (com.fasterxml.jackson.databind.JsonNode)2 ObjectMapper (com.fasterxml.jackson.databind.ObjectMapper)2 ImmutableList (com.google.common.collect.ImmutableList)2 MappedTensor (com.yahoo.tensor.MappedTensor)2 IOException (java.io.IOException)2 List (java.util.List)2 GrowableByteBuffer (com.yahoo.io.GrowableByteBuffer)1 Path (com.yahoo.path.Path)1