Usage of org.apache.calcite.rel.type.RelDataType in the Apache Drill project.
Class InsertLocalExchangeVisitor, method visitExchange.
/**
 * Visits an exchange and, when it is a {@link HashToRandomExchangePrel}, surrounds it with
 * local exchanges.
 *
 * <p>Any other exchange is copied unchanged on top of its already-visited input. For a
 * hash-to-random exchange the plan built here is, bottom-up:
 * <ol>
 *   <li>if {@code isMuxEnabled}: a project that appends a hash column (named
 *       {@code HashPrelUtil.HASH_EXPR_NAME}) computed over the distribution fields, then an
 *       {@link UnorderedMuxExchangePrel};</li>
 *   <li>a new {@link HashToRandomExchangePrel} carrying the original traits and distribution
 *       fields;</li>
 *   <li>if {@code isDeMuxEnabled}: an {@link UnorderedDeMuxExchangePrel};</li>
 *   <li>if {@code isMuxEnabled}: a project that keeps only the input's original fields under
 *       their original names.</li>
 * </ol>
 *
 * @param prel  the exchange being visited
 * @param value unused
 * @return the rewritten sub-plan
 */
@Override
public Prel visitExchange(ExchangePrel prel, Void value) throws RuntimeException {
  final Prel visitedInput = ((Prel) prel.getInput()).accept(this, null);

  // Only hash-to-random exchanges receive local exchanges; everything else is re-parented as is.
  if (!(prel instanceof HashToRandomExchangePrel)) {
    final List<RelNode> newInputs = Collections.singletonList((RelNode) visitedInput);
    return (Prel) prel.copy(prel.getTraitSet(), newInputs);
  }

  final HashToRandomExchangePrel hashExchange = (HashToRandomExchangePrel) prel;
  final List<String> originalNames = visitedInput.getRowType().getFieldNames();

  Prel exchangeInput = visitedInput;
  // One input ref per original column; populated only on the mux path and reused at the end
  // to project the original columns back out.
  List<RexNode> originalColumnRefs = null;

  if (isMuxEnabled) {
    final RexBuilder rexBuilder = prel.getCluster().getRexBuilder();
    final List<RelDataTypeField> inputFields = visitedInput.getRowType().getFieldList();

    // References to the columns the hash exchange distributes on.
    final List<DistributionField> distributionFields = hashExchange.getFields();
    final List<RexNode> distributionRefs = Lists.newArrayListWithExpectedSize(distributionFields.size());
    for (DistributionField distributionField : distributionFields) {
      final int fieldId = distributionField.getFieldId();
      distributionRefs.add(rexBuilder.makeInputRef(inputFields.get(fieldId).getType(), fieldId));
    }

    // Pass-through references to every input column.
    originalColumnRefs = Lists.newArrayListWithExpectedSize(inputFields.size());
    for (RelDataTypeField field : inputFields) {
      originalColumnRefs.add(rexBuilder.makeInputRef(field.getType(), field.getIndex()));
    }

    // Project = all input columns + one trailing hash expression seeded with DIST_SEED.
    final HashExpressionCreatorHelper<RexNode> hashHelper = new RexNodeBasedHashExpressionCreatorHelper(rexBuilder);
    final RexNode seed = rexBuilder.makeBigintLiteral(BigDecimal.valueOf(HashPrelUtil.DIST_SEED));
    final List<RexNode> widenedExprs = Lists.newArrayList(originalColumnRefs);
    widenedExprs.add(HashPrelUtil.createHashBasedPartitionExpression(distributionRefs, seed, hashHelper));

    final List<String> widenedNames = Lists.newArrayList(originalNames);
    widenedNames.add(HashPrelUtil.HASH_EXPR_NAME);

    final RelDataType widenedRowType =
        RexUtil.createStructType(prel.getCluster().getTypeFactory(), widenedExprs, widenedNames);
    final ProjectPrel hashProject = new ProjectPrel(
        visitedInput.getCluster(), visitedInput.getTraitSet(), visitedInput, widenedExprs, widenedRowType);
    exchangeInput = new UnorderedMuxExchangePrel(
        hashProject.getCluster(), hashProject.getTraitSet(), hashProject);
  }

  final HashToRandomExchangePrel rebuiltExchange = new HashToRandomExchangePrel(
      prel.getCluster(), prel.getTraitSet(), exchangeInput, hashExchange.getFields());

  // Optionally insert a DeMuxExchange to narrow down the number of receivers.
  final Prel top = isDeMuxEnabled
      ? new UnorderedDeMuxExchangePrel(
          prel.getCluster(), prel.getTraitSet(), rebuiltExchange, rebuiltExchange.getFields())
      : rebuiltExchange;

  if (!isMuxEnabled) {
    return top;
  }

  // Project only the original columns again: leaving the earlier project's extra column in
  // place creates issues down the road in HashJoin.
  final RelDataType restoredRowType =
      RexUtil.createStructType(top.getCluster().getTypeFactory(), originalColumnRefs, originalNames);
  return new ProjectPrel(top.getCluster(), top.getTraitSet(), top, originalColumnRefs, restoredRowType);
}
Usage of org.apache.calcite.rel.type.RelDataType in the Apache Drill project.
Class StarColumnConverter, method visitProject.
/**
 * Visits a project, re-deriving its output field names after its input has been visited.
 *
 * <p>Side effect: sets {@code prefixedForStar} the first time a project is seen that contains
 * a star column (per {@code StarColumnHelper.containsStarColumnInProject}) and has more than
 * one output field.
 *
 * @param prel  the project being visited
 * @param value unused
 * @return the visited input itself when the rebuilt project is trivial, otherwise the rebuilt
 *         project
 */
