Search in sources :

Example 6 with NodeRef

Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openLooKeng.

Source: class AbstractOperatorBenchmark, method createHashProjectOperator.

/**
 * Creates a factory for an operator that passes every input channel through unchanged and
 * appends one extra BIGINT channel holding the hash of all input channels.
 *
 * @param operatorId id assigned to the created operator factory
 * @param planNodeId plan node the operator belongs to
 * @param types types of the input channels, in channel order
 * @return a filter-and-project operator factory whose output types are {@code types} followed by BIGINT
 */
protected final OperatorFactory createHashProjectOperator(int operatorId, PlanNodeId planNodeId, List<Type> types) {
    PlanSymbolAllocator planSymbolAllocator = new PlanSymbolAllocator();
    ImmutableMap.Builder<Symbol, Integer> symbolToInputMapping = ImmutableMap.builder();
    ImmutableList.Builder<PageProjection> projections = ImmutableList.builder();
    // Each input channel is projected as-is and gets a symbol so the hash expression can refer to it.
    for (int channel = 0; channel < types.size(); channel++) {
        Type channelType = types.get(channel);
        Symbol symbol = planSymbolAllocator.newSymbol("h" + channel, channelType);
        symbolToInputMapping.put(symbol, channel);
        projections.add(new InputPageProjection(channel, channelType));
    }
    Map<Symbol, Type> symbolTypes = planSymbolAllocator.getTypes().allTypes();
    Optional<Expression> hashExpression = HashGenerationOptimizer.getHashExpression(localQueryRunner.getMetadata(), planSymbolAllocator, ImmutableList.copyOf(symbolTypes.keySet()));
    verify(hashExpression.isPresent(), "no hash expression generated for input types %s", types);
    Map<NodeRef<Expression>, Type> expressionTypes = new TypeAnalyzer(localQueryRunner.getSqlParser(), localQueryRunner.getMetadata()).getTypes(session, TypeProvider.copyOf(symbolTypes), hashExpression.get());
    RowExpression translated = translate(hashExpression.get(), SCALAR, expressionTypes, symbolToInputMapping.build(), localQueryRunner.getMetadata().getFunctionAndTypeManager(), session, false);
    PageFunctionCompiler functionCompiler = new PageFunctionCompiler(localQueryRunner.getMetadata(), 0);
    // The hash projection is appended last, so it becomes the trailing BIGINT output channel.
    projections.add(functionCompiler.compileProjection(translated, Optional.empty()).get());
    // Build the projection list once. The processor lambda below may be invoked more than once,
    // and it should not rebuild the list from the builder on every call.
    List<PageProjection> pageProjections = projections.build();
    List<Type> outputTypes = ImmutableList.copyOf(Iterables.concat(types, ImmutableList.of(BIGINT)));
    return new FilterAndProjectOperator.FilterAndProjectOperatorFactory(operatorId, planNodeId, () -> new PageProcessor(Optional.empty(), pageProjections), outputTypes, getFilterAndProjectMinOutputPageSize(session), getFilterAndProjectMinOutputPageRowCount(session));
}
Also used : PageFunctionCompiler(io.prestosql.sql.gen.PageFunctionCompiler) Symbol(io.prestosql.spi.plan.Symbol) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) ImmutableList(com.google.common.collect.ImmutableList) RowExpression(io.prestosql.spi.relation.RowExpression) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) ImmutableMap(com.google.common.collect.ImmutableMap) TypeAnalyzer(io.prestosql.sql.planner.TypeAnalyzer) PageProjection(io.prestosql.operator.project.PageProjection) InputPageProjection(io.prestosql.operator.project.InputPageProjection) NodeRef(io.prestosql.sql.tree.NodeRef) Type(io.prestosql.spi.type.Type) InputPageProjection(io.prestosql.operator.project.InputPageProjection) PageProcessor(io.prestosql.operator.project.PageProcessor) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression)

Example 7 with NodeRef

Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openLooKeng.

Source: class TranslateExpressions, method createRewriter.

