use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.
the class TranslateExpressions method createRewriter.
// Creates the rewriter that replaces original (SQL AST) expressions wrapped inside
// RowExpressions with translated, optimized RowExpressions.
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // special treatment of the CallExpression in Aggregation:
            // only its arguments that still wrap an original expression are translated
            if (expression instanceof CallExpression) {
                CallExpression call = (CallExpression) expression;
                if (call.getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
                    return removeOriginalExpressionArguments(call, context.getSession(), context.getSymbolAllocator(), context);
                }
            }
            return removeOriginalExpression(expression, context, new HashMap<>());
        }

        // Rebuilds the call with every argument translated; display name, function handle
        // and return type are carried over unchanged.
        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Rule.Context context) {
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, planSymbolAllocator.getTypes());
            List<RowExpression> translatedArguments = callExpression.getArguments().stream()
                    .map(argument -> removeOriginalExpression(argument, session, types, context))
                    .collect(toImmutableList());
            return new CallExpression(
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    translatedArguments,
                    Optional.empty());
        }

        // Collects the types of all sub-expressions of the call's original-expression arguments.
        // Lambda arguments are typed from the function signature; all others via TypeAnalyzer.
        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdas = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> collectedTypes = ImmutableMap.<NodeRef<Expression>, Type>builder();
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);

            if (!lambdas.isEmpty()) {
                // function-typed parameters of the signature, in declaration order
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager()
                        .getFunctionMetadata(callExpression.getFunctionHandle())
                        .getArgumentTypes()
                        .stream()
                        .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                        .map(metadata::getType)
                        .map(FunctionType.class::cast)
                        .collect(toImmutableList());
                InternalAggregationFunction implementation = metadata.getFunctionAndTypeManager()
                        .getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class<?>> lambdaInterfaces = implementation.getLambdaInterfaces();
                verify(lambdas.size() == functionTypes.size());
                verify(lambdas.size() == lambdaInterfaces.size());

                for (int lambdaIndex = 0; lambdaIndex < lambdas.size(); lambdaIndex++) {
                    LambdaExpression lambda = lambdas.get(lambdaIndex);
                    FunctionType functionType = functionTypes.get(lambdaIndex);

                    // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
                    // which requires the types of all sub-expressions.
                    //
                    // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                    // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                    //
                    // This does not work here since the function call representation in final aggregation node
                    // is currently a hack: it takes intermediate type as input, and may not be a valid
                    // function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism in project and filter expression should be used here.
                    List<LambdaArgumentDeclaration> declarations = lambda.getArguments();
                    List<Type> parameterTypes = functionType.getArgumentTypes();
                    verify(declarations.size() == parameterTypes.size());

                    Map<NodeRef<Expression>, Type> declarationTypes = new HashMap<>();
                    Map<Symbol, Type> declarationSymbolTypes = new HashMap<>();
                    for (int position = 0; position < declarations.size(); position++) {
                        LambdaArgumentDeclaration declaration = declarations.get(position);
                        Type parameterType = parameterTypes.get(position);
                        declarationTypes.put(NodeRef.of(declaration), parameterType);
                        declarationSymbolTypes.put(new Symbol(declaration.getName().getValue()), parameterType);
                    }

                    // the lambda expression itself
                    collectedTypes.put(NodeRef.of(lambda), functionType);
                    // its declared arguments
                    collectedTypes.putAll(declarationTypes);
                    // every sub-expression of its body, resolved against the lambda's own arguments
                    collectedTypes.putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(declarationSymbolTypes), lambda.getBody()));
                }
            }

            // non-lambda original expressions are analyzed against the plan's symbol types
            for (RowExpression argument : callExpression.getArguments()) {
                if (isExpression(argument) && !(castToExpression(argument) instanceof LambdaExpression)) {
                    collectedTypes.putAll(typeAnalyzer.getTypes(session, typeProvider, castToExpression(argument)));
                }
            }
            return collectedTypes.build();
        }

        // Translates an AST expression into a RowExpression and optimizes it at SERIALIZABLE level.
        private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Session session) {
            RowExpression translated = SqlToRowExpressionTranslator.translate(
                    expression,
                    FunctionKind.SCALAR,
                    types,
                    layout,
                    metadata.getFunctionAndTypeManager(),
                    session,
                    false);
            RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata);
            return optimizer.optimize(translated, RowExpressionInterpreter.Level.SERIALIZABLE, session.toConnectorSession());
        }

        // Translates a wrapped original expression, deriving its types from the rule context;
        // anything that is not a wrapped original expression is returned untouched.
        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context, Map<Symbol, Integer> layout) {
            if (!isExpression(expression)) {
                return expression;
            }
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
            Expression original = castToExpression(expression);
            Map<NodeRef<Expression>, Type> types = typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), original);
            return toRowExpression(original, types, layout, context.getSession());
        }

        // Translates a wrapped original expression using pre-computed types and an empty layout;
        // anything that is not a wrapped original expression is returned untouched.
        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types, Rule.Context context) {
            if (!isExpression(rowExpression)) {
                return rowExpression;
            }
            return toRowExpression(castToExpression(rowExpression), types, new HashMap<>(), session);
        }
    };
}
use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.
the class SimplifyCountOverConstant method apply.
/**
 * Replaces every aggregation for which {@code isCountOverConstant} holds (checked against
 * the child projection's assignments) with an argument-less {@code count} call, and
 * re-parents the aggregation node onto the captured child.
 *
 * @return the rewritten aggregation node, or an empty result when nothing was replaced
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    ProjectNode projection = captures.get(CHILD);
    // start from a copy so untouched aggregations keep their position
    Map<Symbol, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(parent.getAggregations());
    boolean anyReplaced = false;

    for (Entry<Symbol, AggregationNode.Aggregation> candidate : parent.getAggregations().entrySet()) {
        AggregationNode.Aggregation current = candidate.getValue();
        if (!isCountOverConstant(current, projection.getAssignments())) {
            continue;
        }
        CallExpression countCall = new CallExpression(
                "count",
                functionResolution.countFunction(),
                BIGINT,
                ImmutableList.of(),
                Optional.empty());
        // only the mask of the replaced aggregation is carried over
        rewritten.put(candidate.getKey(), new AggregationNode.Aggregation(
                countCall,
                ImmutableList.of(),
                false,
                Optional.empty(),
                Optional.empty(),
                current.getMask()));
        anyReplaced = true;
    }

    if (anyReplaced) {
        return Result.ofPlanNode(new AggregationNode(
                parent.getId(),
                projection,
                rewritten,
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashSymbol(),
                parent.getGroupIdSymbol(),
                parent.getAggregationType(),
                parent.getFinalizeSymbol()));
    }
    return Result.empty();
}
use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.
the class PushPartialAggregationThroughExchange method split.
/**
 * Splits an aggregation into a PARTIAL aggregation over the node's source and a FINAL
 * aggregation on top of it.
 *
 * <p>Each original aggregation yields one intermediate symbol (typed with the function's
 * intermediate type) produced by the partial step; the final step consumes that symbol
 * plus any lambda arguments of the original call.
 *
 * <p>NOTE(review): no exchange is created in this method; presumably the caller inserts
 * it between the two returned aggregation nodes.
 *
 * @param node the aggregation to split
 * @param context rule context supplying the symbol and plan-node-id allocators
 * @return the FINAL aggregation node, whose source is the new PARTIAL aggregation node
 * @throws IllegalStateException if any aggregation carries an ordering scheme
 */
