Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexBuilder in project calcite by apache.
The class TraitPropagationTest, method run.
/**
 * Creates a fresh {@link VolcanoPlanner} and {@link RelOptCluster} for a single test
 * run and applies {@code action} to them.
 *
 * <p>Created so that we can control when the trait definitions are registered
 * (e.g. before the cluster is created): the planner's trait defs and rules are
 * installed first, and only then is the cluster built on top of the planner.
 *
 * <p>The JDBC connection used to obtain the prepare context (type factory, root
 * schema, catalog reader) is closed once {@code action} returns or throws.
 *
 * @param action the planning action to run against the new cluster
 * @param rules rules registered both in the framework config and directly with the planner
 * @return the value returned by {@code action}
 * @throws Exception if opening the connection, preparing, or running the action fails
 */
private static RelNode run(PropAction action, RuleSet rules) throws Exception {
  final FrameworkConfig config = Frameworks.newConfigBuilder().ruleSets(rules).build();
  final Properties info = new Properties();
  // try-with-resources so each run does not leak a connection; closing the
  // Connection also releases the JDBC resources (statements) created from it.
  try (Connection connection = DriverManager.getConnection("jdbc:calcite:", info)) {
    final CalciteServerStatement statement =
        connection.createStatement().unwrap(CalciteServerStatement.class);
    final CalcitePrepare.Context prepareContext = statement.createPrepareContext();
    final JavaTypeFactory typeFactory = prepareContext.getTypeFactory();
    final CalciteCatalogReader catalogReader = new CalciteCatalogReader(
        prepareContext.getRootSchema(), prepareContext.getDefaultSchemaPath(),
        typeFactory, prepareContext.config());
    final RexBuilder rexBuilder = new RexBuilder(typeFactory);
    final RelOptPlanner planner =
        new VolcanoPlanner(config.getCostFactory(), config.getContext());

    // Register trait defs and rules on the planner before the cluster is
    // generated from it.
    planner.clearRelTraitDefs();
    planner.addRelTraitDef(RelCollationTraitDef.INSTANCE);
    planner.addRelTraitDef(ConventionTraitDef.INSTANCE);
    planner.clear();
    for (RelOptRule rule : rules) {
      planner.addRule(rule);
    }

    final RelOptCluster cluster = RelOptCluster.create(planner, rexBuilder);
    return action.apply(cluster, catalogReader, prepareContext.getRootSchema().plus());
  }
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexBuilder in project drill by axbaretto.
The class ConvertHiveParquetScanToDrillParquetScan, method createProjectRel.
/**
 * Builds a project over the native scan whose output matches the row type of the
 * Hive scan it replaces.
 *
 * <p>Each Hive output column is produced in one of two ways. If the column name has
 * an entry in {@code partitionColMapping}, a partition-column cast is built via
 * {@code createPartitionColumnCast}. Otherwise a format conversion of the column is
 * built via {@code createColumnFormatConversion}.
 *
 * @param hiveScanRel the Hive scan being replaced; supplies the cluster, the trait set
 *     and the target row type
 * @param partitionColMapping Hive column name to the native-scan (directory) column name
 *     used for partition columns
 * @param nativeScanRel the native scan used as the project's input
 * @return a project whose row type is the Hive scan's row type
 */
private DrillProjectRel createProjectRel(final DrillScanRel hiveScanRel,
    final Map<String, String> partitionColMapping, final DrillScanRel nativeScanRel) {
  final RexBuilder builder = hiveScanRel.getCluster().getRexBuilder();
  final RelDataType targetRowType = hiveScanRel.getRowType();
  final List<String> targetNames = targetRowType.getFieldNames();
  final List<RexNode> projections = Lists.newArrayListWithCapacity(targetNames.size());

  for (final String targetName : targetNames) {
    final String mappedDirColumn = partitionColMapping.get(targetName);
    final RexNode projection = (mappedDirColumn == null)
        ? createColumnFormatConversion(hiveScanRel, nativeScanRel, targetName, builder)
        : createPartitionColumnCast(hiveScanRel, nativeScanRel, targetName, mappedDirColumn, builder);
    projections.add(projection);
  }

  return DrillProjectRel.create(hiveScanRel.getCluster(), hiveScanRel.getTraitSet(),
      nativeScanRel, projections, targetRowType);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexBuilder in project drill by axbaretto.
The class DrillReduceAggregatesRule, method reduceAggs.
/*
private boolean isMatch(AggregateCall call) {
if (call.getAggregation() instanceof SqlAvgAggFunction) {
final SqlAvgAggFunction.Subtype subtype =
((SqlAvgAggFunction) call.getAggregation()).getSubtype();
return (subtype == SqlAvgAggFunction.Subtype.AVG);
}
return false;
}
*/
/**
 * Reduces calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP and VAR_SAMP in the
 * aggregate list to expressions over simpler aggregate calls (each call is
 * rewritten through {@code reduceAgg}), and hands the rewritten plan to
 * {@code ruleCall}.
 *
 * <p>The replacement plan has three parts:
 * <ul>
 *   <li>an extra project over the original input, added only when the reduced
 *       calls appended expressions to the input list;</li>
 *   <li>a new aggregate holding the reduced calls;</li>
 *   <li>a top project that passes the group keys through, followed by one
 *       expression per original call, under the original field names.</li>
 * </ul>
 *
 * <p>It handles newly generated common subexpressions since this was done
 * at the sql2rel stage.
 *
 * @param ruleCall the rule call that receives the rewritten plan
 * @param oldAggRel the aggregate being reduced
 */
private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
  final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();
  final int groupCount = oldAggRel.getGroupCount();

  // Top-level projection: group keys pass straight through, then one reduced
  // expression is appended per original aggregate call.
  final List<RexNode> outputExprs = new ArrayList<>();
  for (int key = 0; key < groupCount; key++) {
    outputExprs.add(builder.makeInputRef(getFieldType(oldAggRel, key), key));
  }

  // One input ref per input field. reduceAgg may append expressions here; if
  // it does, an extra project is inserted below the new aggregate.
  RelNode input = oldAggRel.getInput();
  final int inputFieldCount = input.getRowType().getFieldCount();
  final List<RexNode> inputExprs = new ArrayList<>();
  for (RelDataTypeField inputField : input.getRowType().getFieldList()) {
    inputExprs.add(builder.makeInputRef(inputField.getType(), inputExprs.size()));
  }

  final List<AggregateCall> newCalls = new ArrayList<>();
  final Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
  for (AggregateCall oldCall : oldAggRel.getAggCallList()) {
    outputExprs.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
  }

  final int extraArgCount = inputExprs.size() - inputFieldCount;
  if (extraArgCount > 0) {
    // Original fields keep their names; appended expressions stay unnamed.
    final List<String> projectNames = CompositeList.of(
        input.getRowType().getFieldNames(),
        Collections.<String>nCopies(extraArgCount, null));
    input = RelOptUtil.createProject(input, inputExprs, projectNames, false,
        relBuilderFactory.create(input.getCluster(), null));
  }

  final Aggregate newAggRel = newAggregateRel(oldAggRel, input, newCalls);
  ruleCall.transformTo(
      RelOptUtil.createProject(newAggRel, outputExprs, oldAggRel.getRowType().getFieldNames()));
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexBuilder in project drill by axbaretto.
The class DrillReduceAggregatesRule, method reduceStddev.
/**
 * Reduces a standard-deviation or variance aggregate call to an expression over
 * SUM(x * x), SUM(x) and COUNT(x).
 *
 * <p>{@code biased} selects between the population and sample forms:
 * <ul>
 *   <li>true (population): divide by COUNT(x);</li>
 *   <li>false (sample): divide by COUNT(x) - 1, or by NULL when COUNT(x) = 1.</li>
 * </ul>
 * {@code sqrt} selects between standard deviation (result raised to the power 0.5)
 * and variance (no power applied).
 *
 * <p>Side effects:
 * <ul>
 *   <li>The argument's entry in {@code inputExprs} is replaced by {@code CastHighOp}
 *       applied to it.</li>
 *   <li>The squared argument is passed to {@code lookupOrAdd} together with
 *       {@code inputExprs}.</li>
 *   <li>Three helper aggregate calls (SUM of the square, SUM, COUNT) are passed to
 *       {@code RexBuilder.addAggCall} together with {@code newCalls} and
 *       {@code aggCallMapping}, in that order.</li>
 * </ul>
 *
 * @param oldAggRel the aggregate that owns {@code oldCall}
 * @param oldCall the call being reduced; expected to have exactly one argument
 *     (checked only by an assert)
 * @param biased true for the population variant, false for the sample variant
 * @param sqrt true to take the square root (standard deviation), false for variance
 * @param newCalls aggregate-call list handed to {@code addAggCall}
 * @param aggCallMapping call-to-reference map handed to {@code addAggCall}
 * @param inputExprs input expressions for the replacement aggregate's input;
 *     modified in place as described above
 * @return the reduced expression; when Drill type inference is disabled it is
 *     wrapped in a cast to ANY
 */
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
// stddev_pop(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x),
// .5)
//
// stddev_samp(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / nullif(count(x) - 1, 0),
// .5)
//
// var_pop(x) / var_samp(x) are the same expressions without the outer
// power(..., .5); that is the sqrt == false case below.
// Planner settings decide both which divide operator is used and whether the
// result is finally cast to ANY (see the two isInferenceEnabled branches).
final PlannerSettings plannerSettings = (PlannerSettings) oldAggRel.getCluster().getPlanner().getContext();
final boolean isInferenceEnabled = plannerSettings.isTypeInferenceEnabled();
final int nGroups = oldAggRel.getGroupCount();
RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
// Type of the argument as it comes from the input (before CastHighOp is applied);
// this is the type passed to every addAggCall below.
final RelDataType argType = getFieldType(oldAggRel.getInput(), argOrdinal);
// final RexNode argRef = inputExprs.get(argOrdinal);
// Wrap the argument in CastHighOp and write it back at argOrdinal, so the SUM(x)
// and COUNT(x) calls below, which reference argOrdinal, see the wrapped expression.
// NOTE(review): CastHighOp is defined elsewhere; presumably it widens the type -- confirm.
RexNode argRef = rexBuilder.makeCall(CastHighOp, inputExprs.get(argOrdinal));
inputExprs.set(argOrdinal, argRef);
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
// lookupOrAdd is not shown here; presumably it returns the ordinal of argSquared
// in inputExprs, appending it first if absent -- confirm.
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
// Both SUMs use Drill's SUM return-type inference for the original call, forced nullable.
RelDataType sumType = TypeInferenceUtils.getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.<DrillFuncHolder>of()).inferReturnType(oldCall.createBinding(oldAggRel));
sumType = typeFactory.createTypeWithNullability(sumType, true);
// SUM(x * x)
final AggregateCall sumArgSquaredAggCall = AggregateCall.create(new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argSquaredOrdinal), -1, sumType, null);
final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
// SUM(x)
final AggregateCall sumArgAggCall = AggregateCall.create(new DrillCalciteSqlAggFunctionWrapper(new SqlSumAggFunction(sumType), sumType), oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argOrdinal), -1, sumType, null);
final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
// COUNT(x), over the same argument list as the original call.
final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
final RelDataType countType = countAgg.getReturnType(typeFactory);
final AggregateCall countArgAggCall = AggregateCall.create(countAgg, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), -1, countType, null);
final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argType));
// diff = SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)
final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
final RexNode denominator;
if (biased) {
// Population form: divide by COUNT(x).
denominator = countArg;
} else {
// Sample form: CASE WHEN COUNT(x) = 1 THEN NULL ELSE COUNT(x) - 1 END,
// i.e. the nullif(count(x) - 1, 0) from the formula above.
final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType());
final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
}
// With type inference, use a Drill "divide" operator typed as the original
// call's type; otherwise the standard DIVIDE.
final SqlOperator divide;
if (isInferenceEnabled) {
divide = new DrillSqlOperator("divide", 2, true, oldCall.getType(), false);
} else {
divide = SqlStdOperatorTable.DIVIDE;
}
final RexNode div = rexBuilder.makeCall(divide, diff, denominator);
// div is the variance; raise it to 0.5 for the standard-deviation variants.
RexNode result = div;
if (sqrt) {
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
}
if (isInferenceEnabled) {
return result;
} else {
/*
* Currently calcite's strategy to infer the return type of aggregate functions
* is wrong because it uses the first known argument to determine output type. For
* instance if we are performing stddev on an integer column then it interprets the
* output type to be integer which is incorrect as it should be double. So based on
* this if we add cast after rewriting the aggregate we add an additional cast which
* would cause wrong results. So we simply add a cast to ANY.
*/
return rexBuilder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY), result);
}
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexBuilder in project drill by axbaretto.
The class DrillReduceAggregatesRule, method reduceAvg.
/**
 * Reduces {@code AVG(x)} to {@code CastHighOp(CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM(x) END) / COUNT(x)}.
 *
 * <p>The SUM is an empty-is-zero sum ({@code SqlSumEmptyIsZeroAggFunction}). The CASE
 * turns the average of an empty group into NULL rather than a division involving zero.
 *
 * @param oldAggRel the aggregate that owns {@code oldCall}
 * @param oldCall the AVG call being reduced; its first argument is the averaged column
 * @param newCalls aggregate-call list handed to {@code RexBuilder.addAggCall}
 * @param aggCallMapping call-to-reference map handed to {@code RexBuilder.addAggCall}
 * @return the quotient expression; with Drill type inference disabled it is wrapped
 *     in a cast to ANY
 */
