Search in sources :

Example 41 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project calcite by apache.

Class AggregateFilterTransposeRule, method onMatch.

/**
 * Pulls a {@code Filter} that sits directly beneath an {@code Aggregate} up above it.
 *
 * <p>The aggregate is rebuilt on the filter's input. Its grouping key is widened with every
 * column the filter condition references, so the condition can be re-evaluated (remapped to
 * output ordinals) on top of it. If no widening was needed and the aggregate is a plain
 * GROUP BY, the lifted filter alone replaces the original tree. Otherwise a second aggregate
 * is stacked on the filter to roll the extra key columns (and any grouping sets) back up.
 * That step is abandoned if any aggregate call is DISTINCT or has no roll-up function.
 *
 * @param call rule call whose operand 0 is the Aggregate and operand 1 the Filter
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final Filter filter = call.rel(1);
    final RexNode condition = filter.getCondition();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();

    // Columns the filter reads must survive the aggregate, so add them to the key.
    final ImmutableBitSet filterColumns = RelOptUtil.InputFinder.bits(condition);
    final ImmutableBitSet widenedGroupSet = groupSet.union(filterColumns);
    final RelNode input = filter.getInput();

    final RelMetadataQuery mq = call.getMetadataQuery();
    if (Boolean.TRUE.equals(mq.areColumnsUnique(input, widenedGroupSet))) {
        // The input is already unique on the widened key, so aggregating again gains
        // little; more importantly, without this check the rule fires forever:
        // A-F => A-F-A => A-A-F-A => A-A-A-F-A => ...
        return;
    }
    final boolean filterUsesOnlyGroupKeys = groupSet.contains(filterColumns);

    // Bottom aggregate: same calls, widened key, no grouping sets, fed by the raw input.
    final Aggregate bottomAggregate = aggregate.copy(aggregate.getTraitSet(), input, false,
            widenedGroupSet, null, aggregate.getAggCallList());

    // Input column c becomes output column widenedGroupSet.indexOf(c) of the bottom aggregate.
    final Mappings.TargetMapping inputToBottom = Mappings.target(new Function<Integer, Integer>() {

        @Override
        public Integer apply(Integer column) {
            return widenedGroupSet.indexOf(column);
        }
    }, input.getRowType().getFieldCount(), widenedGroupSet.cardinality());
    final Filter liftedFilter = filter.copy(filter.getTraitSet(), bottomAggregate,
            RexUtil.apply(inputToBottom, condition));

    final boolean simpleGrouping = aggregate.getGroupType() == Group.SIMPLE;
    if (filterUsesOnlyGroupKeys && simpleGrouping) {
        // Nothing was added to the key, so the lifted filter alone is equivalent.
        assert widenedGroupSet.equals(groupSet);
        call.transformTo(liftedFilter);
        return;
    }

    // Grouping sets always force a split. For a plain GROUP BY we get here only when the
    // filter needed extra key columns, which the top aggregate now rolls away again.
    final ImmutableBitSet.Builder topGroupSet = ImmutableBitSet.builder();
    for (int column : groupSet) {
        topGroupSet.set(widenedGroupSet.indexOf(column));
    }
    ImmutableList<ImmutableBitSet> topGroupingSets = null;
    if (!simpleGrouping) {
        final ImmutableList.Builder<ImmutableBitSet> setsBuilder = ImmutableList.builder();
        for (ImmutableBitSet originalSet : aggregate.getGroupSets()) {
            final ImmutableBitSet.Builder remapped = ImmutableBitSet.builder();
            for (int column : originalSet) {
                remapped.set(widenedGroupSet.indexOf(column));
            }
            setsBuilder.add(remapped.build());
        }
        topGroupingSets = setsBuilder.build();
    }

    // Each original call becomes its roll-up over the matching bottom-aggregate output,
    // which lives right after the widened key columns.
    final List<AggregateCall> topAggCallList = Lists.newArrayList();
    int bottomOrdinal = widenedGroupSet.cardinality();
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.isDistinct()) {
            // Cannot roll up distinct.
            return;
        }
        final SqlAggFunction rollup = SubstitutionVisitor.getRollup(aggregateCall.getAggregation());
        if (rollup == null) {
            // This aggregate cannot be rolled up.
            return;
        }
        // Distinct calls were rejected above, so the roll-up is never distinct.
        topAggCallList.add(AggregateCall.create(rollup, false, aggregateCall.isApproximate(),
                ImmutableList.of(bottomOrdinal++), -1, aggregateCall.type, aggregateCall.name));
    }
    call.transformTo(aggregate.copy(aggregate.getTraitSet(), liftedFilter, aggregate.indicator,
            topGroupSet.build(), topGroupingSets, topAggCallList));
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ImmutableList(com.google.common.collect.ImmutableList) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) Filter(org.apache.calcite.rel.core.Filter) Aggregate(org.apache.calcite.rel.core.Aggregate) RexNode(org.apache.calcite.rex.RexNode)

Example 42 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project druid by druid-io.

Class BaseVarianceSqlAggregator, method toDruidAggregation.

/**
 * Translates a Calcite variance / standard-deviation aggregate call into a Druid
 * {@code VarianceAggregatorFactory}. For the STDDEV family it also adds a
 * {@code StandardDeviationPostAggregator} on top.
 *
 * <p>VAR_POP and STDDEV_POP use the "population" estimator. Every other function uses
 * "sample". The aggregator is named {@code name + ":agg"}. When a post-aggregator is
 * created, it takes {@code name} and reads from that aggregator.
 *
 * @return the aggregation, or {@code null} if the first argument cannot be turned into a
 *         numeric Druid expression
 * @throws IAE if the argument's column type cannot be determined or is not numeric
 */
