use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
the class TestRowExpressionTranslator method testEndToEndFunctionTranslation.
/**
 * End-to-end check of function translation: a nested SQL call is first turned into a
 * {@link CallExpression} by {@code sqlToRowExpressionTranslator}, then rendered through a
 * {@code TestFunctionTranslator} built from {@code TestFunctions}, and the rendered text is
 * compared with the expected string.
 */
@Test
public void testEndToEndFunctionTranslation() {
    // "col1" is the only column the expression refers to; it is typed as BIGINT.
    TypeProvider columnTypes = TypeProvider.copyOf(ImmutableMap.of("col1", BIGINT));
    String sql = "LN(bitwise_and(1, col1))";
    CallExpression parsedCall = (CallExpression) sqlToRowExpressionTranslator.translate(expression(sql), columnTypes);

    TestFunctionTranslator translator = new TestFunctionTranslator(
            functionAndTypeManager,
            buildFunctionTranslator(ImmutableSet.of(TestFunctions.class)));
    TranslatedExpression result = translateWith(parsedCall, translator, emptyMap());

    // Both the outer LN call and the inner bitwise_and must have been translated.
    assertTrue(result.getTranslated().isPresent());
    assertEquals(result.getTranslated().get(), "LNof(1 BITWISE_AND col1)");
}
use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
the class RowExpressionRewriteRuleSet method rewriteAggregation.
/**
 * Builds a copy of {@code aggregation} whose call and filter have been passed through
 * {@code rewriter}. The order-by, distinct flag and mask are carried over unchanged.
 *
 * @param aggregation the aggregation whose call (and filter, via {@code getFilter().map}) is rewritten
 * @param context rule context handed to every {@code rewriter.rewrite} invocation
 * @return a new aggregation wrapping the rewritten call
 */
private AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Rule.Context context) {
    RowExpression candidate = rewriter.rewrite(aggregation.getCall(), context);

    // The Aggregation constructor needs a CallExpression, so any other result is rejected up front.
    boolean stillACall = candidate instanceof CallExpression;
    checkArgument(stillACall, "Aggregation CallExpression must be rewritten to CallExpression");
    CallExpression call = (CallExpression) candidate;

    return new AggregationNode.Aggregation(
            call,
            aggregation.getFilter().map(predicate -> rewriter.rewrite(predicate, context)),
            aggregation.getOrderBy(),
            aggregation.isDistinct(),
            aggregation.getMask());
}
use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
the class SimplifyCountOverConstant method apply.
/**
 * Replaces every aggregation accepted by {@code isCountOverConstant} with a zero-argument
 * {@code "count"} call. The replacement keeps the original mask and the source location of
 * the original call; its filter and order-by slots are {@code Optional.empty()} and its
 * distinct flag is {@code false}.
 *
 * @return {@code Result.empty()} when no aggregation was replaced; otherwise a new
 *         {@link AggregationNode} with the parent's id, grouping sets, step, hash variable and
 *         group-id variable, sourced from the captured child, with an empty list passed for
 *         the argument that follows the grouping sets
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    ProjectNode child = captures.get(CHILD);

    // Work on a copy so the iteration below can run over the parent's own map.
    Map<VariableReferenceExpression, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(parent.getAggregations());
    boolean changed = false;

    for (Entry<VariableReferenceExpression, AggregationNode.Aggregation> candidate : parent.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = candidate.getValue();
        if (!isCountOverConstant(original, child.getAssignments())) {
            continue;
        }

        CallExpression countWithoutArguments = new CallExpression(
                original.getCall().getSourceLocation(),
                "count",
                functionResolution.countFunction(),
                BIGINT,
                ImmutableList.of());
        rewritten.put(
                candidate.getKey(),
                new AggregationNode.Aggregation(countWithoutArguments, Optional.empty(), Optional.empty(), false, original.getMask()));
        changed = true;
    }

    if (changed) {
        return Result.ofPlanNode(new AggregationNode(
                parent.getSourceLocation(),
                parent.getId(),
                child,
                rewritten,
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashVariable(),
                parent.getGroupIdVariable()));
    }
    return Result.empty();
}
use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
the class TranslateExpressions method createRewriter.
/**
 * Creates a {@link PlanRowExpressionRewriter} that replaces "original expression" wrappers
 * (anything {@code OriginalExpressionUtils.isExpression} accepts) with row expressions produced
 * by {@code SqlToRowExpressionTranslator.translate}.
 *
 * A {@link CallExpression} that has at least one such wrapper among its arguments is handled
 * argument by argument; every other input is translated as a whole, or returned unchanged
 * when it is not a wrapper.
 *
 * @param metadata supplies the function-and-type manager used for type analysis and translation
 * @param sqlParser passed through to {@code getExpressionTypes}
 * @return the rewriter
 */
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
return new PlanRowExpressionRewriter() {
@Override
public RowExpression rewrite(RowExpression expression, Rule.Context context) {
// special treatment of the CallExpression in Aggregation
// Only calls with at least one wrapped argument take this path; the call itself is kept and just its arguments are translated.
if (expression instanceof CallExpression && ((CallExpression) expression).getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getVariableAllocator());
}
return removeOriginalExpression(expression, context);
}
// Rebuilds the call with the same source location, display name, function handle and type,
// translating each argument with one shared type map computed for the whole call.
private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanVariableAllocator variableAllocator) {
Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, variableAllocator.getTypes());
return new CallExpression(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), callExpression.getArguments().stream().map(expression -> removeOriginalExpression(expression, session, types)).collect(toImmutableList()));
}
// Computes the type map for every wrapped argument of the call. Lambda arguments are typed from
// the function's declared function-typed parameters; all other wrapped arguments are analyzed
// against typeProvider.
private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
// Wrapped arguments that are lambdas, in argument order.
List<LambdaExpression> lambdaExpressions = callExpression.getArguments().stream().filter(OriginalExpressionUtils::isExpression).map(OriginalExpressionUtils::castToExpression).filter(LambdaExpression.class::isInstance).map(LambdaExpression.class::cast).collect(toImmutableList());
ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.<NodeRef<Expression>, Type>builder();
if (!lambdaExpressions.isEmpty()) {
// Function-typed parameters declared by the function's metadata, in declaration order; paired with the lambdas by index below.
List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager().getFunctionMetadata(callExpression.getFunctionHandle()).getArgumentTypes().stream().filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME)).map(typeSignature -> (FunctionType) (metadata.getFunctionAndTypeManager().getType(typeSignature))).collect(toImmutableList());
InternalAggregationFunction internalAggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
// Only the size of this list is used here, as a consistency check.
List<Class> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
verify(lambdaExpressions.size() == functionTypes.size());
verify(lambdaExpressions.size() == lambdaInterfaces.size());
for (int i = 0; i < lambdaExpressions.size(); i++) {
LambdaExpression lambdaExpression = lambdaExpressions.get(i);
FunctionType functionType = functionTypes.get(i);
// To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
// which requires the types of all sub-expressions.
//
// In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
// is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
//
// This does not work here since the function call representation in final aggregation node
// is currently a hack: it takes intermediate type as input, and may not be a valid
// function call in Presto.
//
// TODO: Once the final aggregation function call representation is fixed,
// the same mechanism in project and filter expression should be used here.
verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size());
// Each lambda parameter is recorded twice: by AST node (for the result map) and by name (so the body can be analyzed).
Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
Map<String, Type> lambdaArgumentSymbolTypes = new HashMap<>();
for (int j = 0; j < lambdaExpression.getArguments().size(); j++) {
LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j);
Type type = functionType.getArgumentTypes().get(j);
lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
lambdaArgumentSymbolTypes.put(argument.getName().getValue(), type);
}
// the lambda expression itself
// ...followed by its parameter declarations and every sub-expression of its body, analyzed with only the lambda parameters in scope.
builder.put(NodeRef.of(lambdaExpression), functionType).putAll(lambdaArgumentExpressionTypes).putAll(getExpressionTypes(session, metadata, sqlParser, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody(), emptyList(), NOOP));
}
}
// Remaining wrapped arguments: skip already-translated arguments and the lambdas handled above.
for (RowExpression argument : callExpression.getArguments()) {
if (!isExpression(argument) || castToExpression(argument) instanceof LambdaExpression) {
continue;
}
builder.putAll(analyze(castToExpression(argument), session, typeProvider));
}
return builder.build();
}
// Thin wrapper over getExpressionTypes that always passes emptyList() and NOOP for the trailing arguments.
private Map<NodeRef<Expression>, Type> analyze(Expression expression, Session session, TypeProvider typeProvider) {
return getExpressionTypes(session, metadata, sqlParser, typeProvider, expression, emptyList(), NOOP);
}
// Translates an AST expression with the given type map; the third translate argument is always an empty map.
private RowExpression toRowExpression(Expression expression, Session session, Map<NodeRef<Expression>, Type> types) {
return SqlToRowExpressionTranslator.translate(expression, types, ImmutableMap.of(), metadata.getFunctionAndTypeManager(), session);
}
// Whole-expression variant: analyzes types from the context's variable allocator and translates; non-wrapped input is returned as-is.
private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context) {
if (isExpression(expression)) {
return toRowExpression(castToExpression(expression), context.getSession(), analyze(castToExpression(expression), context.getSession(), context.getVariableAllocator().getTypes()));
}
return expression;
}
// Pre-analyzed variant: translates with a caller-supplied type map; non-wrapped input is returned as-is.
private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types) {
if (isExpression(rowExpression)) {
Expression expression = castToExpression(rowExpression);
return toRowExpression(expression, session, types);
}
return rowExpression;
}
};
}
use of com.facebook.presto.spi.relation.CallExpression in project presto by prestodb.
the class TestCursorProcessorCompiler method testRewriteRowExpressionWithCSE.
/**
 * Verifies that {@code rewriteRowExpressionsWithCSE} substitutes the variable of the single
 * collected common sub-expression into both the projection and the filter.
 *
 * The expressions come from the fixtures {@code ADD_X_Y_Z} and {@code ADD_X_Y_GREATER_THAN_2};
 * going by their names, the shared sub-expression is X+Y.
 */
