Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.
Class Reshape, method lazyGetType.
/**
 * Computes the output type of a TensorFlow Reshape: one indexed dimension per entry of the
 * (constant) shape input, named "{vespaName}_{index}".
 *
 * @return the output type, or null if the two input types are not yet available
 * @throws IllegalArgumentException if the shape input is not a constant, or if a -1 entry
 *         cannot be resolved (e.g. more than one -1, or a zero-sized explicit dimension)
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation newShape = inputs.get(1);
    if (!newShape.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + "shape input must be a constant.");
    }
    Tensor shape = newShape.getConstantValue().get().asTensor();
    OrderedTensorType inputType = inputs.get(0).type().get();
    OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
    int dimensionIndex = 0;
    for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell cell = cellIterator.next();
        int size = cell.getValue().intValue();
        if (size < 0) {
            // A -1 entry means "infer this size so the total element count is unchanged".
            // With a single -1 in the shape, negating the product of all entries gives the
            // product of the explicitly specified sizes.
            int knownSize = -1 * (int) shape.reduce(Reduce.Aggregator.prod).asDouble();
            if (knownSize <= 0) {
                throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
                                                   "cannot infer size of dimension " + dimensionIndex +
                                                   " from shape " + shape);
            }
            // Inferred size = input element count / product of the known output sizes.
            size = tensorSize(inputType.type()).intValue() / knownSize;
        }
        outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), dimensionIndex), size));
        dimensionIndex++;
    }
    return outputTypeBuilder.build();
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.
Class Reshape, method lazyGetFunction.
/**
 * Builds the Reshape tensor function, which maps the data input (input 0) from its own type
 * to this operation's output type.
 *
 * @return the reshape function, or null while input types or input functions are missing
 */
@Override
protected TensorFunction lazyGetFunction() {
    boolean inputsReady = allInputTypesPresent(2) && allInputFunctionsPresent(2);
    if (!inputsReady) {
        return null;
    }
    TensorFlowOperation source = inputs.get(0);
    OrderedTensorType sourceType = source.type().get();
    TensorFunction sourceFunction = source.function().get();
    return reshape(sourceFunction, sourceType.type(), type.type());
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.
Class Shape, method createConstantValue.
/**
 * Stores the input's shape as this operation's constant value. The value is a tensor holding
 * the size of each input dimension in order, with -1 for dimensions whose size is unknown.
 * Does nothing while the input type is unavailable.
 */
private void createConstantValue() {
    if (allInputTypesPresent(1)) {
        List<TensorType.Dimension> dims = inputs.get(0).type().get().dimensions();
        IndexedTensor.BoundBuilder cells = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
        int position = 0;
        for (TensorType.Dimension dim : dims) {
            cells.cellByDirectIndex(position++, dim.size().orElse(-1L));
        }
        this.setConstantValue(new TensorValue(cells.build()));
    }
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.
Class Select, method lazyGetType.
/**
 * Computes the output type of Select, which is the type of the first value input (input 1).
 * The two value inputs (inputs 1 and 2) must have the same shape.
 *
 * @return the output type, or null if the three input types are not yet available
 * @throws IllegalArgumentException if the two value inputs differ in rank, in total size,
 *         or in the known size of any dimension position
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(3)) {
        return null;
    }
    OrderedTensorType a = inputs.get(1).type().get();
    OrderedTensorType b = inputs.get(2).type().get();
    if ((a.type().rank() != b.type().rank())
        || !(tensorSize(a.type()).equals(tensorSize(b.type())))
        || !knownSizesMatch(a, b)) {
        throw new IllegalArgumentException("'Select': input tensors must have the same shape");
    }
    return a;
}

/**
 * Returns whether a and b have equally many ordered dimensions, and whether each position
 * where both sizes are known has equal sizes. Equal rank and total size alone would accept
 * transposed shapes such as [2,3] vs [3,2].
 */
private static boolean knownSizesMatch(OrderedTensorType a, OrderedTensorType b) {
    if (a.dimensions().size() != b.dimensions().size()) {
        return false;
    }
    for (int i = 0; i < a.dimensions().size(); i++) {
        boolean bothKnown = a.dimensions().get(i).size().isPresent() && b.dimensions().get(i).size().isPresent();
        if (bothKnown && a.dimensions().get(i).size().get().longValue() != b.dimensions().get(i).size().get().longValue()) {
            return false;
        }
    }
    return true;
}
Use of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in project vespa by vespa-engine.
Class Squeeze, method lazyGetType.
/**
 * Computes the output type of Squeeze and records in {@code squeezeDimensions} the names of
 * the dimensions to remove. Only size-1 dimensions are ever squeezed. If the "squeeze_dims"
 * attribute is present, only the listed axes are candidates; otherwise every dimension is.
 *
 * @return the input type if nothing is squeezed, the reduced type otherwise, or null if the
 *         input type is not yet available
 */
@Override
protected OrderedTensorType lazyGetType() {
    if (!allInputTypesPresent(1)) {
        return null;
    }
    OrderedTensorType inputType = inputs.get(0).type().get();
    AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
    if (squeezeDimsAttr == null) {
        squeezeDimensions = inputType.type().dimensions().stream()
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    } else {
        // Axes are positional, so resolve them against the ordered dimension list.
        // A negative axis counts from the end (-1 is the last dimension), hence rank + i.
        int rank = inputType.dimensions().size();
        squeezeDimensions = squeezeDimsAttr.getList().getIList().stream()
                .map(i -> i < 0 ? rank + i : i)
                .map(i -> inputType.dimensions().get(i.intValue()))
                .filter(dim -> TensorConverter.dimensionSize(dim) == 1)
                .map(TensorType.Dimension::name)
                .collect(Collectors.toList());
    }
    return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
}
Aggregations