use of com.yahoo.searchlib.rankingexpression.rule.ComparisonNode in project vespa by vespa-engine.
the class Reshape method reshape.
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
if (!tensorSize(inputType).equals(tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
// Here we create a transformation tensor that is multiplied with the from tensor to map into
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
TensorFunction outputFunction = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), Reduce.Aggregator.sum, inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
return outputFunction;
}
Aggregations