@Nullable
@Override
public Aggregation toDruidAggregation(PlannerContext plannerContext, RowSignature rowSignature, VirtualColumnRegistry virtualColumnRegistry, RexBuilder rexBuilder, String name, AggregateCall aggregateCall, Project project, List<Aggregation> existingAggregations, boolean finalizeAggregations) {
    final RexNode inputOperand = Expressions.fromFieldAccess(rowSignature, project, aggregateCall.getArgList().get(0));
    final DruidExpression input = Aggregations.toDruidExpressionForNumericAggregator(plannerContext, rowSignature, inputOperand);
    if (input == null) {
        return null;
    }
    final RelDataType dataType = inputOperand.getType();
    final ColumnType inputType = Calcites.getColumnTypeForRelDataType(dataType);
    final SqlAggFunction func = calciteFunction();

    // Validate the input type BEFORE building the dimension spec. Building it may register a
    // virtual column in the shared registry, and a call that is about to be rejected must
    // not leave an orphan column behind (nor call toDimensionSpec with a null type).
    if (inputType == null) {
        throw new IAE("VarianceSqlAggregator[%s] has invalid inputType", func);
    }
    if (!inputType.isNumeric()) {
        throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType.asTypeString());
    }
    final String inputTypeName = StringUtils.toLowerCase(inputType.getType().name());

    // Read directly from a column when possible; otherwise via a virtual column.
    final DimensionSpec dimensionSpec;
    if (input.isSimpleExtraction()) {
        dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
    } else {
        final String virtualColumnName = virtualColumnRegistry.getOrCreateVirtualColumnForExpression(input, dataType);
        dimensionSpec = new DefaultDimensionSpec(virtualColumnName, null, inputType);
    }

    final String estimator;
    if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
        estimator = "population";
    } else {
        estimator = "sample";
    }
    final String aggName = StringUtils.format("%s:agg", name);
    final AggregatorFactory aggregatorFactory = new VarianceAggregatorFactory(aggName, dimensionSpec.getDimension(), estimator, inputTypeName);

    // Standard deviation is derived from the variance aggregator by a post-aggregator.
    PostAggregator postAggregator = null;
    if (func == SqlStdOperatorTable.STDDEV_POP || func == SqlStdOperatorTable.STDDEV_SAMP || func == SqlStdOperatorTable.STDDEV) {
        postAggregator = new StandardDeviationPostAggregator(name, aggregatorFactory.getName(), estimator);
    }
    return Aggregation.create(ImmutableList.of(aggregatorFactory), postAggregator);
}
Also used : DefaultDimensionSpec(org.apache.druid.query.dimension.DefaultDimensionSpec) DimensionSpec(org.apache.druid.query.dimension.DimensionSpec) ColumnType(org.apache.druid.segment.column.ColumnType) PostAggregator(org.apache.druid.query.aggregation.PostAggregator) StandardDeviationPostAggregator(org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) VarianceAggregatorFactory(org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory) AggregatorFactory(org.apache.druid.query.aggregation.AggregatorFactory) IAE(org.apache.druid.java.util.common.IAE) DruidExpression(org.apache.druid.sql.calcite.expression.DruidExpression) RexNode(org.apache.calcite.rex.RexNode) Nullable(javax.annotation.Nullable)

