Search in sources :

Example 61 with Pair

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair in project hive by apache.

the class HiveSubQRemoveRelBuilder method join.

/**
 * Combines the two most recently pushed inputs, producing a correlate when
 * exactly one correlating variable is supplied and a
 * {@link org.apache.calcite.rel.core.Join} otherwise.
 *
 * <p>The top of the stack is used as the right input and the frame beneath
 * it as the left input. The result is pushed back together with the fields
 * of both inputs, and a filter is then applied on top of it: the join
 * condition when it was deferred (correlate with a non-LEFT join type), or
 * the literal TRUE in every other case.
 *
 * @param joinType     type of join to create
 * @param condition    join condition
 * @param variablesSet correlating variables; when it holds exactly one
 *                     variable a correlate is built and that variable must
 *                     not be referenced by the left input
 * @return this builder
 * @throws IllegalArgumentException if the single correlating variable is
 *         used by the left input
 */
public HiveSubQRemoveRelBuilder join(JoinRelType joinType, RexNode condition, Set<CorrelationId> variablesSet) {
    Frame rightFrame = stack.pop();
    final Frame leftFrame = stack.pop();
    final boolean useCorrelate = variablesSet.size() == 1;
    // Filter placed above the result; stays TRUE unless the condition is deferred below.
    RexNode deferredCondition = literal(true);
    final RelNode joined;
    if (!useCorrelate) {
        // Regular join: the condition is carried by the join itself.
        joined = joinFactory.createJoin(leftFrame.rel, rightFrame.rel, condition, variablesSet, joinType, false);
    } else {
        final CorrelationId correlationId = Iterables.getOnlyElement(variablesSet);
        final ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(correlationId, rightFrame.rel);
        final boolean leftIsFreeOfVariable = RelOptUtil.notContainsCorrelation(leftFrame.rel, correlationId, Litmus.IGNORE);
        if (!leftIsFreeOfVariable) {
            throw new IllegalArgumentException("variable " + correlationId + " must not be used by left input to correlation");
        }
        switch (joinType) {
            case LEFT:
                // A correlate has no ON clause, so for a LEFT correlate the
                // predicate has to be evaluated on the right input first:
                // push the right frame back, filter it, and take the
                // filtered frame as the new right input.
                stack.push(rightFrame);
                filter(condition.accept(new Shifter(leftFrame.rel, correlationId, rightFrame.rel)));
                rightFrame = stack.pop();
                break;
            default:
                // For the remaining join types the predicate can be applied afterwards.
                deferredCondition = condition;
                break;
        }
        joined = correlateFactory.createCorrelate(leftFrame.rel, rightFrame.rel, correlationId, requiredColumns, SemiJoinType.of(joinType));
    }
    // Fields of the result: left fields followed by right fields.
    final List<Pair<String, RelDataType>> joinedFields = new ArrayList<>(leftFrame.right);
    joinedFields.addAll(rightFrame.right);
    stack.push(new Frame(joined, ImmutableList.copyOf(joinedFields)));
    filter(deferredCondition);
    return this;
}
Also used : RelNode(org.apache.calcite.rel.RelNode) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) ArrayList(java.util.ArrayList) CorrelationId(org.apache.calcite.rel.core.CorrelationId) RexNode(org.apache.calcite.rex.RexNode) Pair(org.apache.calcite.util.Pair)

Example 62 with Pair

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair in project hive by apache.

the class Vectorizer method validateAggregationDesc.

/**
 * Checks whether a single aggregation can be vectorized.
 *
 * <p>The aggregation is rejected when its UDAF name is not contained in
 * {@code supportedAggregationUdfs}, when its parameters fail
 * {@code validateExprNodeDesc}, when no vector aggregate expression can be
 * obtained for it, or when the processing mode is
 * {@code MERGE_PARTIAL} with keys and the aggregate output is not primitive.
 *
 * @param aggDesc        aggregation to validate
 * @param processingMode GROUP BY processing mode
 * @param hasKeys        whether the GROUP BY has keys
 * @return a pair whose first component tells whether the aggregation passed
 *         validation and whose second component tells whether its output
 *         category is primitive; both are {@code false} on rejection
 */
