use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef in project samza by apache.
the class TestJoinTranslator method testTranslateStreamToTableJoin.
/**
 * Drives {@code JoinTranslator.translate} for a stream-to-table join whose table is the LEFT
 * input and checks both the metrics wiring and the join operator spec that gets registered.
 *
 * @param isRemoteTable {@code true} to back the table with a mocked {@code RemoteTableDescriptor};
 *     {@code false} to use a mocked {@code RocksDbTableDescriptor} plus an intermediate
 *     (partitioned) stream
 * @throws IOException declared for the translator/helper calls made by this test
 * @throws ClassNotFoundException declared for the translator/helper calls made by this test
 */
private void testTranslateStreamToTableJoin(boolean isRemoteTable) throws IOException, ClassNotFoundException {
    // setup mock values to the constructor of JoinTranslator
    final String logicalOpId = "sql0_join3";
    final int queryId = 0;
    LogicalJoin mockJoin = PowerMockito.mock(LogicalJoin.class);
    TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
    // The left input is a table scan (the table side); the right input is a generic rel node
    // standing in for the stream side.
    RelNode mockLeftInput = PowerMockito.mock(LogicalTableScan.class);
    RelNode mockRightInput = mock(RelNode.class);
    List<RelNode> inputs = new ArrayList<>();
    inputs.add(mockLeftInput);
    inputs.add(mockRightInput);
    RelOptTable mockLeftTable = mock(RelOptTable.class);
    when(mockLeftInput.getTable()).thenReturn(mockLeftTable);
    List<String> qualifiedTableName = Arrays.asList("test", "LeftTable");
    when(mockLeftTable.getQualifiedName()).thenReturn(qualifiedTableName);
    when(mockLeftInput.getId()).thenReturn(1);
    when(mockRightInput.getId()).thenReturn(2);
    when(mockJoin.getId()).thenReturn(3);
    when(mockJoin.getInputs()).thenReturn(inputs);
    when(mockJoin.getLeft()).thenReturn(mockLeftInput);
    when(mockJoin.getRight()).thenReturn(mockRightInput);

    // Join condition: an EQUALS call over two input refs, both at index 0.
    RexCall mockJoinCondition = mock(RexCall.class);
    when(mockJoinCondition.isAlwaysTrue()).thenReturn(false);
    when(mockJoinCondition.getKind()).thenReturn(SqlKind.EQUALS);
    when(mockJoin.getCondition()).thenReturn(mockJoinCondition);
    RexInputRef mockLeftConditionInput = mock(RexInputRef.class);
    RexInputRef mockRightConditionInput = mock(RexInputRef.class);
    when(mockLeftConditionInput.getIndex()).thenReturn(0);
    when(mockRightConditionInput.getIndex()).thenReturn(0);
    List<RexNode> condOperands = new ArrayList<>();
    condOperands.add(mockLeftConditionInput);
    condOperands.add(mockRightConditionInput);
    when(mockJoinCondition.getOperands()).thenReturn(condOperands);
    RelDataType mockLeftCondDataType = mock(RelDataType.class);
    RelDataType mockRightCondDataType = mock(RelDataType.class);
    when(mockLeftCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
    when(mockRightCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
    when(mockLeftConditionInput.getType()).thenReturn(mockLeftCondDataType);
    when(mockRightConditionInput.getType()).thenReturn(mockRightCondDataType);

    // Row types and field names of both join inputs.
    RelDataType mockLeftRowType = mock(RelDataType.class);
    // NOTE(review): the original author flagged this stub with "?? why ??". The left row type
    // reports zero fields although one field name is stubbed below; presumably JoinTranslator
    // uses the left field count to offset the join-key indexes - TODO confirm before changing.
    when(mockLeftRowType.getFieldCount()).thenReturn(0);
    when(mockLeftInput.getRowType()).thenReturn(mockLeftRowType);
    List<String> leftFieldNames = Collections.singletonList("test_table_field1");
    List<String> rightStreamFieldNames = Collections.singletonList("test_stream_field1");
    when(mockLeftRowType.getFieldNames()).thenReturn(leftFieldNames);
    RelDataType mockRightRowType = mock(RelDataType.class);
    when(mockRightInput.getRowType()).thenReturn(mockRightRowType);
    when(mockRightRowType.getFieldNames()).thenReturn(rightStreamFieldNames);

    // Message streams the translator context hands out for the two inputs.
    StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
    OperatorSpec<Object, SamzaSqlRelMessage> mockLeftInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockLeftInputStream = new MessageStreamImpl<>(mockAppDesc, mockLeftInputOp);
    when(mockTranslatorContext.getMessageStream(eq(mockLeftInput.getId()))).thenReturn(mockLeftInputStream);
    OperatorSpec<Object, SamzaSqlRelMessage> mockRightInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockRightInputStream = new MessageStreamImpl<>(mockAppDesc, mockRightInputOp);
    when(mockTranslatorContext.getMessageStream(eq(mockRightInput.getId()))).thenReturn(mockRightInputStream);
    when(mockTranslatorContext.getStreamAppDescriptor()).thenReturn(mockAppDesc);
    InputOperatorSpec mockInputOp = mock(InputOperatorSpec.class);
    OutputStreamImpl mockOutputStream = mock(OutputStreamImpl.class);
    when(mockInputOp.isKeyed()).thenReturn(true);
    when(mockOutputStream.isKeyed()).thenReturn(true);
    // Capture the stream registered under the join's id (3) so it can be inspected later.
    doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(3), any(MessageStream.class));
    RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
    when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
    Expression mockExpr = mock(Expression.class);
    when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
    if (isRemoteTable) {
        doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RemoteTableDescriptor.class));
    } else {
        // The local-table path additionally needs the intermediate (partitioned) stream.
        IntermediateMessageStreamImpl mockPartitionedStream = new IntermediateMessageStreamImpl(mockAppDesc, mockInputOp, mockOutputStream);
        when(mockAppDesc.getIntermediateStream(any(String.class), any(Serde.class), eq(false))).thenReturn(mockPartitionedStream);
        doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RocksDbTableDescriptor.class));
    }
    when(mockJoin.getJoinType()).thenReturn(JoinRelType.INNER);

    // Execution context -> application config -> IO config -> table descriptor for "test.LeftTable".
    SamzaSqlExecutionContext mockExecutionContext = mock(SamzaSqlExecutionContext.class);
    when(mockTranslatorContext.getExecutionContext()).thenReturn(mockExecutionContext);
    SamzaSqlApplicationConfig mockAppConfig = mock(SamzaSqlApplicationConfig.class);
    when(mockExecutionContext.getSamzaSqlApplicationConfig()).thenReturn(mockAppConfig);
    Map<String, SqlIOConfig> ssConfigBySource = mock(HashMap.class);
    when(mockAppConfig.getInputSystemStreamConfigBySource()).thenReturn(ssConfigBySource);
    SqlIOConfig mockIOConfig = mock(SqlIOConfig.class);
    TableDescriptor mockTableDesc;
    if (isRemoteTable) {
        mockTableDesc = mock(RemoteTableDescriptor.class);
    } else {
        mockTableDesc = mock(RocksDbTableDescriptor.class);
    }
    when(ssConfigBySource.get(String.join(".", qualifiedTableName))).thenReturn(mockIOConfig);
    when(mockIOConfig.getTableDescriptor()).thenReturn(Optional.of(mockTableDesc));
    JoinTranslator joinTranslator = new JoinTranslator(logicalOpId, "", queryId);

    // Verify Metrics Works with Join
    Context mockContext = mock(Context.class);
    ContainerContext mockContainerContext = mock(ContainerContext.class);
    TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
    when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
    when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
    TranslatorInputMetricsMapFunction inputMetricsMF = joinTranslator.getInputMetricsMF();
    assertNotNull(inputMetricsMF);
    inputMetricsMF.init(mockContext);
    TranslatorOutputMetricsMapFunction outputMetricsMF = joinTranslator.getOutputMetricsMF();
    assertNotNull(outputMetricsMF);
    outputMetricsMF.init(mockContext);
    assertEquals(1, testMetricsRegistryImpl.getCounters().size());
    assertEquals(2, testMetricsRegistryImpl.getCounters().get(logicalOpId).size());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(0).getCount());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(1).getCount());
    assertEquals(1, testMetricsRegistryImpl.getGauges().size());

    // Apply translate() method to verify that we are getting the correct join operator constructed
    joinTranslator.translate(mockJoin, mockTranslatorContext);
    // make sure that the join's output message stream has been registered with the context
    verify(mockTranslatorContext, times(1)).registerMessageStream(3, this.getRegisteredMessageStream(3));
    when(mockTranslatorContext.getRelNode(3)).thenReturn(mockJoin);
    when(mockTranslatorContext.getMessageStream(3)).thenReturn(this.getRegisteredMessageStream(3));
    StreamTableJoinOperatorSpec joinSpec = (StreamTableJoinOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(3), "operatorSpec");
    assertNotNull(joinSpec);
    // Expected value goes first so a failure message reads correctly.
    assertEquals(OperatorSpec.OpCode.JOIN, joinSpec.getOpCode());

    // Verify joinSpec has the corresponding setup
    StreamTableJoinFunction joinFn = joinSpec.getJoinFn();
    assertNotNull(joinFn);
    if (isRemoteTable) {
        assertTrue(joinFn instanceof SamzaSqlRemoteTableJoinFunction);
    } else {
        assertTrue(joinFn instanceof SamzaSqlLocalTableJoinFunction);
    }
    // The table is the left input, so the function must not treat the table as the right side.
    assertEquals(Boolean.FALSE, Whitebox.getInternalState(joinFn, "isTablePosOnRight"));
    assertEquals(Collections.singletonList(0), Whitebox.getInternalState(joinFn, "streamFieldIds"));
    assertEquals(leftFieldNames, Whitebox.getInternalState(joinFn, "tableFieldNames"));
    // Output fields are the left (table) fields followed by the right (stream) fields.
    List<String> outputFieldNames = new ArrayList<>();
    outputFieldNames.addAll(leftFieldNames);
    outputFieldNames.addAll(rightStreamFieldNames);
    assertEquals(outputFieldNames, Whitebox.getInternalState(joinFn, "outFieldNames"));
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef in project flink by apache.
the class HiveParserCalcitePlanner method genOBLogicalPlan.
/**
 * Translates the ORDER BY clause of the first clause of {@code qb}, if there is one, into a
 * Calcite sort on top of {@code srcRel}.
 *
 * <p>Sort keys that are plain column references sort directly on the input column. Any other
 * sort key is computed by a synthetic project that is inserted between {@code srcRel} and the
 * sort and that appends the key as an extra ("virtual") column.
 *
 * @param qb query block whose parse info supplies the ORDER BY AST and the limit
 * @param srcRel relational expression to sort
 * @param outermostOB when a synthetic project was added, selects whether the row resolver
 *     registered for the sort is built from the input resolver ({@code true}) or from the
 *     synthetic project's resolver ({@code false})
 * @return pair of (sort, original input); the sort is {@code null} when there is no ORDER BY,
 *     and the original input is {@code srcRel} only when a synthetic project was added
 * @throws SemanticException on ORDER BY without LIMIT in strict mode, on an unresolvable sort
 *     expression, on an unexpected null-ordering token, or on duplicate resolver columns
 */
