Search in sources:

Example 11 with SqlAggFunction

Use of org.apache.calcite.sql.SqlAggFunction in project drill by axbaretto.

The class DrillReduceAggregatesRule, method reduceAvg.

/**
 * Rewrites {@code AVG(x)} as {@code SUM(x) / COUNT(x)}.
 *
 * <p>The SUM is computed with {@link SqlSumEmptyIsZeroAggFunction}, wrapped so that it reports the
 * SUM return type inferred by Drill. It is guarded by
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}, so the numerator is NULL when no rows
 * were counted. The numerator is then passed through {@code CastHighOp} (defined outside this
 * excerpt).
 *
 * @param oldAggRel aggregate that contains the AVG call
 * @param oldCall the AVG call; only its first argument is used as the averaged input
 * @param newCalls receives the SUM and COUNT calls registered for the rewritten aggregate
 * @param aggCallMapping mapping from already-registered aggregate calls to their output references
 * @return the division expression. References in it are relative to the output of the new
 *     aggregate. With type inference enabled it is Drill's {@code divide} typed as the original
 *     call; otherwise it is a standard DIVIDE cast to ANY.
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
    final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
    final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
    final int nGroups = oldAggRel.getGroupCount();
    final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    final int iAvgInput = oldCall.getArgList().get(0);
    final RelDataType avgInputType = getFieldType(oldAggRel.getInput(), iAvgInput);

    // SUM0(x), typed with Drill's inferred SUM type. It is forced nullable for a grand total
    // (no GROUP BY keys).
    RelDataType sumType = TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.<DrillFuncHolder>of()).inferReturnType(oldCall.createBinding(oldAggRel));
    sumType = typeFactory.createTypeWithNullability(sumType, sumType.isNullable() || nGroups == 0);
    final SqlAggFunction sumAgg = new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
    final AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);

    // COUNT(x), with the same DISTINCT/approximate flags and arguments as the original AVG.
    final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
    final RelDataType countType = countAgg.getReturnType(typeFactory);
    final AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);

    // NOTE: these references are with respect to the output of newAggRel.
    final RexNode sumRef = rexBuilder.addAggCall(sumCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
    final RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(avgInputType));

    // CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END
    final RexNode guardedSum = rexBuilder.makeCall(SqlStdOperatorTable.CASE, rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)), rexBuilder.constantNull(), guardedSumElse(sumRef));
    final RexNode numeratorRef = rexBuilder.makeCall(CastHighOp, guardedSum);
    // The denominator is the COUNT(x) column registered above. The same call with the same
    // argument types was previously registered a second time, which resolved to this reference.
    final RexNode denominatorRef = countRef;

    if (isInferenceEnabled) {
        return rexBuilder.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numeratorRef, denominatorRef);
    }
    final RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), divideRef);
}

/** Identity helper naming the ELSE branch of the COUNT guard (the SUM0 reference). */
private static RexNode guardedSumElse(RexNode sumRef) {
    return sumRef;
}
Also used : DrillFuncHolder(org.apache.drill.exec.expr.fn.DrillFuncHolder) PlannerSettings(org.apache.drill.exec.planner.physical.PlannerSettings) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) DrillCalciteSqlAggFunctionWrapper(org.apache.drill.exec.planner.sql.DrillCalciteSqlAggFunctionWrapper) DrillSqlOperator(org.apache.drill.exec.planner.sql.DrillSqlOperator) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RexNode(org.apache.calcite.rex.RexNode)

Example 12 with SqlAggFunction

Use of org.apache.calcite.sql.SqlAggFunction in project drill by axbaretto.

The class DrillReduceAggregatesRule, method reduceAgg.

