use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillReduceAggregatesRule method reduceAvg.
/**
 * Builds the expression that replaces one AVG call: a guarded sum divided by a count.
 *
 * <p>A sum call ({@code SqlSumEmptyIsZeroAggFunction}) and a COUNT call over the same
 * argument list are registered through {@code RexBuilder.addAggCall}. The sum reference is
 * wrapped in {@code CASE WHEN count = 0 THEN NULL ELSE sum END}, passed through
 * {@code CastHighOp}, and finally divided by the count reference.
 *
 * @param oldAggRel aggregate that owns the call being rewritten
 * @param oldCall the call being rewritten; its first argument is the averaged input
 * @param newCalls list handed to {@code addAggCall} to collect the replacement calls
 * @param aggCallMapping map handed to {@code addAggCall} together with {@code newCalls}
 * @return the division expression; cast to ANY when type inference is disabled
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory types = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rex = oldAggRel.getCluster().getRexBuilder();

  // Type of the averaged column, looked up on the aggregate's input.
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);

  // The sum is nullable when the argument is nullable or when there is no group key.
  final boolean sumNullable = argType.isNullable() || groupCount == 0;
  final RelDataType sumType = types.createTypeWithNullability(argType, sumNullable);
  final SqlAggFunction sumFunction = new SqlSumEmptyIsZeroAggFunction();
  final AggregateCall sumCall =
      AggregateCall.create(sumFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, sumType, null);

  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(types);
  final AggregateCall countCall =
      AggregateCall.create(countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);

  // NOTE: these references are with respect to the output of newAggRel.
  final List<RelDataType> argTypes = ImmutableList.of(argType);
  final RexNode sumRef =
      rex.addAggCall(sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
  final RexNode countRef =
      rex.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

  // CASE WHEN count = 0 THEN NULL ELSE sum END
  final RexNode countIsZero =
      rex.makeCall(SqlStdOperatorTable.EQUALS, countRef, rex.makeExactLiteral(BigDecimal.ZERO));
  final RexNode guardedSum =
      rex.makeCall(SqlStdOperatorTable.CASE, countIsZero, rex.constantNull(), sumRef);

  final RexNode numerator = rex.makeCall(CastHighOp, guardedSum);
  // The count call is registered again here, exactly as before.
  // NOTE(review): presumably this resolves to the same reference as countRef through
  // aggCallMapping -- confirm against RexBuilder.addAggCall.
  final RexNode denominator =
      rex.addAggCall(countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);

  if (inferTypes) {
    return rex.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false), numerator, denominator);
  }
  final RexNode quotient = rex.makeCall(SqlStdOperatorTable.DIVIDE, numerator, denominator);
  return rex.makeCast(types.createSqlType(SqlTypeName.ANY), quotient);
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillReduceAggregatesRule method reduceAggs.
/*
private boolean isMatch(AggregateCall call) {
if (call.getAggregation() instanceof SqlAvgAggFunction) {
final SqlAvgAggFunction.Subtype subtype =
((SqlAvgAggFunction) call.getAggregation()).getSubtype();
return (subtype == SqlAvgAggFunction.Subtype.AVG);
}
return false;
}
*/
/**
 * Replaces {@code oldAggRel} with a new aggregate topped by a project.
 *
 * <p>Every aggregate call of {@code oldAggRel} is passed to {@code reduceAgg}, which
 * contributes one expression to the project list and may register new aggregate calls
 * (the original note names AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP as the
 * calls being reduced). Group keys are passed through unchanged.
 *
 * <p>It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * @param ruleCall rule invocation that receives the rewritten plan via {@code transformTo}
 * @param oldAggRel aggregate whose calls are being rewritten
 */
