Search in sources:

Example 1 with OrderedTensorType

use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

The class Reshape, method lazyGetType.

/**
 * Computes the output type of the reshape from the constant shape tensor given as the second input.
 * A negative entry in the shape (TensorFlow uses -1) is replaced by the size inferred from the
 * total size of the input and the product of the other, known, shape entries.
 *
 * @return the reshaped type, or null if the types of both inputs are not yet available
 * @throws IllegalArgumentException if the shape input is not a constant, or if the size of the
 *         negative dimension cannot be inferred
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation newShape = inputs.get(1);
    if (!newShape.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + "shape input must be a constant.");
    }
    Tensor shape = newShape.getConstantValue().get().asTensor();
    OrderedTensorType inputType = inputs.get(0).type().get();
    OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
    int dimensionIndex = 0;
    for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell cell = cellIterator.next();
        int size = cell.getValue().intValue();
        if (size < 0) {
            size = inferredDimensionSize(shape, inputType);
        }
        outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), dimensionIndex), size));
        dimensionIndex++;
    }
    return outputTypeBuilder.build();
}

/**
 * Infers the size of the single negative (-1) entry of a reshape shape: the total input size
 * divided by the product of the remaining shape entries.
 */
private int inferredDimensionSize(Tensor shape, OrderedTensorType inputType) {
    // With exactly one -1 in the shape, the product of all entries is the negated product of the
    // known sizes, so negating it recovers that product.
    long knownSizesProduct = -1L * (long) shape.reduce(Reduce.Aggregator.prod).asDouble();
    if (knownSizesProduct <= 0) {
        // A zero entry or more than one negative entry makes the inference undefined
        // (and would otherwise divide by zero or produce a negative size).
        throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
                "cannot infer the size of the -1 dimension; product of the known shape sizes is " +
                knownSizesProduct);
    }
    long inputSize = tensorSize(inputType.type()).longValue();
    return (int) (inputSize / knownSizesProduct);
}
Also used : Tensor(com.yahoo.tensor.Tensor) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 2 with OrderedTensorType

use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

The class Reshape, method lazyGetFunction.

/**
 * Builds the tensor function that reshapes the first input into this operation's type.
 *
 * @return the reshape function, or null while either the input types or the input
 *         functions of the two inputs are still missing
 */
@Override
protected TensorFunction lazyGetFunction() {
    // Types are checked before functions; the function check is skipped if types are missing.
    if (!allInputTypesPresent(2) || !allInputFunctionsPresent(2)) {
        return null;
    }
    OrderedTensorType sourceType = inputs.get(0).type().get();
    TensorFunction source = inputs.get(0).function().get();
    return reshape(source, sourceType.type(), type.type());
}
Also used : TensorFunction(com.yahoo.tensor.functions.TensorFunction) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 3 with OrderedTensorType

use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

The class Shape, method createConstantValue.

/**
 * Stores, as this operation's constant value, a tensor holding the size of each dimension of
 * the first input's type. Dimensions without a known size are recorded as -1.
 * Does nothing if the input type is not yet available.
 */
private void createConstantValue() {
    if (allInputTypesPresent(1)) {
        OrderedTensorType sourceType = inputs.get(0).type().get();
        IndexedTensor.BoundBuilder shapeBuilder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
        int position = 0;
        for (TensorType.Dimension dimension : sourceType.dimensions()) {
            shapeBuilder.cellByDirectIndex(position++, dimension.size().orElse(-1L));
        }
        setConstantValue(new TensorValue(shapeBuilder.build()));
    }
}
Also used : TensorValue(com.yahoo.searchlib.rankingexpression.evaluation.TensorValue) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) IndexedTensor(com.yahoo.tensor.IndexedTensor)

Example 4 with OrderedTensorType

use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

The class Select, method lazyGetType.

/**
 * Returns the type of the second input, which is also the output type of the select.
 * The second and third inputs (the two value branches) must agree. Note that the check below
 * compares only their rank and total size, not the size of each individual dimension.
 *
 * @return the output type, or null if the types of all three inputs are not yet available
 * @throws IllegalArgumentException if the two value inputs differ in rank or total size
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(3)) {
        return null;
    }
    OrderedTensorType whenTrue = inputs.get(1).type().get();
    OrderedTensorType whenFalse = inputs.get(2).type().get();
    // Short-circuit: sizes are only computed when the ranks already match.
    boolean compatible = whenTrue.type().rank() == whenFalse.type().rank()
            && tensorSize(whenTrue.type()).equals(tensorSize(whenFalse.type()));
    if (!compatible) {
        throw new IllegalArgumentException("'Select': input tensors must have the same shape");
    }
    return whenTrue;
}
Also used : OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)

Example 5 with OrderedTensorType

use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.

The class Squeeze, method lazyGetType.

/**
 * Computes the output type of the squeeze and records the names of the squeezed dimensions
 * in {@code squeezeDimensions}.
 * Without a "squeeze_dims" attribute, every dimension of size 1 is squeezed. Otherwise only the
 * listed axes that have size 1 are squeezed. Negative axes count from the end, so -1 denotes
 * the last dimension.
 *
 * @return the input type if nothing is squeezed, the reduced type otherwise, or null if the
 *         input type is not yet available
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(1)) {
        return null;
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    List<TensorType.Dimension> inputDimensions = inputType.type().dimensions();
    squeezeDimensions = new ArrayList<>();
    AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
    if (squeezeDimsAttr == null) {
        squeezeDimensions = inputDimensions.stream()
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    } else {
        squeezeDimensions = squeezeDimsAttr.getList().getIList().stream()
                // Normalize negative axes: size + axis (previously size - axis, which overflowed the range).
                .map(axis -> axis < 0 ? inputDimensions.size() + axis : axis)
                .map(axis -> inputDimensions.get(axis.intValue()))
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    }
    return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
Also used: TensorConverter(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter) List(java.util.List) NodeDef(org.tensorflow.framework.NodeDef) AttrValue(org.tensorflow.framework.AttrValue) Optional(java.util.Optional) Reduce(com.yahoo.tensor.functions.Reduce) OrderedTensorType(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType) TensorType(com.yahoo.tensor.TensorType) DimensionRenamer(com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer) TensorFunction(com.yahoo.tensor.functions.TensorFunction) Collectors(java.util.stream.Collectors) ArrayList(java.util.ArrayList)

Aggregations

OrderedTensorType (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType)15 TensorFunction (com.yahoo.tensor.functions.TensorFunction)4 Tensor (com.yahoo.tensor.Tensor)3 TensorType (com.yahoo.tensor.TensorType)3 DimensionRenamer (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer)2 RankingExpression (com.yahoo.searchlib.rankingexpression.RankingExpression)1 TensorValue (com.yahoo.searchlib.rankingexpression.evaluation.TensorValue)1 TensorConverter (com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter)1 ParseException (com.yahoo.searchlib.rankingexpression.parser.ParseException)1 IndexedTensor (com.yahoo.tensor.IndexedTensor)1 Reduce (com.yahoo.tensor.functions.Reduce)1 Rename (com.yahoo.tensor.functions.Rename)1 ArrayList (java.util.ArrayList)1 List (java.util.List)1 Optional (java.util.Optional)1 Collectors (java.util.stream.Collectors)1 Test (org.junit.Test)1 AttrValue (org.tensorflow.framework.AttrValue)1 NodeDef (org.tensorflow.framework.NodeDef)1