Search in sources:

Example 6 with OrderedTensorType

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

Class TensorFlowImporter, method importRankingExpression:

/**
 * Registers a ranking expression on the model for one imported TensorFlow operation.
 *
 * <p>Nothing happens if the operation has no tensor function, or if the model already has
 * an expression registered under the operation's node name. When
 * {@code isSignatureOutput(model, operation)} holds and the operation's ordered type differs
 * from {@code OrderedTensorType.fromTensorFlowType(node)}, the function is wrapped in a
 * {@link Rename} that maps the operation's dimension names onto that standard naming.
 *
 * @param model     the model the expression is added to
 * @param operation the operation whose tensor function is converted
 * @throws RuntimeException wrapping a ParseException if the function's string form cannot be
 *                          parsed as a ranking expression
 */
private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
    if ( ! operation.function().isPresent()) return;

    String name = operation.node().getName();
    if (model.expressions().containsKey(name)) return;

    TensorFunction function = operation.function().get();

    // Signature outputs must follow the standard dimension naming convention
    if (isSignatureOutput(model, operation)) {
        OrderedTensorType actualType = operation.type().get();
        OrderedTensorType standardType = OrderedTensorType.fromTensorFlowType(operation.node());
        if ( ! actualType.equals(standardType)) {
            List<String> fromNames = actualType.dimensionNames();
            List<String> toNames = standardType.dimensionNames();
            function = new Rename(function, fromNames, toNames);
        }
    }

    // Every intermediate node is added as its own expression; only those reachable from a
    // signature output end up being used. Parsing the function's string form converts it
    // into a RankingExpression tree.
    try {
        RankingExpression expression = new RankingExpression(name, function.toString());
        model.expression(name, expression);
    } catch (ParseException e) {
        throw new RuntimeException("Tensorflow function " + function + " cannot be parsed as a ranking expression", e);
    }
}
Also used: RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression), TensorFunction (com.yahoo.tensor.functions.TensorFunction), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType), ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException), Rename (com.yahoo.tensor.functions.Rename)

Example 7 with OrderedTensorType

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

Class Join, method smallestInput:

/**
 * Returns whichever of the two inputs has the strictly lower rank.
 * When both ranks are equal, the second input is returned.
 */
private TensorFlowOperation smallestInput() {
    TensorFlowOperation first = inputs.get(0);
    TensorFlowOperation second = inputs.get(1);
    int firstRank = first.type().get().rank();
    int secondRank = second.type().get().rank();
    if (firstRank < secondRank) {
        return first;
    }
    return second;
}
Also used: OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 8 with OrderedTensorType

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

Class Join, method largestInput:

/**
 * Returns whichever of the two inputs has the higher rank.
 * When both ranks are equal, the first input is returned.
 */
private TensorFlowOperation largestInput() {
    TensorFlowOperation first = inputs.get(0);
    TensorFlowOperation second = inputs.get(1);
    int firstRank = first.type().get().rank();
    int secondRank = second.type().get().rank();
    if (firstRank < secondRank) {
        return second;
    }
    return first;
}
Also used: OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 9 with OrderedTensorType

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

Class Matmul, method lazyGetFunction:

/**
 * Builds the tensor function for this matrix multiplication.
 *
 * <p>Returns {@code null} if either input type is not yet available, or if either input has
 * no tensor function. Both inputs must have rank of at least 2 and equal rank.
 *
 * @return a Matmul summing over the name of the left input's dimension at index 1, or null
 * @throws IllegalArgumentException if an input has rank below 2 or the ranks differ
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputTypesPresent(2)) return null;

    OrderedTensorType leftType = inputs.get(0).type().get();
    OrderedTensorType rightType = inputs.get(1).type().get();
    int leftRank = leftType.type().rank();
    int rightRank = rightType.type().rank();
    if (leftRank < 2 || rightRank < 2)
        throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
    if (leftRank != rightRank)
        throw new IllegalArgumentException("Tensors in matmul must have the same rank");

    Optional<TensorFunction> leftFunction = inputs.get(0).function();
    Optional<TensorFunction> rightFunction = inputs.get(1).function();
    if ( ! (leftFunction.isPresent() && rightFunction.isPresent())) return null;

    // NOTE(review): the summed dimension is always index 1 of the left input; for rank > 2
    // TensorFlow contracts the last axis instead — confirm this matches the dimension
    // naming constraints applied elsewhere.
    String sumDimension = leftType.dimensions().get(1).name();
    return new com.yahoo.tensor.functions.Matmul(leftFunction.get(), rightFunction.get(), sumDimension);
}
Also used: TensorFunction (com.yahoo.tensor.functions.TensorFunction), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 10 with OrderedTensorType

Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

Class Mean, method lazyGetType:

/**
 * Computes the output type of this mean operation.
 *
 * <p>The second input must be a constant holding the reduction indices. Each index selects a
 * dimension of the first input's ordered type. Negative indices count from the end, so -1 is
 * the last dimension. The selected dimension names are stored in {@code reduceDimensions},
 * and the result is {@code reducedType(inputType, shouldKeepDimensions())}.
 *
 * @return the reduced type, or null if the input types are not yet available
 * @throws IllegalArgumentException if the reduction indices are not constant, or if an index
 *                                  is outside [-rank, rank) of the input
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation reductionIndices = inputs.get(1);
    if (!reductionIndices.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("Mean in " + node.getName() + ": " + "reduction indices must be a constant.");
    }
    Tensor indices = reductionIndices.getConstantValue().get().asTensor();
    reduceDimensions = new ArrayList<>();
    OrderedTensorType inputType = inputs.get(0).type().get();
    int rank = inputType.dimensions().size();
    for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell cell = cellIterator.next();
        int requestedIndex = cell.getValue().intValue();
        // Negative indices count from the end: -1 maps to rank - 1.
        // (The previous code computed rank - index, which for -1 gave rank + 1, always out of bounds.)
        int dimensionIndex = requestedIndex < 0 ? rank + requestedIndex : requestedIndex;
        if (dimensionIndex < 0 || dimensionIndex >= rank) {
            throw new IllegalArgumentException("Mean in " + node.getName() + ": " + "reduction index " +
                                               requestedIndex + " is out of range for input of rank " + rank + ".");
        }
        reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
    }
    return reducedType(inputType, shouldKeepDimensions());
}
Also used: Tensor (com.yahoo.tensor.Tensor), OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Aggregations

Usage counts:
- OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType): 15
- TensorFunction (com.yahoo.tensor.functions.TensorFunction): 4
- Tensor (com.yahoo.tensor.Tensor): 3
- TensorType (com.yahoo.tensor.TensorType): 3
- DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer): 2
- RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression): 1
- TensorValue (com.yahoo.searchlib.rankingexpression.evaluation.TensorValue): 1
- TensorConverter (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter): 1
- ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException): 1
- IndexedTensor (com.yahoo.tensor.IndexedTensor): 1
- Reduce (com.yahoo.tensor.functions.Reduce): 1
- Rename (com.yahoo.tensor.functions.Rename): 1
- ArrayList (java.util.ArrayList): 1
- List (java.util.List): 1
- Optional (java.util.Optional): 1
- Collectors (java.util.stream.Collectors): 1
- Test (org.junit.Test): 1
- AttrValue (org.tensorflow.framework.AttrValue): 1
- NodeDef (org.tensorflow.framework.NodeDef): 1