Example 6 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine.

From the class Matmul, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    OrderedTensorType aType = inputs.get(0).type().get();
    OrderedTensorType bType = inputs.get(1).type().get();
    if (aType.type().rank() < 2 || bType.type().rank() < 2)
        throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
    if (aType.type().rank() != bType.type().rank())
        throw new IllegalArgumentException("Tensors in matmul must have the same rank");
    Optional<TensorFunction> aFunction = inputs.get(0).function();
    Optional<TensorFunction> bFunction = inputs.get(1).function();
    if (!aFunction.isPresent() || !bFunction.isPresent()) {
        return null;
    }
    return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
}
Also used: TensorFunction (com.yahoo.tensor.functions.TensorFunction), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)
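
The third constructor argument above names the dimension the two inputs share. As a rough illustration of what a matrix product over a shared dimension computes, here is a minimal plain-Java sketch on 2D arrays. It assumes the product is equivalent to multiplying matching cells and summing over the shared dimension; `MatmulSketch` and `matmul` are hypothetical names and do not use the Vespa API.

```java
import java.util.Arrays;

// Hypothetical sketch, not Vespa code: a matrix product written as
// "multiply matching cells, then sum over the shared dimension".
final class MatmulSketch {

    /** Multiplies a (n x k) by b (k x m), summing products over the shared dimension k. */
    static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, k = b.length, m = b[0].length;
        if (a[0].length != k)
            throw new IllegalArgumentException("Inner dimensions must match");
        double[][] out = new double[n][m];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                for (int s = 0; s < k; s++)
                    out[i][j] += a[i][s] * b[s][j]; // multiply, then accumulate over s
        return out;
    }

    public static void main(String[] args) {
        double[][] c = matmul(new double[][] {{1, 2}, {3, 4}}, new double[][] {{5, 6}, {7, 8}});
        System.out.println(Arrays.deepToString(c)); // prints [[19.0, 22.0], [43.0, 50.0]]
    }
}
```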

Example 7 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine.

From the class Mean, method lazyGetFunction.

// todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFunction inputFunction = inputs.get(0).function().get();
    TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
    if (shouldKeepDimensions()) {
        // multiply with a generated tensor created from the reduced dimensions
        TensorType.Builder typeBuilder = new TensorType.Builder();
        for (String name : reduceDimensions) {
            typeBuilder.indexed(name, 1);
        }
        TensorType generatedType = typeBuilder.build();
        ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
        Generate generatedFunction = new Generate(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
        output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
    }
    return output;
}
Also used: GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), TensorType (com.yahoo.tensor.TensorType), Reduce (com.yahoo.tensor.functions.Reduce), ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode), TensorFunction (com.yahoo.tensor.functions.TensorFunction), DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), Generate (com.yahoo.tensor.functions.Generate)
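
The keep-dimensions branch above multiplies the reduced result with a generated tensor of ones whose dimensions each have size 1. The following plain-Java sketch shows the same idea on arrays: averaging removes a dimension, and multiplying with a size-1 tensor of ones brings it back without changing any value. `MeanKeepDimsSketch` and its methods are hypothetical names and do not use the Vespa API.

```java
// Hypothetical sketch, not Vespa code: mean over one dimension, then
// restoring that dimension with size 1 by multiplying with ones.
final class MeanKeepDimsSketch {

    /** Averages each row, removing the second dimension: shape [n][m] -> [n]. */
    static double[] meanOverColumns(double[][] x) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            double sum = 0;
            for (double v : x[i]) sum += v;
            out[i] = sum / x[i].length;
        }
        return out;
    }

    /** Re-adds the reduced dimension with size 1: shape [n] -> [n][1]. */
    static double[][] keepDimensions(double[] reduced) {
        double[] ones = {1.0}; // stands in for the generated tensor of ones
        double[][] out = new double[reduced.length][ones.length];
        for (int i = 0; i < reduced.length; i++)
            for (int j = 0; j < ones.length; j++)
                out[i][j] = reduced[i] * ones[j]; // value unchanged, dimension restored
        return out;
    }
}
```

For example, `meanOverColumns` on `{{1, 2, 3}, {4, 5, 6}}` yields `{2.0, 5.0}`, and `keepDimensions` turns that into `{{2.0}, {5.0}}`.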

Example 8 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine.

From the class TensorFlowFeatureConverter, method reduceBatchDimensionsAtInput.

private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, TypeContext<Reference> typeContext) {
    if (node instanceof TensorFunctionNode) {
        TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
        if (tensorFunction instanceof Rename) {
            List<ExpressionNode> children = ((TensorFunctionNode) node).children();
            if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
                ReferenceNode referenceNode = (ReferenceNode) children.get(0);
                if (model.requiredMacros().containsKey(referenceNode.getName())) {
                    return reduceBatchDimensionExpression(tensorFunction, typeContext);
                }
            }
        }
    }
    if (node instanceof ReferenceNode) {
        ReferenceNode referenceNode = (ReferenceNode) node;
        if (model.requiredMacros().containsKey(referenceNode.getName())) {
            return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
        }
    }
    if (node instanceof CompositeNode) {
        List<ExpressionNode> children = ((CompositeNode) node).children();
        List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
        for (ExpressionNode child : children) {
            transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
        }
        return ((CompositeNode) node).setChildren(transformedChildren);
    }
    return node;
}
Also used: CompositeNode (com.yahoo.searchlib.rankingexpression.rule.CompositeNode), TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode), TensorFunction (com.yahoo.tensor.functions.TensorFunction), ReferenceNode (com.yahoo.searchlib.rankingexpression.rule.ReferenceNode), ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode), ArrayList (java.util.ArrayList), Rename (com.yahoo.tensor.functions.Rename)
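
The method above follows a common recursive rewrite pattern: matching leaves are replaced, composite nodes are rebuilt from their transformed children, and everything else is returned unchanged. A minimal stand-alone sketch of that pattern follows. The `Node` type, the `reduce` wrapper name, and the `inputs` set are hypothetical stand-ins for the Vespa expression classes and `model.requiredMacros()`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

// Hypothetical sketch, not Vespa code: recursive rewrite of an expression tree.
final class RewriteSketch {

    /** Minimal expression node: a named leaf, or a named composite with children. */
    static final class Node {
        final String name;
        final List<Node> children;
        Node(String name, List<Node> children) { this.name = name; this.children = children; }
        static Node leaf(String name) { return new Node(name, List.of()); }
        @Override public String toString() { return children.isEmpty() ? name : name + children; }
    }

    /** Wraps every leaf whose name is in 'inputs' and rebuilds composites around the result. */
    static Node rewrite(Node node, Set<String> inputs) {
        if (node.children.isEmpty())
            return inputs.contains(node.name) ? new Node("reduce", List.of(node)) : node;
        List<Node> transformed = new ArrayList<>(node.children.size());
        for (Node child : node.children)
            transformed.add(rewrite(child, inputs));
        return new Node(node.name, transformed);
    }

    public static void main(String[] args) {
        Node expr = new Node("add", List.of(Node.leaf("x"), Node.leaf("y")));
        System.out.println(rewrite(expr, Set.of("x"))); // prints add[reduce[x], y]
    }
}
```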

Example 9 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine.

From the class Placeholder, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
    if (!standardNamingType.equals(type)) {
        List<String> renameFrom = standardNamingType.dimensionNames();
        List<String> renameTo = type.dimensionNames();
        output = new Rename(output, renameFrom, renameTo);
    }
    return output;
}
Also used: VariableTensor (com.yahoo.tensor.evaluation.VariableTensor), TensorFunction (com.yahoo.tensor.functions.TensorFunction), Rename (com.yahoo.tensor.functions.Rename)

Example 10 with TensorFunction

Use of com.yahoo.tensor.functions.TensorFunction in project vespa by vespa-engine.

From the class Select, method lazyGetFunction.

@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputFunctionsPresent(3)) {
        return null;
    }
    TensorFlowOperation conditionOperation = inputs().get(0);
    TensorFunction a = inputs().get(1).function().get();
    TensorFunction b = inputs().get(2).function().get();
    // Shortcut: if we know during import which tensor to select, do that directly here.
    if (conditionOperation.getConstantValue().isPresent()) {
        Tensor condition = conditionOperation.getConstantValue().get().asTensor();
        if (condition.type().rank() == 0) {
            return ((int) condition.asDouble() == 0) ? b : a;
        }
        if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
            return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
        }
    }
    // The task is to select cells from 'a' or 'b' based on 'condition'.
    // Where 'condition' is 0 (false), select from 'b'; where it is 1 (true),
    // select from 'a'. We do this by joining 'a' with 'condition' and 'b'
    // with (1 - 'condition'), and then adding the two resulting tensors.
    TensorFunction conditionFunction = conditionOperation.function().get();
    TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
    TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {

        @Override
        public double applyAsDouble(double a, double b) {
            return a * (1.0 - b);
        }

        @Override
        public String toString() {
            return "f(a,b)(a * (1-b))";
        }
    });
    return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
}
Also used: Tensor (com.yahoo.tensor.Tensor), DoubleBinaryOperator (java.util.function.DoubleBinaryOperator), TensorFunction (com.yahoo.tensor.functions.TensorFunction)
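
The general branch above computes the selection arithmetically as a * c + b * (1 - c), where c is the condition. A minimal plain-Java sketch of the same three joins on flat arrays is shown below. `SelectSketch`, `join`, and `select` are hypothetical names, and the sketch assumes all three arrays have equal length and that the condition holds only 0 or 1.

```java
import java.util.function.DoubleBinaryOperator;

// Hypothetical sketch, not Vespa code: cell-wise select built from three joins.
final class SelectSketch {

    /** Combines two equally sized arrays cell by cell with the given function. */
    static double[] join(double[] x, double[] y, DoubleBinaryOperator f) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++)
            out[i] = f.applyAsDouble(x[i], y[i]);
        return out;
    }

    /** Picks a[i] where condition[i] is 1 and b[i] where condition[i] is 0. */
    static double[] select(double[] condition, double[] a, double[] b) {
        double[] aCond = join(a, condition, (v, c) -> v * c);         // keeps a where c == 1
        double[] bCond = join(b, condition, (v, c) -> v * (1.0 - c)); // keeps b where c == 0
        return join(aCond, bCond, Double::sum);
    }
}
```

For example, `select(new double[] {1, 0, 1}, new double[] {10, 20, 30}, new double[] {-1, -2, -3})` gives `{10.0, -2.0, 30.0}`. One consequence of the arithmetic form is that a NaN or infinite value in the branch that is not selected still propagates, since multiplying it by 0 gives NaN.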

Aggregations

TensorFunction (com.yahoo.tensor.functions.TensorFunction): 12
OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 6
Reduce (com.yahoo.tensor.functions.Reduce): 6
TensorType (com.yahoo.tensor.TensorType): 4
ExpressionNode (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode): 3
VariableTensor (com.yahoo.tensor.evaluation.VariableTensor): 3
Rename (com.yahoo.tensor.functions.Rename): 3
ArrayList (java.util.ArrayList): 3
GeneratorLambdaFunctionNode (com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode): 2
TensorFunctionNode (com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode): 2
MapEvaluationContext (com.yahoo.tensor.evaluation.MapEvaluationContext): 2
ConstantTensor (com.yahoo.tensor.functions.ConstantTensor): 2
Generate (com.yahoo.tensor.functions.Generate): 2
Join (com.yahoo.tensor.functions.Join): 2
RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression): 1
DoubleValue (com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue): 1
ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException): 1
ComparisonNode (com.yahoo.searchlib.rankingexpression.rule.ComparisonNode): 1
CompositeNode (com.yahoo.searchlib.rankingexpression.rule.CompositeNode): 1
ConstantNode (com.yahoo.searchlib.rankingexpression.rule.ConstantNode): 1