private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
  final RexBuilder rex = oldAggRel.getCluster().getRexBuilder();
  final List<AggregateCall> originalCalls = oldAggRel.getAggCallList();
  final int groupCount = oldAggRel.getGroupCount();

  final List<AggregateCall> replacementCalls = new ArrayList<>();
  final Map<AggregateCall, RexNode> callToRef = new HashMap<>();
  final List<RexNode> projections = new ArrayList<>();

  // Group keys go straight through to the project.
  for (int key = 0; key < groupCount; key++) {
    projections.add(rex.makeInputRef(getFieldType(oldAggRel, key), key));
  }

  // Start with one input reference per input field. An aggregate that needs more
  // appends an expression at the end, and an extra project is then created below.
  RelNode newInput = oldAggRel.getInput();
  final List<RelDataTypeField> inputFields = newInput.getRowType().getFieldList();
  final List<RexNode> inputExprs = new ArrayList<>();
  for (int ordinal = 0; ordinal < inputFields.size(); ordinal++) {
    inputExprs.add(rex.makeInputRef(inputFields.get(ordinal).getType(), ordinal));
  }

  // New aggregate calls and the remaining project expressions are produced together.
  for (AggregateCall originalCall : originalCalls) {
    RexNode replacement = reduceAgg(oldAggRel, originalCall, replacementCalls, callToRef, inputExprs);
    projections.add(replacement);
  }

  final int addedExprCount = inputExprs.size() - newInput.getRowType().getFieldCount();
  if (addedExprCount > 0) {
    // Existing fields keep their names; the appended expressions get null names.
    newInput = RelOptUtil.createProject(
        newInput,
        inputExprs,
        CompositeList.of(newInput.getRowType().getFieldNames(), Collections.<String>nCopies(addedExprCount, null)));
  }

  final Aggregate newAggRel = newAggregateRel(oldAggRel, newInput, replacementCalls);
  final RelNode topProject =
      RelOptUtil.createProject(newAggRel, projections, oldAggRel.getRowType().getFieldNames());
  ruleCall.transformTo(topProject);
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillWindowRel method implement.
/**
 * Converts this window rel into a logical {@code Window} operator.
 *
 * <p>For every window group the ordering keys, the partition ("within") keys and the
 * aggregate calls are added to a single builder. Ordering and output references are
 * named after this rel's own row type; partition keys and aggregate arguments are named
 * after the input's row type.
 *
 * @param implementor used to convert the input rel into its logical operator
 * @return the built window operator, with the converted input attached
 */
