Search in sources :

Example 46 with SqlAggFunction

use of org.apache.calcite.sql.SqlAggFunction in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.

/**
 * Rewrites an aggregate that contains DISTINCT aggregate calls as two stacked aggregates.
 *
 * <p>The lower aggregate groups the input by a list of grouping sets (one per distinct call plus
 * the original group set for non-distinct calls) and emits a {@code $g} GROUPING column. A
 * projection then turns {@code $g} into one filter column per grouping set, and the upper
 * aggregate re-aggregates on the original group keys using those columns as filter arguments.
 * The result is converted to the original row type and handed to {@code call.transformTo}.
 *
 * @param call rule call that supplies the {@link RelBuilder} and receives the rewritten plan
 * @param aggregate the aggregate to rewrite
 */
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate) {
    // Collect the grouping sets of the lower aggregate. The TreeSet removes duplicates and keeps
    // the sets ordered by ImmutableBitSet.ORDERING.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    // Remembers, per grouping set, the filterArg of the distinct call that produced it.
    final Map<ImmutableBitSet, Integer> groupSetToDistinctAggCallFilterArg = new HashMap<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        if (!aggCall.isDistinct()) {
            // Non-distinct calls are computed on the original group set.
            groupSetTreeSet.add(aggregate.getGroupSet());
        } else {
            // A distinct call groups by its arguments, its filter column (if it has one) and the
            // original group keys.
            ImmutableBitSet groupSet = ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet());
            groupSetToDistinctAggCallFilterArg.put(groupSet, aggCall.filterArg);
            groupSetTreeSet.add(groupSet);
        }
    }
    final com.google.common.collect.ImmutableList<ImmutableBitSet> groupSets = com.google.common.collect.ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // NOTE: despite its name, this list holds the calls of the lower aggregate, i.e. the adapted
    // NON-distinct calls followed (below) by the GROUPING call.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            AggregateCall newAggCall = aggCall.left.adaptTo(aggregate.getInput(), aggCall.left.getArgList(), aggCall.left.filterArg, aggregate.getGroupCount(), fullGroupSet.cardinality());
            distinctAggCalls.add(newAggCall.rename(aggCall.right));
        }
    }
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    final int groupCount = fullGroupSet.cardinality();
    // Maps each grouping set to the index of its filter column in the projection built below.
    final Map<ImmutableBitSet, Integer> filters = new LinkedHashMap<>();
    // z is the output index of the `$g` column: it follows the group keys and the non-distinct
    // calls collected so far. It must be computed before the GROUPING call is appended.
    final int z = groupCount + distinctAggCalls.size();
    distinctAggCalls.add(AggregateCall.create(SqlStdOperatorTable.GROUPING, false, false, false, ImmutableIntList.copyOf(fullGroupSet), -1, RelCollations.EMPTY, groupSets.size(), relBuilder.peek(), null, "$g"));
    // The i-th grouping set gets the i-th filter column, which the projection places at z + i.
    for (Ord<ImmutableBitSet> groupSet : Ord.zip(groupSets)) {
        filters.put(groupSet.e, z + groupSet.i);
    }
    // Lower aggregate: group by all grouping sets at once.
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    // Replace the trailing `$g` column with one boolean column per grouping set
    // (`$g = <group value>`), so that the upper aggregate can use them as filter arguments.
    if (!filters.isEmpty()) {
        final List<RexNode> nodes = new ArrayList<>(relBuilder.fields());
        // The last field is `$g`; it is dropped from the projection and only used in comparisons.
        final RexNode nodeZ = nodes.remove(nodes.size() - 1);
        for (Map.Entry<ImmutableBitSet, Integer> entry : filters.entrySet()) {
            // groupValue is defined outside this excerpt; presumably it returns the value `$g`
            // takes for rows of this grouping set -- confirm against its definition.
            final long v = groupValue(fullGroupSet, entry.getKey());
            // Get and remap the filterArg of the distinct aggregate call.
            // (remap is defined outside this excerpt; a negative result is treated as "no filter".)
            int distinctAggCallFilterArg = remap(fullGroupSet, groupSetToDistinctAggCallFilterArg.getOrDefault(entry.getKey(), -1));
            RexNode expr;
            if (distinctAggCallFilterArg < 0) {
                expr = relBuilder.equals(nodeZ, relBuilder.literal(v));
            } else {
                RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
                // merge the filter of the distinct aggregate call itself.
                expr = relBuilder.and(relBuilder.equals(nodeZ, relBuilder.literal(v)), rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, relBuilder.field(distinctAggCallFilterArg)));
            }
            nodes.add(relBuilder.alias(expr, "$g_" + v));
        }
        relBuilder.project(nodes);
    }
    // Build the calls of the upper aggregate, one per original call and in the original order.
    int aggCallIdx = 0;
    // x walks the columns holding the non-distinct results of the lower aggregate; they start
    // right after the group keys.
    int x = groupCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    // TODO supports more aggCalls (currently only supports COUNT)
    // Some aggregate functions (e.g. COUNT) have the special property that they can return a
    // non-null result without any input. We need to make sure we return a result in this case.
    final List<Integer> needDefaultValueAggCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // The value was already computed by the lower aggregate; MIN over the rows selected
            // by the original group set's filter column carries it to the upper aggregate.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
            switch(aggCall.getAggregation().getKind()) {
                case COUNT:
                    // Remember the position so a default value can be substituted further down.
                    needDefaultValueAggCalls.add(aggCallIdx);
                    break;
                default:
            }
        } else {
            // Distinct calls keep their function but lose the DISTINCT flag (see the `false`
            // passed to AggregateCall.create below); the filter column restricts them to the
            // rows of their own grouping set, recomputed here with the same expression as above.
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall = AggregateCall.create(aggregation, false, aggCall.isApproximate(), false, newArgList, newFilterArg, RelCollations.EMPTY, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
        aggCallIdx++;
    }
    // Upper aggregate: group by the original keys, expressed in terms of the lower output.
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    // Only when the original aggregate has no group keys: wrap the recorded COUNT results in a
    // CASE expression that yields 0 instead of NULL.
    if (!needDefaultValueAggCalls.isEmpty() && aggregate.getGroupCount() == 0) {
        final Aggregate newAgg = (Aggregate) relBuilder.peek();
        final List<RexNode> nodes = new ArrayList<>();
        // Group key columns pass through unchanged.
        for (int i = 0; i < newAgg.getGroupCount(); ++i) {
            nodes.add(RexInputRef.of(i, newAgg.getRowType()));
        }
        for (int i = 0; i < newAgg.getAggCallList().size(); ++i) {
            final RexNode inputRef = RexInputRef.of(newAgg.getGroupCount() + i, newAgg.getRowType());
            RexNode newNode = inputRef;
            if (needDefaultValueAggCalls.contains(i)) {
                // newCalls was built in the order of the original call list, so index i
                // addresses the matching original call.
                SqlKind originalFunKind = aggregate.getAggCallList().get(i).getAggregation().getKind();
                switch(originalFunKind) {
                    case COUNT:
                        newNode = relBuilder.call(SqlStdOperatorTable.CASE, relBuilder.isNotNull(inputRef), inputRef, relBuilder.literal(BigDecimal.ZERO));
                        break;
                    default:
                }
            }
            nodes.add(newNode);
        }
        relBuilder.project(nodes);
    }
    // Align the result with the original row type before registering the replacement.
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) ArrayList(java.util.ArrayList) LinkedHashMap(java.util.LinkedHashMap) TreeSet(java.util.TreeSet) RexBuilder(org.apache.calcite.rex.RexBuilder) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlKind(org.apache.calcite.sql.SqlKind) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) HashMap(java.util.HashMap) LinkedHashMap(java.util.LinkedHashMap) Map(java.util.Map) RexNode(org.apache.calcite.rex.RexNode)