Example 43 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.

Class OverConvertRule, method convert.

/**
 * Converts a Table API {@code OVER} call into a Calcite {@code RexOver}.
 *
 * <p>The OVER call's children are read by position:
 * 0 = the aggregate (possibly wrapped in DISTINCT), 1 = the order-by key,
 * 2 = the preceding bound, 3 = the following bound, 4.. = partition keys.
 *
 * @param call    the call to convert
 * @param context conversion context providing the type factory, RelBuilder and
 *                child-expression converter
 * @return the converted expression, or {@code Optional.empty()} if {@code call}
 *         is not an OVER call
 */
@Override
public Optional<RexNode> convert(CallExpression call, ConvertContext context) {
    final List<Expression> operands = call.getChildren();
    if (call.getFunctionDefinition() != BuiltInFunctionDefinitions.OVER) {
        return Optional.empty();
    }
    final FlinkTypeFactory typeFactory = context.getTypeFactory();
    final Expression aggregate = operands.get(0);
    final FunctionDefinition aggregateDefinition = ((CallExpression) aggregate).getFunctionDefinition();
    final boolean distinct = aggregateDefinition == BuiltInFunctionDefinitions.DISTINCT;

    final SqlAggFunction aggFunction = aggregate.accept(new SqlAggFunctionVisitor(context.getRelBuilder()));
    final RelDataType resultType = typeFactory.createFieldTypeFromLogicalType(
            fromDataTypeToLogicalType(((ResolvedExpression) aggregate).getOutputDataType()));

    // Aggregate arguments. Under DISTINCT, each child is replaced by its first operand.
    final List<RexNode> aggOperands = aggregate.getChildren().stream()
            .map(child -> context.toRexNode(distinct ? child.getChildren().get(0) : child))
            .collect(Collectors.toList());

    // A single ascending order key. createCollation receives 'orderKinds', which is then
    // attached to the RexFieldCollation (presumably collation flags; confirm in createCollation).
    final Set<SqlKind> orderKinds = new HashSet<>();
    final RexNode orderRex = createCollation(context.toRexNode(operands.get(1)), RelFieldCollation.Direction.ASCENDING, null, orderKinds);
    final ImmutableList<RexFieldCollation> orderKeys = ImmutableList.of(new RexFieldCollation(orderRex, orderKinds));

    // Everything from position 4 onward is a partition key.
    final List<RexNode> partitionKeys = new ArrayList<>();
    for (Expression partitionKey : operands.subList(4, operands.size())) {
        partitionKeys.add(context.toRexNode(partitionKey));
    }

    // 'physical' is set when the preceding bound is a BIGINT (presumably a row count,
    // i.e. a ROWS window rather than RANGE).
    final Expression precedingBound = operands.get(2);
    final boolean physical = fromDataTypeToLogicalType(((ResolvedExpression) precedingBound).getOutputDataType()).is(LogicalTypeRoot.BIGINT);
    final Expression followingBound = operands.get(3);
    final RexWindowBound lower = createBound(context, precedingBound, SqlKind.PRECEDING);
    final RexWindowBound upper = createBound(context, followingBound, SqlKind.FOLLOWING);

    return Optional.of(context.getRelBuilder().getRexBuilder().makeOver(resultType, aggFunction, aggOperands, partitionKeys, orderKeys, lower, upper, physical, true, false, distinct));
}
Also used : SqlPostfixOperator(org.apache.calcite.sql.SqlPostfixOperator) CallExpression(org.apache.flink.table.expressions.CallExpression) FlinkPlannerImpl(org.apache.flink.table.planner.calcite.FlinkPlannerImpl) OrdinalReturnTypeInference(org.apache.calcite.sql.type.OrdinalReturnTypeInference) Expression(org.apache.flink.table.expressions.Expression) FlinkTypeFactory(org.apache.flink.table.planner.calcite.FlinkTypeFactory) ArrayList(java.util.ArrayList) HashSet(java.util.HashSet) BigDecimal(java.math.BigDecimal) RexFieldCollation(org.apache.calcite.rex.RexFieldCollation) SqlLiteral(org.apache.calcite.sql.SqlLiteral) SqlNode(org.apache.calcite.sql.SqlNode) DecimalType(org.apache.flink.table.types.logical.DecimalType) ImmutableList(com.google.common.collect.ImmutableList) RexNode(org.apache.calcite.rex.RexNode) ResolvedExpression(org.apache.flink.table.expressions.ResolvedExpression) SqlWindow(org.apache.calcite.sql.SqlWindow) SqlOperator(org.apache.calcite.sql.SqlOperator) ExpressionConverter.extractValue(org.apache.flink.table.planner.expressions.converter.ExpressionConverter.extractValue) RexWindowBound(org.apache.calcite.rex.RexWindowBound) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlParserPos(org.apache.calcite.sql.parser.SqlParserPos) SqlKind(org.apache.calcite.sql.SqlKind) FunctionDefinition(org.apache.flink.table.functions.FunctionDefinition) BuiltInFunctionDefinitions(org.apache.flink.table.functions.BuiltInFunctionDefinitions) TableException(org.apache.flink.table.api.TableException) SqlBasicCall(org.apache.calcite.sql.SqlBasicCall) Set(java.util.Set) ValueLiteralExpression(org.apache.flink.table.expressions.ValueLiteralExpression) RelFieldCollation(org.apache.calcite.rel.RelFieldCollation) Collectors(java.util.stream.Collectors) List(java.util.List) LogicalTypeDataTypeConverter.fromDataTypeToLogicalType(org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType) Optional(java.util.Optional) SqlAggFunctionVisitor(org.apache.flink.table.planner.expressions.SqlAggFunctionVisitor) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) LogicalTypeRoot(org.apache.flink.table.types.logical.LogicalTypeRoot) RexCall(org.apache.calcite.rex.RexCall)

