Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class RankProfile, method typeContext.
/**
 * Creates a context holding the type information of everything this rank profile can refer to:
 * its macros (as expression functions), its small constants, the large ranking constants of its
 * search definition, the attributes of all regular and imported fields, and the query features
 * declared by every query profile type in the given registry.
 *
 * @param queryProfiles registry whose type registry supplies the declared query features
 * @return a type context resolving the features referable from this rank profile
 * @throws IllegalArgumentException if the same query feature is declared with types that cannot be
 *         generalized into a single type
 */
public TypeContext<Reference> typeContext(QueryProfileRegistry queryProfiles) {
    MapEvaluationTypeContext context =
            new MapEvaluationTypeContext(getMacros().values()
                                                    .stream()
                                                    .map(Macro::asExpressionFunction)
                                                    .collect(Collectors.toList()));

    // Small constants are held by this profile, large ones by the search definition
    getConstants().forEach((constantName, constant) ->
            context.setType(FeatureNames.asConstantFeature(constantName), constant.type()));
    getSearch().getRankingConstants().forEach((constantName, rankingConstant) ->
            context.setType(FeatureNames.asConstantFeature(constantName), rankingConstant.getTensorType()));

    // Attributes of both regular and imported fields
    getSearch().allFields().forEach(field -> addAttributeFeatureTypes(field, context));
    getSearch().allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context));

    // Query features declared by every query profile type in the registry (fields whose name parses
    // to a reference named "query"). If the context already holds a non-default type for the feature,
    // for example from an earlier query profile type, the two types are generalized into one.
    for (QueryProfileType profileType : queryProfiles.getTypeRegistry().allComponents()) {
        for (FieldDescription declared : profileType.declaredFields().values()) {
            TensorType declaredType = declared.getType().asTensorType();
            Optional<Reference> parsed = Reference.simple(declared.getName());
            if (parsed.isPresent() && parsed.get().name().equals("query")) {
                Reference queryFeature = parsed.get();
                TensorType knownType = context.getType(queryFeature);
                TensorType resultType = declaredType;
                if ( ! Objects.equals(knownType, context.defaultTypeOf(queryFeature))) {
                    resultType = knownType.dimensionwiseGeneralizationWith(declaredType)
                            .orElseThrow(() -> new IllegalArgumentException(profileType + " contains query feature " + queryFeature + " with type " + declared.getType().asTensorType() + ", but this is already defined " + "in another query profile with type " + context.getType(queryFeature)));
                }
                context.setType(queryFeature, resultType);
            }
        }
    }
    return context;
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class TensorFlowFeatureConverter, method reduceBatchDimensions.
/**
 * Check if batch dimensions of inputs can be reduced out. If the input
 * macro specifies that a single exemplar should be evaluated, we can
 * reduce the batch dimension out.
 * <p>
 * Inputs are rewritten both inside the model-generated macros that the expression refers to and in
 * the expression itself. Afterwards, size-1 dimensions of the original output type that no longer
 * appear in the rewritten output type are expanded back at the output.
 *
 * @param expression the ranking expression to rewrite; its root is replaced in place
 * @param model the TensorFlow model whose generated macros may be rewritten
 * @param profile the rank profile holding the generated macros
 * @param queryProfiles used to build the type context for type resolution
 * @throws IllegalArgumentException if the expression refers to a macro generated by the model
 *         that is not present in the rank profile
 */