private Pair<Boolean, Boolean> validateAggregationDesc(AggregationDesc aggDesc, ProcessingMode processingMode, boolean hasKeys) {
    // Lower-case with a fixed locale: the default-locale overload would map
    // 'I' to a dotless 'ı' under e.g. a Turkish locale and make the lookup
    // below miss names such as "min".
    String udfName = aggDesc.getGenericUDAFName().toLowerCase(java.util.Locale.ROOT);
    if (!supportedAggregationUdfs.contains(udfName)) {
        setExpressionIssue("Aggregation Function", "UDF " + udfName + " not supported");
        return new Pair<Boolean, Boolean>(false, false);
    }
    if (aggDesc.getParameters() != null && !validateExprNodeDesc(aggDesc.getParameters(), "Aggregation Function UDF " + udfName + " parameter")) {
        return new Pair<Boolean, Boolean>(false, false);
    }
    // See if we can vectorize the aggregation.
    VectorizationContext vc = new ValidatorVectorizationContext(hiveConf);
    VectorAggregateExpression vectorAggrExpr;
    try {
        vectorAggrExpr = vc.getAggregatorExpression(aggDesc);
    } catch (Exception e) {
        // No vector aggregate expression could be built even though the UDAF
        // name and parameters passed the checks above; record the failure
        // and report the aggregation as not vectorizable.
        if (LOG.isDebugEnabled()) {
            LOG.debug("Vectorization of aggregation should have succeeded ", e);
        }
        setExpressionIssue("Aggregation Function", "Vectorization of aggreation should have succeeded " + e);
        return new Pair<Boolean, Boolean>(false, false);
    }
    if (LOG.isDebugEnabled()) {
        LOG.debug("Aggregation " + aggDesc.getExprString() + " --> " + " vector expression " + vectorAggrExpr.toString());
    }
    ObjectInspector.Category outputCategory = aggregationOutputCategory(vectorAggrExpr);
    boolean outputIsPrimitive = (outputCategory == ObjectInspector.Category.PRIMITIVE);
    // Merge-partial processing with keys is limited to primitive aggregate outputs.
    if (processingMode == ProcessingMode.MERGE_PARTIAL && hasKeys && !outputIsPrimitive) {
        setOperatorIssue("Vectorized Reduce MergePartial GROUP BY keys can only handle aggregate outputs that are primitive types");
        return new Pair<Boolean, Boolean>(false, false);
    }
    return new Pair<Boolean, Boolean>(true, outputIsPrimitive);
}
Also used : ObjectInspector(org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector) StructObjectInspector(org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector) VectorAggregateExpression(org.apache.hadoop.hive.ql.exec.vector.expressions.aggregates.VectorAggregateExpression) VectorizationContext(org.apache.hadoop.hive.ql.exec.vector.VectorizationContext) UDFToString(org.apache.hadoop.hive.ql.udf.UDFToString) UDFToBoolean(org.apache.hadoop.hive.ql.udf.UDFToBoolean) Category(org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category) SemanticException(org.apache.hadoop.hive.ql.parse.SemanticException) HiveException(org.apache.hadoop.hive.ql.metadata.HiveException) Pair(org.apache.calcite.util.Pair) ImmutablePair(org.apache.commons.lang3.tuple.ImmutablePair)

Example 63 with Pair

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method convertSingletonDistinct.

/**
 * Converts an aggregate with one distinct aggregate and one or more
 * non-distinct aggregates to multi-phase aggregates (see reference example
 * below).
 *
 * <p>The aggregate's input is pushed onto {@code relBuilder}, an
 * intermediate aggregate B (grouping on the original keys plus the distinct
 * arguments) is built on top of it, and a final aggregate A (grouping on the
 * original keys) is built on top of B.
 *
 * @param relBuilder Contains the input relational expression
 * @param aggregate  Original aggregate
 * @param argLists   Arguments and filters to the distinct aggregate function
 * @return {@code relBuilder}, with aggregate A pushed on top
 */
