Search in sources :

Example 21 with CallExpression

use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.

the class TranslateExpressions method createRewriter.

/**
 * Creates the rewriter that replaces original (AST-backed) expressions with translated
 * and optimized {@link RowExpression}s.
 *
 * @param metadata  supplies the function and type manager, type lookup and the optimizer input
 * @param sqlParser handed to every {@link TypeAnalyzer} created by the rewriter
 * @return a rewriter that leaves already-translated expressions untouched
 */
private static PlanRowExpressionRewriter createRewriter(Metadata metadata, SqlParser sqlParser) {
    return new PlanRowExpressionRewriter() {

        /**
         * Entry point: call expressions that still carry at least one original-expression
         * argument (the aggregation case) are rewritten argument by argument; everything
         * else goes through the plain single-expression translation.
         */
        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context) {
            if (expression instanceof CallExpression) {
                CallExpression call = (CallExpression) expression;
                boolean hasOriginalArgument = call.getArguments().stream().anyMatch(OriginalExpressionUtils::isExpression);
                if (hasOriginalArgument) {
                    // special treatment of the CallExpression in Aggregation
                    return removeOriginalExpressionArguments(call, context.getSession(), context.getSymbolAllocator(), context);
                }
            }
            return removeOriginalExpression(expression, context, new HashMap<>());
        }

        /**
         * Rebuilds the call with every original-expression argument translated. Display name,
         * function handle and return type are copied from the incoming call.
         */
        private RowExpression removeOriginalExpressionArguments(CallExpression callExpression, Session session, PlanSymbolAllocator planSymbolAllocator, Rule.Context context) {
            Map<NodeRef<Expression>, Type> types = analyzeCallExpressionTypes(callExpression, session, planSymbolAllocator.getTypes());
            List<RowExpression> translatedArguments = callExpression.getArguments().stream()
                    .map(argument -> removeOriginalExpression(argument, session, types, context))
                    .collect(toImmutableList());
            return new CallExpression(
                    callExpression.getDisplayName(),
                    callExpression.getFunctionHandle(),
                    callExpression.getType(),
                    translatedArguments,
                    Optional.empty());
        }

        /**
         * Collects the types of all sub-expressions of the call's original-expression arguments.
         * Lambda arguments are typed from the function's declared {@link FunctionType}s;
         * all other original expressions are analyzed against {@code typeProvider}.
         */
        private Map<NodeRef<Expression>, Type> analyzeCallExpressionTypes(CallExpression callExpression, Session session, TypeProvider typeProvider) {
            List<LambdaExpression> lambdas = callExpression.getArguments().stream()
                    .filter(OriginalExpressionUtils::isExpression)
                    .map(OriginalExpressionUtils::castToExpression)
                    .filter(LambdaExpression.class::isInstance)
                    .map(LambdaExpression.class::cast)
                    .collect(toImmutableList());
            ImmutableMap.Builder<NodeRef<Expression>, Type> typesBuilder = ImmutableMap.<NodeRef<Expression>, Type>builder();
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);

            if (!lambdas.isEmpty()) {
                // Function-typed parameters of the aggregation, in declaration order.
                List<FunctionType> functionTypes = metadata.getFunctionAndTypeManager()
                        .getFunctionMetadata(callExpression.getFunctionHandle())
                        .getArgumentTypes().stream()
                        .filter(typeSignature -> typeSignature.getBase().equals(FunctionType.NAME))
                        .map(metadata::getType)
                        .map(FunctionType.class::cast)
                        .collect(toImmutableList());
                InternalAggregationFunction aggregationFunction = metadata.getFunctionAndTypeManager()
                        .getAggregateFunctionImplementation(callExpression.getFunctionHandle());
                List<Class<?>> lambdaInterfaces = aggregationFunction.getLambdaInterfaces();
                // Lambdas are paired with function types positionally, so the counts must agree.
                verify(lambdas.size() == functionTypes.size());
                verify(lambdas.size() == lambdaInterfaces.size());

                for (int lambdaIndex = 0; lambdaIndex < lambdas.size(); lambdaIndex++) {
                    LambdaExpression lambda = lambdas.get(lambdaIndex);
                    FunctionType functionType = functionTypes.get(lambdaIndex);
                    // Compiling a lambda means generating a LambdaDefinitionExpression from the
                    // LambdaExpression, and that needs the type of every sub-expression.
                    //
                    // Project and filter compilation obtain those types through
                    // ExpressionAnalyzer.getExpressionTypesFromInput (see visitScanFilterAndProject
                    // and visitFilter).
                    //
                    // That approach is not usable here: the function call representation in the
                    // final aggregation node is currently a hack that takes the intermediate type
                    // as input and may not be a valid function call in Presto.
                    //
                    // TODO: Once the final aggregation function call representation is fixed,
                    // the same mechanism in project and filter expression should be used here.
                    verify(lambda.getArguments().size() == functionType.getArgumentTypes().size());

                    Map<NodeRef<Expression>, Type> declarationTypes = new HashMap<>();
                    Map<Symbol, Type> symbolTypes = new HashMap<>();
                    for (int argumentIndex = 0; argumentIndex < lambda.getArguments().size(); argumentIndex++) {
                        LambdaArgumentDeclaration declaration = lambda.getArguments().get(argumentIndex);
                        Type declaredType = functionType.getArgumentTypes().get(argumentIndex);
                        declarationTypes.put(NodeRef.of(declaration), declaredType);
                        symbolTypes.put(new Symbol(declaration.getName().getValue()), declaredType);
                    }

                    // Order matters for the resulting map: the lambda itself, then its argument
                    // declarations, then every sub-expression of its body.
                    typesBuilder.put(NodeRef.of(lambda), functionType);
                    typesBuilder.putAll(declarationTypes);
                    typesBuilder.putAll(typeAnalyzer.getTypes(session, TypeProvider.copyOf(symbolTypes), lambda.getBody()));
                }
            }

            // Remaining original expressions (non-lambda) are analyzed against the plan's types.
            for (RowExpression argument : callExpression.getArguments()) {
                if (isExpression(argument) && !(castToExpression(argument) instanceof LambdaExpression)) {
                    typesBuilder.putAll(typeAnalyzer.getTypes(session, typeProvider, castToExpression(argument)));
                }
            }
            return typesBuilder.build();
        }

        /**
         * Translates an AST expression to a row expression and runs it through the
         * optimizer at the SERIALIZABLE level.
         */
        private RowExpression toRowExpression(Expression expression, Map<NodeRef<Expression>, Type> types, Map<Symbol, Integer> layout, Session session) {
            RowExpression translated = SqlToRowExpressionTranslator.translate(
                    expression,
                    FunctionKind.SCALAR,
                    types,
                    layout,
                    metadata.getFunctionAndTypeManager(),
                    session,
                    false);
            RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata);
            return optimizer.optimize(translated, RowExpressionInterpreter.Level.SERIALIZABLE, session.toConnectorSession());
        }

        /**
         * Translates a single original expression, deriving its types from the symbol
         * allocator of {@code context}; any other row expression is returned as is.
         */
        private RowExpression removeOriginalExpression(RowExpression expression, Rule.Context context, Map<Symbol, Integer> layout) {
            if (!isExpression(expression)) {
                return expression;
            }
            Expression original = castToExpression(expression);
            TypeAnalyzer typeAnalyzer = new TypeAnalyzer(sqlParser, metadata);
            Map<NodeRef<Expression>, Type> types = typeAnalyzer.getTypes(context.getSession(), context.getSymbolAllocator().getTypes(), original);
            return toRowExpression(original, types, layout, context.getSession());
        }

        /**
         * Translates a single original expression using pre-computed {@code types} and an
         * empty layout; any other row expression is returned as is.
         */
        private RowExpression removeOriginalExpression(RowExpression rowExpression, Session session, Map<NodeRef<Expression>, Type> types, Rule.Context context) {
            if (!isExpression(rowExpression)) {
                return rowExpression;
            }
            return toRowExpression(castToExpression(rowExpression), types, new HashMap<>(), session);
        }
    };
}
Also used : FunctionKind(io.prestosql.spi.function.FunctionKind) TypeProvider(io.prestosql.sql.planner.TypeProvider) HashMap(java.util.HashMap) SqlParser(io.prestosql.sql.parser.SqlParser) TypeAnalyzer(io.prestosql.sql.planner.TypeAnalyzer) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) CallExpression(io.prestosql.spi.relation.CallExpression) Verify.verify(com.google.common.base.Verify.verify) Map(java.util.Map) Session(io.prestosql.Session) Type(io.prestosql.spi.type.Type) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) SqlToRowExpressionTranslator(io.prestosql.sql.relational.SqlToRowExpressionTranslator) Symbol(io.prestosql.spi.plan.Symbol) ImmutableMap(com.google.common.collect.ImmutableMap) Rule(io.prestosql.sql.planner.iterative.Rule) ImmutableList.toImmutableList(com.google.common.collect.ImmutableList.toImmutableList) FunctionType(io.prestosql.spi.type.FunctionType) RowExpressionOptimizer(io.prestosql.sql.relational.RowExpressionOptimizer) Metadata(io.prestosql.metadata.Metadata) LambdaArgumentDeclaration(io.prestosql.sql.tree.LambdaArgumentDeclaration) NodeRef(io.prestosql.sql.tree.NodeRef) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) RowExpressionInterpreter(io.prestosql.sql.planner.RowExpressionInterpreter) List(java.util.List) RowExpression(io.prestosql.spi.relation.RowExpression) Optional(java.util.Optional) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) PlanSymbolAllocator(io.prestosql.sql.planner.PlanSymbolAllocator) 
RowExpressionOptimizer(io.prestosql.sql.relational.RowExpressionOptimizer) NodeRef(io.prestosql.sql.tree.NodeRef) LambdaArgumentDeclaration(io.prestosql.sql.tree.LambdaArgumentDeclaration) CallExpression(io.prestosql.spi.relation.CallExpression) FunctionType(io.prestosql.spi.type.FunctionType) RowExpression(io.prestosql.spi.relation.RowExpression) TypeProvider(io.prestosql.sql.planner.TypeProvider) ImmutableMap(com.google.common.collect.ImmutableMap) TypeAnalyzer(io.prestosql.sql.planner.TypeAnalyzer) Type(io.prestosql.spi.type.Type) FunctionType(io.prestosql.spi.type.FunctionType) OriginalExpressionUtils.isExpression(io.prestosql.sql.relational.OriginalExpressionUtils.isExpression) CallExpression(io.prestosql.spi.relation.CallExpression) OriginalExpressionUtils.castToExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToExpression) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) RowExpression(io.prestosql.spi.relation.RowExpression) Expression(io.prestosql.sql.tree.Expression) OriginalExpressionUtils(io.prestosql.sql.relational.OriginalExpressionUtils) LambdaExpression(io.prestosql.sql.tree.LambdaExpression) HashMap(java.util.HashMap) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) Session(io.prestosql.Session)