private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, RankProfile profile, QueryProfileRegistry queryProfiles) {
    TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
    // Remember the output type before rewriting so removed dimensions can be restored afterwards
    TensorType typeBeforeReducing = expression.getRoot().type(typeContext);

    // Check generated macros for inputs to reduce
    Set<String> macroNames = new HashSet<>();
    addMacroNamesIn(expression.getRoot(), macroNames, model);
    for (String macroName : macroNames) {
        // Only macros generated from the model are rewritten
        if (!model.macros().containsKey(macroName)) {
            continue;
        }
        RankProfile.Macro macro = profile.getMacros().get(macroName);
        if (macro == null) {
            throw new IllegalArgumentException("Model refers to generated macro '" + macroName + "' but this macro is not present in " + profile);
        }
        RankingExpression macroExpression = macro.getRankingExpression();
        macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
    }

    // Check expression for inputs to reduce
    ExpressionNode root = expression.getRoot();
    root = reduceBatchDimensionsAtInput(root, model, typeContext);
    TensorType typeAfterReducing = root.type(typeContext);
    root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
    expression.setRoot(root);
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class TensorFlowFeatureConverter, method transformLargeConstant.
/**
 * Handles a large constant from the model. If the rank profile has a macro with the same name, the
 * macro overrides the constant: its type must equal the constant's type, and the constant name is
 * recorded so the reference can be replaced by the macro later. Otherwise the constant is written to
 * the model store and registered as a ranking constant of the search, unless one with that name is
 * already registered.
 *
 * @throws IllegalArgumentException if an overriding macro returns a type different from the constant's
 */
private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByMacros, String constantName, Tensor constantValue) {
    RankProfile.Macro overridingMacro = profile.getMacros().get(constantName);
    if (overridingMacro == null) {
        Path writtenPath = store.writeLargeConstant(constantName, constantValue);
        boolean alreadyRegistered = profile.getSearch().getRankingConstants().containsKey(constantName);
        if ( ! alreadyRegistered) {
            profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), writtenPath.toString()));
        }
        return;
    }

    TensorType returnedType = overridingMacro.getRankingExpression().type(profile.typeContext(queryProfiles));
    if ( ! returnedType.equals(constantValue.type())) {
        throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + "The required type of this is " + constantValue.type() + ", but the macro returns " + returnedType);
    }
    // constant(constantName) is replaced by constantName later
    constantsReplacedByMacros.add(constantName);
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class TensorFlowFeatureConverter, method expandBatchDimensionsAtOutput.
/**
 * If batch dimensions have been reduced away above, bring them back here
 * for any following computation of the tensor.
 * <p>
 * Every size-1 dimension of {@code before} that is missing from {@code after} is re-added by
 * joining (multiplying) the node with a generated tensor of ones spanning those dimensions.
 * The node is returned unchanged when the types are equal or no such dimension exists.
 * Todo: determine when this is not necessary!
 *
 * @param node the expression whose output may lack batch dimensions
 * @param before the output type before batch dimensions were reduced
 * @param after the output type after reduction
 * @return the node itself, or a node producing the node's value with the missing dimensions added
 */
private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
    if (after.equals(before)) return node;

    TensorType.Builder missingBuilder = new TensorType.Builder();
    for (TensorType.Dimension candidate : before.dimensions()) {
        boolean isUnitSized = candidate.size().orElse(-1L) == 1;
        if ( ! isUnitSized || after.dimensionNames().contains(candidate.name())) continue;
        missingBuilder.indexed(candidate.name(), 1);
    }
    TensorType missingType = missingBuilder.build();
    if (missingType.dimensions().isEmpty()) return node;

    // A tensor of ones over the missing dimensions; multiplying by it only adds those dimensions
    ExpressionNode one = new ConstantNode(new DoubleValue(1.0));
    Generate ones = new Generate(missingType, new GeneratorLambdaFunctionNode(missingType, one).asLongListToDoubleOperator());
    return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(node), ones, ScalarFunctions.multiply()));
}
Use of com.yahoo.tensor.TensorType in project vespa by vespa-engine.
In class TensorFlowFeatureConverter, method reduceBatchDimensionExpression.
/**
 * Wraps the given function in an expression node, summing away all of its size-1 dimensions.
 * Functions whose type has at most one dimension, or no size-1 dimensions, are wrapped unchanged.
 *
 * @param function the tensor function to wrap
 * @param context type context used to resolve the function's type
 * @return an expression node for the (possibly reduced) function
 */
private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
    TensorType functionType = function.type(context);
    if (functionType.dimensions().size() <= 1) {
        return new TensorFunctionNode(function);
    }

    List<String> unitDimensions = new ArrayList<>();
    for (TensorType.Dimension candidate : functionType.dimensions()) {
        if (candidate.size().orElse(-1L) == 1) {
            unitDimensions.add(candidate.name());
        }
    }
    TensorFunction reduced = unitDimensions.isEmpty()
            ? function
            : new Reduce(function, Reduce.Aggregator.sum, unitDimensions);
    return new TensorFunctionNode(reduced);
}
Aggregations