@Override
public LogicalOperator implement(DrillImplementor implementor) {
  final LogicalOperator child = implementor.visitChild(this, 0, getInput());
  final org.apache.drill.common.logical.data.Window.Builder windowBuilder =
      new org.apache.drill.common.logical.data.Window.Builder();
  final List<String> outputNames = getRowType().getFieldNames();
  final List<String> inputNames = getInput().getRowType().getFieldNames();

  for (Group windowGroup : groups) {
    // ORDER BY keys
    for (RelFieldCollation collation : windowGroup.orderKeys.getFieldCollations()) {
      FieldReference orderField = new FieldReference(outputNames.get(collation.getFieldIndex()));
      windowBuilder.addOrdering(new Order.Ordering(collation.getDirection(), orderField));
    }

    // PARTITION BY keys
    for (int keyOrdinal : BitSets.toIter(windowGroup.keys)) {
      FieldReference keyField = new FieldReference(inputNames.get(keyOrdinal), ExpressionPosition.UNKNOWN);
      windowBuilder.addWithin(keyField, keyField);
    }

    // Aggregate outputs are looked up at an offset equal to the number of partition keys.
    final int keyCount = windowGroup.keys.cardinality();
    for (Ord<AggregateCall> call : Ord.zip(windowGroup.getAggregateCalls(this))) {
      FieldReference outputField = new FieldReference(outputNames.get(keyCount + call.i));
      LogicalExpression aggregation = toDrill(call.e, inputNames);
      windowBuilder.addAggregation(outputField, aggregation);
    }
  }

  windowBuilder.setInput(child);
  return windowBuilder.build();
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class DrillAggregateRel method implement.
/**
 * Converts this aggregate rel into a logical {@code GroupingAggregate} operator.
 *
 * <p>Each bit of {@code groupSet} becomes a key named after the matching input field, and
 * each aggregate call becomes an expression bound to the output field that follows the
 * group keys.
 *
 * @param implementor used to convert the input rel and the aggregate calls
 * @return the built grouping-aggregate operator
 */
@Override
public LogicalOperator implement(DrillImplementor implementor) {
  final GroupingAggregate.Builder aggBuilder = GroupingAggregate.builder();
  aggBuilder.setInput(implementor.visitChild(this, 0, getInput()));

  final List<String> inputNames = getInput().getRowType().getFieldNames();
  final List<String> outputNames = getRowType().getFieldNames();

  // Group keys: the same reference is used on both sides of the key.
  for (int keyOrdinal : BitSets.toIter(groupSet)) {
    String keyName = inputNames.get(keyOrdinal);
    FieldReference keyField = new FieldReference(keyName, ExpressionPosition.UNKNOWN);
    aggBuilder.addKey(keyField, keyField);
  }

  // Aggregate outputs come right after the group keys in the output row.
  final int firstAggOrdinal = groupSet.cardinality();
  for (Ord<AggregateCall> call : Ord.zip(aggCalls)) {
    String outputName = outputNames.get(firstAggOrdinal + call.i);
    FieldReference outputField = new FieldReference(outputName);
    LogicalExpression aggregation = toDrill(call.e, inputNames, implementor);
    aggBuilder.addExpr(outputField, aggregation);
  }

  return aggBuilder.build();
}
use of org.apache.calcite.rel.core.AggregateCall in project drill by apache.
the class ConvertCountToDirectScan method onMatch.
/**
 * Replaces a single COUNT aggregate over a scan with a direct scan that emits the
 * already-known count.
 *
 * <p>The rule matches either aggregate-scan or aggregate-project-scan (three rels). It
 * bails out without transforming whenever the count cannot be taken from the group scan.
 *
 * @param call the rule invocation; rel 0 is the aggregate and the last rel is the scan
 */
@Override
public void onMatch(RelOptRuleCall call) {
  final DrillAggregateRel agg = (DrillAggregateRel) call.rel(0);
  final DrillScanRel scan = (DrillScanRel) call.rel(call.rels.length - 1);
  final DrillProjectRel proj = call.rels.length == 3 ? (DrillProjectRel) call.rel(1) : null;
  final GroupScan oldGrpScan = scan.getGroupScan();
  final PlannerSettings settings = PrelUtil.getPlannerSettings(call.getPlanner());

  // Only apply the rule when:
  //   1) the scan reports an exact row count,
  //   2) there is no GROUP BY key,
  //   3) there is exactly one aggregate call (checked to be COUNT below),
  //   4) No distinct agg call.
  final boolean applicable = oldGrpScan.getScanStats(settings).getGroupScanProperty().hasExactRowCount()
      && agg.getGroupCount() == 0
      && agg.getAggCallList().size() == 1
      && !agg.containsDistinctCall();
  if (!applicable) {
    return;
  }

  final AggregateCall aggCall = agg.getAggCallList().get(0);
  if (!aggCall.getAggregation().getName().equals("COUNT")) {
    return;
  }

  final List<Integer> args = aggCall.getArgList();
  long cnt;
  if (args.isEmpty()
      || (args.size() == 1
          && !agg.getInput().getRowType().getFieldList().get(args.get(0).intValue()).getType().isNullable())) {
    // count(*) and count(Not-null-input) ==> rowCount
    cnt = (long) oldGrpScan.getScanStats(settings).getRecordCount();
  } else if (args.size() == 1) {
    // count(columnName) ==> Agg ( Scan )) ==> columnValueCount
    int index = args.get(0);
    if (proj != null) {
      // Look through the project only when it is a plain column reference.
      if (!(proj.getProjects().get(index) instanceof RexInputRef)) {
        // do not apply for all other cases.
        return;
      }
      index = ((RexInputRef) proj.getProjects().get(index)).getIndex();
    }
    // Locale.ROOT keeps the lower-casing independent of the JVM default locale
    // (e.g. under a Turkish locale "I" would otherwise become a dotless "ı").
    String columnName = scan.getRowType().getFieldNames().get(index).toLowerCase(java.util.Locale.ROOT);
    cnt = oldGrpScan.getColumnValueCount(SchemaPath.getSimplePath(columnName));
    if (cnt == GroupScan.NO_COLUMN_STATS) {
      // if column stats are not available don't apply this rule
      return;
    }
  } else {
    // More than one argument: do nothing.
    return;
  }

  // Build scan(count) -> project so the output row type matches the aggregate's.
  RelDataType scanRowType = getCountDirectScanRowType(agg.getCluster().getTypeFactory());
  final ScanPrel newScan = ScanPrel.create(
      scan,
      scan.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON),
      getCountDirectScan(cnt),
      scanRowType);
  List<RexNode> exprs = Lists.newArrayList();
  exprs.add(RexInputRef.of(0, scanRowType));
  final ProjectPrel newProj = new ProjectPrel(
      agg.getCluster(),
      agg.getTraitSet().plus(Prel.DRILL_PHYSICAL).plus(DrillDistributionTrait.SINGLETON),
      newScan,
      exprs,
      agg.getRowType());
  call.transformTo(newProj);
}
Aggregations