Example 44 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.

Class AggregateCallJsonDeserializer, method deserialize.

/**
 * Rebuilds an {@code AggregateCall} from its JSON representation.
 *
 * <p>Every field is looked up with {@code required(...)}, so a missing field fails
 * deserialization. The aggregate function itself is decoded by
 * {@code RexNodeJsonDeserializer.deserializeSqlOperator} and cast to {@code SqlAggFunction}.
 * The resulting call always uses {@code RelCollations.EMPTY}.
 *
 * @param jsonParser parser positioned at the serialized call
 * @param ctx        Jackson context carrying the {@code SerdeContext}
 * @return the reconstructed aggregate call
 * @throws IOException if the JSON tree cannot be read
 */
@Override
public AggregateCall deserialize(JsonParser jsonParser, DeserializationContext ctx) throws IOException {
    final JsonNode root = jsonParser.readValueAsTree();
    final SerdeContext serdeContext = SerdeContext.get(ctx);

    final String callName = root.required(FIELD_NAME_NAME).asText();
    final SqlAggFunction function = (SqlAggFunction) RexNodeJsonDeserializer.deserializeSqlOperator(root, serdeContext);

    // Argument ordinals, in serialized order.
    final JsonNode argsNode = root.required(FIELD_NAME_ARG_LIST);
    final List<Integer> args = new ArrayList<>(argsNode.size());
    argsNode.forEach(arg -> args.add(arg.intValue()));

    final int filterArg = root.required(FIELD_NAME_FILTER_ARG).asInt();
    final boolean distinct = root.required(FIELD_NAME_DISTINCT).asBoolean();
    final boolean approximate = root.required(FIELD_NAME_APPROXIMATE).asBoolean();
    final boolean ignoreNulls = root.required(FIELD_NAME_IGNORE_NULLS).asBoolean();
    final RelDataType returnType = RelDataTypeJsonDeserializer.deserialize(root.required(FIELD_NAME_TYPE), serdeContext);

    return AggregateCall.create(function, distinct, approximate, ignoreNulls, args, filterArg, RelCollations.EMPTY, returnType, callName);
}
Also used : ArrayList(java.util.ArrayList) JsonNode(org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode) RelDataType(org.apache.calcite.rel.type.RelDataType) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction)

