Usage of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in the vespa project by vespa-engine.
Class TensorFlowImporter, method importRankingExpression:
/**
 * Adds the tensor function of the given operation to the model as a named ranking expression.
 * <p>
 * Operations without a function, or whose node name is already registered in
 * {@code model.expressions()}, are skipped. If the operation is a signature output and its
 * type differs from the type derived by {@code OrderedTensorType.fromTensorFlowType}, the
 * function is wrapped in a {@link Rename} so its dimension names follow the standard convention.
 *
 * @param model     the model the expression is added to
 * @param operation the operation whose function is imported
 * @throws RuntimeException if the function's string form cannot be parsed as a ranking expression
 */
private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
    if ( ! operation.function().isPresent()) return;

    String name = operation.node().getName();
    // Use the same key for the lookup and the insertion below
    if (model.expressions().containsKey(name)) return;

    TensorFunction function = operation.function().get();
    // Make sure output adheres to standard naming convention
    if (isSignatureOutput(model, operation)) {
        OrderedTensorType operationType = operation.type().get();
        OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
        if ( ! operationType.equals(standardNamingType)) {
            List<String> renameFrom = operationType.dimensionNames();
            List<String> renameTo = standardNamingType.dimensionNames();
            function = new Rename(function, renameFrom, renameTo);
        }
    }
    try {
        // We add all intermediate nodes imported as separate expressions. Only
        // those referenced in a signature output will be used. We parse the
        // TensorFunction here to convert it to a RankingExpression tree.
        model.expression(name, new RankingExpression(name, function.toString()));
    } catch (ParseException e) {
        throw new RuntimeException("Tensorflow function " + function + " cannot be parsed as a ranking expression", e);
    }
}
Usage of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in the vespa project by vespa-engine.
Class Join, method smallestInput:
/**
 * Returns whichever of the two inputs has the lower rank.
 * When both ranks are equal, the second input is returned.
 */
private TensorFlowOperation smallestInput() {
    TensorFlowOperation first = inputs.get(0);
    TensorFlowOperation second = inputs.get(1);
    int firstRank = first.type().get().rank();
    int secondRank = second.type().get().rank();
    if (firstRank < secondRank) {
        return first;
    }
    return second;
}
Usage of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in the vespa project by vespa-engine.
Class Join, method largestInput:
/**
 * Returns whichever of the two inputs has the higher rank.
 * When both ranks are equal, the first input is returned.
 */
private TensorFlowOperation largestInput() {
    TensorFlowOperation first = inputs.get(0);
    TensorFlowOperation second = inputs.get(1);
    int firstRank = first.type().get().rank();
    int secondRank = second.type().get().rank();
    if (firstRank < secondRank) {
        return second;
    }
    return first;
}
Usage of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in the vespa project by vespa-engine.
Class Matmul, method lazyGetFunction:
/**
 * Builds the Vespa matmul function for the two inputs.
 *
 * @return the matmul function, or null if an input type or input function is not yet available
 * @throws IllegalArgumentException if either input has rank below 2, or the ranks differ
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputTypesPresent(2)) return null;

    OrderedTensorType leftType = inputs.get(0).type().get();
    OrderedTensorType rightType = inputs.get(1).type().get();
    int leftRank = leftType.type().rank();
    int rightRank = rightType.type().rank();
    if (leftRank < 2 || rightRank < 2)
        throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
    if (leftRank != rightRank)
        throw new IllegalArgumentException("Tensors in matmul must have the same rank");

    Optional<TensorFunction> leftFunction = inputs.get(0).function();
    Optional<TensorFunction> rightFunction = inputs.get(1).function();
    if ( ! (leftFunction.isPresent() && rightFunction.isPresent())) return null;

    // The second dimension of the left input is passed as Matmul's dimension argument
    // (presumably the dimension summed over — confirm against com.yahoo.tensor.functions.Matmul)
    String matmulDimension = leftType.dimensions().get(1).name();
    return new com.yahoo.tensor.functions.Matmul(leftFunction.get(), rightFunction.get(), matmulDimension);
}
Usage of com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType in the vespa project by vespa-engine.
Class Mean, method lazyGetType:
/**
 * Computes the output type of this mean operation.
 * <p>
 * The second input must be a constant holding the reduction indices. Each index selects a
 * dimension of the first input's type; its name is recorded in {@code reduceDimensions}.
 * Negative indices count from the end, so -1 refers to the last dimension.
 *
 * @return the reduced type, or null if an input type is not yet available
 * @throws IllegalArgumentException if the reduction indices are not a constant, or an index
 *                                  is outside [-rank, rank)
 */
@Override
protected OrderedTensorType lazyGetType() {
    if ( ! allInputTypesPresent(2)) {
        return null;
    }
    TensorFlowOperation reductionIndices = inputs.get(1);
    if ( ! reductionIndices.getConstantValue().isPresent()) {
        throw new IllegalArgumentException("Mean in " + node.getName() + ": " + "reduction indices must be a constant.");
    }
    Tensor indices = reductionIndices.getConstantValue().get().asTensor();
    reduceDimensions = new ArrayList<>();

    OrderedTensorType inputType = inputs.get(0).type().get();
    int rank = inputType.dimensions().size();
    for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext(); ) {
        Tensor.Cell cell = cellIterator.next();
        int requestedIndex = cell.getValue().intValue();
        // Negative indices count from the end, so they must be ADDED to the rank
        // (previously subtracted, which made e.g. -1 map past the last dimension).
        int dimensionIndex = requestedIndex < 0 ? rank + requestedIndex : requestedIndex;
        if (dimensionIndex < 0 || dimensionIndex >= rank) {
            throw new IllegalArgumentException("Mean in " + node.getName() + ": reduction index " +
                                               requestedIndex + " is out of range for input of rank " + rank);
        }
        reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
    }
    return reducedType(inputType, shouldKeepDimensions());
}
Aggregations