Example 47 with SqlAggFunction

use of org.apache.calcite.sql.SqlAggFunction in project drill by apache.

the class DrillReduceAggregatesRule method reduceAvg.

/**
 * Builds the expression that replaces an AVG aggregate call.
 *
 * <p>A SUM call and a COUNT call over the same arguments are registered through
 * {@code RexBuilder.addAggCall}; the returned expression divides the SUM (replaced by NULL when
 * the COUNT is zero, then wrapped in {@code CastHighOp}) by the COUNT.
 *
 * @param oldAggRel the aggregate that contains the AVG call
 * @param oldCall the AVG call being reduced
 * @param newCalls list that receives the SUM and COUNT calls
 * @param aggCallMapping mapping from aggregate calls to the expressions that reference them
 * @return the division expression standing in for the AVG call
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean inferTypes = settings.isTypeInferenceEnabled();
    final int groupCount = oldAggRel.getGroupCount();
    final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rex = oldAggRel.getCluster().getRexBuilder();

    // The type of AVG's first argument is passed along with every addAggCall below.
    final int avgArgIndex = oldCall.getArgList().get(0);
    final RelDataType avgArgType = getFieldType(oldAggRel.getInput(), avgArgIndex);
    final ImmutableList<RelDataType> avgArgTypes = ImmutableList.of(avgArgType);

    // SUM: the return type is inferred for the old call's binding and forced to be nullable
    // when the aggregate has no group keys.
    final RelDataType inferredSumType = TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.of()).inferReturnType(oldCall.createBinding(oldAggRel));
    final boolean sumNullable = inferredSumType.isNullable() || groupCount == 0;
    final RelDataType sumType = types.createTypeWithNullability(inferredSumType, sumNullable);
    final SqlAggFunction sumFunction = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
    final AggregateCall sumCall = AggregateCall.create(sumFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

    // COUNT over the same arguments.
    final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countFunction.getReturnType(types);
    final AggregateCall countCall = AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

    // Register SUM first, then COUNT.
    // NOTE: per the original author, these references are relative to the output of the new
    // aggregate.
    final RexNode sumRef = rex.addAggCall(sumCall, groupCount, newCalls, aggCallMapping, avgArgTypes);
    final RexNode countRef = rex.addAggCall(countCall, groupCount, newCalls, aggCallMapping, avgArgTypes);

    // Numerator: NULL when COUNT is zero, otherwise SUM; then wrapped in CastHighOp.
    final RexNode countIsZero = rex.makeCall(SqlStdOperatorTable.EQUALS, countRef, rex.makeExactLiteral(BigDecimal.ZERO));
    final RexNode sumOrNull = rex.makeCall(SqlStdOperatorTable.CASE, countIsZero, rex.constantNull(), sumRef);
    final RexNode numerator = rex.makeCall(CastHighOp, sumOrNull);

    // Denominator: the COUNT call is passed to addAggCall a second time, as in the original.
    final RexNode denominator = rex.addAggCall(countCall, groupCount, newCalls, aggCallMapping, avgArgTypes);

    if (!inferTypes) {
        // Without type inference the standard DIVIDE is used and the result is cast to ANY.
        final RexNode quotient = rex.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
        return rex.makeCast(types.createSqlType(SqlTypeName.ANY), quotient);
    }
    // With type inference the Drill "divide" operator carries the old call's type.
    return rex.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numerator, denominator);
}
Also used : PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) DrillCalciteSqlAggFunctionWrapper(org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

Example 48 with SqlAggFunction

use of org.apache.calcite.sql.SqlAggFunction in project drill by apache.

the class DrillOperatorTable method populateWrappedCalciteOperators.

/**
 * Wraps every operator of the inner operator table in the matching Drill wrapper and records the
 * wrapper in {@code calciteToWrapper}, keyed by the wrapped Calcite operator.
 *
 * <p>Operators that are neither functions nor BETWEEN are skipped when no Drill function is
 * found under their Drill name.
 */
