Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openlookeng.
From class AbstractOperatorBenchmark, method createHashProjectOperator.
protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNodeId planNodeId, List<Type> types) {
    PlanSymbolAllocator planSymbolAllocator = new PlanSymbolAllocator();
    ImmutableMap.Builder<Symbol, Integer> symbolToInputMapping = ImmutableMap.builder();
    ImmutableList.Builder<PageProjection> projections = ImmutableList.builder();
    for (int channel = 0; channel < types.size(); channel++) {
        Symbol symbol = planSymbolAllocator.newSymbol("h" + channel, types.get(channel));
        symbolToInputMapping.put(symbol, channel);
        projections.add(new InputPageProjection(channel, types.get(channel)));
    }
    Map<Symbol, Type> symbolTypes = planSymbolAllocator.getTypes().allTypes();
    Optional<Expression> hashExpression = HashGenerationOptimizer.getHashExpression(
            localQueryRunner.getMetadata(),
            planSymbolAllocator,
            ImmutableList.copyOf(symbolTypes.keySet()));
    verify(hashExpression.isPresent());
    Map<NodeRef<Expression>, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata())
            .getTypes(session, TypeProvider.copyOf(symbolTypes), hashExpression.get());
    RowExpression translated = translate(
            hashExpression.get(),
            SCALAR,
            expressionTypes,
            symbolToInputMapping.build(),
            localQueryRunner.getMetadata().getFunctionAndTypeManager(),
            session,
            false);
    PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0);
    projections.add(functionCompiler.compileProjection(translated, Optional.empty()).get());
    return new FilterAndProjectOperator.FilterAndProjectOperatorFactory(
            operatorId,
            planNodeId,
            () -> new PageProcessor(Optional.empty(), projections.build()),
            ImmutableList.copyOf(Iterables.concat(types, ImmutableList.of(BIGINT))),
            getFilterAndProjectMinOutputPageSize(session),
            getFilterAndProjectMinOutputPageRowCount(session));
}
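The translate call above takes a Map<NodeRef<Expression>, Type> rather than a plain Map<Expression, Type> because NodeRef compares by node identity, so structurally identical sub-expressions can each carry their own type, and a lookup only hits when given the exact node that was analyzed. A minimal sketch of that behaviour (not taken from hetu-core; it assumes the same io.prestosql.sql.tree and Guava imports as the snippet above):

Expression first = new LongLiteral("1");
Expression second = new LongLiteral("1");   // structurally equal to 'first', but a distinct node
Map<NodeRef<Expression>, Type> expressionTypes = ImmutableMap.of(
        NodeRef.of(first), BIGINT,
        NodeRef.of(second), BIGINT);
// Each node keeps its own entry because NodeRef.equals compares node references, not content.
verify(expressionTypes.size() == 2);
// A lookup works only with the very node instance that was registered.
verify(expressionTypes.containsKey(NodeRef.of(first)));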
Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openlookeng.
From class TranslateExpressions, method createRewriter.
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // special treatment of the CallExpression in Aggregation
            if (expression instanceof CallExpression && ((CallExpression) expression).getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
                return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getSymbolAllocator(), context);
            }
            return removeOriginalExpression(expression, context, new HashMap<>());
        }

        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Rule.Context context) {
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, planSymbolAllocator.getTypes());
            return new CallExpression(
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    callExpression.getArguments().stream()
                            .map(expression -> removeOriginalExpression(expression, session, types, context))
                            .collect(toImmutableList()),
                    Optional.empty());
        }

        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdaExpressions = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.<NodeRef<Expression>, Type>builder();
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
            if (!lambdaExpressions.isEmpty()) {
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager().getFunctionMetadata(callExpression.getFunctionHandle()).getArgumentTypes().stream()
                        .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                        .map(metadata::getType)
                        .map(FunctionType.class::cast)
                        .collect(toImmutableList());
                InternalAggregationFunction internalAggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class<?>> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
                verify(lambdaExpressions.size() == functionTypes.size());
                verify(lambdaExpressions.size() == lambdaInterfaces.size());
                for (int i = 0; i < lambdaExpressions.size(); i++) {
                    LambdaExpression lambdaExpression = lambdaExpressions.get(i);
                    FunctionType functionType = functionTypes.get(i);
                    // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
                    // which requires the types of all sub-expressions.
                    //
                    // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                    // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                    //
                    // This does not work here since the function call representation in final aggregation node
                    // is currently a hack: it takes intermediate type as input, and may not be a valid
                    // function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism in project and filter expression should be used here.
                    verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size());
                    Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
                    Map<Symbol, Type> lambdaArgumentSymbolTypes = new HashMap<>();
                    for (int j = 0; j < lambdaExpression.getArguments().size(); j++) {
                        LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j);
                        Type type = functionType.getArgumentTypes().get(j);
                        lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
                        lambdaArgumentSymbolTypes.put(new Symbol(argument.getName().getValue()), type);
                    }
                    // the lambda expression itself
                    builder.put(NodeRef.of(lambdaExpression), functionType)
                            .putAll(lambdaArgumentExpressionTypes)
                            .putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody()));
                }
            }
            for (RowExpression argument : callExpression.getArguments()) {
                if (!isExpression(argument) || castToExpression(argument) instanceof LambdaExpression) {
                    continue;
                }
                builder.putAll(typeAnalyzer.getTypes(session, typeProvider, castToExpression(argument)));
            }
            return builder.build();
        }

        private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Session session) {
            RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, FunctionKind.SCALAR, types, layout, metadata.getFunctionAndTypeManager(), session, false);
            return new RowExpressionOptimizer(metadata).optimize(rowExpression, RowExpressionInterpreter.Level.SERIALIZABLE, session.toConnectorSession());
        }

        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context, Map<Symbol, Integer> layout) {
            if (isExpression(expression)) {
                TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
                return toRowExpression(
                        castToExpression(expression),
                        typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), castToExpression(expression)),
                        layout,
                        context.getSession());
            }
            return expression;
        }

        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types, Rule.Context context) {
            if (isExpression(rowExpression)) {
                Expression expression = castToExpression(rowExpression);
                return toRowExpression(expression, types, new HashMap<>(), session);
            }
            return rowExpression;
        }
    };
}
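For a single lambda argument such as (x) -> x + 1, the per-lambda portion of the map built in analyzeCallExpressionTypes contains an entry for the lambda node itself, one per argument declaration, and the analyzer's entries for the body. The following hand-built sketch shows that shape only; the argument name, body, and BIGINT types are placeholders, it assumes the same static imports (ADD, BIGINT, toImmutableList) as the surrounding snippets, and it assumes FunctionType's (argumentTypes, returnType) constructor:

LambdaArgumentDeclaration x = new LambdaArgumentDeclaration(new Identifier("x"));
Expression body = new ArithmeticBinaryExpression(ADD, new SymbolReference("x"), new LongLiteral("1"));
LambdaExpression lambda = new LambdaExpression(ImmutableList.of(x), body);
FunctionType functionType = new FunctionType(ImmutableList.of(BIGINT), BIGINT);

Map<NodeRef<Expression>, Type> lambdaTypes = ImmutableMap.<NodeRef<Expression>, Type>builder()
        .put(NodeRef.of(lambda), functionType)   // the lambda expression itself
        .put(NodeRef.of(x), BIGINT)              // each argument declaration
        // ...plus typeAnalyzer.getTypes(...) entries for every sub-expression of the body
        .build();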
Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openlookeng.
From class TestSqlToRowExpressionTranslator, method testPossibleExponentialOptimizationTime.
@Test(timeOut = 10_000)
public void testPossibleExponentialOptimizationTime() {
    Expression expression = new LongLiteral("1");
    ImmutableMap.Builder<NodeRef<Expression>, Type> types = ImmutableMap.builder();
    types.put(NodeRef.of(expression), BIGINT);
    for (int i = 0; i < 100; i++) {
        expression = new CoalesceExpression(expression, new LongLiteral("2"));
        types.put(NodeRef.of(expression), BIGINT);
    }
    translateAndOptimize(expression, types.build());
}
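Each iteration wraps the previous expression in a new CoalesceExpression node, and every wrapper level needs its own identity-keyed entry so the translator can resolve its type. Hand-expanding the first iteration of the loop above gives roughly:

Expression leaf = new LongLiteral("1");
Expression level1 = new CoalesceExpression(leaf, new LongLiteral("2"));   // COALESCE(1, 2)

ImmutableMap.Builder<NodeRef<Expression>, Type> types = ImmutableMap.builder();
types.put(NodeRef.of(leaf), BIGINT);     // the innermost literal
types.put(NodeRef.of(level1), BIGINT);   // the wrapper gets its own entry; identity, not structure, is the key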
Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openlookeng.
From class SimplifyExpressions, method rewrite.
public static Expression rewrite(Expression inputExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, TypeAnalyzer typeAnalyzer) {
    Expression expression = inputExpression;
    requireNonNull(metadata, "metadata is null");
    requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    if (expression instanceof SymbolReference) {
        return expression;
    }
    expression = pushDownNegations(expression);
    expression = extractCommonPredicates(expression);
    Map<NodeRef<Expression>, Type> expressionTypes = typeAnalyzer.getTypes(session, planSymbolAllocator.getTypes(), expression);
    ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes);
    Object optimized = interpreter.optimize(NoOpSymbolResolver.INSTANCE);
    return literalEncoder.toExpression(optimized, expressionTypes.get(NodeRef.of(expression)));
}
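The final lookup, expressionTypes.get(NodeRef.of(expression)), succeeds only because expression is the very node handed to typeAnalyzer.getTypes after the negation push-down and common-predicate extraction; a structurally identical copy would miss the map. A small illustration of that pitfall (hypothetical expressions, assuming static imports for EQUAL from ComparisonExpression.Operator and for BOOLEAN):

Expression analyzed = new ComparisonExpression(EQUAL, new SymbolReference("a"), new LongLiteral("1"));
Map<NodeRef<Expression>, Type> types = ImmutableMap.of(NodeRef.of(analyzed), BOOLEAN);

Expression copy = new ComparisonExpression(EQUAL, new SymbolReference("a"), new LongLiteral("1"));
verify(types.get(NodeRef.of(analyzed)) == BOOLEAN);   // same node instance: hit
verify(types.get(NodeRef.of(copy)) == null);          // equal-looking copy: miss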
Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openlookeng.
From class GroupingOperationRewriter, method rewriteGroupingOperation.
public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, FieldId> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");
    // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
    if (groupingSets.size() == 1) {
        return new LongLiteral("0");
    } else {
        checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");
        RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId();
        List<Integer> columns = expression.getGroupingColumns().stream()
                .map(NodeRef::of)
                .peek(groupingColumn -> checkState(columnReferenceFields.containsKey(groupingColumn), "the grouping column is not in the columnReferencesField map"))
                .map(columnReferenceFields::get)
                .map(fieldId -> translateFieldToInteger(fieldId, relationId))
                .collect(toImmutableList());
        List<Expression> groupingResults = groupingSets.stream()
                .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns)))
                .map(LongLiteral::new)
                .collect(toImmutableList());
        // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
        return new SubscriptExpression(
                new ArrayConstructor(groupingResults),
                new ArithmeticBinaryExpression(ADD, toSymbolReference(groupIdSymbol.get()), new GenericLiteral("BIGINT", "1")));
    }
}
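To make the rewrite concrete: calculateGrouping produces, for each grouping set, the standard GROUPING bitmask in which a bit is 1 when the corresponding argument is rolled up (not part of that grouping set), with the first argument in the most significant position. The following is an illustrative reimplementation of that rule, not the hetu-core method itself, followed by a worked example (java.util imports assumed):

// Illustrative only; the real logic lives in GroupingOperationRewriter.calculateGrouping.
static long groupingBitmask(Set<Integer> groupingSet, List<Integer> columns) {
    long result = 0;
    for (int column : columns) {
        result <<= 1;                         // first (leftmost) argument ends up in the highest bit
        if (!groupingSet.contains(column)) {
            result |= 1;                      // 1 = the column is rolled up in this grouping set
        }
    }
    return result;
}

// GROUPING(a, b) over grouping sets (a) and (b), with a and b resolving to field ordinals 0 and 1:
// groupingBitmask({0}, [0, 1]) == 1 and groupingBitmask({1}, [0, 1]) == 2,
// so the rewritten expression behaves like ARRAY[1, 2][(groupId + 1)].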