
Example 1 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine, in class TensorFlowFeatureConverter, method reduceBatchDimensionExpression.

private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
    TensorFunction result = function;
    TensorType type = function.type(context);
    if (type.dimensions().size() > 1) {
        List<String> reduceDimensions = new ArrayList<>();
        for (TensorType.Dimension dimension : type.dimensions()) {
            if (dimension.size().orElse(-1L) == 1) {
                reduceDimensions.add(dimension.name());
            }
        }
        if (reduceDimensions.size() > 0) {
            result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
        }
    }
    return new TensorFunctionNode(result);
}
Also used: TensorFunction (com.yahoo.tensor.functions.TensorFunction), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), ArrayList (java.util.ArrayList), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce)
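The selection logic in reduceBatchDimensionExpression can be modelled without any Vespa dependency. The following is a minimal sketch, not Vespa's API: Dim and BatchDimensions are hypothetical names, and a tensor type is simplified to an ordered list of dimensions with an optional size. It returns the dimensions that a sum-Reduce would remove. Because summing over a dimension of size 1 adds up exactly one cell, the cell values are unchanged and only the dimension disappears from the type.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalLong;

// Hypothetical, simplified stand-in for a tensor dimension: a name and an
// optional size (empty means unbound). Not Vespa's TensorType.Dimension.
final class Dim {
    final String name;
    final OptionalLong size;

    Dim(String name, OptionalLong size) {
        this.name = name;
        this.size = size;
    }
}

final class BatchDimensions {
    // Returns the names of dimensions whose size is exactly 1, mirroring the
    // selection in reduceBatchDimensionExpression. Types of rank 0 or 1 are
    // left alone, so an empty list is returned for them.
    static List<String> dimensionsToReduce(List<Dim> dims) {
        List<String> result = new ArrayList<>();
        if (dims.size() <= 1) {
            return result;
        }
        for (Dim d : dims) {
            // Unbound sizes become -1 and are therefore never reduced.
            if (d.size.orElse(-1L) == 1L) {
                result.add(d.name);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        List<Dim> dims = List.of(
                new Dim("d0", OptionalLong.of(1)),    // batch dimension
                new Dim("d1", OptionalLong.of(784)));
        System.out.println(dimensionsToReduce(dims)); // prints [d0]
    }
}
```

As in the original, a dimension with an unbound size is treated as size -1, so it is kept even if it later turns out to hold a single element.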

Example 2 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine, in class Reshape, method reshape.

public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
    if (!tensorSize(inputType).equals(tensorSize(outputType))) {
        throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
    }
    // Conceptually, reshaping consists of unrolling a tensor into an array using its dimension order,
    // then using the dimension order of the new shape to roll the array back into a tensor.
    // Here we create a transformation tensor that is multiplied with the input tensor to map it into
    // the new shape. If dimension names in the new and old tensor types overlap, temporary dimension
    // names have to be introduced and renamed back afterwards.
    ExpressionNode unrollFrom = unrollTensorExpression(inputType);
    ExpressionNode unrollTo = unrollTensorExpression(outputType);
    ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
    TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
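    // Transformation tensor over both the input and output dimensions:
    // 1 where the unrolled input and output positions are equal, otherwise 0.
    Generate transformTensor = new Generate(transformationType,
            new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
    // Multiply the input with the transformation tensor, then sum out the input
    // dimensions so that only the output dimensions remain.
    TensorFunction outputFunction = new Reduce(
            new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
            Reduce.Aggregator.sum,
            inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));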
    return outputFunction;
}
Also used: GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce), ComparisonNode (com.yahoo.searchlib.rankingexpression.rule.ComparisonNode), TensorFunction (com.yahoo.tensor.functions.TensorFunction), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), Generate (com.yahoo.tensor.functions.Generate)
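The comment in reshape describes the approach as unroll, compare, then roll back. The sketch below models it in plain Java, assuming row-major unrolling (the last dimension varies fastest); that ordering is an assumption about unrollTensorExpression, whose body is not shown here. A tensor is represented as a map from cell address to value, and ReshapeSketch with its methods is hypothetical. For every output address it sums the input cells multiplied by a generated 0/1 transformation value, which plays the role of the Join with ScalarFunctions.multiply() followed by the sum Reduce over the input dimensions.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ReshapeSketch {
    // Row-major unrolling: the last dimension varies fastest (assumed order).
    static long unroll(List<Integer> address, int[] shape) {
        long flat = 0;
        for (int d = 0; d < shape.length; d++) {
            flat = flat * shape[d] + address.get(d);
        }
        return flat;
    }

    // Enumerates every cell address of a dense tensor with the given shape.
    static List<List<Integer>> addresses(int[] shape) {
        List<List<Integer>> result = new ArrayList<>();
        result.add(new ArrayList<>());
        for (int size : shape) {
            List<List<Integer>> next = new ArrayList<>();
            for (List<Integer> prefix : result) {
                for (int i = 0; i < size; i++) {
                    List<Integer> a = new ArrayList<>(prefix);
                    a.add(i);
                    next.add(a);
                }
            }
            result = next;
        }
        return result;
    }

    // out[o] = sum over input addresses i of in[i] * T(i, o), where the
    // generated transformation T(i, o) is 1.0 if unroll(i) == unroll(o), else 0.0.
    static Map<List<Integer>, Double> reshape(Map<List<Integer>, Double> in,
                                              int[] inShape, int[] outShape) {
        List<List<Integer>> inAddrs = addresses(inShape);
        List<List<Integer>> outAddrs = addresses(outShape);
        if (inAddrs.size() != outAddrs.size()) {
            throw new IllegalArgumentException(
                    "New and old shape of tensor must have the same size when reshaping");
        }
        Map<List<Integer>, Double> out = new HashMap<>();
        for (List<Integer> o : outAddrs) {
            double sum = 0.0;
            for (List<Integer> i : inAddrs) {
                double t = unroll(i, inShape) == unroll(o, outShape) ? 1.0 : 0.0;
                sum += in.getOrDefault(i, 0.0) * t;
            }
            out.put(o, sum);
        }
        return out;
    }

    public static void main(String[] args) {
        int[] inShape = {2, 3};
        int[] outShape = {3, 2};
        Map<List<Integer>, Double> in = new HashMap<>();
        for (List<Integer> a : addresses(inShape)) {
            in.put(a, (double) unroll(a, inShape)); // cell value = its flat position
        }
        Map<List<Integer>, Double> out = reshape(in, inShape, outShape);
        System.out.println(out.get(List.of(2, 1))); // prints 5.0
    }
}
```

The nested loop is quadratic in the number of cells, matching the size of the generated transformation tensor; computing the target index directly would be cheaper, but would hide the join-and-reduce structure the original relies on.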

Example 3 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine, in class Reshape, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    if (!allInputFunctionsPresent(2)) {
        return null;
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    TensorFunction inputFunction = inputs.get(0).function().get();
    return reshape(inputFunction, inputType.type(), type.type());
}
Also used: TensorFunction (com.yahoo.tensor.functions.TensorFunction), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)
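Example 3 returns null while its inputs are not yet available. The caller side of that contract is not shown here, so the following is only a sketch of how such a lazily computed, cached function could be consumed. LazyOperation, Constant and ReshapeLike are hypothetical classes, the function is simplified to an expression string, and the choice to retry after a null result instead of caching it is an assumption.

```java
import java.util.List;
import java.util.Optional;

// Hypothetical base class; not Vespa's TensorFlowOperation.
abstract class LazyOperation {
    protected final List<LazyOperation> inputs;
    private String cachedFunction; // null until successfully computed

    LazyOperation(List<LazyOperation> inputs) {
        this.inputs = inputs;
    }

    // Computes the function on the first successful request and caches it.
    // A null from lazyGetFunction() is not cached, so a later call retries.
    final Optional<String> function() {
        if (cachedFunction == null) {
            cachedFunction = lazyGetFunction();
        }
        return Optional.ofNullable(cachedFunction);
    }

    protected boolean allInputFunctionsPresent(int expected) {
        return inputs.size() == expected
                && inputs.stream().allMatch(op -> op.function().isPresent());
    }

    protected abstract String lazyGetFunction();
}

final class Constant extends LazyOperation {
    private final String value;

    Constant(String value) {
        super(List.of());
        this.value = value;
    }

    @Override
    protected String lazyGetFunction() {
        return value;
    }
}

final class ReshapeLike extends LazyOperation {
    ReshapeLike(List<LazyOperation> inputs) {
        super(inputs);
    }

    @Override
    protected String lazyGetFunction() {
        if (!allInputFunctionsPresent(1)) {
            return null; // input missing: no function yet
        }
        return "reshape(" + inputs.get(0).function().get() + ")";
    }

    public static void main(String[] args) {
        LazyOperation ok = new ReshapeLike(List.of(new Constant("x")));
        System.out.println(ok.function());      // prints Optional[reshape(x)]
        LazyOperation missing = new ReshapeLike(List.of());
        System.out.println(missing.function()); // prints Optional.empty
    }
}
```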

Example 4 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine, in class TensorFlowImporter, method importRankingExpression.

private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
    if (operation.function().isPresent()) {
        String name = operation.node().getName();
        if (!model.expressions().containsKey(name)) {
            TensorFunction function = operation.function().get();
            // Make sure output adheres to standard naming convention
            if (isSignatureOutput(model, operation)) {
                OrderedTensorType operationType = operation.type().get();
                OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
                if (!operationType.equals(standardNamingType)) {
                    List<String> renameFrom = operationType.dimensionNames();
                    List<String> renameTo = standardNamingType.dimensionNames();
                    function = new Rename(function, renameFrom, renameTo);
                }
            }
            try {
                // All intermediate nodes are added as separate expressions; only
                // those referenced in a signature output will be used. The
                // TensorFunction is parsed here to convert it to a RankingExpression tree.
                model.expression(name, new RankingExpression(name, function.toString()));
            } catch (ParseException e) {
                throw new RuntimeException("TensorFlow function " + function + " cannot be parsed as a ranking expression", e);
            }
        }
    }
}
Also used: RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression), TensorFunction (com.yahoo.tensor.functions.TensorFunction), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException), Rename (com.yahoo.tensor.functions.Rename)
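The renaming step in Example 4 maps an operation's dimension names onto the names expected for a signature output, and each node name is imported only once. A minimal sketch under stated assumptions: RenameSketch is hypothetical, the d0, d1, ... naming produced by standardNames is an illustrative assumption rather than what OrderedTensorType.fromTensorFlowType is known to produce, and a plain map stands in for model.expressions(). The rename is applied as one simultaneous mapping, so swapping two names works.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class RenameSketch {
    // Positional rename: from.get(k) becomes to.get(k). All pairs are applied
    // at once, and dimensions not listed in `from` keep their names.
    static List<String> rename(List<String> dims, List<String> from, List<String> to) {
        if (from.size() != to.size()) {
            throw new IllegalArgumentException("from and to must have the same length");
        }
        Map<String, String> mapping = new HashMap<>();
        for (int k = 0; k < from.size(); k++) {
            mapping.put(from.get(k), to.get(k));
        }
        List<String> result = new ArrayList<>();
        for (String d : dims) {
            result.add(mapping.getOrDefault(d, d));
        }
        return result;
    }

    // Hypothetical standard names for a rank-n output: d0, d1, ...
    static List<String> standardNames(int rank) {
        List<String> names = new ArrayList<>();
        for (int i = 0; i < rank; i++) {
            names.add("d" + i);
        }
        return names;
    }

    // Registers an expression only if the name is not taken yet, like the
    // containsKey check in importRankingExpression. Returns true if added.
    static boolean register(Map<String, String> expressions, String name, String expression) {
        return expressions.putIfAbsent(name, expression) == null;
    }

    public static void main(String[] args) {
        List<String> operationDims = List.of("d3", "d7");
        System.out.println(rename(operationDims, operationDims, standardNames(2))); // prints [d0, d1]
        Map<String, String> expressions = new LinkedHashMap<>();
        System.out.println(register(expressions, "output", "first"));  // prints true
        System.out.println(register(expressions, "output", "second")); // prints false
    }
}
```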

Example 5 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine, in class Join, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    if (!allInputFunctionsPresent(2)) {
        return null;
    }
    TensorFlowOperation a = largestInput();
    TensorFlowOperation b = smallestInput();
    List<String> aDimensionsToReduce = new ArrayList<>();
    List<String> bDimensionsToReduce = new ArrayList<>();
    int sizeDifference = a.type().get().rank() - b.type().get().rank();
    for (int i = 0; i < b.type().get().rank(); ++i) {
        TensorType.Dimension bDim = b.type().get().dimensions().get(i);
        TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
        long bSize = bDim.size().orElse(-1L);
        long aSize = aDim.size().orElse(-1L);
        if (bSize == 1L && aSize != 1L) {
            bDimensionsToReduce.add(bDim.name());
        }
        if (aSize == 1L && bSize != 1L) {
            // The Reduce is applied to a's function, so use a's own dimension name
            aDimensionsToReduce.add(aDim.name());
        }
    }
    TensorFunction aReducedFunction = a.function().get();
    if (aDimensionsToReduce.size() > 0) {
        aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
    }
    TensorFunction bReducedFunction = b.function().get();
    if (bDimensionsToReduce.size() > 0) {
        bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
    }
    return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
}
Also used: ArrayList (java.util.ArrayList), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce), TensorFunction (com.yahoo.tensor.functions.TensorFunction)
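Example 5 aligns the two inputs from their last dimension, in the style of NumPy broadcasting, and sums away any size-1 dimension that faces a dimension of a different size. The sketch below reproduces only that planning step in plain Java. JoinAlignmentSketch and Plan are hypothetical, an unknown size is encoded as -1 as in the original's orElse(-1L), and the first input is assumed to have the higher rank, which is what largestInput() appears to provide.

```java
import java.util.ArrayList;
import java.util.List;

final class JoinAlignmentSketch {
    // Which dimensions each side should sum away before the join.
    static final class Plan {
        final List<String> reduceA = new ArrayList<>();
        final List<String> reduceB = new ArrayList<>();
    }

    // `a` must have rank >= rank of `b`; dimensions are aligned from the right.
    // A size of -1 stands for an unknown (unbound) size.
    static Plan plan(List<String> aNames, long[] aSizes, List<String> bNames, long[] bSizes) {
        Plan p = new Plan();
        int offset = aSizes.length - bSizes.length;
        for (int i = 0; i < bSizes.length; i++) {
            long aSize = aSizes[i + offset];
            long bSize = bSizes[i];
            if (bSize == 1L && aSize != 1L) {
                p.reduceB.add(bNames.get(i));          // b's size-1 dim faces a larger dim
            }
            if (aSize == 1L && bSize != 1L) {
                p.reduceA.add(aNames.get(i + offset)); // a's size-1 dim faces a larger dim
            }
        }
        return p;
    }

    public static void main(String[] args) {
        // a: [d0=4, d1=1, d2=3], b: [d1=5, d2=1], aligned from the right.
        Plan p = plan(List.of("d0", "d1", "d2"), new long[] {4, 1, 3},
                      List.of("d1", "d2"), new long[] {5, 1});
        System.out.println(p.reduceA + " " + p.reduceB); // prints [d1] [d2]
    }
}
```

Note that the leading dimension d0 of the higher-rank input is never inspected, so it survives the join unchanged.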

Aggregations

TensorFunction (com.yahoo.tensor.functions.TensorFunction): 12
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 6
Reduce (com.yahoo.tensor.functions.Reduce): 6
TensorType (com.yahoo.tensor.TensorType): 4
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 3
VariableTensor (com.yahoo.tensor.evaluation.VariableTensor): 3
Rename (com.yahoo.tensor.functions.Rename): 3
ArrayList (java.util.ArrayList): 3
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 2
TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode): 2
MapEvaluationContext (com.yahoo.tensor.evaluation.MapEvaluationContext): 2
ConstantTensor (com.yahoo.tensor.functions.ConstantTensor): 2
Generate (com.yahoo.tensor.functions.Generate): 2
Join (com.yahoo.tensor.functions.Join): 2
RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression): 1
DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 1
ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException): 1
ComparisonNode (com.yahoo.searchlib.rankingexpression.rule.ComparisonNode): 1
CompositeNode (com.yahoo.searchlib.rankingexpression.rule.CompositeNode): 1
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 1