/**
 * Builds a rewriter that replaces RowExpressions wrapping original AST {@link Expression}s
 * (as detected by {@code OriginalExpressionUtils.isExpression}) with translated RowExpressions.
 * Translation goes through {@code SqlToRowExpressionTranslator.translate} followed by
 * {@code RowExpressionOptimizer} at {@code SERIALIZABLE} level.
 *
 * @param metadata metadata used for type analysis, function lookup and optimization
 * @param sqlParser parser handed to every {@code TypeAnalyzer} created by the rewriter
 * @return the rewriter
 */
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {

        /**
         * Rewrites one plan-level expression. A CallExpression with at least one original-expression
         * argument is translated argument by argument; anything else is translated as a whole
         * with an empty symbol-to-channel layout.
         */
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            // special treatment of the CallExpression in Aggregation
            if (expression instanceof CallExpression && ((CallExpression) expression).getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression)) {
                return removeOriginalExpressionArguments((CallExpression) expression, context.getSession(), context.getSymbolAllocator(), context);
            }
            return removeOriginalExpression(expression, context, new HashMap<>());
        }

        /**
         * Returns a copy of {@code callExpression} whose original-expression arguments have been translated.
         * Display name, function handle and return type are kept as-is.
         */
        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Rule.Context context) {
            // Types for all arguments (including lambda bodies) are analyzed once and shared by every argument translation.
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, planSymbolAllocator.getTypes());
            // NOTE(review): the last constructor argument is always Optional.empty(); confirm no value
            // from the original callExpression needs to be carried over here.
            return new CallExpression(callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), callExpression.getArguments().stream().map(expression -> removeOriginalExpression(expression, session, types, context)).collect(toImmutableList()), Optional.empty());
        }

        /**
         * Computes the types of every sub-expression of the call's original-expression arguments.
         * Lambda arguments are typed using the function-typed parameters of the call's function signature,
         * matched to lambdas by position. Other original-expression arguments are typed against {@code typeProvider}.
         * The result is an ImmutableMap, whose builder rejects duplicate keys at build() time.
         */
        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdaExpressions = callExpression.getArguments().stream().filter(OriginalExpressionUtils::isExpression).map(OriginalExpressionUtils::castToExpression).filter(LambdaExpression.class::isInstance).map(LambdaExpression.class::cast).collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> builder = ImmutableMap.<NodeRef<Expression>, Type>builder();
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
            if (!lambdaExpressions.isEmpty()) {
                // Function-typed parameters of the signature, in declaration order; the i-th lambda is paired with the i-th entry.
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager().getFunctionMetadata(callExpression.getFunctionHandle()).getArgumentTypes().stream().filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME)).map(metadata::getType).map(FunctionType.class::cast).collect(toImmutableList());
                // The handle is resolved as an aggregate function here; its lambda interfaces are only used for the size check below.
                InternalAggregationFunction internalAggregationFunction = metadata.getFunctionAndTypeManager().getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class<?>> lambdaInterfaces = internalAggregationFunction.getLambdaInterfaces();
                verify(lambdaExpressions.size() == functionTypes.size());
                verify(lambdaExpressions.size() == lambdaInterfaces.size());
                for (int i = 0; i < lambdaExpressions.size(); i++) {
                    LambdaExpression lambdaExpression = lambdaExpressions.get(i);
                    FunctionType functionType = functionTypes.get(i);
                    // To compile lambda, LambdaDefinitionExpression needs to be generated from LambdaExpression,
                    // which requires the types of all sub-expressions.
                    //
                    // In project and filter expression compilation, ExpressionAnalyzer.getExpressionTypesFromInput
                    // is used to generate the types of all sub-expressions. (see visitScanFilterAndProject and visitFilter)
                    //
                    // This does not work here since the function call representation in final aggregation node
                    // is currently a hack: it takes intermediate type as input, and may not be a valid
                    // function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism in project and filter expression should be used here.
                    verify(lambdaExpression.getArguments().size() == functionType.getArgumentTypes().size());
                    Map<NodeRef<Expression>, Type> lambdaArgumentExpressionTypes = new HashMap<>();
                    Map<Symbol, Type> lambdaArgumentSymbolTypes = new HashMap<>();
                    // Type each lambda parameter both as an AST node (its declaration) and as a symbol (for analyzing the body).
                    for (int j = 0; j < lambdaExpression.getArguments().size(); j++) {
                        LambdaArgumentDeclaration argument = lambdaExpression.getArguments().get(j);
                        Type type = functionType.getArgumentTypes().get(j);
                        lambdaArgumentExpressionTypes.put(NodeRef.of(argument), type);
                        lambdaArgumentSymbolTypes.put(new Symbol(argument.getName().getValue()), type);
                    }
                    // the lambda expression itself, its parameter declarations, and all sub-expressions of its body
                    // NOTE(review): the body is analyzed with only the lambda parameters in scope; confirm bodies never reference outer symbols.
                    builder.put(NodeRef.of(lambdaExpression), functionType).putAll(lambdaArgumentExpressionTypes).putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(lambdaArgumentSymbolTypes), lambdaExpression.getBody()));
                }
            }
            // Remaining original-expression arguments (non-lambda) are typed against the plan's symbol types.
            for (RowExpression argument : callExpression.getArguments()) {
                if (!isExpression(argument) || castToExpression(argument) instanceof LambdaExpression) {
                    continue;
                }
                builder.putAll(typeAnalyzer.getTypes(session, typeProvider, castToExpression(argument)));
            }
            return builder.build();
        }

        /** Translates an AST expression to a scalar RowExpression and optimizes it at SERIALIZABLE level. */
        private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Session session) {
            RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, FunctionKind.SCALAR, types, layout, metadata.getFunctionAndTypeManager(), session, false);
            return new RowExpressionOptimizer(metadata).optimize(rowExpression, RowExpressionInterpreter.Level.SERIALIZABLE, session.toConnectorSession());
        }

        /**
         * Translates {@code expression} if it wraps an original AST expression, deriving types from the
         * rule context's symbol allocator; otherwise returns it unchanged.
         */
        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context, Map<Symbol, Integer> layout) {
            if (isExpression(expression)) {
                TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
                return toRowExpression(castToExpression(expression), typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), castToExpression(expression)), layout, context.getSession());
            }
            return expression;
        }

        /**
         * Translates {@code rowExpression} if it wraps an original AST expression, using precomputed
         * {@code types} and an empty layout; otherwise returns it unchanged.
         * NOTE(review): {@code context} is not used in this overload.
         */
        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types, Rule.Context context) {
            if (isExpression(rowExpression)) {
                Expression expression = castToExpression(rowExpression);
                return toRowExpression(expression, types, new HashMap<>(), session);
            }
            return rowExpression;
        }
    };
}
Also used : FunctionKind(io.prestosql.spi.function.FunctionKind) TypeProvider(io.prestosql.sql.planner.TypeProvider) HashMap(java.util.HashMap) SqlParser(io.prestosql.sql.parser.SqlParser) TypeAnalyzer(io.prestosql.sql.planner.TypeAnalyzer) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) CallExpression(io.prestosql.spi.relation.CallExpression) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Session(io.prestosql.Session) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) SqlToRowExpressionTranslator(io.prestosql.sql.relational.SqlToRowExpressionTranslator) Symbol(io.prestosql.spi.plan.Symbol) ImmutableMap(com.google.common.collect.ImmutableMap) Rule(io.prestosql.sql.planner.iterative.Rule) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) FunctionType(io.prestosql.spi.type.FunctionType) RowExpressionOptimizer(io.prestosql.sql.relational.RowExpressionOptimizer) Metadata(io.prestosql.metadata.Metadata) LambdaArgumentDeclaration(io.prestosql.sql.tree.LambdaArgumentDeclaration) NodeRef(io.prestosql.sql.tree.NodeRef) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) RowExpressionInterpreter(io.prestosql.sql.planner.RowExpressionInterpreter) List(java.util.List) RowExpression(io.prestosql.spi.relation.RowExpression) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) 
RowExpressionOptimizer(io.prestosql.sql.relational.RowExpressionOptimizer) NodeRef(io.prestosql.sql.tree.NodeRef) LambdaArgumentDeclaration(io.prestosql.sql.tree.LambdaArgumentDeclaration) CallExpression(io.prestosql.spi.relation.CallExpression) FunctionType(io.prestosql.spi.type.FunctionType) RowExpression(io.prestosql.spi.relation.RowExpression) TypeProvider(io.prestosql.sql.planner.TypeProvider) ImmutableMap(com.google.common.collect.ImmutableMap) TypeAnalyzer(io.prestosql.sql.planner.TypeAnalyzer) Type(io.prestosql.spi.type.Type) FunctionType(io.prestosql.spi.type.FunctionType) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) Session(io.prestosql.Session)