private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
  final boolean isInferenceEnabled =
      ((PlannerSettings) oldAggRel.getCluster().getPlanner().getContext()).isTypeInferenceEnabled();
  final int groupCount = oldAggRel.getGroupCount();
  final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
  final RexBuilder builder = oldAggRel.getCluster().getRexBuilder();

  // Argument types handed to addAggCall: the input type of AVG's first argument.
  final int avgArgOrdinal = oldCall.getArgList().get(0);
  final List<RelDataType> argTypes =
      ImmutableList.of(getFieldType(oldAggRel.getInput(), avgArgOrdinal));

  // SUM type from Drill's inference, made nullable when it already was or when
  // there are no group keys.
  final RelDataType inferredSumType = TypeInferenceUtils
      .getDrillSqlReturnTypeInference(SqlKind.SUM.name(), ImmutableList.<DrillFuncHolder>of())
      .inferReturnType(oldCall.createBinding(oldAggRel));
  final RelDataType sumType = typeFactory.createTypeWithNullability(
      inferredSumType, inferredSumType.isNullable() || groupCount == 0);

  final SqlAggFunction sumAgg =
      new DrillCalciteSqlAggFunctionWrapper(new SqlSumEmptyIsZeroAggFunction(), sumType);
  final AggregateCall sumCall = AggregateCall.create(sumAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, sumType, null);
  final SqlCountAggFunction countAgg = (SqlCountAggFunction) SqlStdOperatorTable.COUNT;
  final AggregateCall countCall = AggregateCall.create(countAgg, oldCall.isDistinct(),
      oldCall.isApproximate(), oldCall.getArgList(), -1, countAgg.getReturnType(typeFactory), null);

  // SUM is registered before COUNT. COUNT's reference doubles as the
  // denominator: registering the identical countCall again would just hit
  // aggCallMapping and return the same reference.
  final RexNode sumRef = builder.addAggCall(sumCall, groupCount, oldAggRel.indicator,
      newCalls, aggCallMapping, argTypes);
  final RexNode countRef = builder.addAggCall(countCall, groupCount, oldAggRel.indicator,
      newCalls, aggCallMapping, argTypes);

  // Empty group (COUNT = 0) -> NULL numerator instead of the empty-is-zero SUM's value.
  final RexNode countIsZero = builder.makeCall(SqlStdOperatorTable.EQUALS, countRef,
      builder.makeExactLiteral(BigDecimal.ZERO));
  final RexNode numerator = builder.makeCall(CastHighOp,
      builder.makeCall(SqlStdOperatorTable.CASE, countIsZero, builder.constantNull(), sumRef));

  if (!isInferenceEnabled) {
    // Without Drill type inference the quotient is cast to ANY.
    return builder.makeCast(typeFactory.createSqlType(SqlTypeName.ANY),
        builder.makeCall(SqlStdOperatorTable.DIVIDE, numerator, countRef));
  }
  return builder.makeCall(new DrillSqlOperator("divide", 2, true, oldCall.getType(), false),
      numerator, countRef);
}
Aggregations