Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in project hive by apache: class HiveRelFieldTrimmer, method rewriteGBConstantKeys.
/**
 * Rewrites a GROUP BY over constant keys into a GROUP BY over the boolean
 * literal TRUE. The rewrite fires only when all of the following hold:
 * <ul>
 *   <li>the aggregate has no grouping sets (no indicator columns),</li>
 *   <li>every group-by key is a constant expression, and</li>
 *   <li>no operator above the aggregate refers to those keys.</li>
 * </ul>
 * When applicable, a new Project replacing the constant keys with TRUE is
 * introduced underneath the aggregate. The main motivation is enabling Hive
 * to push queries that group by a constant of a druid-unsupported type down
 * into druid.
 *
 * @param aggregate     the aggregate to rewrite
 * @param fieldsUsed    fields referenced by operators above the aggregate
 * @param aggCallFields fields consumed by the aggregate's own calls
 * @return the rewritten aggregate, or the original when the rewrite does not apply
 */
private Aggregate rewriteGBConstantKeys(Aggregate aggregate, ImmutableBitSet fieldsUsed, ImmutableBitSet aggCallFields) {
  // Grouping sets, an empty group set, or group keys still referenced above
  // the aggregate all make the rewrite inapplicable.
  if ((aggregate.getIndicatorCount() > 0) || (aggregate.getGroupSet().isEmpty()) || fieldsUsed.contains(aggregate.getGroupSet())) {
    return aggregate;
  }
  final RelNode input = aggregate.getInput();
  // Only a child Project exposes the group keys as inspectable expressions.
  final List<RexNode> inputExprs = input instanceof Project ? ((Project) input).getProjects() : null;
  if (inputExprs == null || inputExprs.isEmpty()) {
    return aggregate;
  }
  // Every group key must map to a literal in the child Project.
  boolean onlyConstantKeys = true;
  for (int groupKey : aggregate.getGroupSet()) {
    // getChildExprs on Join could return less number of expressions than there are coming out of join
    if (inputExprs.size() <= groupKey || !isRexLiteral(inputExprs.get(groupKey))) {
      onlyConstantKeys = false;
      break;
    }
  }
  if (!onlyConstantKeys) {
    return aggregate;
  }
  final RelDataType rowType = input.getRowType();
  final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
  final List<RexNode> projections = new ArrayList<>();
  for (int field = 0; field < rowType.getFieldCount(); field++) {
    // Substitute TRUE for each constant group key that no aggregate call
    // consumes; pass every other column through unchanged.
    if (aggregate.getGroupSet().get(field) && !aggCallFields.get(field)) {
      projections.add(rexBuilder.makeLiteral(true));
    } else {
      projections.add(rexBuilder.makeInputRef(input, field));
    }
  }
  final RelBuilder relBuilder = REL_BUILDER.get();
  relBuilder.push(input);
  relBuilder.project(projections);
  return new HiveAggregate(aggregate.getCluster(), aggregate.getTraitSet(), relBuilder.build(), aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList());
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in project hive by apache: class HiveRemoveGBYSemiJoinRule, method onMatch.
/**
 * Removes a redundant GROUP BY on the right input of a SEMI/ANTI join.
 * The group-by is redundant when it has no aggregate calls, uses simple
 * grouping, and its group keys are exactly the equi-join keys coming from
 * the right side — in that case deduplication adds nothing to the join
 * semantics and the aggregate can be replaced by a plain projection.
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final Join join = call.rel(0);
  // Only SEMI and ANTI joins make the right-side distinct redundant.
  if (join.getJoinType() != JoinRelType.SEMI && join.getJoinType() != JoinRelType.ANTI) {
    return;
  }
  final RelNode left = call.rel(1);
  final Aggregate rightAggregate = call.rel(2);
  // if grouping sets are involved do early return
  if (rightAggregate.getGroupType() != Aggregate.Group.SIMPLE) {
    return;
  }
  if (rightAggregate.indicator) {
    return;
  }
  // if there is any aggregate function this group by is not un-necessary
  if (!rightAggregate.getAggCallList().isEmpty()) {
    return;
  }
  JoinPredicateInfo joinPredInfo;
  try {
    joinPredInfo = HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(join);
  } catch (CalciteSemanticException e) {
    LOG.warn("Exception while extracting predicate info from {}", join);
    return;
  }
  // Non-equi predicates may rely on duplicate elimination; bail out.
  if (!joinPredInfo.getNonEquiJoinPredicateElements().isEmpty()) {
    return;
  }
  // Collect the right-side join keys from every equi predicate.
  final ImmutableBitSet.Builder rightKeyBuilder = ImmutableBitSet.builder();
  for (JoinLeafPredicateInfo predInfo : joinPredInfo.getEquiJoinPredicateElements()) {
    rightKeyBuilder.addAll(predInfo.getProjsFromRightPartOfJoinKeysInChildSchema());
  }
  // The rewrite is safe only when the join keys cover precisely the group keys.
  if (!rightKeyBuilder.build().equals(ImmutableBitSet.range(rightAggregate.getGroupCount()))) {
    return;
  }
  final RelBuilder relBuilder = call.builder();
  final RelNode newRightInput = relBuilder.project(relBuilder.push(rightAggregate.getInput()).fields(rightAggregate.getGroupSet().asList())).build();
  final RelNode newJoin;
  if (join.getJoinType() == JoinRelType.SEMI) {
    newJoin = call.builder().push(left).push(newRightInput).semiJoin(join.getCondition()).build();
  } else {
    newJoin = call.builder().push(left).push(newRightInput).antiJoin(join.getCondition()).build();
  }
  call.transformTo(newJoin);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in project beam by apache: class BeamAggregationRule, method updateWindow.
/**
 * Extracts a windowing function from the aggregate's group-by expressions.
 * Scans each group-by field's projection for a RexCall that
 * {@code createWindowFn} recognizes as a window definition; when found,
 * records the WindowFn, remembers the field index, and replaces the call
 * with its first operand (the raw timestamp column) in the projection.
 *
 * <p>Note: if several group keys carry a window call, the last one scanned
 * wins — each match overwrites {@code windowFn}/{@code windowFieldIndex}.
 *
 * @param aggregate the matched aggregate
 * @param project   the project feeding the aggregate
 * @return the rewritten BeamAggregationRel, or {@code null} when no
 *         group-by expression defines a window (non-windowed case)
 */
private static RelNode updateWindow(Aggregate aggregate, Project project) {
  ImmutableBitSet groupByFields = aggregate.getGroupSet();
  // Mutable copy declared against the List interface; the raw-typed
  // `new ArrayList(...)` constructor caused an unchecked-assignment warning.
  List<RexNode> projects = new ArrayList<>(project.getProjects());
  // Raw WindowFn retained: the element types depend on createWindowFn and
  // the BeamAggregationRel constructor, which are declared elsewhere.
  WindowFn windowFn = null;
  int windowFieldIndex = -1;
  for (int groupFieldIndex : groupByFields.asList()) {
    RexNode projNode = projects.get(groupFieldIndex);
    // Plain column references cannot define a window.
    if (!(projNode instanceof RexCall)) {
      continue;
    }
    RexCall rexCall = (RexCall) projNode;
    WindowFn fn = createWindowFn(rexCall.getOperands(), rexCall.op.kind);
    if (fn != null) {
      windowFn = fn;
      windowFieldIndex = groupFieldIndex;
      // Unwrap the window call so the projection emits the bare timestamp.
      projects.set(groupFieldIndex, rexCall.getOperands().get(0));
    }
  }
  if (windowFn == null) {
    return null;
  }
  final Project newProject = project.copy(project.getTraitSet(), project.getInput(), projects, project.getRowType());
  return new BeamAggregationRel(aggregate.getCluster(), aggregate.getTraitSet().replace(BeamLogicalConvention.INSTANCE), convert(newProject, newProject.getTraitSet().replace(BeamLogicalConvention.INSTANCE)), aggregate.getGroupSet(), aggregate.getGroupSets(), aggregate.getAggCallList(), windowFn, windowFieldIndex);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in project beam by apache: class BeamAggregationRule, method onMatch.
/**
 * Converts an Aggregate-over-Project pair into a windowed
 * BeamAggregationRel when one of the group-by expressions defines a window.
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final Aggregate aggregate = call.rel(0);
  final Project project = call.rel(1);
  // Grouping sets are not handled by this rule.
  if (aggregate.getGroupType() != Aggregate.Group.SIMPLE) {
    return;
  }
  final RelNode windowed = updateWindow(aggregate, project);
  if (windowed != null) {
    call.transformTo(windowed);
  }
  // Non-windowed case should be handled by the BeamBasicAggregationRule
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.Aggregate in project beam by apache: class BeamDDLTest, method unparseAggregateFunction.
/** Verifies that CREATE AGGREGATE FUNCTION unparses to the expected SQL text. */
@Test
public void unparseAggregateFunction() {
  final SqlParserPos pos = SqlParserPos.ZERO;
  final SqlIdentifier functionName = new SqlIdentifier("foo", pos);
  final SqlNode jarLocation = SqlLiteral.createCharString("path/to/udf.jar", pos);
  // Last flag marks the function as an AGGREGATE function.
  final SqlCreateFunction ddl = new SqlCreateFunction(pos, false, functionName, jarLocation, true);
  final SqlWriter writer = new SqlPrettyWriter(BeamBigQuerySqlDialect.DEFAULT);
  ddl.unparse(writer, 0, 0);
  assertEquals("CREATE AGGREGATE FUNCTION foo USING JAR 'path/to/udf.jar'", writer.toSqlString().getSql());
}
Aggregations