Example 8 with NodeRef

Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openLooKeng.

Source: class TestSqlToRowExpressionTranslator, method testPossibleExponentialOptimizationTime.

/**
 * Builds a 100-level-deep chain of COALESCE expressions and translates/optimizes it.
 * As the test name suggests, this guards against optimization time growing exponentially
 * with nesting depth; the test fails if it takes longer than 10 seconds.
 */
@Test(timeOut = 10_000)
public void testPossibleExponentialOptimizationTime() {
    ImmutableMap.Builder<NodeRef<Expression>, Type> expressionTypes = ImmutableMap.builder();
    Expression nested = new LongLiteral("1");
    expressionTypes.put(NodeRef.of(nested), BIGINT);
    int depth = 0;
    while (depth < 100) {
        // wrap the previous level: COALESCE(previous, 2), typed BIGINT like every level
        nested = new CoalesceExpression(nested, new LongLiteral("2"));
        expressionTypes.put(NodeRef.of(nested), BIGINT);
        depth++;
    }
    translateAndOptimize(nested, expressionTypes.build());
}
Also used : NodeRef(io.prestosql.sql.tree.NodeRef) Type(io.prestosql.spi.type.Type) DecimalType.createDecimalType(io.prestosql.spi.type.DecimalType.createDecimalType) CoalesceExpression(io.prestosql.sql.tree.CoalesceExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) LongLiteral(io.prestosql.sql.tree.LongLiteral) CoalesceExpression(io.prestosql.sql.tree.CoalesceExpression) ImmutableMap(com.google.common.collect.ImmutableMap) Test(org.testng.annotations.Test)