private void populateWrappedCalciteOperators() {
    for (SqlOperator operator : inner.getOperatorList()) {
        // Aggregate functions are checked before plain functions, as in the original chain.
        if (operator instanceof SqlAggFunction) {
            calciteToWrapper.put(operator, new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) operator, getFunctionListWithInference(operator.getName())));
            continue;
        }
        if (operator instanceof SqlFunction) {
            calciteToWrapper.put(operator, new DrillCalciteSqlFunctionWrapper((SqlFunction) operator, getFunctionListWithInference(operator.getName())));
            continue;
        }
        if (operator instanceof SqlBetweenOperator) {
            // While converting to RexNode, StandardConvertletTable.convertBetween expects the
            // operator to be a subclass of SqlBetweenOperator, hence the dedicated wrapper.
            calciteToWrapper.put(operator, new DrillCalciteSqlBetweenOperatorWrapper((SqlBetweenOperator) operator));
            continue;
        }

        // Unary minus and unary plus keep their Calcite name instead of being converted;
        // otherwise Calcite would mix them up with the binary operators subtract (-) and add (+).
        final boolean unarySign = operator == SqlStdOperatorTable.UNARY_MINUS || operator == SqlStdOperatorTable.UNARY_PLUS;
        final String drillName = unarySign
                ? operator.getName()
                : FunctionCallFactory.convertToDrillFunctionName(operator.getName());

        // Only wrap operators for which at least one Drill function is found.
        final List<DrillFuncHolder> holders = getFunctionListWithInference(drillName);
        if (!holders.isEmpty()) {
            calciteToWrapper.put(operator, new DrillCalciteSqlOperatorWrapper(operator, drillName, holders));
        }
    }
}
Also used : DrillFuncHolder(org.apache.drill.exec.expr.fn.DrillFuncHolder) SqlOperator(org.apache.calcite.sql.SqlOperator) SqlBetweenOperator(org.apache.calcite.sql.fun.SqlBetweenOperator) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlFunction(org.apache.calcite.sql.SqlFunction)

Aggregations

SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)47 RelDataType (org.apache.calcite.rel.type.RelDataType)30 ArrayList (java.util.ArrayList)25 RexNode (org.apache.calcite.rex.RexNode)24 AggregateCall (org.apache.calcite.rel.core.AggregateCall)20 RexBuilder (org.apache.calcite.rex.RexBuilder)15 RelNode (org.apache.calcite.rel.RelNode)13 List (java.util.List)11 Aggregate (org.apache.calcite.rel.core.Aggregate)10 RelBuilder (org.apache.calcite.tools.RelBuilder)10 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)9 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)9 HashMap (java.util.HashMap)8 ImmutableList (com.google.common.collect.ImmutableList)7 SqlOperator (org.apache.calcite.sql.SqlOperator)6 Type (java.lang.reflect.Type)5 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)5 RexInputRef (org.apache.calcite.rex.RexInputRef)5 SqlKind (org.apache.calcite.sql.SqlKind)5 Mappings (org.apache.calcite.util.mapping.Mappings)5