Use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate in project hive by apache.
The class HiveAggregateProjectMergeRule, method onMatch.
@Override
public void onMatch(RelOptRuleCall call) {
  final HiveAggregate aggregate = call.rel(0);
  final HiveProject project = call.rel(1);
  RelNode x = apply(aggregate, project);
  if (x != null) {
    call.transformTo(x);
  }
}
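A rule like this is normally fired by a Calcite planner rather than invoked directly. The following is a minimal sketch of wiring it into a HepPlanner; it assumes the rule exposes Hive's usual INSTANCE singleton, and 'plan' stands for whatever Calcite tree you already have (both are assumptions, not taken from the snippet above).

import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.plan.hep.HepProgramBuilder;
import org.apache.calcite.rel.RelNode;

// Build a one-rule program and run it over an existing plan.
static RelNode mergeAggregateProject(RelNode plan) {
  HepProgram program = new HepProgramBuilder()
      .addRuleInstance(HiveAggregateProjectMergeRule.INSTANCE) // assumed singleton
      .build();
  HepPlanner planner = new HepPlanner(program);
  planner.setRoot(plan);
  return planner.findBestExp(); // fires onMatch above wherever HiveAggregate sits on HiveProject
}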
Use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate in project hive by apache.
The class PlanModifierForASTConv, method replaceEmptyGroupAggr.
private static void replaceEmptyGroupAggr(final RelNode rel, RelNode parent) {
  // If this function is called, the parent should only include constants.
  List<RexNode> exps = parent instanceof Project ? ((Project) parent).getProjects() : Collections.emptyList();
  for (RexNode rexNode : exps) {
    if (!rexNode.accept(new HiveCalciteUtil.ConstantFinder())) {
      throw new RuntimeException("We expect " + parent.toString() + " to contain only constants."
          + " However, " + rexNode.toString() + " is " + rexNode.getKind());
    }
  }
  HiveAggregate oldAggRel = (HiveAggregate) rel;
  RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, typeFactory);
  RelDataType intType = TypeConverter.convert(TypeInfoFactory.intTypeInfo, typeFactory);
  // Create the dummy aggregation.
  SqlAggFunction countFn = SqlFunctionConverter.getCalciteAggFn("count", false, ImmutableList.of(intType), longType);
  // TODO: Using 0 might be wrong; might need to walk down to find the
  // proper index of a dummy.
  List<Integer> argList = ImmutableList.of(0);
  AggregateCall dummyCall = new AggregateCall(countFn, false, argList, longType, null);
  Aggregate newAggRel = oldAggRel.copy(oldAggRel.getTraitSet(), oldAggRel.getInput(), oldAggRel.indicator,
      oldAggRel.getGroupSet(), oldAggRel.getGroupSets(), ImmutableList.of(dummyCall));
  RelNode select = introduceDerivedTable(newAggRel);
  parent.replaceInput(0, select);
}
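The effect of this rewrite is easier to see in plain Calcite terms: an Aggregate with an empty GROUP BY under a constant-only Project is given a dummy count so the AST converter has a concrete column to emit. Below is a minimal RelBuilder sketch of the resulting plan shape, assuming a FrameworkConfig 'config' whose schema contains a table 't' (both names are illustrative, not from the Hive code above).

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;

// Builds the shape the rewrite produces: Aggregate({}, count(col0)) over a scan.
static RelNode dummyCountOverEmptyGroup(FrameworkConfig config) {
  RelBuilder b = RelBuilder.create(config);
  return b.scan("t")                             // 't' is an assumed table in the schema
      .aggregate(b.groupKey(),                   // empty grouping set, as in the code above
          b.count(false, "dummy", b.field(0)))   // count on column 0 mirrors argList = [0]
      .build();
}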
Use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate in project hive by apache.
The class HiveRelDecorrelator, method decorrelateRel.
public Frame decorrelateRel(HiveAggregate rel) throws SemanticException {
  // Aggregate itself should not reference cor vars.
  assert !cm.mapRefRelToCorRef.containsKey(rel);
  final RelNode oldInput = rel.getInput();
  final Frame frame = getInvoke(oldInput, rel);
  if (frame == null) {
    // If input has not been rewritten, do not rewrite this rel.
    return null;
  }
  // assert !frame.corVarOutputPos.isEmpty();
  final RelNode newInput = frame.r;
  // map from newInput positions to the output positions of the new Project
  Map<Integer, Integer> mapNewInputToProjOutputs = new HashMap<>();
  final int oldGroupKeyCount = rel.getGroupSet().cardinality();
  // Project projects the original expressions,
  // plus any correlated variables the input wants to pass along.
  final List<Pair<RexNode, String>> projects = Lists.newArrayList();
  List<RelDataTypeField> newInputOutput = newInput.getRowType().getFieldList();
  int newPos = 0;
  // oldInput has the original group by keys in the front.
  final NavigableMap<Integer, RexLiteral> omittedConstants = new TreeMap<>();
  for (int i = 0; i < oldGroupKeyCount; i++) {
    final RexLiteral constant = projectedLiteral(newInput, i);
    if (constant != null) {
      // Exclude constants. Aggregate({true}) occurs because Aggregate({})
      // would generate 1 row even when applied to an empty table.
      omittedConstants.put(i, constant);
      continue;
    }
    int newInputPos = frame.oldToNewOutputs.get(i);
    projects.add(RexInputRef.of2(newInputPos, newInputOutput));
    mapNewInputToProjOutputs.put(newInputPos, newPos);
    newPos++;
  }
  final SortedMap<CorDef, Integer> corDefOutputs = new TreeMap<>();
  if (!frame.corDefOutputs.isEmpty()) {
    // Now add the corVars from the input, starting at
    // position oldGroupKeyCount.
    for (Map.Entry<CorDef, Integer> entry : frame.corDefOutputs.entrySet()) {
      projects.add(RexInputRef.of2(entry.getValue(), newInputOutput));
      corDefOutputs.put(entry.getKey(), newPos);
      mapNewInputToProjOutputs.put(entry.getValue(), newPos);
      newPos++;
    }
  }
  // add the remaining fields
  final int newGroupKeyCount = newPos;
  for (int i = 0; i < newInputOutput.size(); i++) {
    if (!mapNewInputToProjOutputs.containsKey(i)) {
      projects.add(RexInputRef.of2(i, newInputOutput));
      mapNewInputToProjOutputs.put(i, newPos);
      newPos++;
    }
  }
  assert newPos == newInputOutput.size();
  // This Project will be what the old input maps to,
  // replacing any previous mapping from old input.
  RelNode newProject = HiveProject.create(newInput, Pair.left(projects), Pair.right(projects));
  // update mappings:
  // oldInput ----> newInput
  //
  //                newProject
  //                   |
  // oldInput ----> newInput
  //
  // is transformed to
  //
  // oldInput ----> newProject
  //                   |
  //                newInput
  Map<Integer, Integer> combinedMap = Maps.newHashMap();
  for (Integer oldInputPos : frame.oldToNewOutputs.keySet()) {
    combinedMap.put(oldInputPos, mapNewInputToProjOutputs.get(frame.oldToNewOutputs.get(oldInputPos)));
  }
  register(oldInput, newProject, combinedMap, corDefOutputs);
  // now it's time to rewrite the Aggregate
  final ImmutableBitSet newGroupSet = ImmutableBitSet.range(newGroupKeyCount);
  List<AggregateCall> newAggCalls = Lists.newArrayList();
  List<AggregateCall> oldAggCalls = rel.getAggCallList();
  int oldInputOutputFieldCount = rel.getGroupSet().cardinality();
  int newInputOutputFieldCount = newGroupSet.cardinality();
  int i = -1;
  for (AggregateCall oldAggCall : oldAggCalls) {
    ++i;
    List<Integer> oldAggArgs = oldAggCall.getArgList();
    List<Integer> aggArgs = Lists.newArrayList();
    // Adjust the aggregate argument positions. The aggregate does not change
    // input ordering, so the input-output position mapping can be used to
    // derive the new positions for the argument.
    for (int oldPos : oldAggArgs) {
      aggArgs.add(combinedMap.get(oldPos));
    }
    final int filterArg = oldAggCall.filterArg < 0 ? oldAggCall.filterArg
        : combinedMap.get(oldAggCall.filterArg);
    newAggCalls.add(oldAggCall.adaptTo(newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount));
    // The old to new output position mapping will be the same as that
    // of newProject, plus any aggregates that the oldAgg produces.
    combinedMap.put(oldInputOutputFieldCount + i, newInputOutputFieldCount + i);
  }
  relBuilder.push(new HiveAggregate(rel.getCluster(), rel.getTraitSet(), newProject, newGroupSet, null, newAggCalls));
  if (!omittedConstants.isEmpty()) {
    final List<RexNode> postProjects = new ArrayList<>(relBuilder.fields());
    for (Map.Entry<Integer, RexLiteral> entry : omittedConstants.descendingMap().entrySet()) {
      postProjects.add(entry.getKey() + frame.corDefOutputs.size(), entry.getValue());
    }
    relBuilder.project(postProjects);
  }
  // Aggregate does not change input ordering, so corVars should be
  // located at the same position as in the input newProject.
  return register(rel, relBuilder.build(), combinedMap, corDefOutputs);
}
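This method belongs to Hive's fork of Calcite's RelDecorrelator, and the plans it sees come from correlated subqueries whose aggregates must be rebuilt on top of a widened Project, as above. For orientation, here is a hedged sketch of the equivalent entry point in plain Calcite; 'root' and 'config' are assumed inputs, not part of the Hive code.

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.sql2rel.RelDecorrelator;
import org.apache.calcite.tools.FrameworkConfig;
import org.apache.calcite.tools.RelBuilder;

// Rewrites correlated variables out of a plan such as the one produced by
//   SELECT d.name FROM dept d
//   WHERE d.deptno IN (SELECT e.deptno FROM emp e WHERE e.sal > d.budget)
// Aggregates below the correlation get rebuilt much like decorrelateRel above.
static RelNode decorrelate(RelNode root, FrameworkConfig config) {
  RelBuilder relBuilder = RelBuilder.create(config);
  return RelDecorrelator.decorrelateQuery(root, relBuilder);
}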
Use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate in project hive by apache.
The class HiveIntersectRewriteRule, method onMatch.
// ~ Methods ----------------------------------------------------------------
public void onMatch(RelOptRuleCall call) {
  final HiveIntersect hiveIntersect = call.rel(0);
  final RelOptCluster cluster = hiveIntersect.getCluster();
  final RexBuilder rexBuilder = cluster.getRexBuilder();
  int numOfBranch = hiveIntersect.getInputs().size();
  Builder<RelNode> bldr = new ImmutableList.Builder<RelNode>();
  // 1st level GB: create a GB (col0, col1, count(1) as c) for each branch
  for (int index = 0; index < numOfBranch; index++) {
    RelNode input = hiveIntersect.getInputs().get(index);
    final List<RexNode> gbChildProjLst = Lists.newArrayList();
    final List<Integer> groupSetPositions = Lists.newArrayList();
    for (int cInd = 0; cInd < input.getRowType().getFieldList().size(); cInd++) {
      gbChildProjLst.add(rexBuilder.makeInputRef(input, cInd));
      groupSetPositions.add(cInd);
    }
    gbChildProjLst.add(rexBuilder.makeBigintLiteral(new BigDecimal(1)));
    // create the project before the GB because we need a new project with the extra column '1'.
    RelNode gbInputRel = null;
    try {
      gbInputRel = HiveProject.create(input, gbChildProjLst, null);
    } catch (CalciteSemanticException e) {
      LOG.debug(e.toString());
      throw new RuntimeException(e);
    }
    // groupSetPositions includes all the input column positions
    final ImmutableBitSet groupSet = ImmutableBitSet.of(groupSetPositions);
    List<AggregateCall> aggregateCalls = Lists.newArrayList();
    RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
    // count(1): the literal '1' sits at position input.getRowType().getFieldList().size()
    AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster,
        TypeInfoFactory.longTypeInfo, input.getRowType().getFieldList().size(), aggFnRetType);
    aggregateCalls.add(aggregateCall);
    HiveRelNode aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
        gbInputRel, groupSet, null, aggregateCalls);
    bldr.add(aggregateRel);
  }
  // create a union above all the branches
  HiveRelNode union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster), bldr.build());
  // 2nd level GB: create a GB (col0, col1, count(c)) over the union of all branches
  final List<Integer> groupSetPositions = Lists.newArrayList();
  // the index of c
  int cInd = union.getRowType().getFieldList().size() - 1;
  for (int index = 0; index < union.getRowType().getFieldList().size(); index++) {
    if (index != cInd) {
      groupSetPositions.add(index);
    }
  }
  List<AggregateCall> aggregateCalls = Lists.newArrayList();
  RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
  AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster,
      TypeInfoFactory.longTypeInfo, cInd, aggFnRetType);
  aggregateCalls.add(aggregateCall);
  if (hiveIntersect.all) {
    aggregateCall = HiveCalciteUtil.createSingleArgAggCall("min", cluster,
        TypeInfoFactory.longTypeInfo, cInd, aggFnRetType);
    aggregateCalls.add(aggregateCall);
  }
  final ImmutableBitSet groupSet = ImmutableBitSet.of(groupSetPositions);
  HiveRelNode aggregateRel = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
      union, groupSet, null, aggregateCalls);
  // add a filter: count(c) = #branches
  int countInd = cInd;
  List<RexNode> childRexNodeLst = new ArrayList<RexNode>();
  RexInputRef ref = rexBuilder.makeInputRef(aggregateRel, countInd);
  RexLiteral literal = rexBuilder.makeBigintLiteral(new BigDecimal(numOfBranch));
  childRexNodeLst.add(ref);
  childRexNodeLst.add(literal);
  ImmutableList.Builder<RelDataType> calciteArgTypesBldr = new ImmutableList.Builder<RelDataType>();
  calciteArgTypesBldr.add(TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
  calciteArgTypesBldr.add(TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()));
  RexNode factoredFilterExpr = null;
  try {
    factoredFilterExpr = rexBuilder.makeCall(
        SqlFunctionConverter.getCalciteFn("=", calciteArgTypesBldr.build(),
            TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory()), true, false),
        childRexNodeLst);
  } catch (CalciteSemanticException e) {
    LOG.debug(e.toString());
    throw new RuntimeException(e);
  }
  RelNode filterRel = new HiveFilter(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
      aggregateRel, factoredFilterExpr);
  if (!hiveIntersect.all) {
    // the schema for intersect distinct is like this:
    // R3 on all attributes + count(c) as cnt
    // finally add a project to project out the last column
    Set<Integer> projectOutColumnPositions = new HashSet<>();
    projectOutColumnPositions.add(filterRel.getRowType().getFieldList().size() - 1);
    try {
      call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(filterRel, projectOutColumnPositions));
    } catch (CalciteSemanticException e) {
      LOG.debug(e.toString());
      throw new RuntimeException(e);
    }
  } else {
    // the schema for intersect all is like this:
    // R3 + count(c) as cnt + min(c) as m
    // we create an input project for the UDTF whose schema is like this:
    // min(c) as m + R3
    List<RexNode> originalInputRefs = Lists.transform(filterRel.getRowType().getFieldList(),
        new Function<RelDataTypeField, RexNode>() {
          @Override
          public RexNode apply(RelDataTypeField input) {
            return new RexInputRef(input.getIndex(), input.getType());
          }
        });
    List<RexNode> copyInputRefs = new ArrayList<>();
    copyInputRefs.add(originalInputRefs.get(originalInputRefs.size() - 1));
    for (int i = 0; i < originalInputRefs.size() - 2; i++) {
      copyInputRefs.add(originalInputRefs.get(i));
    }
    RelNode srcRel = null;
    try {
      srcRel = HiveProject.create(filterRel, copyInputRefs, null);
      HiveTableFunctionScan udtf = HiveCalciteUtil.createUDTFForSetOp(cluster, srcRel);
      // finally add a project to project out the 1st column
      Set<Integer> projectOutColumnPositions = new HashSet<>();
      projectOutColumnPositions.add(0);
      call.transformTo(HiveCalciteUtil.createProjectWithoutColumn(udtf, projectOutColumnPositions));
    } catch (SemanticException e) {
      LOG.debug(e.toString());
      throw new RuntimeException(e);
    }
  }
}
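The three-step plan assembled above (per-branch GROUP BY, UNION ALL, second GROUP BY plus filter) reads more naturally as the SQL it is equivalent to. A sketch for the distinct case with two illustrative branches r1(a, b) and r2(a, b); the table and column names are made up for the example.

// Equivalent rewrite of:  SELECT a, b FROM r1 INTERSECT SELECT a, b FROM r2
String rewritten =
    "SELECT a, b FROM (\n"
    + "  SELECT a, b, COUNT(1) AS c FROM r1 GROUP BY a, b\n"   // 1st-level GB per branch
    + "  UNION ALL\n"
    + "  SELECT a, b, COUNT(1) AS c FROM r2 GROUP BY a, b\n"
    + ") u\n"
    + "GROUP BY a, b\n"                                        // 2nd-level GB over the union
    + "HAVING COUNT(c) = 2";                                   // the filter count(c) = #branches

For INTERSECT ALL, the rule additionally keeps min(c) and routes the filtered rows through a UDTF that replicates each surviving row min(c) times, which is what the copyInputRefs reordering above prepares.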
Use of org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate in project hive by apache.
The class JDBCAggregationPushDownRule, method matches.
@Override
public boolean matches(RelOptRuleCall call) {
  final HiveAggregate agg = call.rel(0);
  final HiveJdbcConverter converter = call.rel(1);
  if (agg.getGroupType() != Group.SIMPLE) {
    // TODO: Grouping sets not supported yet
    return false;
  }
  for (AggregateCall aggCall : agg.getAggCallList()) {
    SqlAggFunction f = aggCall.getAggregation();
    if (f instanceof HiveSqlCountAggFunction) {
      // count distinct with more than one argument is not supported
      HiveSqlCountAggFunction countAgg = (HiveSqlCountAggFunction) f;
      if (countAgg.isDistinct() && 1 < aggCall.getArgList().size()) {
        return false;
      }
    }
    SqlKind kind = f.getKind();
    if (!converter.getJdbcDialect().supportsAggregateFunction(kind)) {
      return false;
    }
  }
  return true;
}
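The converter.getJdbcDialect().supportsAggregateFunction(kind) test ultimately delegates to Calcite's SqlDialect. A minimal sketch of the dialect side of that contract, using a hypothetical anonymous dialect that only admits the four common aggregates (real deployments use the dialect classes shipped with Calcite):

import org.apache.calcite.sql.SqlDialect;
import org.apache.calcite.sql.SqlKind;

// A hypothetical restrictive dialect: matches() above would veto pushdown
// of any aggregate call whose SqlKind this method rejects.
static SqlDialect restrictiveDialect() {
  return new SqlDialect(SqlDialect.EMPTY_CONTEXT) {
    @Override
    public boolean supportsAggregateFunction(SqlKind kind) {
      switch (kind) {
        case COUNT:
        case SUM:
        case MIN:
        case MAX:
          return true;
        default:
          return false;
      }
    }
  };
}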