Example 22 with CallExpression

use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.

the class SimplifyCountOverConstant method apply.

/**
 * Replaces every aggregation for which {@code isCountOverConstant} holds (checked against
 * the child projection's assignments) with a zero-argument {@code count} call that keeps
 * the original mask. The aggregation node is re-parented directly onto the captured child.
 *
 * @return the rewritten aggregation node, or an empty result when nothing matched
 */
@Override
public Result apply(AggregationNode parent, Captures captures, Context context) {
    ProjectNode child = captures.get(CHILD);
    // Start from a copy so untouched aggregations keep their position in the map.
    Map<Symbol, AggregationNode.Aggregation> rewritten = new LinkedHashMap<>(parent.getAggregations());
    boolean anyRewritten = false;

    for (Entry<Symbol, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = entry.getValue();
        if (!isCountOverConstant(original, child.getAssignments())) {
            continue;
        }
        CallExpression argumentlessCount = new CallExpression(
                "count",
                functionResolution.countFunction(),
                BIGINT,
                ImmutableList.of(),
                Optional.empty());
        AggregationNode.Aggregation replacement = new AggregationNode.Aggregation(
                argumentlessCount,
                ImmutableList.of(),
                false,
                Optional.empty(),
                Optional.empty(),
                original.getMask());
        rewritten.put(entry.getKey(), replacement);
        anyRewritten = true;
    }

    if (anyRewritten) {
        return Result.ofPlanNode(new AggregationNode(
                parent.getId(),
                child,
                rewritten,
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashSymbol(),
                parent.getGroupIdSymbol(),
                parent.getAggregationType(),
                parent.getFinalizeSymbol()));
    }
    return Result.empty();
}
Also used : Symbol(io.prestosql.spi.plan.Symbol) ProjectNode(io.prestosql.spi.plan.ProjectNode) AggregationNode(io.prestosql.spi.plan.AggregationNode) CallExpression(io.prestosql.spi.relation.CallExpression) LinkedHashMap(java.util.LinkedHashMap)

