Usage of com.yahoo.searchlib.rankingexpression.evaluation.TensorValue in project vespa by vespa-engine.
In class ConstantTensorTransformer, method transformConstantReference:
/**
 * Rewrites a reference to a tensor-valued rank profile constant into a reference to the
 * CONSTANT feature, and records the constant's value and type as rank properties.
 * References whose name is not a known constant, or whose constant has rank 0, are returned unchanged.
 *
 * @param node    the reference node being transformed
 * @param context supplies the known constants and receives the emitted rank properties
 * @return the original node, or a new CONSTANT reference wrapping the constant's name
 */
private ExpressionNode transformConstantReference(ReferenceNode node, RankProfileTransformContext context) {
    String constantName = node.getName();
    Value constant = context.constants().get(constantName);
    boolean isTensorConstant = constant != null && constant.type().rank() != 0;
    if (!isTensorConstant) {
        // Unknown names and rank-0 constants are left in the expression as they are
        return node;
    }

    TensorValue tensorConstant = (TensorValue) constant;
    String propertyPrefix = CONSTANT + "(" + constantName + ")";
    String typeSpec = tensorConstant.asTensor().type().toString();
    context.rankPropertiesOutput().put(propertyPrefix + ".value", tensorConstant.toString());
    context.rankPropertiesOutput().put(propertyPrefix + ".type", typeSpec);
    // TODO: This allows us to reference constant "a" as "a" instead of "constant(a)", but we shouldn't allow that
    return new ReferenceNode(CONSTANT, Arrays.asList(new NameNode(constantName)), null);
}
Usage of com.yahoo.searchlib.rankingexpression.evaluation.TensorValue in project vespa by vespa-engine.
In class TestableTensorFlowModel, method contextFrom:
/**
 * Builds an evaluation context in which every constant of the imported model is bound
 * under the name "constant(NAME)".
 * Large constants are inserted first and small constants second, so a small constant overwrites
 * a large constant with the same name if MapContext.put replaces existing entries.
 *
 * @param result the imported TensorFlow model whose constants should be bound
 * @return a context containing all of the model's constants as tensor values
 */
private Context contextFrom(TensorFlowModel result) {
    MapContext constantContext = new MapContext();
    result.largeConstants().forEach((constantName, constantTensor) -> {
        String featureName = "constant(" + constantName + ")";
        constantContext.put(featureName, new TensorValue(constantTensor));
    });
    result.smallConstants().forEach((constantName, constantTensor) -> {
        String featureName = "constant(" + constantName + ")";
        constantContext.put(featureName, new TensorValue(constantTensor));
    });
    return constantContext;
}
Usage of com.yahoo.searchlib.rankingexpression.evaluation.TensorValue in project vespa by vespa-engine.
In class TestableTensorFlowModel, method evaluateMacro:
/**
 * Evaluates the named macro of the model and stores its result in the context under the
 * macro's name. The macro's dependencies are evaluated first.
 * Nothing is done if the context already contains a value with that name.
 *
 * @param context   the evaluation context, which receives the macro's result
 * @param model     the model owning the macro
 * @param macroName the name of the macro to evaluate
 */
private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
    if (context.names().contains(macroName)) {
        return; // already evaluated earlier
    }
    RankingExpression macroExpression = model.macros().get(macroName);
    evaluateMacroDependencies(context, model, macroExpression.getRoot());
    TensorValue macroResult = new TensorValue(macroExpression.evaluate(context).asTensor());
    context.put(macroName, macroResult);
}
Usage of com.yahoo.searchlib.rankingexpression.evaluation.TensorValue in project vespa by vespa-engine.
In class Shape, method createConstantValue:
/**
 * Computes this Shape operation's value as a constant. Cell i of the result holds the size of
 * dimension i of the first input's type, or -1 when that dimension has no known size.
 * Does nothing unless allInputTypesPresent(1) holds.
 */
private void createConstantValue() {
    if (allInputTypesPresent(1)) {
        OrderedTensorType shapeSource = inputs.get(0).type().get();
        // Assumes this operation's own type builds a bound indexed tensor, as the cast requires
        IndexedTensor.BoundBuilder shapeBuilder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
        int cellIndex = 0;
        for (TensorType.Dimension dimension : shapeSource.dimensions()) {
            // A dimension without a known size is recorded as -1
            shapeBuilder.cellByDirectIndex(cellIndex++, dimension.size().orElse(-1L));
        }
        setConstantValue(new TensorValue(shapeBuilder.build()));
    }
}
Usage of com.yahoo.searchlib.rankingexpression.evaluation.TensorValue in project vespa by vespa-engine.
In class TensorFlowImporter, method importConstant:
/**
 * Registers the constant produced by the given operation with the model, unless a constant
 * with the operation's Vespa name is already registered.
 * If the operation already has a constant value, that value is used; a value that is not a
 * TensorValue is not registered. Otherwise the variable is read from the saved model bundle and
 * also stored as the operation's constant value. Tensors with rank 0 or at most one cell are
 * registered as small constants; all others are registered as large constants.
 *
 * @param model     the model receiving the constant
 * @param operation the operation producing the constant
 * @param bundle    the saved model to read variable values from
 * @return the operation's function, in every case
 */
private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) {
    String vespaName = operation.vespaName();
    boolean alreadyImported = model.largeConstants().containsKey(vespaName)
                              || model.smallConstants().containsKey(vespaName);
    if (alreadyImported) {
        return operation.function();
    }

    Tensor constantTensor;
    if (!operation.getConstantValue().isPresent()) {
        // Here we use the type from the operation, which will have correct dimension names after name resolving
        constantTensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), operation.type().get());
        operation.setConstantValue(new TensorValue(constantTensor));
    } else {
        Value precomputed = operation.getConstantValue().get();
        if (!(precomputed instanceof TensorValue)) {
            // scalar values are inserted directly into the expression
            return operation.function();
        }
        constantTensor = precomputed.asTensor();
    }

    boolean isSmall = constantTensor.type().rank() == 0 || constantTensor.size() <= 1;
    if (isSmall) {
        model.smallConstant(vespaName, constantTensor);
    } else {
        model.largeConstant(vespaName, constantTensor);
    }
    return operation.function();
}
Aggregations