Usage of `com.yahoo.tensor.functions.TensorFunction` in the vespa project (vespa-engine).
Class `Matmul`, method `lazyGetFunction`:
/**
 * Builds the Vespa matrix-multiplication function for this operation's two inputs.
 *
 * Returns null when either input type or either input function is not yet available.
 * The contracted dimension is the second dimension (index 1) of the first input.
 * NOTE(review): for inputs of rank > 2 this is still index 1 rather than the last
 * dimension — confirm this matches the type this operation computes.
 *
 * @throws IllegalArgumentException if either input has rank below 2, or the ranks differ
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputTypesPresent(2))
        return null;

    OrderedTensorType leftType = inputs.get(0).type().get();
    OrderedTensorType rightType = inputs.get(1).type().get();
    int leftRank = leftType.type().rank();
    int rightRank = rightType.type().rank();

    if (Math.min(leftRank, rightRank) < 2)
        throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
    if (leftRank != rightRank)
        throw new IllegalArgumentException("Tensors in matmul must have the same rank");

    Optional<TensorFunction> leftFunction = inputs.get(0).function();
    Optional<TensorFunction> rightFunction = inputs.get(1).function();
    if ( ! (leftFunction.isPresent() && rightFunction.isPresent()))
        return null;

    String contractedDimension = leftType.dimensions().get(1).name();
    return new com.yahoo.tensor.functions.Matmul(leftFunction.get(), rightFunction.get(), contractedDimension);
}
Usage of `com.yahoo.tensor.functions.TensorFunction` in the vespa project (vespa-engine).
Class `Mean`, method `lazyGetFunction`:
// todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
/**
 * Builds an average-reduction of the first input over {@code reduceDimensions}.
 *
 * When {@code shouldKeepDimensions()} is true, the result is multiplied by a generated
 * all-ones tensor that has each reduced dimension with size 1. This reintroduces those
 * dimensions into the output type.
 *
 * @return the reduce function, or null if the input types or the first input's function
 *         are not yet available
 */
@Override
protected TensorFunction lazyGetFunction() {
    if (!allInputTypesPresent(2)) {
        return null;
    }
    // A known type does not imply a known function. Return null (like the other operations)
    // instead of failing with NoSuchElementException on an unchecked Optional.get().
    TensorFunction inputFunction = inputs.get(0).function().orElse(null);
    if (inputFunction == null) {
        return null;
    }
    TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
    if (shouldKeepDimensions()) {
        // multiply with a generated tensor created from the reduced dimensions
        TensorType.Builder typeBuilder = new TensorType.Builder();
        for (String name : reduceDimensions) {
            typeBuilder.indexed(name, 1);
        }
        TensorType generatedType = typeBuilder.build();
        ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
        Generate generatedFunction = new Generate(generatedType, new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
        output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
    }
    return output;
}
Usage of `com.yahoo.tensor.functions.TensorFunction` in the vespa project (vespa-engine).
Class `TensorFlowFeatureConverter`, method `reduceBatchDimensionsAtInput`:
/**
 * Recursively rewrites an expression tree, passing inputs that refer to one of the model's
 * required macros through {@code reduceBatchDimensionExpression}.
 *
 * Two shapes are rewritten:
 * <ul>
 *   <li>a {@code Rename} tensor function whose single argument is a reference to a required macro
 *       (the Rename function itself is passed on), and</li>
 *   <li>a bare reference to a required macro (wrapped as a tensor function argument first).</li>
 * </ul>
 * Any other composite node is rebuilt with recursively rewritten children. Any other leaf is
 * returned unchanged.
 */
private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, TypeContext<Reference> typeContext) {
    // Shape 1: rename(reference-to-required-macro)
    if (node instanceof TensorFunctionNode) {
        TensorFunctionNode functionNode = (TensorFunctionNode) node;
        TensorFunction function = functionNode.function();
        if (function instanceof Rename) {
            List<ExpressionNode> arguments = functionNode.children();
            boolean renamesRequiredMacro = arguments.size() == 1
                    && arguments.get(0) instanceof ReferenceNode
                    && model.requiredMacros().containsKey(((ReferenceNode) arguments.get(0)).getName());
            if (renamesRequiredMacro)
                return reduceBatchDimensionExpression(function, typeContext);
        }
    }

    // Shape 2: a bare reference to a required macro
    if (node instanceof ReferenceNode && model.requiredMacros().containsKey(((ReferenceNode) node).getName()))
        return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);

    // Leaves that matched neither shape are kept as-is
    if ( ! (node instanceof CompositeNode))
        return node;

    // Composite nodes are rebuilt from their rewritten children
    CompositeNode composite = (CompositeNode) node;
    List<ExpressionNode> originalChildren = composite.children();
    List<ExpressionNode> rewrittenChildren = new ArrayList<>(originalChildren.size());
    for (ExpressionNode originalChild : originalChildren)
        rewrittenChildren.add(reduceBatchDimensionsAtInput(originalChild, model, typeContext));
    return composite.setChildren(rewrittenChildren);
}
Usage of `com.yahoo.tensor.functions.TensorFunction` in the vespa project (vespa-engine).
Class `Placeholder`, method `lazyGetFunction`:
/**
 * Returns a variable tensor named {@code vespaName()} that has the standard-naming type.
 * If this operation's {@code type} differs from {@code standardNamingType}, the variable is
 * wrapped in a Rename. That Rename maps the standard dimension names onto this operation's
 * dimension names.
 */