Example 23 with CallExpression

use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.

the class PushPartialAggregationThroughExchange method split.

/**
 * Splits an aggregation into a PARTIAL node feeding a FINAL node.
 *
 * For each aggregation, a fresh intermediate symbol (typed with the function's intermediate
 * type) carries the partial result; the final aggregation consumes that symbol plus any
 * lambda arguments of the original call.
 *
 * @param node    the aggregation to split
 * @param context supplies the symbol and plan-node id allocators
 * @return the FINAL aggregation node, whose source is the new PARTIAL node
 * @throws IllegalStateException if an aggregation carries an ordering scheme
 */
private PlanNode split(AggregationNode node, Context context) {
    // otherwise, add a partial and final with an exchange in between
    Map<Symbol, AggregationNode.Aggregation> intermediateAggregation = new HashMap<>();
    Map<Symbol, AggregationNode.Aggregation> finalAggregation = new HashMap<>();

    for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : node.getAggregations().entrySet()) {
        AggregationNode.Aggregation original = entry.getValue();
        FunctionHandle functionHandle = original.getFunctionHandle();
        String functionName = metadata.getFunctionAndTypeManager()
                .getFunctionMetadata(functionHandle)
                .getName()
                .getObjectName();
        InternalAggregationFunction function = metadata.getFunctionAndTypeManager()
                .getAggregateFunctionImplementation(functionHandle);
        Symbol intermediateSymbol = context.getSymbolAllocator().newSymbol(functionName, function.getIntermediateType());

        checkState(!original.getOrderingScheme().isPresent(), "Aggregate with ORDER BY does not support partial aggregation");

        // Partial step: same arguments as the original, but producing the intermediate type.
        CallExpression partialCall = new CallExpression(
                functionName,
                functionHandle,
                function.getIntermediateType(),
                original.getArguments(),
                Optional.empty());
        intermediateAggregation.put(intermediateSymbol, new AggregationNode.Aggregation(
                partialCall,
                original.getArguments(),
                original.isDistinct(),
                original.getFilter(),
                original.getOrderingScheme(),
                original.getMask()));

        // rewrite final aggregation in terms of intermediate function:
        // its inputs are the intermediate symbol followed by the original lambda arguments.
        ImmutableList<RowExpression> finalArguments = ImmutableList.<RowExpression>builder()
                .add(new VariableReferenceExpression(intermediateSymbol.getName(), function.getIntermediateType()))
                .addAll(original.getArguments().stream()
                        .filter(PushPartialAggregationThroughExchange::isLambda)
                        .collect(toImmutableList()))
                .build();
        CallExpression finalCall = new CallExpression(
                functionName,
                functionHandle,
                function.getFinalType(),
                finalArguments,
                Optional.empty());
        finalAggregation.put(entry.getKey(), new AggregationNode.Aggregation(
                finalCall,
                finalArguments,
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty()));
    }

    // preGroupedSymbols are dropped (empty list) on both nodes: passing through the exchange
    // may or may not preserve these properties, so it is safest not to carry them over.
    PlanNode partial = new AggregationNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            intermediateAggregation,
            node.getGroupingSets(),
            ImmutableList.of(),
            PARTIAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());

    return new AggregationNode(
            node.getId(),
            partial,
            finalAggregation,
            node.getGroupingSets(),
            ImmutableList.of(),
            FINAL,
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol());
}
Also used : HashMap(java.util.HashMap) Symbol(io.prestosql.spi.plan.Symbol) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) InternalAggregationFunction(io.prestosql.operator.aggregation.InternalAggregationFunction) SystemSessionProperties.preferPartialAggregation(io.prestosql.SystemSessionProperties.preferPartialAggregation) PlanNode(io.prestosql.spi.plan.PlanNode) VariableReferenceExpression(io.prestosql.spi.relation.VariableReferenceExpression) HashMap(java.util.HashMap) Map(java.util.Map) FunctionHandle(io.prestosql.spi.function.FunctionHandle) CallExpression(io.prestosql.spi.relation.CallExpression)