/**
 * Reduces one aggregate call of {@code oldAggRel} to simpler aggregate calls.
 *
 * <p>SUM is handled by {@code reduceSum}. AVG is handled by {@code reduceAvg}. STDDEV_POP,
 * STDDEV_SAMP, VAR_POP and VAR_SAMP are handled by {@code reduceStddev}. Every other aggregate is
 * registered unchanged.
 *
 * @param oldAggRel aggregate being rewritten
 * @param oldCall call to reduce; any Drill wrapper around its function is removed before dispatch
 * @param newCalls receives the aggregate calls of the rewritten aggregate
 * @param aggCallMapping mapping from already-registered aggregate calls to their output references
 * @param inputExprs input expressions, indexed by argument ordinal
 * @return an expression that computes the value of the original call
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final SqlAggFunction unwrapped = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

    if (unwrapped instanceof SqlSumAggFunction) {
        // CASE COUNT(x) WHEN 0 THEN NULL ELSE SUM0(x) END
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }

    if (!(unwrapped instanceof SqlAvgAggFunction)) {
        // Not reducible: register the original call unchanged.
        final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
        final int groupCount = oldAggRel.getGroupCount();
        final List<RelDataType> argTypes = new ArrayList<>();
        final List<Integer> argOrdinals = oldCall.getArgList();
        assert argOrdinals.size() <= inputExprs.size();
        for (int i = 0; i < argOrdinals.size(); i++) {
            final int ordinal = argOrdinals.get(i);
            argTypes.add(inputExprs.get(ordinal).getType());
        }
        return builder.addAggCall(oldCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
    }

    final SqlKind kind = unwrapped.getKind();
    // Flags forwarded to reduceStddev: POP variants pass true first; STDDEV variants pass true second.
    final boolean populationVariant;
    final boolean squareRoot;
    switch (kind) {
        case AVG:
            // AVG(x) -> SUM(x) / COUNT(x)
            return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
        case STDDEV_POP:
            // STDDEV_POP(x) -> SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))
            populationVariant = true;
            squareRoot = true;
            break;
        case STDDEV_SAMP:
            // STDDEV_SAMP(x) -> SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
            //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
            populationVariant = false;
            squareRoot = true;
            break;
        case VAR_POP:
            // VAR_POP(x) -> (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
            populationVariant = true;
            squareRoot = false;
            break;
        case VAR_SAMP:
            // VAR_SAMP(x) -> (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
            //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
            populationVariant = false;
            squareRoot = false;
            break;
        default:
            throw Util.unexpected(kind);
    }
    return reduceStddev(oldAggRel, oldCall, populationVariant, squareRoot, newCalls, aggCallMapping, inputExprs);
}
Also used : SqlAvgAggFunction(org.apache.calcite.sql.fun.SqlAvgAggFunction) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlKind(org.apache.calcite.sql.SqlKind)

Example 13 with SqlAggFunction

Use of org.apache.calcite.sql.SqlAggFunction in project drill by axbaretto.

The class DrillOperatorTable, method populateWrappedCalciteOperators.

/**
 * Wraps each operator of the inner Calcite operator table in its Drill counterpart and records the
 * pairing in {@code calciteToWrapper}.
 *
 * <p>Aggregate functions and plain functions are always wrapped, using the function list for the
 * operator's own name. Every other operator is renamed through
 * {@code FunctionCallFactory.replaceOpWithFuncName}. It is wrapped only when Drill has function
 * holders for that name. UNARY_MINUS and UNARY_PLUS are always skipped (NOTE(review): the reason
 * is not visible here — presumably they are left to Calcite's own handling; confirm).
 */
private void populateWrappedCalciteOperators() {
    for (SqlOperator op : inner.getOperatorList()) {
        final SqlOperator wrapped;
        if (op instanceof SqlAggFunction) {
            wrapped = new DrillCalciteSqlAggFunctionWrapper((SqlAggFunction) op, getFunctionListWithInference(op.getName()));
        } else if (op instanceof SqlFunction) {
            wrapped = new DrillCalciteSqlFunctionWrapper((SqlFunction) op, getFunctionListWithInference(op.getName()));
        } else {
            final String funcName = FunctionCallFactory.replaceOpWithFuncName(op.getName());
            final List<DrillFuncHolder> holders = getFunctionListWithInference(funcName);
            final boolean isUnarySign = op == SqlStdOperatorTable.UNARY_MINUS || op == SqlStdOperatorTable.UNARY_PLUS;
            if (holders.isEmpty() || isUnarySign) {
                // No Drill implementation, or an operator that is deliberately not wrapped.
                continue;
            }
            wrapped = new DrillCalciteSqlOperatorWrapper(op, funcName, holders);
        }
        calciteToWrapper.put(op, wrapped);
    }
}
Also used : SqlOperator(org.apache.calcite.sql.SqlOperator) List(java.util.List) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlFunction(org.apache.calcite.sql.SqlFunction)

