Search in sources :

Example 26 with TensorType

use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

the class MnistSoftmaxImportTestCase method testMnistSoftmaxImport.

/**
 * Imports the saved MNIST softmax TensorFlow model and verifies the imported constants,
 * macros, signature and ranking expression, and finally that evaluation gives equal results.
 */
@Test
public void testMnistSoftmaxImport() {
    TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/mnist_softmax/saved");

    // Large constants: exactly two are expected
    assertEquals(2, model.get().largeConstants().size());

    Tensor variableConstant = model.get().largeConstants().get("test_Variable_read");
    assertNotNull(variableConstant);
    TensorType expectedVariableType = new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build();
    assertEquals(expectedVariableType, variableConstant.type());
    assertEquals(7840, variableConstant.size());

    Tensor variable1Constant = model.get().largeConstants().get("test_Variable_1_read");
    assertNotNull(variable1Constant);
    TensorType expectedVariable1Type = new TensorType.Builder().indexed("d1", 10).build();
    assertEquals(expectedVariable1Type, variable1Constant.type());
    assertEquals(10, variable1Constant.size());

    // Macros provided by the model: none
    assertEquals(0, model.get().macros().size());

    // Macros required by the model: only the placeholder input
    TensorType expectedInputType = new TensorType.Builder().indexed("d0").indexed("d1", 784).build();
    assertEquals(1, model.get().requiredMacros().size());
    assertTrue(model.get().requiredMacros().containsKey("Placeholder"));
    assertEquals(expectedInputType, model.get().requiredMacros().get("Placeholder"));

    // Signatures: a single default serving signature
    assertEquals(1, model.get().signatures().size());
    TensorFlowModel.Signature servingSignature = model.get().signatures().get("serving_default");
    assertNotNull(servingSignature);

    // Signature inputs: "x" has the same type as the required placeholder macro
    assertEquals(1, servingSignature.inputs().size());
    TensorType inputX = servingSignature.inputArgument("x");
    assertNotNull(inputX);
    assertEquals(expectedInputType, inputX);

    // Signature outputs: "y" is the imported expression named "add"
    assertEquals(1, servingSignature.outputs().size());
    RankingExpression outputY = servingSignature.outputExpression("y");
    assertNotNull(outputY);
    assertEquals("add", outputY.getName());
    assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", outputY.getRoot().toString());

    // Execution: compare results for the intermediate and the final operation
    model.assertEqualResult("Placeholder", "MatMul");
    model.assertEqualResult("Placeholder", "add");
}
Also used : Tensor(com.yahoo.tensor.Tensor) RankingExpression(com.yahoo.searchlib.rankingexpression.RankingExpression) TensorType(com.yahoo.tensor.TensorType) Test(org.junit.Test)

Example 27 with TensorType

use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

the class Join method mappedHashJoin.

/**
 * Joins two tensors by hashing the cells of the smaller tensor on the dimensions the two
 * tensors have in common, and then probing that hash map once per cell of the larger tensor.
 * Falls back to {@code mappedGeneralJoin} when the tensors share no dimensions.
 *
 * @param a the first argument to the join
 * @param b the second argument to the join
 * @param joinedType the type of the resulting tensor
 * @return a tensor of type {@code joinedType} holding the combined value of each pair of
 *         cells whose addresses could be joined
 */
private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
    TensorType commonDimensionType = commonDimensions(a, b);
    if (commonDimensionType.dimensions().isEmpty()) {
        // Nothing to hash on: fallback
        return mappedGeneralJoin(a, b, joinedType);
    }

    // From here on "a" always denotes the smaller tensor, as it is the one being hashed
    boolean swapTensors = a.size() > b.size();
    if (swapTensors) {
        Tensor temp = a;
        a = b;
        b = temp;
    }

    // Map dimension indexes to common and joined type
    int[] aIndexesInCommon = mapIndexes(commonDimensionType, a.type());
    int[] bIndexesInCommon = mapIndexes(commonDimensionType, b.type());
    int[] aIndexesInJoined = mapIndexes(a.type(), joinedType);
    int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);

    // Iterate once through the smaller tensor and construct a hash map for common dimensions
    Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
    for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell aCell = cellIterator.next();
        TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
        // computeIfAbsent only allocates a list the first time an address is seen
        aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, key -> new ArrayList<>()).add(aCell);
    }

    // Iterate once through the larger tensor and use the hash map to find joinable cells
    Tensor.Builder builder = Tensor.Builder.of(joinedType);
    for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell bCell = cellIterator.next();
        TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
        for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, Collections.emptyList())) {
            TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType);
            // not combinable
            if (combinedAddress == null)
                continue;
            // If the tensors were swapped, swap the values back so the combinator sees the caller's argument order
            double combinedValue = swapTensors ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
            builder.cell(combinedAddress, combinedValue);
        }
    }
    return builder.build();
}
Also used : TensorAddress(com.yahoo.tensor.TensorAddress) IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) HashMap(java.util.HashMap) TensorType(com.yahoo.tensor.TensorType) ArrayList(java.util.ArrayList) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList)