@Override
public Prel visitProject(ProjectPrel prel, Void value) throws RuntimeException {
  // Prefix rename is required when there are other expressions in addition to a star column.
  // Only evaluated while the flag has not been set yet.
  if (!prefixedForStar) {
    final boolean hasStar =
        StarColumnHelper.containsStarColumnInProject(prel.getInput().getRowType(), prel.getProjects());
    if (hasStar && prel.getRowType().getFieldNames().size() > 1) {
      prefixedForStar = true;
    }
  }

  final RelNode visitedInput = ((Prel) prel.getInput(0)).accept(this, null);

  // A plain input reference must carry the same name as the (visited) input field it points
  // to, because visiting the input may have renamed fields. Any other expression keeps the
  // project's own output name.
  final List<String> derivedNames = Lists.newArrayList();
  for (Pair<String, RexNode> namedExpr : Pair.zip(prel.getRowType().getFieldNames(), prel.getProjects())) {
    final String outputName = (namedExpr.right instanceof RexInputRef)
        ? visitedInput.getRowType().getFieldNames().get(((RexInputRef) namedExpr.right).getIndex())
        : namedExpr.left;
    derivedNames.add(outputName);
  }

  // A row type must not contain duplicate field names.
  final List<String> uniqueNames = makeUniqueNames(derivedNames);

  final RelDataType newRowType =
      RexUtil.createStructType(prel.getCluster().getTypeFactory(), prel.getProjects(), uniqueNames);
  final ProjectPrel rebuilt =
      (ProjectPrel) prel.copy(prel.getTraitSet(), visitedInput, prel.getProjects(), newRowType);

  return ProjectRemoveRule.isTrivial(rebuilt) ? (Prel) visitedInput : rebuilt;
}
Usage of org.apache.calcite.rel.type.RelDataType in the Apache Drill project.
Class DrillReduceAggregatesRule, method reduceStddev.
/**
 * Rewrites a standard-deviation / variance aggregate in terms of SUM and COUNT.
 *
 * <p>The expression built is
 * <pre>
 *   (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / d
 * </pre>
 * where {@code d} is {@code COUNT(x)} when {@code biased} is true and
 * {@code CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END} otherwise. When {@code sqrt}
 * is true the quotient is additionally wrapped in {@code POWER(..., 0.5)}.
 *
 * <p>Side effects: the argument expression in {@code inputExprs} is replaced by itself wrapped
 * in {@code CastHighOp}; the squared argument is registered through {@code lookupOrAdd}; and
 * the SUM(x * x), SUM(x) and COUNT(x) calls are registered, in that order, through
 * {@code rexBuilder.addAggCall} using {@code newCalls} and {@code aggCallMapping}.
 *
 * @param oldAggRel      aggregate being reduced
 * @param oldCall        the single-argument stddev/variance call to rewrite
 * @param biased         true for the population variants (divide by COUNT)
 * @param sqrt           true to take the square root of the variance
 * @param newCalls       aggregate calls of the replacement aggregate
 * @param aggCallMapping mapping from registered aggregate calls to their references
 * @param inputExprs     input expressions of the replacement aggregate; mutated
 * @return the expression replacing {@code oldCall}
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
  final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
  final List<RelDataType> aggArgTypes = ImmutableList.of(argType);

  // Replace the raw argument by its CastHighOp-wrapped form before anything is derived from it.
  final RexNode x = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
  inputExprs.set(argOrdinal, x);

  final RexNode xTimesX = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, x, x);
  final int xTimesXOrdinal = lookupOrAdd(inputExprs, xTimesX);

  final RelDataType sumType = typeFactory.createTypeWithNullability(argType, true);

  // SUM(x * x)
  final AggregateCall sumOfSquaresCall = AggregateCall.create(
      new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(xTimesXOrdinal), -1, sumType, null);
  final RexNode sumOfSquares = rexBuilder.addAggCall(
      sumOfSquaresCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

  // SUM(x) and SUM(x) * SUM(x)
  final AggregateCall sumCall = AggregateCall.create(
      new SqlSumAggFunction(sumType), oldCall.isDistinct(), ImmutableIntList.of(argOrdinal), -1, sumType, null);
  final RexNode sum = rexBuilder.addAggCall(
      sumCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);
  final RexNode squareOfSum = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sum, sum);

  // COUNT(x)
  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(typeFactory);
  final AggregateCall countCall = AggregateCall.create(
      countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);
  final RexNode count = rexBuilder.addAggCall(
      countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

  // Numerator: SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
  final RexNode squareOfSumOverCount = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, squareOfSum, count);
  final RexNode numerator = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumOfSquares, squareOfSumOverCount);

  // Denominator: COUNT(x) for the population form; for the sample form COUNT(x) - 1, with a
  // NULL denominator when COUNT(x) = 1.
  final RexNode denominator;
  if (!biased) {
    final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
    final RexNode nullCount = rexBuilder.makeNullLiteral(count.getType().getSqlTypeName());
    final RexNode countLessOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, count, one);
    final RexNode countIsOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, count, one);
    denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsOne, nullCount, countLessOne);
  } else {
    denominator = count;
  }

  // With type inference enabled, a Drill operator typed with the original call's type is used
  // for the final division instead of the standard DIVIDE.
  final SqlOperator divideOp = inferTypes
      ? new DrillSqlOperator("divide", 2, true, oldCall.getType(), false)
      : SqlStdOperatorTable.DIVIDE;
  final RexNode variance = rexBuilder.makeCall(divideOp, numerator, denominator);

  final RexNode reduced = sqrt
      ? rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance, rexBuilder.makeExactLiteral(new BigDecimal("0.5")))
      : variance;

  if (!inferTypes) {
    // Calcite's strategy for inferring the return type of aggregate functions uses the first
    // known argument to determine the output type, which is wrong here: stddev over an integer
    // column would be typed as integer although it should be double. A cast added on top of
    // the rewritten aggregate with that type would produce wrong results, so the result is
    // cast to ANY instead.
    return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), reduced);
  }
  return reduced;
}
Usage of org.apache.calcite.rel.type.RelDataType in the Apache Drill project.
Class DrillReduceAggregatesRule, method reduceSum.
/**
 * Rewrites {@code SUM(x)} as
 * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}, or as plain {@code SUM0(x)} when
 * the original call's type is not nullable.
 *
 * <p>With type inference enabled the SUM0 function is wrapped in a
 * {@code DrillCalciteSqlAggFunctionWrapper} typed with the original call's type; otherwise the
 * sum is typed from the argument field.
 *
 * @param oldAggRel      aggregate being reduced
 * @param oldCall        the SUM call to rewrite
 * @param newCalls       aggregate calls of the replacement aggregate
 * @param aggCallMapping mapping from registered aggregate calls to their references
 * @return the expression replacing {@code oldCall}
 */