@Override
protected TensorFunction lazyGetFunction() {
    TensorFunction variable = new VariableTensor(vespaName(), standardNamingType.type());
    if (standardNamingType.equals(type))
        return variable;
    return new Rename(variable, standardNamingType.dimensionNames(), type.dimensionNames());
}
Usage of `com.yahoo.tensor.functions.TensorFunction` in the vespa project (vespa-engine).
Class `Select`, method `lazyGetFunction`:
/**
 * Builds the function for a select/where operation with inputs (condition, ifTrue, ifFalse).
 *
 * If the condition has a constant value known at import time, and that value is a scalar
 * or a rank-1 tensor whose single dimension has size 1, the branch is chosen here. A value
 * that truncates to int 0 picks the third input; anything else picks the second.
 * Otherwise, the result is computed as
 * {@code ifTrue * condition + ifFalse * (1 - condition)} using tensor joins.
 *
 * @return the function, or null if any of the three input functions is unavailable
 */
@Override
protected TensorFunction lazyGetFunction() {
    if ( ! allInputFunctionsPresent(3))
        return null;

    TensorFlowOperation conditionOperation = inputs().get(0);
    TensorFunction ifTrue = inputs().get(1).function().get();
    TensorFunction ifFalse = inputs().get(2).function().get();

    // Constant condition: pick the branch directly at import time.
    if (conditionOperation.getConstantValue().isPresent()) {
        Tensor condition = conditionOperation.getConstantValue().get().asTensor();
        int conditionRank = condition.type().rank();
        if (conditionRank == 0)
            return ((int) condition.asDouble() != 0) ? ifTrue : ifFalse;
        if (conditionRank == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1)
            return (condition.cellIterator().next().getValue().intValue() != 0) ? ifTrue : ifFalse;
    }

    // General case: mask 'ifTrue' with the condition, mask 'ifFalse' with its complement
    // (1 - condition), then add the two masked tensors.
    TensorFunction conditionFunction = conditionOperation.function().get();
    TensorFunction maskedTrue = new com.yahoo.tensor.functions.Join(ifTrue, conditionFunction, ScalarFunctions.multiply());
    TensorFunction maskedFalse = new com.yahoo.tensor.functions.Join(ifFalse, conditionFunction, new DoubleBinaryOperator() {
        @Override
        public double applyAsDouble(double value, double cond) {
            return value * (1.0 - cond);
        }

        @Override
        public String toString() {
            return "f(a,b)(a * (1-b))";
        }
    });
    return new com.yahoo.tensor.functions.Join(maskedTrue, maskedFalse, ScalarFunctions.add());
}
Aggregations