Example 28 with TensorType

use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

the class Join method commonDimensions.

/**
 * Builds a new tensor type from the dimensions which are present, as equal dimensions,
 * in the types of both a and b.
 *
 * @param a the first tensor
 * @param b the second tensor
 * @return a type containing each dimension of b which equals some dimension of a
 */
private TensorType commonDimensions(Tensor a, Tensor b) {
    TensorType firstType = a.type();
    TensorType secondType = b.type();
    TensorType.Builder shared = new TensorType.Builder();
    for (TensorType.Dimension fromFirst : firstType.dimensions()) {
        for (TensorType.Dimension fromSecond : secondType.dimensions()) {
            if (fromFirst.equals(fromSecond)) {
                shared.set(fromSecond);
            }
        }
    }
    return shared.build();
}
Also used : TensorType(com.yahoo.tensor.TensorType)

Example 29 with TensorType

use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

the class Reduce method evaluate.

/**
 * Evaluates the argument of this and reduces the resulting tensor over the dimensions of this.
 * When no dimensions are given, or all dimensions of the argument are given, all cells are
 * reduced at once; otherwise cells are aggregated per address in the reduced type.
 *
 * @param context the context used to evaluate the argument
 * @return the reduced tensor
 * @throws IllegalArgumentException if this has dimensions which are not all present in the
 *         evaluated argument
 */
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor argument = this.argument.evaluate(context);
    if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions))
        throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + dimensions + ": Not all those dimensions are present in this tensor");

    // Special case: Reduce all
    if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) {
        if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
            return reduceIndexedVector((IndexedTensor) argument);
        else
            return reduceAllGeneral(argument);
    }

    TensorType reducedType = type(argument.type());

    // Reduce cells: one aggregator per address in the reduced type
    Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
    for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
        Map.Entry<TensorAddress, Double> cell = i.next();
        TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType);
        // computeIfAbsent only creates an aggregator the first time a reduced address is seen
        aggregatingCells.computeIfAbsent(reducedAddress, key -> ValueAggregator.ofType(aggregator))
                        .aggregate(cell.getValue());
    }

    Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
    for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet()) {
        reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
    }
    return reducedBuilder.build();
}
Also used : TensorAddress(com.yahoo.tensor.TensorAddress) IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) HashMap(java.util.HashMap) TensorType(com.yahoo.tensor.TensorType) Map(java.util.Map)

Example 30 with TensorType

use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

the class DenseBinaryFormat method decode.

/**
 * Decodes a tensor from the given buffer. The type is always read from the buffer first;
 * if a type is also supplied by the caller, the serialized type must be assignable to it,
 * and the supplied type becomes the type of the returned tensor. The dimension sizes are
 * taken from the serialized type in both cases.
 *
 * @param optionalType the type the caller expects, if any
 * @param buffer the buffer to read the serialized type and cells from
 * @return the decoded tensor
 * @throws IllegalArgumentException if a type is supplied and the serialized type cannot be
 *         assigned to it
 */
@Override
public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
    boolean typeSupplied = optionalType.isPresent();
    TensorType serializedType = decodeType(buffer);
    TensorType type = typeSupplied ? optionalType.get() : serializedType;
    if (typeSupplied && !serializedType.isAssignableTo(type)) {
        throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + " cannot be assigned to type " + type);
    }

    DimensionSizes sizes = sizesFromType(serializedType);
    Tensor.Builder builder = Tensor.Builder.of(type, sizes);
    decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder) builder);
    return builder.build();
}
Also used : IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) DimensionSizes(com.yahoo.tensor.DimensionSizes) TensorType(com.yahoo.tensor.TensorType)

Aggregations

TensorType (com.yahoo.tensor.TensorType): 38, Tensor (com.yahoo.tensor.Tensor): 18, ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 8, IndexedTensor (com.yahoo.tensor.IndexedTensor): 7, MixedTensor (com.yahoo.tensor.MixedTensor): 6, TensorAddress (com.yahoo.tensor.TensorAddress): 6, Test (org.junit.Test): 6, DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 4, OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 4, ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 4, GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 4, Generate (com.yahoo.tensor.functions.Generate): 4, List (java.util.List): 4, ImmutableList (com.google.common.collect.ImmutableList): 3, RankProfile (com.yahoo.searchdefinition.RankProfile): 3, RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression): 3, DimensionSizes (com.yahoo.tensor.DimensionSizes): 3, Reduce (com.yahoo.tensor.functions.Reduce): 3, TensorFunction (com.yahoo.tensor.functions.TensorFunction): 3, ArrayList (java.util.ArrayList): 3