private RelBuilder convertSingletonDistinct(RelBuilder relBuilder, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // For example,
    //    SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
    //    FROM emp
    //    GROUP BY deptno
    //
    // becomes
    //
    //    SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
    //    FROM (
    //          SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
    //          FROM EMP
    //          GROUP BY deptno, sal)            // Aggregate B
    //    GROUP BY deptno                        // Aggregate A
    relBuilder.push(aggregate.getInput());
    // NOTE(review): within this method 'projects' is only appended to and
    // queried for its size, i.e. it serves as a running output-position
    // counter; it is never handed to the builder here.
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    // Maps an input field index (or a fake argument index) to its output position.
    final Map<Integer, Integer> sourceOf = new HashMap<>();
    SortedSet<Integer> newGroupSet = new TreeSet<>();
    final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
    final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;
    // Original group keys; unlike newGroupSet this excludes the distinct arguments.
    SortedSet<Integer> groupSet = new TreeSet<>(aggregate.getGroupSet().asList());
    // Add the distinct aggregate column(s) to the group-by columns,
    // if not already a part of the group-by
    newGroupSet.addAll(aggregate.getGroupSet().asList());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        newGroupSet.addAll(argList.getKey());
    }
    // Re-map the arguments to the aggregate A. These arguments will get
    // remapped because of the intermediate aggregate B generated as part of the
    // transformation.
    for (int arg : newGroupSet) {
        sourceOf.put(arg, projects.size());
        projects.add(RexInputRef.of2(arg, childFields));
    }
    // Generate the intermediate aggregate B
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final List<AggregateCall> newAggCalls = new ArrayList<>();
    final List<Integer> fakeArgs = new ArrayList<>();
    final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
    // First identify the real arguments, then use the rest for fake arguments
    // e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
    //
    // The entries added here for non-group-key arguments are placeholders:
    // they only make the fake-argument search below skip those indexes and
    // are removed again right after that search.
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
                    sourceOf.put(arg, projects.size());
                }
            }
        }
    }
    // Reserve one fake argument for every non-distinct call that either has
    // no arguments or uses a group key as an argument. Each fake argument is
    // the next index that has no entry in sourceOf.
    int fakeArg0 = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // We will deal with non-distinct aggregates below
        if (!aggCall.isDistinct()) {
            boolean isGroupKeyUsedInAgg = false;
            for (int arg : aggCall.getArgList()) {
                if (groupSet.contains(arg)) {
                    isGroupKeyUsedInAgg = true;
                    break;
                }
            }
            if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
                while (sourceOf.get(fakeArg0) != null) {
                    ++fakeArg0;
                }
                fakeArgs.add(fakeArg0);
                // Move past the index just handed out so it is not reused.
                ++fakeArg0;
            }
        }
    }
    // Drop the placeholder entries registered before the fake-argument search.
    for (final AggregateCall aggCall : aggCalls) {
        if (!aggCall.isDistinct()) {
            for (int arg : aggCall.getArgList()) {
                if (!groupSet.contains(arg)) {
                    sourceOf.remove(arg);
                }
            }
        }
    }
    // Compute the remapped arguments using fake arguments for non-distinct
    // aggregates with no arguments e.g. count(*).
    int fakeArgIdx = 0;
    for (final AggregateCall aggCall : aggCalls) {
        // Project the column corresponding to the distinct aggregate. Project
        // as-is all the non-distinct aggregates
        if (!aggCall.isDistinct()) {
            final AggregateCall newCall = AggregateCall.create(aggCall.getAggregation(), false, aggCall.getArgList(), -1, ImmutableBitSet.of(newGroupSet).cardinality(), relBuilder.peek(), null, aggCall.name);
            newAggCalls.add(newCall);
            if (newCall.getArgList().size() == 0) {
                // Argument-less call: record its output position under the
                // fake argument reserved for it.
                int fakeArg = fakeArgs.get(fakeArgIdx);
                callArgMap.put(newCall, fakeArg);
                sourceOf.put(fakeArg, projects.size());
                projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
                ++fakeArgIdx;
            } else {
                for (int arg : newCall.getArgList()) {
                    if (groupSet.contains(arg)) {
                        // The argument is itself a group key, whose sourceOf
                        // entry already points at the key column; the call's
                        // output position is recorded under a fake argument
                        // instead.
                        int fakeArg = fakeArgs.get(fakeArgIdx);
                        callArgMap.put(newCall, fakeArg);
                        sourceOf.put(fakeArg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()), newCall.getName()));
                        ++fakeArgIdx;
                    } else {
                        sourceOf.put(arg, projects.size());
                        projects.add(Pair.of((RexNode) new RexInputRef(arg, newCall.getType()), newCall.getName()));
                    }
                }
            }
        }
    }
    // Generate the aggregate B (see the reference example above)
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
    // Convert the existing aggregate to aggregate A (see the reference example above)
    final List<AggregateCall> newTopAggCalls = Lists.newArrayList(aggregate.getAggCallList());
    // Use the remapped arguments for the (non)distinct aggregate calls
    for (int i = 0; i < newTopAggCalls.size(); i++) {
        // Re-map arguments.
        final AggregateCall aggCall = newTopAggCalls.get(i);
        final int argCount = aggCall.getArgList().size();
        final List<Integer> newArgs = new ArrayList<>(argCount);
        final AggregateCall newCall;
        for (int j = 0; j < argCount; j++) {
            final Integer arg = aggCall.getArgList().get(j);
            // Calls registered in callArgMap are addressed through their fake
            // argument; all others through the original argument index.
            if (callArgMap.containsKey(aggCall)) {
                newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
            } else {
                newArgs.add(sourceOf.get(arg));
            }
        }
        if (aggCall.isDistinct()) {
            // The distinct call is re-created as a non-distinct call over the remapped arguments.
            newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
        } else {
            // If aggregate B had a COUNT aggregate call the corresponding aggregate at
            // aggregate A must be SUM. For other aggregates, it remains the same.
            if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
                if (aggCall.getArgList().size() == 0) {
                    // COUNT(*) has no argument of its own, so the loop above added nothing.
                    newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
                }
                // With a GROUP BY a plain SUM is used; without one, the
                // SqlSumEmptyIsZeroAggFunction variant is used instead.
                if (hasGroupBy) {
                    SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
                } else {
                    SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
                    newCall = AggregateCall.create(sumAgg, false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.getName());
                }
            } else {
                newCall = AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1, aggregate.getGroupSet().cardinality(), relBuilder.peek(), aggCall.getType(), aggCall.name);
            }
        }
        newTopAggCalls.set(i, newCall);
    }
    // Populate the group-by keys with the remapped arguments for aggregate A
    newGroupSet.clear();
    for (int arg : aggregate.getGroupSet()) {
        newGroupSet.add(sourceOf.get(arg));
    }
    relBuilder.push(aggregate.copy(aggregate.getTraitSet(), relBuilder.build(), aggregate.indicator, ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
    return relBuilder;
}
Also used : HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) SqlCountAggFunction(org.apache.calcite.sql.fun.SqlCountAggFunction) AggregateCall(org.apache.calcite.rel.core.AggregateCall) SqlSumEmptyIsZeroAggFunction(org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) TreeSet(java.util.TreeSet) SqlSumAggFunction(org.apache.calcite.sql.fun.SqlSumAggFunction) RexInputRef(org.apache.calcite.rex.RexInputRef) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Example 64 with Pair

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair in project flink by apache.