private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final PlannerSettings settings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
  final boolean inferTypes = settings.isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();

  final int argOrdinal = oldCall.getArgList().get(0);
  final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
  final List<RelDataType> aggArgTypes = ImmutableList.of(argType);

  final RelDataType sumType = inferTypes
      ? oldCall.getType()
      : typeFactory.createTypeWithNullability(argType, argType.isNullable());
  final SqlSumEmptyIsZeroAggFunction plainSumZero = new SqlSumEmptyIsZeroAggFunction();
  final SqlAggFunction sumZeroFunction = inferTypes
      ? new DrillCalciteSqlAggFunctionWrapper(plainSumZero, sumType)
      : plainSumZero;
  final AggregateCall sumZeroCall = AggregateCall.create(
      sumZeroFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, sumType, null);

  final SqlCountAggFunction countFunction = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final RelDataType countType = countFunction.getReturnType(typeFactory);
  final AggregateCall countCall = AggregateCall.create(
      countFunction, oldCall.isDistinct(), oldCall.getArgList(), -1, countType, null);

  // NOTE: the references obtained below are with respect to the output of the new aggregate.
  final RexNode sumZero = rexBuilder.addAggCall(
      sumZeroCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);

  if (!oldCall.getType().isNullable()) {
    // SUM(x) was typed as NOT NULL, so a NULL result is not expected and the COUNT guard is
    // unnecessary. Therefore we translate to SUM0(x).
    return sumZero;
  }

  final RexNode count = rexBuilder.addAggCall(
      countCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, aggArgTypes);
  final RexNode countIsZero = rexBuilder.makeCall(
      SqlStdOperatorTable.EQUALS, count, rexBuilder.makeExactLiteral(BigDecimal.ZERO));
  final RexNode nullResult = rexBuilder.constantNull();
  return rexBuilder.makeCall(SqlStdOperatorTable.CASE, countIsZero, nullResult, sumZero);
}
Usage of org.apache.calcite.rel.type.RelDataType in the Apache Drill project.
Class DrillReduceAggregatesRule, method reduceAgg.
/**
 * Dispatches one aggregate call to the matching reduction.
 *
 * <ul>
 *   <li>SUM is handed to {@code reduceSum};</li>
 *   <li>AVG is handed to {@code reduceAvg};</li>
 *   <li>STDDEV_POP / STDDEV_SAMP / VAR_POP / VAR_SAMP are handed to {@code reduceStddev} with
 *       the corresponding {@code biased} and {@code sqrt} flags;</li>
 *   <li>any other aggregate is preserved by re-registering the original call.</li>
 * </ul>
 *
 * <p>The aggregate function is first unwrapped with
 * {@code DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper}.
 *
 * @param oldAggRel      aggregate being reduced
 * @param oldCall        the call to reduce
 * @param newCalls       aggregate calls of the replacement aggregate
 * @param aggCallMapping mapping from registered aggregate calls to their references
 * @param inputExprs     input expressions of the replacement aggregate
 * @return the expression replacing {@code oldCall}
 */