Example 45 with SqlAggFunction

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.sql.SqlAggFunction in project flink by apache.

Class FlinkAggregateJoinTransposeRule, method onMatch.

/**
 * Pushes an {@code Aggregate} down through the inner {@code Join} directly beneath it.
 *
 * <p>For each join input, the rule decides one of two treatments:
 * <ul>
 *   <li>If the input is (or, when {@code allowFunctions} is false, is assumed to be) unique
 *       on the columns it contributes to the grouping/join key, it is only projected.</li>
 *   <li>Otherwise a partial aggregate is built beneath the join.</li>
 * </ul>
 * Above the new join, a top aggregate combines the per-side sub-totals. When every join
 * column is a key column, the rule tries to replace that aggregate with a projection.
 *
 * <p>The rule does nothing in these cases:
 * <ul>
 *   <li>the join is not INNER;</li>
 *   <li>the join condition has a non-equi part;</li>
 *   <li>any aggregate call is not splittable, has a FILTER, or is DISTINCT;</li>
 *   <li>functions are present while {@code allowFunctions} is false;</li>
 *   <li>both sides turn out to be unique.</li>
 * </ul>
 *
 * @param call rule call whose operand 0 is the Aggregate and operand 1 the Join
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate origAgg = call.rel(0);
    final Join join = call.rel(1);
    final RexBuilder rexBuilder = origAgg.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();
    // Converts an aggregate with AUXILIARY_GROUP calls to a regular aggregate. The result
    // comes paired with a projection list that is applied on top at the end of this method
    // to restore the original row type (it may be null; see the check at the end).
    // If the converted aggregate can be pushed down,
    // AggregateReduceGroupingRule will try to reduce the grouping of the new aggregates
    // created by this rule.
    final Pair<Aggregate, List<RexNode>> newAggAndProject = toRegularAggregate(origAgg);
    final Aggregate aggregate = newAggAndProject.left;
    final List<RexNode> projectAfterAgg = newAggAndProject.right;
    // Bail out if any aggregate call is not splittable, has a filter, or is distinct.
    for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
        if (aggregateCall.getAggregation().unwrap(SqlSplittableAggFunction.class) == null) {
            return;
        }
        if (aggregateCall.filterArg >= 0 || aggregateCall.isDistinct()) {
            return;
        }
    }
    // Only inner joins are handled.
    if (join.getJoinType() != JoinRelType.INNER) {
        return;
    }
    // 'allowFunctions' (a rule setting) restricts the rule to aggregates without calls.
    if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
        return;
    }
    // Do the columns used by the join appear in the output of the aggregate?
    final ImmutableBitSet aggregateColumns = aggregate.getGroupSet();
    final RelMetadataQuery mq = call.getMetadataQuery();
    // keyColumns(...) combines the group columns with the join's pulled-up predicates;
    // presumably it adds columns equated to group columns (confirm in keyColumns).
    final ImmutableBitSet keyColumns = keyColumns(aggregateColumns, mq.getPulledUpPredicates(join).pulledUpPredicates);
    final ImmutableBitSet joinColumns = RelOptUtil.InputFinder.bits(join.getCondition());
    final boolean allColumnsInAggregate = keyColumns.contains(joinColumns);
    // Columns that must survive below the join: group columns plus join-condition columns.
    final ImmutableBitSet belowAggregateColumns = aggregateColumns.union(joinColumns);
    // Split join condition
    final List<Integer> leftKeys = com.google.common.collect.Lists.newArrayList();
    final List<Integer> rightKeys = com.google.common.collect.Lists.newArrayList();
    final List<Boolean> filterNulls = com.google.common.collect.Lists.newArrayList();
    RexNode nonEquiConj = RelOptUtil.splitJoinCondition(join.getLeft(), join.getRight(), join.getCondition(), leftKeys, rightKeys, filterNulls);
    // If it contains non-equi join conditions, we bail out
    if (!nonEquiConj.isAlwaysTrue()) {
        return;
    }
    // Push each aggregate function down to each side that contains all of its
    // arguments. Note that COUNT(*), because it has no arguments, can go to
    // both sides.
    // 'map': join-output column ordinal -> ordinal in the join of the new inputs.
    final Map<Integer, Integer> map = new HashMap<>();
    final List<Side> sides = new ArrayList<>();
    int uniqueCount = 0;
    // 'offset' walks the original join row; 'belowOffset' walks the new join row.
    int offset = 0;
    int belowOffset = 0;
    // s == 0 is the left input, s == 1 the right input.
    for (int s = 0; s < 2; s++) {
        final Side side = new Side();
        final RelNode joinInput = join.getInput(s);
        int fieldCount = joinInput.getRowType().getFieldCount();
        final ImmutableBitSet fieldSet = ImmutableBitSet.range(offset, offset + fieldCount);
        final ImmutableBitSet belowAggregateKeyNotShifted = belowAggregateColumns.intersect(fieldSet);
        for (Ord<Integer> c : Ord.zip(belowAggregateKeyNotShifted)) {
            map.put(c.e, belowOffset + c.i);
        }
        // Maps this side's columns from join-row ordinals to the side's own ordinals
        // (identity on the left; shifted down by 'offset' on the right).
        final Mappings.TargetMapping mapping = s == 0 ? Mappings.createIdentity(fieldCount) : Mappings.createShiftMapping(fieldCount + offset, 0, offset, fieldCount);
        final ImmutableBitSet belowAggregateKey = belowAggregateKeyNotShifted.shift(-offset);
        final boolean unique;
        if (!allowFunctions) {
            assert aggregate.getAggCallList().isEmpty();
            // If there are no functions, it doesn't matter as much whether we
            // aggregate the inputs before the join, because there will not be
            // any functions experiencing a cartesian product effect.
            // 
            // But finding out whether the input is already unique requires a call
            // to areColumnsUnique that currently (until [CALCITE-1048] "Make
            // metadata more robust" is fixed) places a heavy load on
            // the metadata system.
            // 
            // So we choose to imagine the input is already unique, which is
            // untrue but harmless.
            // 
            Util.discard(Bug.CALCITE_1048_FIXED);
            unique = true;
        } else {
            final Boolean unique0 = mq.areColumnsUnique(joinInput, belowAggregateKey);
            unique = unique0 != null && unique0;
        }
        if (unique) {
            // No aggregate below the join on this side. Project the key columns, then one
            // singleton expression per aggregate call whose arguments all come from this side.
            ++uniqueCount;
            side.aggregate = false;
            relBuilder.push(joinInput);
            final Map<Integer, Integer> belowAggregateKeyToNewProjectMap = new HashMap<>();
            final List<RexNode> projects = new ArrayList<>();
            for (Integer i : belowAggregateKey) {
                belowAggregateKeyToNewProjectMap.put(i, projects.size());
                projects.add(relBuilder.field(i));
            }
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                if (!aggCall.e.getArgList().isEmpty() && fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    final RexNode singleton = splitter.singleton(rexBuilder, joinInput.getRowType(), aggCall.e.transform(mapping));
                    final RexNode targetSingleton = rexBuilder.ensureType(aggCall.e.type, singleton, false);
                    if (targetSingleton instanceof RexInputRef) {
                        // Reuse an already-projected key column instead of projecting it twice.
                        final int index = ((RexInputRef) targetSingleton).getIndex();
                        if (!belowAggregateKey.get(index)) {
                            projects.add(targetSingleton);
                            side.split.put(aggCall.i, projects.size() - 1);
                        } else {
                            side.split.put(aggCall.i, belowAggregateKeyToNewProjectMap.get(index));
                        }
                    } else {
                        projects.add(targetSingleton);
                        side.split.put(aggCall.i, projects.size() - 1);
                    }
                }
            }
            relBuilder.project(projects);
            side.newInput = relBuilder.build();
        } else {
            // Aggregate this side below the join, grouping by belowAggregateKey.
            side.aggregate = true;
            List<AggregateCall> belowAggCalls = new ArrayList<>();
            final SqlSplittableAggFunction.Registry<AggregateCall> belowAggCallRegistry = registry(belowAggCalls);
            final int oldGroupKeyCount = aggregate.getGroupCount();
            final int newGroupKeyCount = belowAggregateKey.cardinality();
            for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
                final SqlAggFunction aggregation = aggCall.e.getAggregation();
                final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
                final AggregateCall call1;
                if (fieldSet.contains(ImmutableBitSet.of(aggCall.e.getArgList()))) {
                    // All arguments are on this side: split the call itself.
                    final AggregateCall splitCall = splitter.split(aggCall.e, mapping);
                    call1 = splitCall.adaptTo(joinInput, splitCall.getArgList(), splitCall.filterArg, oldGroupKeyCount, newGroupKeyCount);
                } else {
                    // Arguments live on the other side: ask the splitter for this side's
                    // companion call (it may be null, in which case nothing is recorded).
                    call1 = splitter.other(rexBuilder.getTypeFactory(), aggCall.e);
                }
                if (call1 != null) {
                    // The sub-total column comes after this side's group key columns.
                    side.split.put(aggCall.i, belowAggregateKey.cardinality() + belowAggCallRegistry.register(call1));
                }
            }
            side.newInput = relBuilder.push(joinInput).aggregate(relBuilder.groupKey(belowAggregateKey, null), belowAggCalls).build();
        }
        offset += fieldCount;
        belowOffset += side.newInput.getRowType().getFieldCount();
        sides.add(side);
    }
    if (uniqueCount == 2) {
        // Both inputs to the join are unique, so there is nothing to be gained from
        // this rule. In fact, this aggregate+join may be the result of a previous
        // invocation of this rule; if we continue we might loop forever.
        return;
    }
    // Update condition
    final Mapping mapping = (Mapping) Mappings.target(map::get, join.getRowType().getFieldCount(), belowOffset);
    final RexNode newCondition = RexUtil.apply(mapping, join.getCondition());
    // Create new join
    relBuilder.push(sides.get(0).newInput).push(sides.get(1).newInput).join(join.getJoinType(), newCondition);
    // Aggregate above to sum up the sub-totals
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final int groupIndicatorCount = aggregate.getGroupCount() + aggregate.getIndicatorCount();
    final int newLeftWidth = sides.get(0).newInput.getRowType().getFieldCount();
    // Starts as an identity projection of the new join; topSplit may append expressions
    // to it through registry(projects).
    final List<RexNode> projects = new ArrayList<>(rexBuilder.identityProjects(relBuilder.peek().getRowType()));
    for (Ord<AggregateCall> aggCall : Ord.zip(aggregate.getAggCallList())) {
        final SqlAggFunction aggregation = aggCall.e.getAggregation();
        final SqlSplittableAggFunction splitter = Preconditions.checkNotNull(aggregation.unwrap(SqlSplittableAggFunction.class));
        final Integer leftSubTotal = sides.get(0).split.get(aggCall.i);
        final Integer rightSubTotal = sides.get(1).split.get(aggCall.i);
        // -1 marks "no sub-total on that side"; right ordinals are shifted past the left input.
        newAggCalls.add(splitter.topSplit(rexBuilder, registry(projects), groupIndicatorCount, relBuilder.peek().getRowType(), aggCall.e, leftSubTotal == null ? -1 : leftSubTotal, rightSubTotal == null ? -1 : rightSubTotal + newLeftWidth));
    }
    relBuilder.project(projects);
    boolean aggConvertedToProjects = false;
    if (allColumnsInAggregate) {
        // let's see if we can convert aggregate into projects
        List<RexNode> projects2 = new ArrayList<>();
        for (int key : Mappings.apply(mapping, aggregate.getGroupSet())) {
            projects2.add(relBuilder.field(key));
        }
        int aggCallIdx = projects2.size();
        for (AggregateCall newAggCall : newAggCalls) {
            final SqlSplittableAggFunction splitter = newAggCall.getAggregation().unwrap(SqlSplittableAggFunction.class);
            if (splitter != null) {
                final RelDataType rowType = relBuilder.peek().getRowType();
                final RexNode singleton = splitter.singleton(rexBuilder, rowType, newAggCall);
                // Cast back to the type the original aggregate produced for this call.
                final RelDataType originalAggCallType = aggregate.getRowType().getFieldList().get(aggCallIdx).getType();
                final RexNode targetSingleton = rexBuilder.ensureType(originalAggCallType, singleton, false);
                projects2.add(targetSingleton);
            }
            aggCallIdx += 1;
        }
        // Only commit if every aggregate call was converted.
        if (projects2.size() == aggregate.getGroupSet().cardinality() + newAggCalls.size()) {
            // We successfully converted agg calls into projects.
            relBuilder.project(projects2);
            aggConvertedToProjects = true;
        }
    }
    if (!aggConvertedToProjects) {
        relBuilder.aggregate(relBuilder.groupKey(Mappings.apply(mapping, aggregate.getGroupSet()), Mappings.apply2(mapping, aggregate.getGroupSets())), newAggCalls);
    }
    // Restore the row type of the original (AUXILIARY_GROUP) aggregate, if needed.
    if (projectAfterAgg != null) {
        relBuilder.project(projectAfterAgg, origAgg.getRowType().getFieldNames());
    }
    call.transformTo(relBuilder.build());
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) Mapping(org.apache.calcite.util.mapping.Mapping) RelDataType(org.apache.calcite.rel.type.RelDataType) RexBuilder(org.apache.calcite.rex.RexBuilder) List(java.util.List) SqlSplittableAggFunction(org.apache.calcite.sql.SqlSplittableAggFunction) RelBuilder(org.apache.calcite.tools.RelBuilder) Join(org.apache.calcite.rel.core.Join) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelNode(org.apache.calcite.rel.RelNode) Mappings(org.apache.calcite.util.mapping.Mappings) RexInputRef(org.apache.calcite.rex.RexInputRef) Aggregate(org.apache.calcite.rel.core.Aggregate) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

SqlAggFunction (org.apache.calcite.sql.SqlAggFunction)47 RelDataType (org.apache.calcite.rel.type.RelDataType)30 ArrayList (java.util.ArrayList)25 RexNode (org.apache.calcite.rex.RexNode)24 AggregateCall (org.apache.calcite.rel.core.AggregateCall)20 RexBuilder (org.apache.calcite.rex.RexBuilder)15 RelNode (org.apache.calcite.rel.RelNode)13 List (java.util.List)11 Aggregate (org.apache.calcite.rel.core.Aggregate)10 RelBuilder (org.apache.calcite.tools.RelBuilder)10 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)9 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)9 HashMap (java.util.HashMap)8 ImmutableList (com.google.common.collect.ImmutableList)7 SqlOperator (org.apache.calcite.sql.SqlOperator)6 Type (java.lang.reflect.Type)5 RelMetadataQuery (org.apache.calcite.rel.metadata.RelMetadataQuery)5 RexInputRef (org.apache.calcite.rex.RexInputRef)5 SqlKind (org.apache.calcite.sql.SqlKind)5 Mappings (org.apache.calcite.util.mapping.Mappings)5