Example 14 with SqlAggFunction

Use of org.apache.calcite.sql.SqlAggFunction in project drill by apache.

The class WindowPrule, method onMatch.

/**
 * Converts a logical {@link DrillWindowRel} into a chain of {@code WindowPrel}s, one per window
 * group. Each {@code WindowPrel} is fed by an input converted to the traits that group needs:
 * <ul>
 *   <li>empty OVER clause: SINGLETON distribution;</li>
 *   <li>PARTITION BY: hash distribution on the partition fields, plus the group's collation;</li>
 *   <li>ORDER BY only: hash distribution on the collation fields, plus the group's collation. When
 *       not in single mode, a {@code SingleMergeExchangePrel} then merges the sorted streams into
 *       one.</li>
 * </ul>
 * Operands that reference the window's constants are shifted past the aggregate columns that
 * earlier groups have appended.
 */
@Override
public void onMatch(RelOptRuleCall call) {
    final DrillWindowRel window = call.rel(0);
    RelNode input = call.rel(1);
    // TODO: Order window based on existing partition by
    // input.getTraitSet().subsumes()
    // The start index of the constant fields of DrillWindowRel
    final int startConstantsIndex = window.getInput().getRowType().getFieldCount();
    // Number of aggregate columns appended by the WindowPrels built so far; constant references
    // in later groups are shifted by this amount.
    int constantShiftIndex = 0;
    for (final Ord<Window.Group> w : Ord.zip(window.groups)) {
        Window.Group windowBase = w.getValue();
        final boolean hasPartitionKeys = !windowBase.keys.isEmpty();
        final boolean hasOrderKeys = !windowBase.orderKeys.getFieldCollations().isEmpty();
        // Per-group flag. It used to be declared outside this loop and never reset, so an ORDER
        // BY-only group forced a SingleMergeExchange onto every later group, even a PARTITION BY
        // one that did not ask for it.
        boolean addMerge = false;
        RelTraitSet traits = call.getPlanner().emptyTraitSet().plus(Prel.DRILL_PHYSICAL);
        if (!hasPartitionKeys && !hasOrderKeys) {
            // For empty Over-Clause
            traits = traits.plus(new DrillDistributionTrait(DrillDistributionTrait.DistributionType.SINGLETON));
        } else if (hasPartitionKeys) {
            traits = traits.plus(new DrillDistributionTrait(DrillDistributionTrait.DistributionType.HASH_DISTRIBUTED, ImmutableList.copyOf(getDistributionFields(windowBase))));
        } else if (hasOrderKeys) {
            // If only the order-by clause is specified, there is a single partition consisting of
            // all the rows, so we do a distributed sort followed by a single merge as the input of
            // the window operator.
            traits = traits.plus(new DrillDistributionTrait(DrillDistributionTrait.DistributionType.HASH_DISTRIBUTED, ImmutableList.copyOf(getDistributionFieldsFromCollation(windowBase))));
            if (!isSingleMode(call)) {
                addMerge = true;
            }
        }
        // Add collation trait if either partition-by or order-by is specified.
        if (hasPartitionKeys || hasOrderKeys) {
            final RelCollation collation = getCollation(windowBase);
            traits = traits.plus(collation);
        }
        RelNode convertedInput = convert(input, traits);
        if (addMerge) {
            traits = traits.plus(DrillDistributionTrait.SINGLETON);
            convertedInput = new SingleMergeExchangePrel(window.getCluster(), traits, convertedInput, windowBase.collation());
        }

        // Output row: the converted input's fields followed by this group's "w<i>$" window fields.
        final List<RelDataTypeField> newRowFields = Lists.newArrayList();
        newRowFields.addAll(convertedInput.getRowType().getFieldList());
        final String windowFieldPrefix = "w" + w.i + "$";
        final Iterable<RelDataTypeField> newWindowFields = Iterables.filter(window.getRowType().getFieldList(), new Predicate<RelDataTypeField>() {

            @Override
            public boolean apply(RelDataTypeField relDataTypeField) {
                return relDataTypeField.getName().startsWith(windowFieldPrefix);
            }
        });
        for (RelDataTypeField newField : newWindowFields) {
            newRowFields.add(newField);
        }
        final RelDataType rowType = new RelRecordType(newRowFields);

        final List<Window.RexWinAggCall> newWinAggCalls = Lists.newArrayList();
        for (Ord<Window.RexWinAggCall> aggOrd : Ord.zip(windowBase.aggCalls)) {
            final Window.RexWinAggCall aggCall = aggOrd.getValue();
            // If the argument points at the constant and additional fields have been generated by
            // the Window below, the index of constants will be shifted.
            final List<RexNode> newOperandsOfWindowFunction = Lists.newArrayList();
            for (RexNode operand : aggCall.getOperands()) {
                if (operand instanceof RexInputRef) {
                    final int refIndex = ((RexInputRef) operand).getIndex();
                    // Check if this RexInputRef points at the constants
                    if (refIndex >= startConstantsIndex) {
                        operand = new RexInputRef(refIndex + constantShiftIndex, window.constants.get(refIndex - startConstantsIndex).getType());
                    }
                }
                newOperandsOfWindowFunction.add(operand);
            }
            // Rebuild the call with the shifted operands, renumbered to its position in this group.
            newWinAggCalls.add(new Window.RexWinAggCall((SqlAggFunction) aggCall.getOperator(), aggCall.getType(), newOperandsOfWindowFunction, aggOrd.i, aggCall.distinct));
        }
        windowBase = new Window.Group(windowBase.keys, windowBase.isRows, windowBase.lowerBound, windowBase.upperBound, windowBase.orderKeys, newWinAggCalls);
        input = new WindowPrel(window.getCluster(), window.getTraitSet().merge(traits), convertedInput, window.getConstants(), rowType, windowBase);
        constantShiftIndex += windowBase.aggCalls.size();
    }
    call.transformTo(input);
}
Also used : RelDataType(org.apache.calcite.rel.type.RelDataType) RelTraitSet(org.apache.calcite.plan.RelTraitSet) DrillWindowRel(org.apache.drill.exec.planner.logical.DrillWindowRel) Window(org.apache.calcite.rel.core.Window) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) RelRecordType(org.apache.calcite.rel.type.RelRecordType) RelCollation(org.apache.calcite.rel.RelCollation) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 15 with SqlAggFunction

