use of org.apache.calcite.util.Pair in project calcite by apache.
the class AggregateProjectPullUpConstantsRule method onMatch.
// ~ Methods ----------------------------------------------------------------
/**
 * Pulls constant GROUP BY keys out of an {@link Aggregate}.
 *
 * <p>Group keys that the pulled-up predicates of the aggregate's input
 * report as constant are removed from the group set. The aggregate is
 * rebuilt on the remaining keys, and a projection on top re-inserts the
 * constant values at their original positions with the original field
 * names. At least one group key is always kept.
 *
 * @param call rule call; operand 0 is the {@link Aggregate}, operand 1 the
 *             input it is rebuilt on
 */
public void onMatch(RelOptRuleCall call) {
final Aggregate aggregate = call.rel(0);
final RelNode input = call.rel(1);
assert !aggregate.indicator : "predicate ensured no grouping sets";
final int groupCount = aggregate.getGroupCount();
if (groupCount == 1) {
// No room for optimization since we cannot convert from a non-empty
// GROUP BY list to the empty one.
return;
}
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final RelMetadataQuery mq = call.getMetadataQuery();
final RelOptPredicateList predicates = mq.getPulledUpPredicates(aggregate.getInput());
if (predicates == null) {
return;
}
// Constant group keys: input field ordinal -> constant value. A TreeMap
// keeps the ordinals sorted so first() below yields the lowest one.
final NavigableMap<Integer, RexNode> map = new TreeMap<>();
for (int key : aggregate.getGroupSet()) {
final RexInputRef ref = rexBuilder.makeInputRef(aggregate.getInput(), key);
if (predicates.constantMap.containsKey(ref)) {
map.put(key, predicates.constantMap.get(ref));
}
}
// None of the group expressions are constant. Nothing to do.
if (map.isEmpty()) {
return;
}
if (groupCount == map.size()) {
// At least a single item in group by is required.
// Otherwise "GROUP BY 1, 2" might be altered to "GROUP BY ()".
// Keep the lowest-numbered constant key as a real group key by dropping
// it from the set of keys to remove.
// NOTE(review): an older comment promised a "trim groupCount" fast path
// below; no such special case exists in this method.
map.remove(map.navigableKeySet().first());
}
// New group set: the original keys minus the constant ones.
ImmutableBitSet newGroupSet = aggregate.getGroupSet();
for (int key : map.keySet()) {
newGroupSet = newGroupSet.clear(key);
}
final int newGroupCount = newGroupSet.cardinality();
// Rebuild the aggregate on the input with the reduced group set. All cases
// (including constants at the end of the group list) take this same path.
final RelBuilder relBuilder = call.builder();
relBuilder.push(input);
// Clone aggregate calls.
// Each call is passed its own argument list and filter argument, plus the
// old and new group counts.
final List<AggregateCall> newAggCalls = new ArrayList<>();
for (AggregateCall aggCall : aggregate.getAggCallList()) {
newAggCalls.add(aggCall.adaptTo(input, aggCall.getArgList(), aggCall.filterArg, groupCount, newGroupCount));
}
relBuilder.aggregate(relBuilder.groupKey(newGroupSet, null), newAggCalls);
// Create a projection back again: one entry per field of the original
// aggregate, in the original order and with the original names.
List<Pair<RexNode, String>> projects = new ArrayList<>();
// Index, in the new aggregate's output, of the next retained group key.
int source = 0;
for (RelDataTypeField field : aggregate.getRowType().getFieldList()) {
RexNode expr;
final int i = field.getIndex();
if (i >= groupCount) {
// Aggregate expressions keep their names and relative order, but move
// left by the number of group keys that were removed.
expr = relBuilder.field(i - map.size());
} else {
// The i-th group column corresponds to the i-th set bit of the
// original group set.
int pos = aggregate.getGroupSet().nth(i);
if (map.containsKey(pos)) {
// Re-generate the constant expression in the project.
// projects.size() equals i here (one entry is added per field), so
// originalType is this field's type. Cast the constant when its type
// differs; the third makeCast argument is presumably matchNullability
// -- confirm against RexBuilder.
RelDataType originalType = aggregate.getRowType().getFieldList().get(projects.size()).getType();
if (!originalType.equals(map.get(pos).getType())) {
expr = rexBuilder.makeCast(originalType, map.get(pos), true);
} else {
expr = map.get(pos);
}
} else {
// Project the retained group key from the new aggregate, placing it
// at its original position.
expr = relBuilder.field(source);
++source;
}
}
projects.add(Pair.of(expr, field.getName()));
}
// Restore the original aggregate's field order and names on top of the
// rebuilt aggregate.
relBuilder.project(Pair.left(projects), Pair.right(projects));
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.util.Pair in project calcite by apache.
the class MultiJoinOptimizeBushyRule method onMatch.
/**
 * Greedily builds a bushy join tree for a {@link MultiJoin}.
 *
 * <p>Phase one works on an abstract graph. Every join factor becomes a leaf
 * vertex whose cost is its estimated row count, and every join filter
 * becomes an edge. Each iteration picks an edge, or an un-joined vertex
 * when no edges remain, and merges its two vertexes into a new join vertex.
 * It then re-points the remaining edges at that new vertex.
 *
 * <p>Phase two walks the vertexes in creation order and builds the actual
 * inner joins with {@link RelBuilder}. It tracks how each intermediate
 * result's fields map to the MultiJoin's fields. A final projection is
 * built from the top vertex's mapping.
 *
 * @param call rule call whose operand 0 is the {@link MultiJoin}
 */
@Override
public void onMatch(RelOptRuleCall call) {
final MultiJoin multiJoinRel = call.rel(0);
final RexBuilder rexBuilder = multiJoinRel.getCluster().getRexBuilder();
final RelBuilder relBuilder = call.builder();
final RelMetadataQuery mq = call.getMetadataQuery();
final LoptMultiJoin multiJoin = new LoptMultiJoin(multiJoinRel);
// One leaf vertex per join factor. Its cost is the estimated row count,
// and x tracks where the factor's fields start in the MultiJoin row.
final List<Vertex> vertexes = Lists.newArrayList();
int x = 0;
for (int i = 0; i < multiJoin.getNumJoinFactors(); i++) {
final RelNode rel = multiJoin.getJoinFactor(i);
double cost = mq.getRowCount(rel);
vertexes.add(new LeafVertex(i, rel, cost, x));
x += rel.getRowType().getFieldCount();
}
assert x == multiJoin.getNumTotalFields();
// Every join filter starts as an unused edge between factors.
final List<LoptMultiJoin.Edge> unusedEdges = Lists.newArrayList();
for (RexNode node : multiJoin.getJoinFilters()) {
unusedEdges.add(multiJoin.createEdge(node));
}
// Comparator that chooses the best edge. A "good edge" is one that has
// a large difference in the number of rows on LHS and RHS.
// The comparator orders edges by that difference; which end of the order
// counts as "best" is decided by chooseBestEdge (not shown here).
final Comparator<LoptMultiJoin.Edge> edgeComparator = new Comparator<LoptMultiJoin.Edge>() {
public int compare(LoptMultiJoin.Edge e0, LoptMultiJoin.Edge e1) {
return Double.compare(rowCountDiff(e0), rowCountDiff(e1));
}
private double rowCountDiff(LoptMultiJoin.Edge edge) {
assert edge.factors.cardinality() == 2 : edge.factors;
final int factor0 = edge.factors.nextSetBit(0);
final int factor1 = edge.factors.nextSetBit(factor0 + 1);
return Math.abs(vertexes.get(factor0).cost - vertexes.get(factor1).cost);
}
};
final List<LoptMultiJoin.Edge> usedEdges = Lists.newArrayList();
// Phase one: merge two vertexes per iteration until everything is joined.
for (; ; ) {
final int edgeOrdinal = chooseBestEdge(unusedEdges, edgeComparator);
// pw is declared outside this method; when non-null it receives debug
// tracing.
if (pw != null) {
trace(vertexes, unusedEdges, usedEdges, edgeOrdinal, pw);
}
final int[] factors;
if (edgeOrdinal == -1) {
// No more edges. Are there any un-joined vertexes?
// z is the highest-numbered vertex below the last one that the last
// vertex does not already contain; if there is none, we are done.
// Otherwise z is joined to the last vertex.
final Vertex lastVertex = Util.last(vertexes);
final int z = lastVertex.factors.previousClearBit(lastVertex.id - 1);
if (z < 0) {
break;
}
factors = new int[] { z, lastVertex.id };
} else {
final LoptMultiJoin.Edge bestEdge = unusedEdges.get(edgeOrdinal);
// factors on this edge.
assert bestEdge.factors.cardinality() == 2;
factors = bestEdge.factors.toArray();
}
// Determine which factor is to be on the LHS of the join.
// The vertex with the lower (or equal) cost becomes majorFactor, which is
// passed to JoinVertex first (presumably its leftFactor -- phase two joins
// leftFactor on the left).
final int majorFactor;
final int minorFactor;
if (vertexes.get(factors[0]).cost <= vertexes.get(factors[1]).cost) {
majorFactor = factors[0];
minorFactor = factors[1];
} else {
majorFactor = factors[1];
minorFactor = factors[0];
}
final Vertex majorVertex = vertexes.get(majorFactor);
final Vertex minorVertex = vertexes.get(minorFactor);
// Find the join conditions. All conditions whose factors are now all in
// the join can now be used.
// The new vertex's id is its index in vertexes. Its factor set is the
// union of both inputs' factors plus the new id itself.
final int v = vertexes.size();
final ImmutableBitSet newFactors = majorVertex.factors.rebuild().addAll(minorVertex.factors).set(v).build();
// Move every edge fully covered by newFactors from unused to used, and
// collect its condition for this join.
final List<RexNode> conditions = Lists.newArrayList();
final Iterator<LoptMultiJoin.Edge> edgeIterator = unusedEdges.iterator();
while (edgeIterator.hasNext()) {
LoptMultiJoin.Edge edge = edgeIterator.next();
if (newFactors.contains(edge.factors)) {
conditions.add(edge.condition);
edgeIterator.remove();
usedEdges.add(edge);
}
}
// Estimated rows of the join: product of input costs times the guessed
// selectivity of the combined conditions.
double cost = majorVertex.cost * minorVertex.cost * RelMdUtil.guessSelectivity(RexUtil.composeConjunction(rexBuilder, conditions, false));
final Vertex newVertex = new JoinVertex(v, majorFactor, minorFactor, newFactors, cost, ImmutableList.copyOf(conditions));
vertexes.add(newVertex);
// Re-compute selectivity of edges above the one just chosen.
// Suppose that we just chose the edge between "product" (10k rows) and
// "product_class" (10 rows).
// Both of those vertices are now replaced by a new vertex "P-PC".
// This vertex has fewer rows (1k rows) -- a fact that is critical to
// decisions made later. (Hence "greedy" algorithm not "simple".)
// The adjacent edges are modified.
// Every remaining edge touching either merged vertex has the factors now
// covered by the new vertex replaced with the new vertex's id.
final ImmutableBitSet merged = ImmutableBitSet.of(minorFactor, majorFactor);
for (int i = 0; i < unusedEdges.size(); i++) {
final LoptMultiJoin.Edge edge = unusedEdges.get(i);
if (edge.factors.intersects(merged)) {
ImmutableBitSet newEdgeFactors = edge.factors.rebuild().removeAll(newFactors).set(v).build();
assert newEdgeFactors.cardinality() == 2;
final LoptMultiJoin.Edge newEdge = new LoptMultiJoin.Edge(edge.condition, newEdgeFactors, edge.columns);
unusedEdges.set(i, newEdge);
}
}
}
// We have a winner!
// Phase two: build one RelNode per vertex, indexed like vertexes. Each is
// paired with a mapping from MultiJoin field ordinals to its own fields.
List<Pair<RelNode, Mappings.TargetMapping>> relNodes = Lists.newArrayList();
for (Vertex vertex : vertexes) {
if (vertex instanceof LeafVertex) {
// A leaf maps its own fields, offset to where they sit in the MultiJoin.
LeafVertex leafVertex = (LeafVertex) vertex;
final Mappings.TargetMapping mapping = Mappings.offsetSource(Mappings.createIdentity(leafVertex.rel.getRowType().getFieldCount()), leafVertex.fieldOffset, multiJoin.getNumTotalFields());
relNodes.add(Pair.of(leafVertex.rel, mapping));
} else {
// A join combines its children's mappings, shifting the right side
// past the left input's fields. Its conditions are rewritten through
// that mapping before the inner join is built.
JoinVertex joinVertex = (JoinVertex) vertex;
final Pair<RelNode, Mappings.TargetMapping> leftPair = relNodes.get(joinVertex.leftFactor);
RelNode left = leftPair.left;
final Mappings.TargetMapping leftMapping = leftPair.right;
final Pair<RelNode, Mappings.TargetMapping> rightPair = relNodes.get(joinVertex.rightFactor);
RelNode right = rightPair.left;
final Mappings.TargetMapping rightMapping = rightPair.right;
final Mappings.TargetMapping mapping = Mappings.merge(leftMapping, Mappings.offsetTarget(rightMapping, left.getRowType().getFieldCount()));
if (pw != null) {
pw.println("left: " + leftMapping);
pw.println("right: " + rightMapping);
pw.println("combined: " + mapping);
pw.println();
}
final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle(mapping, left, right);
final RexNode condition = RexUtil.composeConjunction(rexBuilder, joinVertex.conditions, false);
final RelNode join = relBuilder.push(left).push(right).join(JoinRelType.INNER, condition.accept(shuttle)).build();
relNodes.add(Pair.of(join, mapping));
}
if (pw != null) {
pw.println(Util.last(relNodes));
}
}
// The last vertex is the root of the tree. Project through its mapping to
// produce the result row.
final Pair<RelNode, Mappings.TargetMapping> top = Util.last(relNodes);
relBuilder.push(top.left).project(relBuilder.fields(top.right));
call.transformTo(relBuilder.build());
}
use of org.apache.calcite.util.Pair in project calcite by apache.
the class PushProjector method createProjectRefsAndExprs.
/**
 * Builds a projection on top of {@code projChild}.
 *
 * <p>The leading columns are input references to the fields selected in
 * {@code projRefs} for the requested side, re-based onto
 * {@code projChild}'s fields. The preserved expressions for that side
 * follow them; each is named after its operator.
 *
 * @param projChild child that the projection will be created on top of
 * @param adjust if true, the preserved expressions are rewritten so that
 *     their input references are shifted down by the side's field offset;
 *     otherwise they are reused unchanged
 * @param rightSide if true, the projection is for the right-hand input of a
 *     join; otherwise for the left (or only) input
 * @return created projection
 */
public Project createProjectRefsAndExprs(RelNode projChild, boolean adjust, boolean rightSide) {
  // Per-side state: which preserved expressions, how many input references,
  // and where this side's fields begin in the combined row.
  final List<RexNode> preserveExprs = rightSide ? rightPreserveExprs : childPreserveExprs;
  final int nInputRefs = rightSide ? nRightProject : nProject;
  final int offset = rightSide ? nSysFields + nFields : nSysFields;

  final List<RelDataTypeField> destFields = projChild.getRowType().getFieldList();
  final List<RexNode> exprs = new ArrayList<>();
  final List<String> names = new ArrayList<>();

  // Input references: walk the set bits of projRefs starting at this side's
  // offset and re-base each one onto projChild's fields.
  int bit = offset - 1;
  for (int n = 0; n < nInputRefs; n++) {
    bit = projRefs.nextSetBit(bit + 1);
    assert bit >= 0;
    final int childOrdinal = bit - offset;
    final RelDataTypeField destField = destFields.get(childOrdinal);
    exprs.add(rexBuilder.makeInputRef(destField.getType(), childOrdinal));
    names.add(destField.getName());
  }

  // Adjustment vector for rewriting preserved expressions: every child field
  // at or beyond the offset is shifted down by the offset.
  int[] adjustments = {};
  if (adjust && !preserveExprs.isEmpty()) {
    adjustments = new int[childFields.size()];
    for (int idx = offset; idx < adjustments.length; idx++) {
      adjustments[idx] = -offset;
    }
  }

  // Preserved expressions follow the input references.
  for (RexNode preserved : preserveExprs) {
    final RexNode rewritten = adjust
        ? preserved.accept(new RelOptUtil.RexInputConverter(rexBuilder, childFields, destFields, adjustments))
        : preserved;
    exprs.add(rewritten);
    names.add(((RexCall) preserved).getOperator().getName());
  }

  return (Project) relBuilder.push(projChild).projectNamed(exprs, names, true).build();
}
use of org.apache.calcite.util.Pair in project calcite by apache.
the class JoinProjectTransposeRule method createProjectExprs.
/**
 * Appends to {@code projects} the expressions, with their names, that one
 * join input contributes.
 *
 * <p>If the input has a projection, its named expressions are copied. When
 * {@code adjustmentAmount} is non-zero, their input references are shifted
 * by that amount against {@code joinChildrenFields}. Without a projection,
 * one input reference per field of {@code joinChild} is emitted, offset by
 * {@code adjustmentAmount}.
 *
 * @param projRel the projection input into the join, or null if there is
 *     none
 * @param joinChild the child of the projection input (if there is a
 *     projection); otherwise, this is the join input
 * @param adjustmentAmount the amount the expressions need to be shifted by
 * @param rexBuilder rex builder
 * @param joinChildrenFields concatenation of the fields from the left and
 *     right join inputs (once the projections have been removed)
 * @param projects output list to which the expressions and names are
 *     appended
 */
protected void createProjectExprs(Project projRel, RelNode joinChild, int adjustmentAmount, RexBuilder rexBuilder, List<RelDataTypeField> joinChildrenFields, List<Pair<RexNode, String>> projects) {
  final List<RelDataTypeField> childFields = joinChild.getRowType().getFieldList();
  if (projRel == null) {
    // No projection: reference every field of the join input directly,
    // shifted to its position in the combined join row.
    int ordinal = 0;
    for (RelDataTypeField field : childFields) {
      projects.add(Pair.of((RexNode) rexBuilder.makeInputRef(field.getType(), ordinal + adjustmentAmount), field.getName()));
      ordinal++;
    }
    return;
  }
  // Every child field is shifted by the same amount.
  final int[] adjustments = new int[childFields.size()];
  for (int k = 0; k < adjustments.length; k++) {
    adjustments[k] = adjustmentAmount;
  }
  for (Pair<RexNode, String> namedProject : projRel.getNamedProjects()) {
    final RexNode shifted = adjustmentAmount == 0
        ? namedProject.left
        : namedProject.left.accept(new RelOptUtil.RexInputConverter(rexBuilder, childFields, joinChildrenFields, adjustments));
    projects.add(Pair.of(shifted, namedProject.right));
  }
}
use of org.apache.calcite.util.Pair in project calcite by apache.
the class UnionPullUpConstantsRule method onMatch.
/**
 * Pulls columns that are constant across a {@link Union} above it.
 *
 * <p>Every union input gets a projection that keeps only its non-constant
 * columns. A new union is built over those projections. A top projection
 * re-creates the original columns: constants are inlined and the rest are
 * references into the narrowed union. The result is converted back to the
 * original union's row type.
 *
 * @param call rule call whose operand 0 is the {@link Union}
 */
@Override
public void onMatch(RelOptRuleCall call) {
final Union union = call.rel(0);
final int count = union.getRowType().getFieldCount();
if (count == 1) {
// No room for optimization: each input projection must keep at least one
// column, so pulling up the only column would leave an equivalent plan
// and this rule would cycle.
return;
}
final RexBuilder rexBuilder = union.getCluster().getRexBuilder();
final RelMetadataQuery mq = call.getMetadataQuery();
final RelOptPredicateList predicates = mq.getPulledUpPredicates(union);
if (predicates == null) {
return;
}
// Constant columns of the union output: field index -> constant value.
// Only predicates keyed by a plain input reference are used.
final Map<Integer, RexNode> constants = new HashMap<>();
for (Map.Entry<RexNode, RexNode> e : predicates.constantMap.entrySet()) {
if (e.getKey() instanceof RexInputRef) {
constants.put(((RexInputRef) e.getKey()).getIndex(), e.getValue());
}
}
// None of the expressions are constant. Nothing to do.
if (constants.isEmpty()) {
return;
}
// Create expressions for Project operators before and after the Union
// topChildExprs/topChildExprsFields hold one entry per original column:
// the constant, or a reference to the union field. refs/refsIndex record
// the non-constant columns, which the input projections keep.
List<RelDataTypeField> fields = union.getRowType().getFieldList();
List<RexNode> topChildExprs = new ArrayList<>();
List<String> topChildExprsFields = new ArrayList<>();
List<RexNode> refs = new ArrayList<>();
ImmutableBitSet.Builder refsIndexBuilder = ImmutableBitSet.builder();
for (RelDataTypeField field : fields) {
final RexNode constant = constants.get(field.getIndex());
if (constant != null) {
topChildExprs.add(constant);
topChildExprsFields.add(field.getName());
} else {
final RexNode expr = rexBuilder.makeInputRef(union, field.getIndex());
topChildExprs.add(expr);
topChildExprsFields.add(field.getName());
refs.add(expr);
refsIndexBuilder.set(field.getIndex());
}
}
ImmutableBitSet refsIndex = refsIndexBuilder.build();
// Update top Project positions
// Renumber the references in the top expressions using the inverse of the
// permutation built from refs. This presumably maps each kept column to its
// position in the narrowed union -- confirm against RelOptUtil.permutation.
final Mappings.TargetMapping mapping = RelOptUtil.permutation(refs, union.getInput(0).getRowType()).inverse();
topChildExprs = ImmutableList.copyOf(RexUtil.apply(mapping, topChildExprs));
// Create new Project-Union-Project sequences
// Push, for each input, a projection of its non-constant columns, keeping
// that input's own field names.
final RelBuilder relBuilder = call.builder();
for (RelNode input : union.getInputs()) {
List<Pair<RexNode, String>> newChildExprs = new ArrayList<>();
for (int j : refsIndex) {
newChildExprs.add(Pair.<RexNode, String>of(rexBuilder.makeInputRef(input, j), input.getRowType().getFieldList().get(j).getName()));
}
if (newChildExprs.isEmpty()) {
// At least a single item in project is required.
// Every column is constant here, so the first (constant) top
// expression serves as a placeholder column.
newChildExprs.add(Pair.of(topChildExprs.get(0), topChildExprsFields.get(0)));
}
// Add the input with project on top
relBuilder.push(input);
relBuilder.project(Pair.left(newChildExprs), Pair.right(newChildExprs));
}
// Union the projected inputs just pushed, keeping the original ALL flag.
relBuilder.union(union.all, union.getInputs().size());
// Create top Project fixing nullability of fields
relBuilder.project(topChildExprs, topChildExprsFields);
relBuilder.convert(union.getRowType(), false);
call.transformTo(relBuilder.build());
}
Aggregations