use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster in project calcite by apache.
the class LogicalTableModify method create.
/**
* Creates a LogicalTableModify.
*/
public static LogicalTableModify create(RelOptTable table, Prepare.CatalogReader schema, RelNode input, Operation operation, List<String> updateColumnList, List<RexNode> sourceExpressionList, boolean flattened) {
final RelOptCluster cluster = input.getCluster();
final RelTraitSet traitSet = cluster.traitSetOf(Convention.NONE);
return new LogicalTableModify(cluster, traitSet, table, schema, input, operation, updateColumnList, sourceExpressionList, flattened);
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster in project calcite by apache.
the class EnumerableCalc method create.
/**
* Creates an EnumerableCalc.
*/
public static EnumerableCalc create(final RelNode input, final RexProgram program) {
final RelOptCluster cluster = input.getCluster();
final RelMetadataQuery mq = cluster.getMetadataQuery();
final RelTraitSet traitSet = cluster.traitSet().replace(EnumerableConvention.INSTANCE).replaceIfs(RelCollationTraitDef.INSTANCE, new Supplier<List<RelCollation>>() {
public List<RelCollation> get() {
return RelMdCollation.calc(mq, input, program);
}
}).replaceIf(RelDistributionTraitDef.INSTANCE, new Supplier<RelDistribution>() {
public RelDistribution get() {
return RelMdDistribution.calc(mq, input, program);
}
});
return new EnumerableCalc(cluster, traitSet, input, program);
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster in project calcite by apache.
the class EnumerableJoinRule method convert.
@Override
public RelNode convert(RelNode rel) {
LogicalJoin join = (LogicalJoin) rel;
List<RelNode> newInputs = new ArrayList<>();
for (RelNode input : join.getInputs()) {
if (!(input.getConvention() instanceof EnumerableConvention)) {
input = convert(input, input.getTraitSet().replace(EnumerableConvention.INSTANCE));
}
newInputs.add(input);
}
final RelOptCluster cluster = join.getCluster();
final RelTraitSet traitSet = join.getTraitSet().replace(EnumerableConvention.INSTANCE);
final RelNode left = newInputs.get(0);
final RelNode right = newInputs.get(1);
final JoinInfo info = JoinInfo.of(left, right, join.getCondition());
if (!info.isEqui() && join.getJoinType() != JoinRelType.INNER) {
// if it is an inner join.
try {
return new EnumerableThetaJoin(cluster, traitSet, left, right, join.getCondition(), join.getVariablesSet(), join.getJoinType());
} catch (InvalidRelException e) {
EnumerableRules.LOGGER.debug(e.toString());
return null;
}
}
RelNode newRel;
try {
newRel = new EnumerableJoin(cluster, join.getTraitSet().replace(EnumerableConvention.INSTANCE), left, right, info.getEquiCondition(left, right, cluster.getRexBuilder()), info.leftKeys, info.rightKeys, join.getVariablesSet(), join.getJoinType());
} catch (InvalidRelException e) {
EnumerableRules.LOGGER.debug(e.toString());
return null;
}
if (!info.isEqui()) {
newRel = new EnumerableFilter(cluster, newRel.getTraitSet(), newRel, info.getRemaining(cluster.getRexBuilder()));
}
return newRel;
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster in project calcite by apache.
the class EnumerableProject method create.
/**
* Creates an EnumerableProject, specifying row type rather than field
* names.
*/
public static EnumerableProject create(final RelNode input, final List<? extends RexNode> projects, RelDataType rowType) {
final RelOptCluster cluster = input.getCluster();
final RelMetadataQuery mq = cluster.getMetadataQuery();
final RelTraitSet traitSet = cluster.traitSet().replace(EnumerableConvention.INSTANCE).replaceIfs(RelCollationTraitDef.INSTANCE, new Supplier<List<RelCollation>>() {
public List<RelCollation> get() {
return RelMdCollation.project(mq, input, projects);
}
});
return new EnumerableProject(cluster, traitSet, input, projects, rowType);
}
use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.plan.RelOptCluster in project calcite by apache.
the class AggregateReduceFunctionsRule method reduceStddev.
private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
// stddev_pop(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / count(x),
// .5)
//
// stddev_samp(x) ==>
// power(
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / nullif(count(x) - 1, 0),
// .5)
final int nGroups = oldAggRel.getGroupCount();
final RelOptCluster cluster = oldAggRel.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), argOrdinalType.isNullable());
final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
final AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM, argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
final RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));
final AggregateCall sumArgAggCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), ImmutableIntList.of(argOrdinal), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
final RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
final RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
final AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg, oldAggRel.getGroupCount(), oldAggRel, null, null);
final RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, oldAggRel.indicator, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));
final RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumArgSquared, avgSumSquaredArg);
final RexNode denominator;
if (biased) {
denominator = countArg;
} else {
final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
}
final RexNode div = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
RexNode result = div;
if (sqrt) {
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
}
return rexBuilder.makeCast(oldCall.getType(), result);
}
Aggregations