Use of org.apache.calcite.sql.SqlAggFunction in project drill by apache.

The class DrillReduceAggregatesRule, method reduceAgg.

/**
 * Reduces one aggregate call of {@code oldAggRel} to simpler aggregate calls.
 *
 * <p>SUM is handled by {@code reduceSum}. AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP are
 * handled by {@code reduceAvg}/{@code reduceStddev}, except when the call's type is DECIMAL; in
 * that case it is kept as-is. Every other aggregate is registered unchanged.
 *
 * @param oldAggRel aggregate being rewritten
 * @param oldCall call to reduce; any Drill wrapper around its function is removed before dispatch
 * @param newCalls receives the aggregate calls of the rewritten aggregate
 * @param aggCallMapping mapping from already-registered aggregate calls to their output references
 * @param inputExprs input expressions, indexed by argument ordinal
 * @return an expression that computes the value of the original call
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
    final SqlAggFunction unwrapped = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

    if (unwrapped instanceof SqlSumAggFunction) {
        // CASE COUNT(x) WHEN 0 THEN NULL ELSE SUM0(x) END
        return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
    }

    if (!(unwrapped instanceof SqlAvgAggFunction)) {
        // Not reducible: keep the original call.
        final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
        final int groupCount = oldAggRel.getGroupCount();
        final List<RelDataType> argTypes = new ArrayList<>();
        final List<Integer> argOrdinals = oldCall.getArgList();
        assert argOrdinals.size() <= inputExprs.size();
        for (int i = 0; i < argOrdinals.size(); i++) {
            final int ordinal = argOrdinals.get(i);
            argTypes.add(inputExprs.get(ordinal).getType());
        }
        // aggCallMapping may already hold an equal call whose registered reference has a
        // different RelDataType. Append oldCall as a separate output column instead of reusing it.
        final boolean typeConflict = aggCallMapping.containsKey(oldCall) && !aggCallMapping.get(oldCall).getType().equals(oldCall.getType());
        if (typeConflict) {
            final int outputIndex = newCalls.size() + groupCount;
            newCalls.add(oldCall);
            return builder.makeInputRef(oldCall.getType(), outputIndex);
        }
        return builder.addAggCall(oldCall, groupCount, newCalls, aggCallMapping, argTypes);
    }

    // DECIMAL results are not reduced (the reduction causes the loss of the scale); keep the call.
    if (oldCall.getType().getSqlTypeName() == SqlTypeName.DECIMAL) {
        final RelDataType argType = getFieldType(oldAggRel.getInput(), oldCall.getArgList().get(0));
        return oldAggRel.getCluster().getRexBuilder().addAggCall(oldCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, ImmutableList.of(argType));
    }

    final SqlKind kind = unwrapped.getKind();
    // Flags forwarded to reduceStddev: POP variants pass true first; STDDEV variants pass true second.
    final boolean populationVariant;
    final boolean squareRoot;
    switch (kind) {
        case AVG:
            // AVG(x) -> SUM(x) / COUNT(x)
            return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
        case STDDEV_POP:
            // STDDEV_POP(x) -> SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))
            populationVariant = true;
            squareRoot = true;
            break;
        case STDDEV_SAMP:
            // STDDEV_SAMP(x) -> SQRT((SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
            //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
            populationVariant = false;
            squareRoot = true;
            break;
        case VAR_POP:
            // VAR_POP(x) -> (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
            populationVariant = true;
            squareRoot = false;
            break;
        case VAR_SAMP:
            // VAR_SAMP(x) -> (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
            //     / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
            populationVariant = false;
            squareRoot = false;
            break;
        default:
            throw Util.unexpected(kind);
    }
    return reduceStddev(oldAggRel, oldCall, populationVariant, squareRoot, newCalls, aggCallMapping, inputExprs);
}
Also used : SqlAvgAggFunction(org.apache.calcite.sql.fun.SqlAvgAggFunction) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) ArrayList(java.util.ArrayList) RexBuilder(org.apache.calcite.rex.RexBuilder) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) SqlKind(org.apache.calcite.sql.SqlKind)

Aggregations

SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)43 RelDataType (org.apache.calcite.rel.type.RelDataType)30 ArrayList (java.util.ArrayList)22 RexNode (org.apache.calcite.rex.RexNode)22 AggregateCall (org.apache.calcite.rel.core.AggregateCall)17 RexBuilder (org.apache.calcite.rex.RexBuilder)13 List (java.util.List)11 RelNode (org.apache.calcite.rel.RelNode)11 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)9 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)8 ImmutableList (com.google.common.collect.ImmutableList)7 HashMap (java.util.HashMap)7 Aggregate (org.apache.calcite.rel.core.Aggregate)7 RelBuilder (org.apache.calcite.tools.RelBuilder)7 SqlOperator (org.apache.calcite.sql.SqlOperator)6 Type (java.lang.reflect.Type)5 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)5 RexInputRef (org.apache.calcite.rex.RexInputRef)5 Mappings (org.apache.calcite.util.mapping.Mappings)5 Join (org.apache.calcite.rel.core.Join)4