Example 24 with CallExpression

use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.

the class RewriteSpatialPartitioningAggregation method apply.

/**
 * Rewrites single-argument aggregations whose function name equals {@code NAME} into a
 * two-argument call over (envelope, partition count), and inserts a projection below the
 * aggregation that provides both inputs.
 *
 * The projection keeps all source outputs, adds the hash partition count of the session as
 * a literal, and adds one envelope symbol per rewritten aggregation. An argument that is
 * already an {@code ST_Envelope} call is projected as is; otherwise it is wrapped in one.
 * Aggregations that do not match are passed through unchanged.
 *
 * @return the aggregation node rebuilt on top of the new projection
 */
@Override
public Result apply(AggregationNode node, Captures captures, Context context) {
    ImmutableMap.Builder<Symbol, Aggregation> rewrittenAggregations = ImmutableMap.builder();
    Symbol partitionCountSymbol = context.getSymbolAllocator().newSymbol("partition_count", INTEGER);
    ImmutableMap.Builder<Symbol, RowExpression> envelopeAssignments = ImmutableMap.builder();

    for (Map.Entry<Symbol, Aggregation> entry : node.getAggregations().entrySet()) {
        Aggregation aggregation = entry.getValue();
        QualifiedObjectName name = metadata.getFunctionAndTypeManager()
                .getFunctionMetadata(aggregation.getFunctionHandle())
                .getName();
        Type geometryType = metadata.getType(GEOMETRY_TYPE_SIGNATURE);

        boolean rewritable = name.equals(NAME) && aggregation.getArguments().size() == 1;
        if (!rewritable) {
            rewrittenAggregations.put(entry);
            continue;
        }

        ImmutableList<RowExpression> originalArguments = aggregation.getArguments().stream().collect(toImmutableList());
        RowExpression geometry = getOnlyElement(originalArguments);
        Symbol envelopeSymbol = context.getSymbolAllocator().newSymbol("envelope", metadata.getType(GEOMETRY_TYPE_SIGNATURE));

        RowExpression envelope;
        if (isFunctionNameMatch(geometry, "ST_Envelope")) {
            // Already an envelope; project the argument unchanged.
            envelope = geometry;
        }
        else {
            envelope = castToRowExpression(new FunctionCallBuilder(metadata)
                    .setName(QualifiedName.of("ST_Envelope"))
                    .addArgument(GEOMETRY_TYPE_SIGNATURE, castToExpression(geometry))
                    .build());
        }
        envelopeAssignments.put(envelopeSymbol, envelope);

        // The same (envelope, partition count) pair serves as the call's arguments and
        // as the aggregation's argument list.
        ImmutableList<RowExpression> spatialArguments = ImmutableList.of(
                castToRowExpression(toSymbolReference(envelopeSymbol)),
                castToRowExpression(toSymbolReference(partitionCountSymbol)));
        CallExpression spatialCall = new CallExpression(
                name.getObjectName(),
                metadata.getFunctionAndTypeManager().lookupFunction(NAME.getObjectName(), fromTypes(geometryType, INTEGER)),
                context.getSymbolAllocator().getTypes().get(entry.getKey()),
                spatialArguments,
                Optional.empty());
        rewrittenAggregations.put(entry.getKey(), new Aggregation(
                spatialCall,
                spatialArguments,
                false,
                Optional.empty(),
                Optional.empty(),
                aggregation.getMask()));
    }

    ProjectNode projection = new ProjectNode(
            context.getIdAllocator().getNextId(),
            node.getSource(),
            Assignments.builder()
                    .putAll(AssignmentUtils.identityAsSymbolReferences(node.getSource().getOutputSymbols()))
                    .put(partitionCountSymbol, castToRowExpression(new LongLiteral(Integer.toString(getHashPartitionCount(context.getSession())))))
                    .putAll(envelopeAssignments.build())
                    .build());

    return Result.ofPlanNode(new AggregationNode(
            node.getId(),
            projection,
            rewrittenAggregations.build(),
            node.getGroupingSets(),
            node.getPreGroupedSymbols(),
            node.getStep(),
            node.getHashSymbol(),
            node.getGroupIdSymbol(),
            node.getAggregationType(),
            node.getFinalizeSymbol()));
}
Also used : LongLiteral(io.prestosql.sql.tree.LongLiteral) Symbol(io.prestosql.spi.plan.Symbol) OriginalExpressionUtils.castToRowExpression(io.prestosql.sql.relational.OriginalExpressionUtils.castToRowExpression) RowExpression(io.prestosql.spi.relation.RowExpression) AggregationNode(io.prestosql.spi.plan.AggregationNode) ImmutableMap(com.google.common.collect.ImmutableMap) QualifiedObjectName(io.prestosql.spi.connector.QualifiedObjectName) Aggregation(io.prestosql.spi.plan.AggregationNode.Aggregation) Type(io.prestosql.spi.type.Type) ProjectNode(io.prestosql.spi.plan.ProjectNode) Map(java.util.Map) ImmutableMap(com.google.common.collect.ImmutableMap) CallExpression(io.prestosql.spi.relation.CallExpression) FunctionCallBuilder(io.prestosql.sql.planner.FunctionCallBuilder)