@Test
public void testRewriteRowExpressionWithCSE() {
    CursorProcessorCompiler cseCursorCompiler = new CursorProcessorCompiler(METADATA, true, emptyMap());
    ClassDefinition cursorProcessorClassDefinition = new ClassDefinition(a(PUBLIC, FINAL), makeClassName(CursorProcessor.class.getSimpleName()), type(Object.class), type(CursorProcessor.class));

    RowExpression filter = new SpecialFormExpression(AND, BIGINT, ADD_X_Y_GREATER_THAN_2);
    List<RowExpression> projections = ImmutableList.of(ADD_X_Y_Z);
    List<RowExpression> rowExpressions = ImmutableList.<RowExpression>builder().addAll(projections).add(filter).build();

    Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(rowExpressions);
    Map<VariableReferenceExpression, CommonSubExpressionRewriter.CommonSubExpressionFields> cseFields = declareCommonSubExpressionFields(cursorProcessorClassDefinition, commonSubExpressionsByLevel);
    // Flatten the per-level maps into one expression -> variable map for the rewriter.
    Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream().flatMap(m -> m.entrySet().stream()).collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

    // X+Y is the only CSE. The actual value goes first and the expected value second, matching
    // the assertion convention used by the other tests, so a failure reports them correctly.
    assertEquals(cseFields.size(), 1);
    VariableReferenceExpression cseVariable = cseFields.keySet().iterator().next();

    RowExpression rewrittenFilter = cseCursorCompiler.rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0);
    List<RowExpression> rewrittenProjections = cseCursorCompiler.rewriteRowExpressionsWithCSE(projections, commonSubExpressions);

    // X+Y+Z contains the CSE X+Y
    assertTrue(((CallExpression) rewrittenProjections.get(0)).getArguments().contains(cseVariable));
    // X+Y > 2 contains the CSE X+Y
    assertTrue(((CallExpression) ((SpecialFormExpression) rewrittenFilter).getArguments().get(0)).getArguments().contains(cseVariable));
}
Aggregations