the class FlinkAggregateExpandDistinctAggregatesRule method rewriteUsingGroupingSets.

/*
	public RelBuilder convertSingletonDistinct(RelBuilder relBuilder,
											   Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
		// For example,
		//	SELECT deptno, COUNT(*), SUM(bonus), MIN(DISTINCT sal)
		//	FROM emp
		//	GROUP BY deptno
		//
		// becomes
		//
		//	SELECT deptno, SUM(cnt), SUM(bonus), MIN(sal)
		//	FROM (
		//		  SELECT deptno, COUNT(*) as cnt, SUM(bonus), sal
		//		  FROM EMP
		//		  GROUP BY deptno, sal)			// Aggregate B
		//	GROUP BY deptno						// Aggregate A
		relBuilder.push(aggregate.getInput());
		final List<Pair<RexNode, String>> projects = new ArrayList<>();
		final Map<Integer, Integer> sourceOf = new HashMap<>();
		SortedSet<Integer> newGroupSet = new TreeSet<>();
		final List<RelDataTypeField> childFields =
				relBuilder.peek().getRowType().getFieldList();
		final boolean hasGroupBy = aggregate.getGroupSet().size() > 0;

		// Add the distinct aggregate column(s) to the group-by columns,
		// if not already a part of the group-by
		newGroupSet.addAll(aggregate.getGroupSet().asList());
		for (Pair<List<Integer>, Integer> argList : argLists) {
			newGroupSet.addAll(argList.getKey());
		}

		// Re-map the arguments to the aggregate A. These arguments will get
		// remapped because of the intermediate aggregate B generated as part of the
		// transformation.
		for (int arg : newGroupSet) {
			sourceOf.put(arg, projects.size());
			projects.add(RexInputRef.of2(arg, childFields));
		}
		// Generate the intermediate aggregate B
		final List<AggregateCall> aggCalls = aggregate.getAggCallList();
		final List<AggregateCall> newAggCalls = new ArrayList<>();
		final List<Integer> fakeArgs = new ArrayList<>();
		final Map<AggregateCall, Integer> callArgMap = new HashMap<>();
		// First identify the real arguments, then use the rest for fake arguments
		// e.g. if real arguments are 0, 1, 3. Then the fake arguments will be 2, 4
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.put(arg, projects.size());
					}
				}
			}
		}
		int fakeArg0 = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// We will deal with non-distinct aggregates below
			if (!aggCall.isDistinct()) {
				boolean isGroupKeyUsedInAgg = false;
				for (int arg : aggCall.getArgList()) {
					if (sourceOf.containsKey(arg)) {
						isGroupKeyUsedInAgg = true;
						break;
					}
				}
				if (aggCall.getArgList().size() == 0 || isGroupKeyUsedInAgg) {
					while (sourceOf.get(fakeArg0) != null) {
						++fakeArg0;
					}
					fakeArgs.add(fakeArg0);
				}
			}
		}
		for (final AggregateCall aggCall : aggCalls) {
			if (!aggCall.isDistinct()) {
				for (int arg : aggCall.getArgList()) {
					if (!sourceOf.containsKey(arg)) {
						sourceOf.remove(arg);
					}
				}
			}
		}
		// Compute the remapped arguments using fake arguments for non-distinct
		// aggregates with no arguments e.g. count(*).
		int fakeArgIdx = 0;
		for (final AggregateCall aggCall : aggCalls) {
			// Project the column corresponding to the distinct aggregate. Project
			// as-is all the non-distinct aggregates
			if (!aggCall.isDistinct()) {
				final AggregateCall newCall =
						AggregateCall.create(aggCall.getAggregation(), false,
								aggCall.getArgList(), -1,
								ImmutableBitSet.of(newGroupSet).cardinality(),
								relBuilder.peek(), null, aggCall.name);
				newAggCalls.add(newCall);
				if (newCall.getArgList().size() == 0) {
					int fakeArg = fakeArgs.get(fakeArgIdx);
					callArgMap.put(newCall, fakeArg);
					sourceOf.put(fakeArg, projects.size());
					projects.add(
							Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
									newCall.getName()));
					++fakeArgIdx;
				} else {
					for (int arg : newCall.getArgList()) {
						if (sourceOf.containsKey(arg)) {
							int fakeArg = fakeArgs.get(fakeArgIdx);
							callArgMap.put(newCall, fakeArg);
							sourceOf.put(fakeArg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(fakeArg, newCall.getType()),
											newCall.getName()));
							++fakeArgIdx;
						} else {
							sourceOf.put(arg, projects.size());
							projects.add(
									Pair.of((RexNode) new RexInputRef(arg, newCall.getType()),
											newCall.getName()));
						}
					}
				}
			}
		}
		// Generate the aggregate B (see the reference example above)
		relBuilder.push(
				aggregate.copy(
						aggregate.getTraitSet(), relBuilder.build(),
						false, ImmutableBitSet.of(newGroupSet), null, newAggCalls));
		// Convert the existing aggregate to aggregate A (see the reference example above)
		final List<AggregateCall> newTopAggCalls =
				Lists.newArrayList(aggregate.getAggCallList());
		// Use the remapped arguments for the (non)distinct aggregate calls
		for (int i = 0; i < newTopAggCalls.size(); i++) {
			// Re-map arguments.
			final AggregateCall aggCall = newTopAggCalls.get(i);
			final int argCount = aggCall.getArgList().size();
			final List<Integer> newArgs = new ArrayList<>(argCount);
			final AggregateCall newCall;


			for (int j = 0; j < argCount; j++) {
				final Integer arg = aggCall.getArgList().get(j);
				if (callArgMap.containsKey(aggCall)) {
					newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
				}
				else {
					newArgs.add(sourceOf.get(arg));
				}
			}
			if (aggCall.isDistinct()) {
				newCall =
						AggregateCall.create(aggCall.getAggregation(), false, newArgs,
								-1, aggregate.getGroupSet().cardinality(), relBuilder.peek(),
								aggCall.getType(), aggCall.name);
			} else {
				// If aggregate B had a COUNT aggregate call the corresponding aggregate at
				// aggregate A must be SUM. For other aggregates, it remains the same.
				if (aggCall.getAggregation() instanceof SqlCountAggFunction) {
					if (aggCall.getArgList().size() == 0) {
						newArgs.add(sourceOf.get(callArgMap.get(aggCall)));
					}
					if (hasGroupBy) {
						SqlSumAggFunction sumAgg = new SqlSumAggFunction(null);
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					} else {
						SqlSumEmptyIsZeroAggFunction sumAgg = new SqlSumEmptyIsZeroAggFunction();
						newCall =
								AggregateCall.create(sumAgg, false, newArgs, -1,
										aggregate.getGroupSet().cardinality(), relBuilder.peek(),
										aggCall.getType(), aggCall.getName());
					}
				} else {
					newCall =
							AggregateCall.create(aggCall.getAggregation(), false, newArgs, -1,
									aggregate.getGroupSet().cardinality(),
									relBuilder.peek(), aggCall.getType(), aggCall.name);
				}
			}
			newTopAggCalls.set(i, newCall);
		}
		// Populate the group-by keys with the remapped arguments for aggregate A
		newGroupSet.clear();
		for (int arg : aggregate.getGroupSet()) {
			newGroupSet.add(sourceOf.get(arg));
		}
		relBuilder.push(
				aggregate.copy(aggregate.getTraitSet(),
						relBuilder.build(), aggregate.indicator,
						ImmutableBitSet.of(newGroupSet), null, newTopAggCalls));
		return relBuilder;
	}
	*/