Example 9 with NodeRef

Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openLooKeng.

Source: class SimplifyExpressions, method rewrite.

/**
 * Simplifies an expression. Negations are pushed down and common predicates extracted.
 * The result is then constant-folded with the expression optimizer and converted back
 * to an AST expression.
 *
 * @param inputExpression expression to simplify; a bare symbol reference is returned unchanged
 * @param session session used for type analysis and optimization
 * @param planSymbolAllocator supplies the symbol types used for type analysis
 * @param metadata must not be null
 * @param literalEncoder turns the optimized value back into an expression
 * @param typeAnalyzer must not be null
 * @return the simplified expression
 */
public static Expression rewrite(Expression inputExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Metadata metadata, LiteralEncoder literalEncoder, TypeAnalyzer typeAnalyzer) {
    requireNonNull(metadata, "metadata is null");
    requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    // nothing to simplify in a plain symbol reference
    if (inputExpression instanceof SymbolReference) {
        return inputExpression;
    }
    Expression normalized = extractCommonPredicates(pushDownNegations(inputExpression));
    Map<NodeRef<Expression>, Type> types = typeAnalyzer.getTypes(session, planSymbolAllocator.getTypes(), normalized);
    ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(normalized, metadata, session, types);
    Object optimizedValue = optimizer.optimize(NoOpSymbolResolver.INSTANCE);
    Type resultType = types.get(NodeRef.of(normalized));
    return literalEncoder.toExpression(optimizedValue, resultType);
}
Also used : NodeRef(io.prestosql.sql.tree.NodeRef) Type(io.prestosql.spi.type.Type) Expression(io.prestosql.sql.tree.Expression) SymbolReference(io.prestosql.sql.tree.SymbolReference) ExpressionInterpreter(io.prestosql.sql.planner.ExpressionInterpreter)

Example 10 with NodeRef

Use of io.prestosql.sql.tree.NodeRef in project hetu-core by openLooKeng.

Source: class GroupingOperationRewriter, method rewriteGroupingOperation.

/**
 * Rewrites a GROUPING(...) operation into a plain expression.
 * With a single grouping set the result is the literal 0. Otherwise the grouping value of
 * every set is precomputed into an array literal, and the array is subscripted by
 * {@code groupIdSymbol + 1}.
 *
 * @param expression the grouping operation to rewrite
 * @param groupingSets the grouping sets of the enclosing aggregation
 * @param columnReferenceFields field ids of the column references, keyed by expression node
 * @param groupIdSymbol symbol holding the group id; must be present when there is more than one grouping set
 * @return the rewritten expression
 * @throws IllegalStateException if the group id symbol is missing, or a grouping column has no entry in
 *         {@code columnReferenceFields}
 */