private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
  final SqlAggFunction aggFunction = DrillCalciteWrapperUtility.extractSqlOperatorFromWrapper(oldCall.getAggregation());

  if (aggFunction instanceof SqlSumAggFunction) {
    // Replace SUM(x) with: case COUNT(x) when 0 then null else SUM0(x) end
    return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
  }

  if (!(aggFunction instanceof SqlAvgAggFunction)) {
    // Anything else: preserve the original call, typed from the current input expressions.
    final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
    final int groupCount = oldAggRel.getGroupCount();
    final List<RelDataType> argTypes = new ArrayList<>();
    final List<Integer> argOrdinals = oldCall.getArgList();
    assert argOrdinals.size() <= inputExprs.size();
    for (int argOrdinal : argOrdinals) {
      argTypes.add(inputExprs.get(argOrdinal).getType());
    }
    return rexBuilder.addAggCall(oldCall, groupCount, oldAggRel.indicator, newCalls, aggCallMapping, argTypes);
  }

  final SqlAvgAggFunction.Subtype subtype = ((SqlAvgAggFunction) aggFunction).getSubtype();

  // biased: population form, divides by COUNT(x); otherwise the sample form is used, dividing
  //         by CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END.
  // sqrt:   standard deviation (square root of the variance) rather than the variance itself.
  final boolean biased;
  final boolean sqrt;
  switch (subtype) {
  case AVG:
    // Replace original AVG(x) with SUM(x) / COUNT(x).
    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
  case STDDEV_POP:
    biased = true;
    sqrt = true;
    break;
  case STDDEV_SAMP:
    biased = false;
    sqrt = true;
    break;
  case VAR_POP:
    biased = true;
    sqrt = false;
    break;
  case VAR_SAMP:
    biased = false;
    sqrt = false;
    break;
  default:
    throw Util.unexpected(subtype);
  }
  return reduceStddev(oldAggRel, oldCall, biased, sqrt, newCalls, aggCallMapping, inputExprs);
}
Aggregations