/**
 * Rewrites an aggregate that contains distinct aggregate calls into a
 * two-level aggregate: a lower aggregate over several grouping sets that
 * carries the non-distinct calls, and an upper aggregate that applies every
 * original call with a per-grouping-set filter. The result is registered
 * through {@code call.transformTo}.
 *
 * @param call      rule call providing the builder and receiving the result
 * @param aggregate aggregate to rewrite
 * @param argLists  (argument list, filter argument) pairs; a filter argument
 *                  is only added to a grouping set when it is {@code >= 0}
 */
@SuppressWarnings("DanglingJavadoc")
private void rewriteUsingGroupingSets(RelOptRuleCall call, Aggregate aggregate, Set<Pair<List<Integer>, Integer>> argLists) {
    // Collect the grouping sets: the aggregate's own group set, plus one set
    // per entry of argLists made of its arguments, its filter argument (if
    // any) and the aggregate's group set.
    final Set<ImmutableBitSet> groupSetTreeSet = new TreeSet<>(ImmutableBitSet.ORDERING);
    groupSetTreeSet.add(aggregate.getGroupSet());
    for (Pair<List<Integer>, Integer> argList : argLists) {
        groupSetTreeSet.add(ImmutableBitSet.of(argList.left).setIf(argList.right, argList.right >= 0).union(aggregate.getGroupSet()));
    }
    final ImmutableList<ImmutableBitSet> groupSets = ImmutableList.copyOf(groupSetTreeSet);
    final ImmutableBitSet fullGroupSet = ImmutableBitSet.union(groupSets);
    // Despite its name, this list receives the NON-distinct calls (renamed to
    // their declared names); they are the calls evaluated by the lower aggregate.
    final List<AggregateCall> distinctAggCalls = new ArrayList<>();
    for (Pair<AggregateCall, String> aggCall : aggregate.getNamedAggCalls()) {
        if (!aggCall.left.isDistinct()) {
            distinctAggCalls.add(aggCall.left.rename(aggCall.right));
        }
    }
    // Lower aggregate: group the input by all grouping sets; the indicator
    // flag is set only when there is more than one grouping set.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    relBuilder.aggregate(relBuilder.groupKey(fullGroupSet, groupSets.size() > 1, groupSets), distinctAggCalls);
    final RelNode distinct = relBuilder.peek();
    final int groupCount = fullGroupSet.cardinality();
    final int indicatorCount = groupSets.size() > 1 ? groupCount : 0;
    final RelOptCluster cluster = aggregate.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
    final RelDataType booleanType = typeFactory.createTypeWithNullability(typeFactory.createSqlType(SqlTypeName.BOOLEAN), false);
    // Predicate expressions (with names) to append as extra projected columns.
    final List<Pair<RexNode, String>> predicates = new ArrayList<>();
    // Grouping set -> column position of the predicate registered for it.
    final Map<ImmutableBitSet, Integer> filters = new HashMap<>();
    /**
     * Function to register a filter for a group set.
     *
     * <p>Each registration appends one predicate to {@code predicates} and
     * returns the column position assigned to it.
     */
    class Registrar {

        // Expression shared by all predicates; built on first use by register().
        RexNode group = null;

        /**
         * Adds the predicate {@code group = <number>} for the given group
         * set, where the number is {@code toNumber} of the set remapped
         * against {@code fullGroupSet}, and returns the predicate's column
         * position.
         */
        private int register(ImmutableBitSet groupSet) {
            if (group == null) {
                group = makeGroup(groupCount - 1);
            }
            final RexNode node = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, group, rexBuilder.makeExactLiteral(toNumber(remap(fullGroupSet, groupSet))));
            predicates.add(Pair.of(node, toString(groupSet)));
            // Position of the predicate just added, counted after the group,
            // indicator and aggregate-call columns.
            // NOTE(review): presumes the lower aggregate's row holds exactly
            // those columns in that order - confirm.
            return groupCount + indicatorCount + distinctAggCalls.size() + predicates.size() - 1;
        }

        /**
         * Recursively builds the sum, for j = 0..i, of a CASE call with
         * operands (boolean input ref at position groupCount + j, 0, 2^j).
         * NOTE(review): positions groupCount + j are presumably the
         * indicator columns - confirm.
         */
        private RexNode makeGroup(int i) {
            final RexInputRef ref = rexBuilder.makeInputRef(booleanType, groupCount + i);
            final RexNode kase = rexBuilder.makeCall(SqlStdOperatorTable.CASE, ref, rexBuilder.makeExactLiteral(BigDecimal.ZERO), rexBuilder.makeExactLiteral(TWO.pow(i)));
            if (i == 0) {
                return kase;
            } else {
                return rexBuilder.makeCall(SqlStdOperatorTable.PLUS, makeGroup(i - 1), kase);
            }
        }

        /** Returns the sum of {@code TWO.pow(key)} over all keys in the bit set. */
        private BigDecimal toNumber(ImmutableBitSet bitSet) {
            BigDecimal n = BigDecimal.ZERO;
            for (int key : bitSet) {
                n = n.add(TWO.pow(key));
            }
            return n;
        }

        /**
         * Builds the predicate column name: {@code "$i"} followed by the keys
         * separated by {@code '_'}.
         */
        private String toString(ImmutableBitSet bitSet) {
            final StringBuilder buf = new StringBuilder("$i");
            for (int key : bitSet) {
                buf.append(key).append('_');
            }
            // Drops the last character: the trailing '_' for a non-empty set.
            // For an empty set this removes the 'i' and yields "$".
            return buf.substring(0, buf.length() - 1);
        }
    }
    final Registrar registrar = new Registrar();
    for (ImmutableBitSet groupSet : groupSets) {
        filters.put(groupSet, registrar.register(groupSet));
    }
    // Project every existing column followed by the registered predicates, so
    // that the positions returned by register() refer to real columns.
    if (!predicates.isEmpty()) {
        List<Pair<RexNode, String>> nodes = new ArrayList<>();
        for (RelDataTypeField f : relBuilder.peek().getRowType().getFieldList()) {
            final RexNode node = rexBuilder.makeInputRef(f.getType(), f.getIndex());
            nodes.add(Pair.of(node, f.getName()));
        }
        nodes.addAll(predicates);
        relBuilder.project(Pair.left(nodes), Pair.right(nodes));
    }
    // Next column holding a non-distinct call's result from the lower aggregate.
    int x = groupCount + indicatorCount;
    final List<AggregateCall> newCalls = new ArrayList<>();
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
        final int newFilterArg;
        final List<Integer> newArgList;
        final SqlAggFunction aggregation;
        if (!aggCall.isDistinct()) {
            // Non-distinct call: take MIN of the column computed below,
            // filtered to the rows of the aggregate's own group set.
            aggregation = SqlStdOperatorTable.MIN;
            newArgList = ImmutableIntList.of(x++);
            newFilterArg = filters.get(aggregate.getGroupSet());
        } else {
            // Distinct call: keep the aggregation, remap its arguments, and
            // filter to the rows of the grouping set built from its arguments.
            aggregation = aggCall.getAggregation();
            newArgList = remap(fullGroupSet, aggCall.getArgList());
            newFilterArg = filters.get(ImmutableBitSet.of(aggCall.getArgList()).setIf(aggCall.filterArg, aggCall.filterArg >= 0).union(aggregate.getGroupSet()));
        }
        final AggregateCall newCall = AggregateCall.create(aggregation, false, newArgList, newFilterArg, aggregate.getGroupCount(), distinct, null, aggCall.name);
        newCalls.add(newCall);
    }
    // Upper aggregate on the original group keys (remapped), then convert to
    // the original row type and hand the result to the rule call.
    relBuilder.aggregate(relBuilder.groupKey(remap(fullGroupSet, aggregate.getGroupSet()), aggregate.indicator, remap(fullGroupSet, aggregate.getGroupSets())), newCalls);
    relBuilder.convert(aggregate.getRowType(), true);
    call.transformTo(relBuilder.build());
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) ImmutableBitSet(org.apache.calcite.util.ImmutableBitSet) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) TreeSet(java.util.TreeSet) RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) ArrayList(java.util.ArrayList) ImmutableList(com.google.common.collect.ImmutableList) ImmutableIntList(org.apache.calcite.util.ImmutableIntList) List(java.util.List) Pair(org.apache.calcite.util.Pair) RelBuilder(org.apache.calcite.tools.RelBuilder) SqlAggFunction(org.apache.calcite.sql.SqlAggFunction) BigDecimal(java.math.BigDecimal) AggregateCall(org.apache.calcite.rel.core.AggregateCall) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) RexInputRef(org.apache.calcite.rex.RexInputRef) RexNode(org.apache.calcite.rex.RexNode)