Example 25 with CallExpression

use of io.prestosql.spi.relation.CallExpression in project hetu-core by openlookeng.

the class RowExpressionRewriteRuleSet method rewriteAggregation.

/**
 * Runs the aggregation's arguments through {@code rewriter} by wrapping them in a temporary
 * call expression, then builds a new aggregation from the rewritten arguments.
 *
 * The returned aggregation keeps the original function call, distinct flag, filter,
 * ordering scheme and mask; only the argument list is replaced. The rewriter's result is
 * cast to {@link CallExpression}, so a rewriter returning any other kind of expression
 * fails with a {@link ClassCastException}.
 *
 * @param aggregation            the aggregation whose arguments are rewritten
 * @param returnType             type given to the temporary call expression
 * @param context                rule context forwarded to the rewriter
 * @param functionAndTypeManager not used by this method body
 * @return the aggregation with rewritten arguments
 */
private AggregationNode.Aggregation rewriteAggregation(AggregationNode.Aggregation aggregation, Type returnType, Rule.Context context, FunctionAndTypeManager functionAndTypeManager) {
    // Temporary call that exists only so the rewriter sees the arguments in call form.
    CallExpression temporaryCall = new CallExpression(
            aggregation.getFunctionCall().getDisplayName(),
            aggregation.getFunctionHandle(),
            returnType,
            aggregation.getArguments(),
            Optional.empty());
    RowExpression rewritten = rewriter.rewrite(temporaryCall, context);
    return new AggregationNode.Aggregation(
            aggregation.getFunctionCall(),
            ((CallExpression) rewritten).getArguments(),
            aggregation.isDistinct(),
            aggregation.getFilter(),
            aggregation.getOrderingScheme(),
            aggregation.getMask());
}
Also used : RowExpression(io.prestosql.spi.relation.RowExpression) CallExpression(io.prestosql.spi.relation.CallExpression)

