Example 11 with TensorType

Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

From class Reshape, method reshape.

public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
    if (!tensorSize(inputType).equals(tensorSize(outputType))) {
        throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
    }
    // Conceptually, reshaping consists of unrolling a tensor to an array using the dimension order,
    // then using the dimension order of the new shape to roll it back into a tensor.
    // Here we create a transformation tensor that is multiplied with the input tensor to map it into
    // the new shape. We have to introduce temporary dimension names and rename them back if dimension
    // names in the new and old tensor types overlap.
    ExpressionNode unrollFrom = unrollTensorExpression(inputType);
    ExpressionNode unrollTo = unrollTensorExpression(outputType);
    ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
    TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
    Generate transformTensor = new Generate(transformationType,
            new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
    // Multiply the input with the transformation tensor, then sum out the input dimensions
    TensorFunction outputFunction = new Reduce(
            new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
            Reduce.Aggregator.sum,
            inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
    return outputFunction;
}
Also used : GeneratorLambdaFunctionNode(com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType) Reduce(com.yahoo.tensor.functions.Reduce) ComparisonNode(com.yahoo.searchlib.rankingexpression.rule.ComparisonNode) TensorFunction(com.yahoo.tensor.functions.TensorFunction) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) Generate(com.yahoo.tensor.functions.Generate)
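To make the transformation-tensor idea concrete, here is a minimal plain-Java sketch for the rank-2 case only, assuming row-major unrolling (the last dimension varies fastest). `ReshapeSketch` is a hypothetical illustration, not Vespa code: the `unrollFrom == unrollTo` test plays the role of the generated transformation tensor, the multiplication is the join, and the running sum is the reduce over the input dimensions.

```java
import java.util.Arrays;

class ReshapeSketch {

    // Reshape a rank-2 array into another rank-2 shape by multiplying with a
    // 0/1 transformation tensor and summing out the input dimensions.
    static double[][] reshape(double[][] input, int outRows, int outCols) {
        int inRows = input.length;
        int inCols = input[0].length;
        if (inRows * inCols != outRows * outCols)
            throw new IllegalArgumentException(
                    "New and old shape of tensor must have the same size when reshaping");
        double[][] output = new double[outRows][outCols];
        for (int r2 = 0; r2 < outRows; r2++) {
            for (int c2 = 0; c2 < outCols; c2++) {
                double sum = 0;
                for (int r = 0; r < inRows; r++) {
                    for (int c = 0; c < inCols; c++) {
                        // Row-major unrolling of the input and output addresses (assumption)
                        int unrollFrom = r * inCols + c;
                        int unrollTo = r2 * outCols + c2;
                        // Transformation tensor cell: 1 where both addresses unroll to the same position
                        double transform = (unrollFrom == unrollTo) ? 1.0 : 0.0;
                        sum += input[r][c] * transform; // join with multiply
                    }
                }
                output[r2][c2] = sum; // reduce with sum over the input dimensions
            }
        }
        return output;
    }

    public static void main(String[] args) {
        double[][] x = { {1, 2, 3}, {4, 5, 6} };
        System.out.println(Arrays.deepToString(reshape(x, 3, 2)));
    }
}
```

Reshaping `{{1, 2, 3}, {4, 5, 6}}` to 3x2 prints `[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]`. The sketch does work proportional to the input size times the output size, which mirrors the fact that the generated transformation tensor spans the combined input and output dimensions.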

Example 12 with TensorType

Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

From class JsonWriterTestCase, method registerTensorDocumentType.

private void registerTensorDocumentType() {
    DocumentType x = new DocumentType("testtensor");
    TensorType tensorType = new TensorType.Builder().mapped("x").mapped("y").build();
    x.addField(new Field("tensorfield", new TensorDataType(tensorType)));
    types.registerDocumentType(x);
}
Also used : Field(com.yahoo.document.Field) TensorDataType(com.yahoo.document.TensorDataType) DocumentType(com.yahoo.document.DocumentType) TensorType(com.yahoo.tensor.TensorType)
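`registerTensorDocumentType` gives the field a tensor type with two mapped (sparse) dimensions, `x` and `y`, which Vespa's type-spec notation writes as `tensor(x{},y{})`. The sketch below is a hypothetical plain-Java builder (`TensorTypeSpecSketch` is not part of Vespa) that produces that notation, assuming dimensions are listed sorted by name, `{}` marks a mapped dimension and `[N]` an indexed dimension of size N.

```java
import java.util.Map;
import java.util.StringJoiner;
import java.util.TreeMap;

class TensorTypeSpecSketch {

    // Dimension name -> suffix ("{}" for mapped, "[size]" for indexed).
    // A TreeMap keeps dimensions sorted by name; unlike a real builder,
    // adding the same name twice silently overwrites the earlier entry.
    private final Map<String, String> dimensions = new TreeMap<>();

    TensorTypeSpecSketch mapped(String name) {
        dimensions.put(name, "{}");
        return this;
    }

    TensorTypeSpecSketch indexed(String name, long size) {
        dimensions.put(name, "[" + size + "]");
        return this;
    }

    // Render the type as a spec string such as tensor(x{},y{})
    String build() {
        StringJoiner spec = new StringJoiner(",", "tensor(", ")");
        for (Map.Entry<String, String> d : dimensions.entrySet())
            spec.add(d.getKey() + d.getValue());
        return spec.toString();
    }

    public static void main(String[] args) {
        System.out.println(new TensorTypeSpecSketch().mapped("x").mapped("y").build());
    }
}
```

Running `main` prints `tensor(x{},y{})`, the type carried by `tensorfield` above.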

Example 13 with TensorType

Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

From class Mean, method lazyGetFunction.

// TODO: optimization: if keepDims is set and there is one reduce dimension of size 1, this is the identity.
@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFunction inputFunction = inputs.get(0).function().get();
    TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
    if (shouldKeepDimensions()) {
        // Keep the reduced dimensions: join (multiply) with a generated all-ones tensor
        // that has each reduced dimension with size 1
        TensorType.Builder typeBuilder = new TensorType.Builder();
        for (String name : reduceDimensions) {
            typeBuilder.indexed(name, 1);
        }
        TensorType generatedType = typeBuilder.build();
        ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
        Generate generatedFunction = new Generate(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
        output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
    }
    return output;
}
Also used : GeneratorLambdaFunctionNode(com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType) Reduce(com.yahoo.tensor.functions.Reduce) ConstantNode(com.yahoo.searchlib.rankingexpression.rule.ConstantNode) TensorFunction(com.yahoo.tensor.functions.TensorFunction) DoubleValue(com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) ExpressionNode(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) Generate(com.yahoo.tensor.functions.Generate)
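The keep-dimensions branch works because multiplying by a tensor of ones whose dimensions all have size 1 adds those dimensions back without changing any value. A minimal plain-Java sketch for a rank-2 input averaged over its second dimension, assuming row-major arrays; the class and method names are hypothetical.

```java
import java.util.Arrays;

class MeanKeepDimsSketch {

    // Average over dimension 1 of a rank-2 array (the Reduce with Aggregator.avg)
    static double[] reduceAvgOverColumns(double[][] input) {
        double[] result = new double[input.length];
        for (int r = 0; r < input.length; r++) {
            double sum = 0;
            for (double v : input[r]) sum += v;
            result[r] = sum / input[r].length;
        }
        return result;
    }

    // Join with a generated all-ones tensor of size 1 along the reduced dimension:
    // the values are unchanged, but the result has rank 2 again
    static double[][] keepDims(double[] reduced) {
        double[] ones = {1.0}; // generated tensor: a single cell with value 1
        double[][] out = new double[reduced.length][ones.length];
        for (int r = 0; r < reduced.length; r++)
            for (int k = 0; k < ones.length; k++)
                out[r][k] = reduced[r] * ones[k];
        return out;
    }

    public static void main(String[] args) {
        double[][] x = { {1, 2, 3}, {4, 5, 6} };
        System.out.println(Arrays.deepToString(keepDims(reduceAvgOverColumns(x))));
    }
}
```

For the 2x3 input this prints `[[2.0], [5.0]]`: a 2x1 result rather than a flat vector of two values.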

Example 14 with TensorType

Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

From class Join, method evaluate.

@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor a = argumentA.evaluate(context);
    Tensor b = argumentB.evaluate(context);
    TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
    // Choose join algorithm
    if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
        return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
    else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
        return singleSpaceJoin(a, b, joinedType);
    else if (a.type().dimensions().containsAll(b.type().dimensions()))
        return subspaceJoin(b, a, joinedType, true);
    else if (b.type().dimensions().containsAll(a.type().dimensions()))
        return subspaceJoin(a, b, joinedType, false);
    else
        return generalJoin(a, b, joinedType);
}
Also used : IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor) TensorType(com.yahoo.tensor.TensorType)
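The branch order in `evaluate` can be sketched without Vespa's classes by working on lists of dimension names. This is a hypothetical simplification: `aSingleIndexed` and `bSingleIndexed` stand in for `hasSingleIndexedDimension`, and dimensions are compared by name only, whereas the original compares `TensorType.Dimension` objects. The joined type is modeled as the union of both dimension sets.

```java
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

class JoinDispatchSketch {

    // Returns the name of the join algorithm the branch order above would pick
    static String chooseJoin(List<String> aDims, boolean aSingleIndexed,
                             List<String> bDims, boolean bSingleIndexed) {
        // Joined type: union of the dimensions of a and b
        Set<String> joined = new LinkedHashSet<>(aDims);
        joined.addAll(bDims);
        if (aSingleIndexed && bSingleIndexed && aDims.get(0).equals(bDims.get(0)))
            return "indexedVectorJoin";
        else if (joined.size() == aDims.size() && joined.size() == bDims.size())
            return "singleSpaceJoin";      // a and b have the same dimensions
        else if (aDims.containsAll(bDims))
            return "subspaceJoin(b, a)";   // b's dimensions are a subset of a's
        else if (bDims.containsAll(aDims))
            return "subspaceJoin(a, b)";   // a's dimensions are a subset of b's
        else
            return "generalJoin";          // partially overlapping or disjoint dimensions
    }

    public static void main(String[] args) {
        System.out.println(chooseJoin(List.of("x"), true, List.of("x"), true));
        System.out.println(chooseJoin(List.of("x", "y"), false, List.of("y", "x"), false));
        System.out.println(chooseJoin(List.of("x", "y"), false, List.of("y"), true));
        System.out.println(chooseJoin(List.of("x"), true, List.of("y"), true));
    }
}
```

The four calls in `main` select `indexedVectorJoin`, `singleSpaceJoin`, `subspaceJoin(b, a)` and `generalJoin` respectively. Ordering the checks from most to least specialized lets the cheap vector and same-space paths run before the fully general one.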

Example 15 with TensorType

Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.

From class Concat, method evaluate.

@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
    Tensor a = argumentA.evaluate(context);
    Tensor b = argumentB.evaluate(context);
    a = ensureIndexedDimension(dimension, a);
    b = ensureIndexedDimension(dimension, b);
    // Only indexed tensors are supported: these casts throw a ClassCastException if a mixed tensor reaches this point
    IndexedTensor aIndexed = (IndexedTensor) a;
    IndexedTensor bIndexed = (IndexedTensor) b;
    TensorType concatType = type(a.type(), b.type());
    DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
    Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
    long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
    int[] aToIndexes = mapIndexes(a.type(), concatType);
    int[] bToIndexes = mapIndexes(b.type(), concatType);
    concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
    concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
    return builder.build();
}
Also used : Arrays(java.util.Arrays) Iterator(java.util.Iterator) Set(java.util.Set) TensorType(com.yahoo.tensor.TensorType) Collectors(java.util.stream.Collectors) Objects(java.util.Objects) TensorAddress(com.yahoo.tensor.TensorAddress) List(java.util.List) ImmutableList(com.google.common.collect.ImmutableList) EvaluationContext(com.yahoo.tensor.evaluation.EvaluationContext) DimensionSizes(com.yahoo.tensor.DimensionSizes) TypeContext(com.yahoo.tensor.evaluation.TypeContext) Optional(java.util.Optional) IndexedTensor(com.yahoo.tensor.IndexedTensor) Tensor(com.yahoo.tensor.Tensor)
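Concatenation places the cells of `b` after those of `a` along `dimension`, which is why the cell placement above is parameterized by an offset (`aDimensionLength` in one `concatenateTo` call, 0 in the other). A minimal plain-Java sketch of that offset arithmetic for rank-2 arrays concatenated along their first dimension; `ConcatSketch` is hypothetical, and requiring equal sizes in the other dimension is a simplification of this sketch, not a claim about how `concatSize` resolves differing sizes.

```java
import java.util.Arrays;

class ConcatSketch {

    // Concatenate two rank-2 arrays along dimension 0: cells of a keep their index,
    // cells of b are shifted by a's length along that dimension
    static double[][] concatAlongDim0(double[][] a, double[][] b) {
        int cols = a[0].length;
        if (b[0].length != cols) // simplification made by this sketch
            throw new IllegalArgumentException("Other dimension sizes must match in this sketch");
        int aDimensionLength = a.length;
        double[][] out = new double[a.length + b.length][cols];
        for (int r = 0; r < a.length; r++)
            for (int c = 0; c < cols; c++)
                out[r][c] = a[r][c];                    // offset 0
        for (int r = 0; r < b.length; r++)
            for (int c = 0; c < cols; c++)
                out[aDimensionLength + r][c] = b[r][c]; // offset aDimensionLength
        return out;
    }

    public static void main(String[] args) {
        double[][] a = { {1, 2} };
        double[][] b = { {3, 4}, {5, 6} };
        System.out.println(Arrays.deepToString(concatAlongDim0(a, b)));
    }
}
```

This prints `[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]`: the result's size along the concatenation dimension is the sum of both inputs' sizes.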

Aggregations

TensorType (com.yahoo.tensor.TensorType) 38
Tensor (com.yahoo.tensor.Tensor) 18
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode) 8
IndexedTensor (com.yahoo.tensor.IndexedTensor) 7
MixedTensor (com.yahoo.tensor.MixedTensor) 6
TensorAddress (com.yahoo.tensor.TensorAddress) 6
Test (org.junit.Test) 6
DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue) 4
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) 4
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode) 4
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode) 4
Generate (com.yahoo.tensor.functions.Generate) 4
List (java.util.List) 4
ImmutableList (com.google.common.collect.ImmutableList) 3
RankProfile (com.yahoo.searchdefinition.RankProfile) 3
RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression) 3
DimensionSizes (com.yahoo.tensor.DimensionSizes) 3
Reduce (com.yahoo.tensor.functions.Reduce) 3
TensorFunction (com.yahoo.tensor.functions.TensorFunction) 3
ArrayList (java.util.ArrayList) 3