Example 65 with Pair

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.Pair in project flink by apache.

the class FlinkRelDecorrelator method aggregateCorrelatorOutput.

/**
 * Builds a projection on top of a {@link Correlate} that replaces the
 * {@link Project} which used to be its right-hand input.
 *
 * <p>The projection emits every field of the correlate's left input,
 * followed by the expressions of {@code project} after they have been
 * passed through {@code removeCorrelationExpr}.
 *
 * @param correlate Correlate
 * @param project   the original project as the RHS input of the join
 * @param isCount   Positions which are calls to the <code>COUNT</code>
 *                  aggregation function
 * @return the subtree with the new projection at the root
 */
private RelNode aggregateCorrelatorOutput(Correlate correlate, LogicalProject project, Set<Integer> isCount) {
    final RelNode leftInput = correlate.getLeft();
    final JoinRelType correlateJoinType = correlate.getJoinType().toJoinType();
    // Expressions (with names) of the projection being assembled.
    final List<Pair<RexNode, String>> projections = Lists.newArrayList();
    // Left-hand fields come first and are forwarded as plain input references.
    final List<RelDataTypeField> leftFields = leftInput.getRowType().getFieldList();
    int ordinal = 0;
    while (ordinal < leftFields.size()) {
        projections.add(RexInputRef.of2(ordinal, leftFields));
        ordinal++;
    }
    // Tell the expression rewrite whether the original projections now come
    // out of the null-generating side of the join, so that their types can
    // be made nullable.
    final boolean rightMayBeNull = correlateJoinType.generatesNullsOnRight();
    for (Pair<RexNode, String> namedExpr : project.getNamedProjects()) {
        final RexNode rewritten = removeCorrelationExpr(namedExpr.left, rightMayBeNull, isCount);
        projections.add(Pair.of(rewritten, namedExpr.right));
    }
    return RelOptUtil.createProject(correlate, projections, false);
}
Also used : JoinRelType(org.apache.calcite.rel.core.JoinRelType) RelDataTypeField(org.apache.calcite.rel.type.RelDataTypeField) RelNode(org.apache.calcite.rel.RelNode) Pair(org.apache.calcite.util.Pair) RexNode(org.apache.calcite.rex.RexNode)

Aggregations

Pair (org.apache.calcite.util.Pair)112 RexNode (org.apache.calcite.rex.RexNode)72 ArrayList (java.util.ArrayList)70 RelNode (org.apache.calcite.rel.RelNode)59 RelDataTypeField (org.apache.calcite.rel.type.RelDataTypeField)55 RexInputRef (org.apache.calcite.rex.RexInputRef)29 ImmutableBitSet (org.apache.calcite.util.ImmutableBitSet)29 HashMap (java.util.HashMap)26 RexBuilder (org.apache.calcite.rex.RexBuilder)23 Map (java.util.Map)21 AggregateCall (org.apache.calcite.rel.core.AggregateCall)20 List (java.util.List)19 RelDataType (org.apache.calcite.rel.type.RelDataType)19 ImmutableList (com.google.common.collect.ImmutableList)18 JoinRelType (org.apache.calcite.rel.core.JoinRelType)16 TreeMap (java.util.TreeMap)14 RelDataTypeFactory (org.apache.calcite.rel.type.RelDataTypeFactory)13 RelBuilder (org.apache.calcite.tools.RelBuilder)13 ImmutableMap (com.google.common.collect.ImmutableMap)12 ImmutableSortedMap (com.google.common.collect.ImmutableSortedMap)12