private PlanNode split(AggregationNode node, Context context) {
    // otherwise, add a partial and final with an exchange in between
    Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();
    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation originalAggregation = entry.getValue();
        // fail fast, before any symbol is allocated for an aggregation that cannot be split
        checkState(!originalAggregation.getOrderingScheme().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        FunctionHandle functionHandle = originalAggregation.getFunctionHandle();
        String functionName = metadata.getFunctionAndTypeManager().getFunctionMetadata(functionHandle).getName().getObjectName();
        InternalAggregationFunction function = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(functionHandle);
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(functionName, function.getIntermediateType());

        intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(
                new CallExpression(functionName, functionHandle, function.getIntermediateType(), originalAggregation.getArguments(), Optional.empty()),
                originalAggregation.getArguments(),
                originalAggregation.isDistinct(),
                originalAggregation.getFilter(),
                originalAggregation.getOrderingScheme(),
                originalAggregation.getMask()));

        // rewrite final aggregation in terms of intermediate function:
        // its inputs are the intermediate symbol followed by the original lambda arguments.
        // The list is immutable, so one instance serves both the call and the aggregation.
        ImmutableList<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(new VariableReferenceExpression(intermediateSymbol.getName(), function.getIntermediateType()))
                .addAll(originalAggregation.getArguments().stream()
                        .filter(PushPartialAggregationThroughExchange::isLambda)
                        .collect(toImmutableList()))
                .build();
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(
                new CallExpression(functionName, functionHandle, function.getFinalType(), finalArguments, Optional.empty()),
                finalArguments,
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty()));
    }

    // Pushing the partial aggregation through the exchange may or may not preserve the
    // pre-grouping properties. Hence, it is safest to drop preGroupedSymbols here
    // (an empty list is passed for both the partial and the final node).
    PlanNode partial = new AggregationNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            intermediateAggregation,
            node.getGroupingSets(),
            ImmutableList.of(),
            PARTIAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());
    return new AggregationNode(
            node.getId(),
            partial,
            finalAggregation,
            node.getGroupingSets(),
            ImmutableList.of(),
            FINAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());
}
use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.
the class RewriteSpatialPartitioningAggregation method apply.
/**
 * Rewrites single-argument aggregations whose function name equals {@code NAME} into the
 * two-argument form taking an envelope symbol and a partition-count symbol, and inserts
 * a projection below the aggregation that computes both symbols.
 *
 * <p>The envelope is the original argument itself when it is already an
 * {@code ST_Envelope} call, otherwise the argument wrapped in {@code ST_Envelope}.
 * The partition count is a literal taken from the session's hash partition count.
 * All other aggregations are passed through unchanged.
 *
 * @return a new aggregation node over the added projection
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    ImmutableMap.Builder<Symbol, Aggregation> aggregations = ImmutableMap.builder();
    Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", INTEGER);
    ImmutableMap.Builder<Symbol, RowExpression> envelopeAssignments = ImmutableMap.builder();
    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        QualifiedObjectName name = metadata.getFunctionAndTypeManager().getFunctionMetadata(aggregation.getFunctionHandle()).getName();
        if (name.equals(NAME) && aggregation.getArguments().size() == 1) {
            // resolve the geometry type once, and only for aggregations that are rewritten
            Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE);
            RowExpression geometry = getOnlyElement(aggregation.getArguments());
            Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", geometryType);
            if (isFunctionNameMatch(geometry, "ST_Envelope")) {
                // argument is already an envelope; project it as-is
                envelopeAssignments.put(envelopeSymbol, geometry);
            }
            else {
                envelopeAssignments.put(envelopeSymbol, castToRowExpression(new FunctionCallBuilder(metadata)
                        .setName(QualifiedName.of("ST_Envelope"))
                        .addArgument(GEOMETRY_TYPE_SIGNATURE, castToExpression(geometry))
                        .build()));
            }
            // immutable, so one list serves both the call expression and the aggregation
            ImmutableList<RowExpression> rewrittenArguments = ImmutableList.of(
                    castToRowExpression(toSymbolReference(envelopeSymbol)),
                    castToRowExpression(toSymbolReference(partitionCountSymbol)));
            aggregations.put(entry.getKey(), new Aggregation(
                    new CallExpression(
                            name.getObjectName(),
                            metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), fromTypes(geometryType, INTEGER)),
                            context.getSymbolAllocator().getTypes().get(entry.getKey()),
                            rewrittenArguments,
                            Optional.empty()),
                    rewrittenArguments,
                    false,
                    Optional.empty(),
                    Optional.empty(),
                    aggregation.getMask()));
        }
        else {
            aggregations.put(entry);
        }
    }

    // projection: pass every source symbol through, then add partition_count and the envelopes
    ProjectNode projection = new ProjectNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            Assignments.builder()
                    .putAll(AssignmentUtils.identityAsSymbolReferences(node.getSource().getOutputSymbols()))
                    .put(partitionCountSymbol, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))))
                    .putAll(envelopeAssignments.build())
                    .build());
    return Result.ofPlanNode(new AggregationNode(
            node.getId(),
            projection,
            aggregations.build(),
            node.getGroupingSets(),
            node.getPreGroupedSymbols(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol()));
}
use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.
the class RowExpressionRewriteRuleSet method rewriteAggregation.
/**
 * Runs the rule set's rewriter over an aggregation's arguments.
 *
 * <p>A temporary call expression (original display name and function handle, the given
 * return type, the aggregation's arguments) is handed to the rewriter. Only the rewritten
 * arguments are taken from the result; the function call, distinct flag, filter,
 * ordering scheme and mask come from the original aggregation.
 *
 * <p>The rewriter's result is cast to {@code CallExpression} without a prior type check.
 * The {@code functionAndTypeManager} parameter is not used by this method.
 */
private AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Type returnType, Rule.Context context, FunctionAndTypeManager functionAndTypeManager) {
    CallExpression probe = new CallExpression(
            aggregation.getFunctionCall().getDisplayName(),
            aggregation.getFunctionHandle(),
            returnType,
            aggregation.getArguments(),
            Optional.empty());
    RowExpression rewritten = rewriter.rewrite(probe, context);
    CallExpression rewrittenCall = (CallExpression) rewritten;
    return new AggregationNode.Aggregation(
            aggregation.getFunctionCall(),
            rewrittenCall.getArguments(),
            aggregation.isDistinct(),
            aggregation.getFilter(),
            aggregation.getOrderingScheme(),
            aggregation.getMask());
}
Aggregations