Aggregations

CallExpression (io.prestosql.spi.relation.CallExpression): 96; RowExpression (io.prestosql.spi.relation.RowExpression): 51; BuiltInFunctionHandle (io.prestosql.spi.function.BuiltInFunctionHandle): 45; Signature (io.prestosql.spi.function.Signature): 36; VariableReferenceExpression (io.prestosql.spi.relation.VariableReferenceExpression): 29; Test (org.testng.annotations.Test): 28; Symbol (io.prestosql.spi.plan.Symbol): 27; Type (io.prestosql.spi.type.Type): 25; FunctionHandle (io.prestosql.spi.function.FunctionHandle): 24; TypeSignature (io.prestosql.spi.type.TypeSignature): 24; ArrayList (java.util.ArrayList): 23; ConstantExpression (io.prestosql.spi.relation.ConstantExpression): 22; AggregationNode (io.prestosql.spi.plan.AggregationNode): 20; Map (java.util.Map): 20; ImmutableMap (com.google.common.collect.ImmutableMap): 17; OperatorType (io.prestosql.spi.function.OperatorType): 17; SpecialForm (io.prestosql.spi.relation.SpecialForm): 15; Optional (java.util.Optional): 15; List (java.util.List): 14; ImmutableList (com.google.common.collect.ImmutableList): 13