public static Expression rewriteGroupingOperation(GroupingOperation expression, List<Set<Integer>> groupingSets, Map<NodeRef<Expression>, FieldId> columnReferenceFields, Optional<Symbol> groupIdSymbol) {
    requireNonNull(groupIdSymbol, "groupIdSymbol is null");
    // See SQL:2011:4.16.2 and SQL:2011:6.9.10.
    if (groupingSets.size() == 1) {
        return new LongLiteral("0");
    }
    checkState(groupIdSymbol.isPresent(), "groupId symbol is missing");
    // Validate every grouping column before any lookup. Previously the first column's relation id
    // was read before the (lazy, peek-based) check ran, so a missing first column surfaced as a
    // NullPointerException instead of this descriptive error.
    for (Expression groupingColumn : expression.getGroupingColumns()) {
        checkState(columnReferenceFields.containsKey(NodeRef.of(groupingColumn)), "the grouping column is not in the columnReferencesField map");
    }
    RelationId relationId = columnReferenceFields.get(NodeRef.of(expression.getGroupingColumns().get(0))).getRelationId();
    List<Integer> columns = expression.getGroupingColumns().stream()
            .map(NodeRef::of)
            .map(columnReferenceFields::get)
            .map(fieldId -> translateFieldToInteger(fieldId, relationId))
            .collect(toImmutableList());
    List<Expression> groupingResults = groupingSets.stream()
            .map(groupingSet -> String.valueOf(calculateGrouping(groupingSet, columns)))
            .map(LongLiteral::new)
            .collect(toImmutableList());
    // It is necessary to add a 1 to the groupId because the underlying array is indexed starting at 1
    return new SubscriptExpression(new ArrayConstructor(groupingResults), new ArithmeticBinaryExpression(ADD, toSymbolReference(groupIdSymbol.get()), new GenericLiteral("BIGINT", "1")));
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) GroupingOperation(io.prestosql.sql.tree.GroupingOperation) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) Set(java.util.Set) RelationId(io.prestosql.sql.analyzer.RelationId) NodeRef(io.prestosql.sql.tree.NodeRef) Preconditions.checkState(com.google.common.base.Preconditions.checkState) SymbolUtils.toSymbolReference(io.prestosql.sql.planner.SymbolUtils.toSymbolReference) List(java.util.List) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) LongLiteral(io.prestosql.sql.tree.LongLiteral) ADD(io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD) Map(java.util.Map) Objects.requireNonNull(java.util.Objects.requireNonNull) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor) Optional(java.util.Optional) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) FieldId(io.prestosql.sql.analyzer.FieldId) Expression(io.prestosql.sql.tree.Expression) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) LongLiteral(io.prestosql.sql.tree.LongLiteral) GenericLiteral(io.prestosql.sql.tree.GenericLiteral) NodeRef(io.prestosql.sql.tree.NodeRef) ArithmeticBinaryExpression(io.prestosql.sql.tree.ArithmeticBinaryExpression) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) Expression(io.prestosql.sql.tree.Expression) RelationId(io.prestosql.sql.analyzer.RelationId) SubscriptExpression(io.prestosql.sql.tree.SubscriptExpression) ArrayConstructor(io.prestosql.sql.tree.ArrayConstructor)

Aggregations

NodeRef (io.prestosql.sql.tree.NodeRef)15 Type (io.prestosql.spi.type.Type)14 Expression (io.prestosql.sql.tree.Expression)10 ExpressionInterpreter (io.prestosql.sql.planner.ExpressionInterpreter)7 RowExpression (io.prestosql.spi.relation.RowExpression)6 ImmutableMap (com.google.common.collect.ImmutableMap)4 Symbol (io.prestosql.spi.plan.Symbol)4 ImmutableList.toImmutableList (com.google.common.collect.ImmutableList.toImmutableList)3 OperatorType (io.prestosql.spi.function.OperatorType)3 DecimalType.createDecimalType (io.prestosql.spi.type.DecimalType.createDecimalType)3 LambdaExpression (io.prestosql.sql.tree.LambdaExpression)3 FunctionType (io.prestosql.spi.type.FunctionType)2 RowType (io.prestosql.spi.type.RowType)2 VarbinaryType (io.prestosql.spi.type.VarbinaryType)2 VarcharType (io.prestosql.spi.type.VarcharType)2 VarcharType.createVarcharType (io.prestosql.spi.type.VarcharType.createVarcharType)2 PlanSymbolAllocator (io.prestosql.sql.planner.PlanSymbolAllocator)2 RowExpressionInterpreter (io.prestosql.sql.planner.RowExpressionInterpreter)2 TypeAnalyzer (io.prestosql.sql.planner.TypeAnalyzer)2 CoalesceExpression (io.prestosql.sql.tree.CoalesceExpression)2