private Pair<Sort, RelNode> genOBLogicalPlan(HiveParserQB qb, RelNode srcRel, boolean outermostOB) throws SemanticException {
    HiveParserQBParseInfo qbp = qb.getParseInfo();
    String dest = qbp.getClauseNames().iterator().next();
    HiveParserASTNode obAST = qbp.getOrderByForClause(dest);
    if (obAST == null) {
        // No ORDER BY: nothing to build.
        return new Pair<Sort, RelNode>(null, null);
    }

    // 1. OB Expr sanity test:
    // in strict mode, in the presence of order by, limit must be specified.
    Integer limit = qb.getParseInfo().getDestLimit(dest);
    if (limit == null) {
        String mapRedMode = semanticAnalyzer.getConf().getVar(HiveConf.ConfVars.HIVEMAPREDMODE);
        boolean banLargeQuery = Boolean.parseBoolean(semanticAnalyzer.getConf().get("hive.strict.checks.large.query", "false"));
        if (banLargeQuery || "strict".equalsIgnoreCase(mapRedMode)) {
            throw new SemanticException(generateErrorMessage(obAST, "Order by-s without limit"));
        }
    }

    // 2. Walk through the OB exprs, collecting field collations and the additional
    // virtual columns that are needed.
    final List<RexNode> virtualCols = new ArrayList<>();
    final List<RelFieldCollation> fieldCollations = new ArrayList<>();
    final List<Pair<HiveParserASTNode, TypeInfo>> vcASTAndType = new ArrayList<>();
    final HiveParserRowResolver inputRR = relToRowResolver.get(srcRel);
    final HiveParserRowResolver outputRR = new HiveParserRowResolver();
    final HiveParserRexNodeConverter converter =
            new HiveParserRexNodeConverter(
                    cluster,
                    srcRel.getRowType(),
                    relToHiveColNameCalcitePosMap.get(srcRel),
                    0,
                    false,
                    funcConverter);
    final int numSrcFields = srcRel.getRowType().getFieldCount();

    for (Node child : obAST.getChildren()) {
        // 2.1 Convert AST Expr to ExprNode
        HiveParserASTNode sortKeyAst = (HiveParserASTNode) child;
        HiveParserASTNode nullOrderAst = (HiveParserASTNode) sortKeyAst.getChild(0);
        HiveParserASTNode ref = (HiveParserASTNode) nullOrderAst.getChild(0);
        ExprNodeDesc sortKeyDesc = semanticAnalyzer.genAllExprNodeDesc(ref, inputRR).get(ref);
        if (sortKeyDesc == null) {
            throw new SemanticException("Invalid order by expression: " + sortKeyAst);
        }

        // 2.2 Convert ExprNode to RexNode
        RexNode rexNode = converter.convert(sortKeyDesc).accept(funcConverter);

        // 2.3 Determine the index of the OB expr in the child schema. A sort key that is not
        // a plain input reference must be present in the child, hence it is queued as a
        // virtual column for the child Project Rel added in step 3.
        final int fieldIndex;
        if (rexNode instanceof RexInputRef) {
            fieldIndex = ((RexInputRef) rexNode).getIndex();
        } else {
            fieldIndex = numSrcFields + virtualCols.size();
            virtualCols.add(rexNode);
            vcASTAndType.add(new Pair<>(ref, sortKeyDesc.getTypeInfo()));
        }

        // 2.4 Determine the direction of order by (descending unless the ASC token is used)
        final RelFieldCollation.Direction direction =
                sortKeyAst.getType() == HiveASTParser.TOK_TABSORTCOLNAMEASC
                        ? RelFieldCollation.Direction.ASCENDING
                        : RelFieldCollation.Direction.DESCENDING;
        final int nullOrderToken = nullOrderAst.getType();
        final RelFieldCollation.NullDirection nullOrder;
        if (nullOrderToken == HiveASTParser.TOK_NULLS_FIRST) {
            nullOrder = RelFieldCollation.NullDirection.FIRST;
        } else if (nullOrderToken == HiveASTParser.TOK_NULLS_LAST) {
            nullOrder = RelFieldCollation.NullDirection.LAST;
        } else {
            throw new SemanticException("Unexpected null ordering option: " + nullOrderToken);
        }

        // 2.5 Add to field collations
        fieldCollations.add(new RelFieldCollation(fieldIndex, direction, nullOrder));
    }

    // 3. Add Child Project Rel if needed, generate Output RR, input Sel Rel
    // for top constraining Sel
    RelNode obInputRel = srcRel;
    RelNode originalOBInput = null;
    // Resolver whose columns are copied into outputRR once the sort input is known.
    HiveParserRowResolver outputSourceRR = inputRR;
    if (!virtualCols.isEmpty()) {
        List<RexNode> originalInputRefs =
                srcRel.getRowType().getFieldList().stream()
                        .map(field -> new RexInputRef(field.getIndex(), field.getType()))
                        .collect(Collectors.toList());
        HiveParserRowResolver obSyntheticProjectRR = new HiveParserRowResolver();
        if (!HiveParserRowResolver.add(obSyntheticProjectRR, inputRR)) {
            throw new SemanticException("Duplicates detected when adding columns to RR: see previous message");
        }
        // Virtual columns are named after the positions following the input's own columns.
        int vcolPos = inputRR.getRowSchema().getSignature().size();
        for (Pair<HiveParserASTNode, TypeInfo> astAndType : vcASTAndType) {
            ColumnInfo vcolInfo = new ColumnInfo(getColumnInternalName(vcolPos), astAndType.getValue(), null, false);
            obSyntheticProjectRR.putExpression(astAndType.getKey(), vcolInfo);
            vcolPos++;
        }
        obInputRel = genSelectRelNode(CompositeList.of(originalInputRefs, virtualCols), obSyntheticProjectRR, srcRel);
        if (!outermostOB) {
            outputSourceRR = obSyntheticProjectRR;
        }
        originalOBInput = srcRel;
    }
    if (!HiveParserRowResolver.add(outputRR, outputSourceRR)) {
        throw new SemanticException("Duplicates detected when adding columns to RR: see previous message");
    }

    // 4. Construct SortRel
    RelTraitSet traitSet = cluster.traitSet();
    RelCollation canonizedCollation = traitSet.canonize(RelCollationImpl.of(fieldCollations));
    Sort sortRel = LogicalSort.create(obInputRel, canonizedCollation, null, null);

    // 5. Update the maps
    Map<String, Integer> hiveColNameCalcitePosMap = buildHiveToCalciteColumnMap(outputRR);
    relToRowResolver.put(sortRel, outputRR);
    relToHiveColNameCalcitePosMap.put(sortRel, hiveColNameCalcitePosMap);

    return new Pair<>(sortRel, originalOBInput);
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef in project flink by apache.
the class FlinkAggregateJoinTransposeRule method toRegularAggregate.
/**
 * Converts an aggregate that uses AUXILIARY_GROUP into a regular aggregate.
 *
 * <p>If the given aggregate has no auxiliary group it is returned as-is together with a
 * {@code null} projection. Otherwise the result is a copy of the aggregate that groups by the
 * full group set and keeps only the regular aggregate calls, together with one input reference
 * per original output field that permutes the new aggregate's columns back into the original
 * output order.
 *
 * @param aggregate aggregate to convert
 * @return pair of (aggregate, projection to apply on top of it, or {@code null})
 */
private Pair<Aggregate, List<RexNode>> toRegularAggregate(Aggregate aggregate) {
    Tuple2<int[], Seq<AggregateCall>> split = AggregateUtil.checkAndSplitAggCalls(aggregate);
    final int[] auxGroup = split._1;
    final Seq<AggregateCall> regularAggCalls = split._2;
    if (auxGroup.length == 0) {
        // Nothing to rewrite.
        return new Pair<>(aggregate, null);
    }

    final int[] fullGroupSet = AggregateUtil.checkAndGetFullGroupSet(aggregate);
    final ImmutableBitSet newGroupSet = ImmutableBitSet.of(fullGroupSet);
    final List<AggregateCall> aggCalls = JavaConverters.seqAsJavaListConverter(regularAggCalls).asJava();
    final Aggregate newAgg =
            aggregate.copy(
                    aggregate.getTraitSet(),
                    aggregate.getInput(),
                    aggregate.indicator,
                    newGroupSet,
                    com.google.common.collect.ImmutableList.of(newGroupSet),
                    aggCalls);

    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final int fieldCntOfAgg = aggFields.size();
    final List<RexNode> projectAfterAgg = new ArrayList<>();

    // Group columns: the i-th original group field is found at the position its column
    // occupies inside the new group bit set.
    int pos = 0;
    while (pos < fullGroupSet.length) {
        int indexInNewGroup = newGroupSet.indexOf(fullGroupSet[pos]);
        projectAfterAgg.add(new RexInputRef(indexInNewGroup, aggFields.get(pos).getType()));
        ++pos;
    }
    // Remaining (aggregate call) columns keep their positions.
    while (pos < fieldCntOfAgg) {
        projectAfterAgg.add(new RexInputRef(pos, aggFields.get(pos).getType()));
        ++pos;
    }

    Preconditions.checkArgument(projectAfterAgg.size() == fieldCntOfAgg);
    return new Pair<>(newAgg, projectAfterAgg);
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method createSelectDistinct.
/**
 * Given an {@link org.apache.calcite.rel.core.Aggregate} and the ordinals of the arguments to a
 * particular call to an aggregate function, creates a 'select distinct' relational expression
 * which projects the group columns, the filter column (if any) and those arguments, but nothing
 * else.
 *
 * <p>For example, given
 *
 * <blockquote>
 *
 * <pre>select f0, count(distinct f1), count(distinct f2)
 * from t group by f0</pre>
 *
 * </blockquote>
 *
 * <p>and the argument list {@code {2}}, this produces
 *
 * <blockquote>
 *
 * <pre>select distinct f0, f2 from t</pre>
 *
 * </blockquote>
 *
 * <p>The {@code sourceOf} map is populated with, for every input column that is projected, the
 * position it occupies in the projection; in the example above, {@code sourceOf.get(0) = 0} and
 * {@code sourceOf.get(2) = 1}.
 *
 * <p>When {@code filterArg >= 0}, each argument is projected as
 * {@code CASE WHEN filter THEN arg ELSE NULL END} under the name {@code "i$" + argName}.
 *
 * @param relBuilder Relational expression builder
 * @param aggregate Aggregate relational expression
 * @param argList Ordinals of columns to make distinct
 * @param filterArg Ordinal of column to filter on, or -1
 * @param sourceOf Out parameter, is populated with a map of where each output field came from
 * @return the given {@code relBuilder}, with the 'select distinct' aggregate pushed onto it
 */
private RelBuilder createSelectDistinct(RelBuilder relBuilder, Aggregate aggregate, List<Integer> argList, int filterArg, Map<Integer, Integer> sourceOf) {
    relBuilder.push(aggregate.getInput());
    final List<Pair<RexNode, String>> projects = new ArrayList<>();
    final List<RelDataTypeField> childFields = relBuilder.peek().getRowType().getFieldList();
    final boolean hasFilter = filterArg >= 0;

    // Group columns come first.
    for (int groupCol : aggregate.getGroupSet()) {
        sourceOf.put(groupCol, projects.size());
        projects.add(RexInputRef.of2(groupCol, childFields));
    }

    // Then the filter column, when the call is filtered.
    if (hasFilter) {
        sourceOf.put(filterArg, projects.size());
        projects.add(RexInputRef.of2(filterArg, childFields));
    }

    // Finally the arguments of the aggregate call.
    for (Integer arg : argList) {
        if (hasFilter) {
            // Implement
            //   agg(DISTINCT arg) FILTER $f
            // by generating
            //   SELECT DISTINCT ... CASE WHEN $f THEN arg ELSE NULL END AS arg
            // and then applying
            //   agg(arg)
            // as usual.
            //
            // It works except for (rare) agg functions that need to see null
            // values.
            final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
            final RexInputRef filterRef = RexInputRef.of(filterArg, childFields);
            final Pair<RexNode, String> argRef = RexInputRef.of2(arg, childFields);
            final RexNode nullOfArgType = rexBuilder.makeNullLiteral(argRef.left.getType());
            final RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.CASE, filterRef, argRef.left, nullOfArgType);
            sourceOf.put(arg, projects.size());
            projects.add(Pair.of(condition, "i$" + argRef.right));
        } else if (sourceOf.get(arg) == null) {
            // Unfiltered argument that has not been projected yet.
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2(arg, childFields));
        }
    }
    relBuilder.project(Pair.left(projects), Pair.right(projects));

    // Get the distinct values of the GROUP BY fields and the arguments
    // to the agg functions: group by every projected column, with no agg calls.
    final Aggregate distinctAgg =
            aggregate.copy(
                    aggregate.getTraitSet(),
                    relBuilder.build(),
                    ImmutableBitSet.range(projects.size()),
                    null,
                    com.google.common.collect.ImmutableList.<AggregateCall>of());
    relBuilder.push(distinctAgg);
    return relBuilder;
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexInputRef in project flink by apache.
the class FlinkAggregateExpandDistinctAggregatesRule method onMatch.
// ~ Methods ----------------------------------------------------------------
/**
 * Expands the accurate distinct aggregate calls of the matched aggregate.
 *
 * <p>Depending on the mix of aggregate calls this delegates to {@code convertMonopole},
 * {@code rewriteUsingGroupingSets} or {@code convertSingletonDistinct}; otherwise it aggregates
 * the non-distinct calls once, applies {@code doRewrite} for every distinct argument list, and
 * projects the result back to the original field names.
 *
 * @param call rule call whose first operand is the aggregate to rewrite
 */
public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    if (!AggregateUtil.containsAccurateDistinctCall(aggCalls)) {
        return;
    }
    // Reject an aggregate that mixes approximate distinct calls with accurate distinct calls.
    if (AggregateUtil.containsApproximateDistinctCall(aggCalls)) {
        throw new TableException("There are both Distinct AggCall and Approximate Distinct AggCall in one sql statement, " + "it is not supported yet.\nPlease choose one of them.");
    }
    // Aggregates with more than one group set are skipped here; per the original note they
    // are handled first by DecomposeGroupingSetsRule, after which this rule expands their
    // distinct aggregates.
    if (aggregate.getGroupSets().size() > 1) {
        return;
    }

    // Find all of the agg expressions. We use a LinkedHashSet to ensure determinism.
    final Set<Pair<List<Integer>, Integer>> argLists = new LinkedHashSet<>();
    // Number of aggregate calls without distinct.
    int nonDistinctAggCallCount = 0;
    // Number of aggregate calls without distinct, not counting those whose distinct
    // optionality is IGNORED (the original note names MAX, MIN, BIT_AND, BIT_OR).
    int nonDistinctAggCallExcludingIgnoredCount = 0;
    int filterCount = 0;
    int unsupportedNonDistinctAggCallCount = 0;
    for (AggregateCall aggCall : aggCalls) {
        if (aggCall.filterArg >= 0) {
            ++filterCount;
        }
        if (aggCall.isDistinct()) {
            argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
            continue;
        }
        ++nonDistinctAggCallCount;
        // We only support COUNT/SUM/MIN/MAX for the "single" count distinct optimization
        switch (aggCall.getAggregation().getKind()) {
            case COUNT:
            case SUM:
            case SUM0:
            case MIN:
            case MAX:
                break;
            default:
                ++unsupportedNonDistinctAggCallCount;
                break;
        }
        if (aggCall.getAggregation().getDistinctOptionality() == Optionality.IGNORED) {
            argLists.add(Pair.of(aggCall.getArgList(), aggCall.filterArg));
        } else {
            ++nonDistinctAggCallExcludingIgnoredCount;
        }
    }
    final int distinctAggCallCount = aggCalls.size() - nonDistinctAggCallCount;
    Preconditions.checkState(!argLists.isEmpty(), "containsDistinctCall lied");

    // A single argument list, no non-distinct call other than the IGNORED-optionality ones,
    // and a simple group: the "monopole" form can be used.
    if (nonDistinctAggCallExcludingIgnoredCount == 0
            && argLists.size() == 1
            && aggregate.getGroupType() == Group.SIMPLE) {
        final Pair<List<Integer>, Integer> onlyArgList = com.google.common.collect.Iterables.getOnlyElement(argLists);
        final RelBuilder monopoleBuilder = call.builder();
        convertMonopole(monopoleBuilder, aggregate, onlyArgList.left, onlyArgList.right);
        call.transformTo(monopoleBuilder.build());
        return;
    }

    if (useGroupingSets) {
        rewriteUsingGroupingSets(call, aggregate);
        return;
    }

    // Exactly one distinct aggregate, no filter, and one or more non-distinct aggregates that
    // are all sum/min/max/count: multi-phase aggregates can be generated.
    final boolean singletonDistinct =
            distinctAggCallCount == 1
                    && filterCount == 0
                    && unsupportedNonDistinctAggCallCount == 0
                    && nonDistinctAggCallCount > 0;
    if (singletonDistinct) {
        final RelBuilder singletonBuilder = call.builder();
        convertSingletonDistinct(singletonBuilder, aggregate, argLists);
        call.transformTo(singletonBuilder.build());
        return;
    }

    // Create a list of the expressions which will yield the final result.
    // Initially, the expressions point to the input field.
    final List<RelDataTypeField> aggFields = aggregate.getRowType().getFieldList();
    final List<RexInputRef> refs = new ArrayList<>();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet groupSet = aggregate.getGroupSet();
    final int groupCount = aggregate.getGroupCount();
    for (int g = 0; g < groupCount; g++) {
        refs.add(RexInputRef.of(g, aggFields));
    }

    // Aggregate the original relation, including any non-distinct aggregates.
    // Distinct calls get a null placeholder in refs.
    final List<AggregateCall> newAggCallList = new ArrayList<>();
    for (int i = 0; i < aggCalls.size(); i++) {
        final AggregateCall aggCall = aggCalls.get(i);
        if (aggCall.isDistinct()) {
            refs.add(null);
        } else {
            refs.add(new RexInputRef(groupCount + newAggCallList.size(), aggFields.get(groupCount + i).getType()));
            newAggCallList.add(aggCall);
        }
    }

    // In the case where there are no non-distinct aggregates (regardless of
    // whether there are group bys), there's no need to generate the
    // extra aggregate and join.
    final RelBuilder relBuilder = call.builder();
    relBuilder.push(aggregate.getInput());
    int n = 0;
    if (!newAggCallList.isEmpty()) {
        relBuilder.aggregate(relBuilder.groupKey(groupSet, aggregate.getGroupSets()), newAggCallList);
        ++n;
    }

    // Apply the rewrite once for each distinct set of operands.
    for (Pair<List<Integer>, Integer> argsAndFilter : argLists) {
        doRewrite(relBuilder, aggregate, n++, argsAndFilter.left, argsAndFilter.right, refs);
    }
    relBuilder.project(refs, fieldNames);
    call.transformTo(relBuilder.build());
}
Aggregations