
Example 6 with LogicalAggregate

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate in project beam by apache.

the class AggregateScanConverter method convert.

@Override
public RelNode convert(ResolvedAggregateScan zetaNode, List<RelNode> inputs) {
    LogicalProject input = convertAggregateScanInputScanToLogicalProject(zetaNode, inputs.get(0));
    // Calcite LogicalAggregate's groupSet holds the indexes of the group fields, starting from 0.
    int groupFieldsListSize = zetaNode.getGroupByList().size();
    ImmutableBitSet groupSet;
    if (groupFieldsListSize != 0) {
        groupSet = ImmutableBitSet.of(IntStream.rangeClosed(0, groupFieldsListSize - 1).boxed().collect(Collectors.toList()));
    } else {
        groupSet = ImmutableBitSet.of();
    }
    // TODO: add support for indicator
    List<AggregateCall> aggregateCalls;
    if (zetaNode.getAggregateList().isEmpty()) {
        aggregateCalls = ImmutableList.of();
    } else {
        aggregateCalls = new ArrayList<>();
        // The input refs of aggregate calls follow the GROUP BY input refs.
        int columnRefoff = groupFieldsListSize;
        for (ResolvedComputedColumn computedColumn : zetaNode.getAggregateList()) {
            AggregateCall aggCall = convertAggCall(computedColumn, columnRefoff, groupSet.size(), input);
            aggregateCalls.add(aggCall);
            if (!aggCall.getArgList().isEmpty()) {
                // Only increment column reference offset when aggregates use them (BEAM-8042).
                // Ex: COUNT(*) does not have arguments, while COUNT(`field`) does.
                columnRefoff++;
            }
        }
    }
    LogicalAggregate logicalAggregate = new LogicalAggregate(getCluster(), input.getTraitSet(), input, groupSet, ImmutableList.of(groupSet), aggregateCalls);
    return logicalAggregate;
}
Also used : AggregateCall(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.core.AggregateCall) LogicalAggregate(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate) ImmutableBitSet(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.util.ImmutableBitSet) LogicalProject(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject) ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn)
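
The offset bookkeeping in convert can be illustrated without Calcite. The following is a minimal sketch, assuming the input projection lists the GROUP BY columns first, followed by one column per aggregate that takes an argument. The class ColumnRefOffsetSketch and its method argIndexes are hypothetical names used only for illustration, and -1 is an assumed marker for "no argument".

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration; not part of Beam or Calcite.
class ColumnRefOffsetSketch {

    // For each aggregate, returns the index of its argument in the input
    // projection, or -1 when the aggregate takes no argument (e.g. COUNT(*)).
    static List<Integer> argIndexes(int groupFieldCount, List<Boolean> aggregateHasArgument) {
        List<Integer> indexes = new ArrayList<>();
        // Aggregate arguments start right after the GROUP BY columns.
        int columnRefOffset = groupFieldCount;
        for (boolean hasArgument : aggregateHasArgument) {
            if (hasArgument) {
                indexes.add(columnRefOffset);
                // Advance only when the aggregate consumed a projected column.
                columnRefOffset++;
            } else {
                indexes.add(-1);
            }
        }
        return indexes;
    }

    public static void main(String[] args) {
        // SELECT k1, k2, SUM(a), COUNT(*), MAX(b) ... GROUP BY k1, k2
        System.out.println(argIndexes(2, Arrays.asList(true, false, true))); // [2, -1, 3]
    }
}
```

Because COUNT(*) does not advance the offset, MAX(b) reads column 3 rather than column 4, which is the situation the BEAM-8042 comment describes.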

Example 7 with LogicalAggregate

use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalAggregate in project beam by apache.

the class AggregateScanConverter method convertAggregateScanInputScanToLogicalProject.

private LogicalProject convertAggregateScanInputScanToLogicalProject(ResolvedAggregateScan node, RelNode input) {
    // AggregateScan's input is the source of data (e.g. TableScan), which is different from the
    // design of CalciteSQL, in which the LogicalAggregate's input is a LogicalProject, whose input
    // is a LogicalTableScan. When AggregateScan's input is WithRefScan, the WithRefScan is
    // equivalent to a LogicalTableScan. So it's still required to build another LogicalProject as
    // the input of LogicalAggregate.
    List<RexNode> projects = new ArrayList<>();
    List<String> fieldNames = new ArrayList<>();
    // Project the GROUP BY expressions first; they become the group fields of the LogicalAggregate.
    for (ResolvedComputedColumn computedColumn : node.getGroupByList()) {
        projects.add(getExpressionConverter().convertRexNodeFromResolvedExpr(computedColumn.getExpr(), node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.of()));
        fieldNames.add(getTrait().resolveAlias(computedColumn.getColumn()));
    }
    // TODO: remove duplicate columns in projects.
    for (ResolvedComputedColumn resolvedComputedColumn : node.getAggregateList()) {
        // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn.
        // TODO: handle aggregate function with more than one argument and handle OVER
        // TODO: is there a general way to do column reference tracking and deduplication for
        // aggregation?
        ResolvedAggregateFunctionCall aggregateFunctionCall = ((ResolvedAggregateFunctionCall) resolvedComputedColumn.getExpr());
        if (aggregateFunctionCall.getArgumentList() != null && aggregateFunctionCall.getArgumentList().size() >= 1) {
            ResolvedExpr resolvedExpr = aggregateFunctionCall.getArgumentList().get(0);
            for (int i = 0; i < aggregateFunctionCall.getArgumentList().size(); i++) {
                if (i == 0) {
                    // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
                    // TODO: user might use multiple CAST so we need to handle this rare case.
                    projects.add(getExpressionConverter().convertRexNodeFromResolvedExpr(resolvedExpr, node.getInputScan().getColumnList(), input.getRowType().getFieldList(), ImmutableMap.of()));
                } else {
                    projects.add(getExpressionConverter().convertRexNodeFromResolvedExpr(aggregateFunctionCall.getArgumentList().get(i)));
                }
                fieldNames.add(getTrait().resolveAlias(resolvedComputedColumn.getColumn()));
            }
        }
    }
    return LogicalProject.create(input, ImmutableList.of(), projects, fieldNames);
}
Also used : ResolvedAggregateFunctionCall(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedAggregateFunctionCall) ArrayList(java.util.ArrayList) ResolvedExpr(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedExpr) ResolvedComputedColumn(com.google.zetasql.resolvedast.ResolvedNodes.ResolvedComputedColumn) RexNode(org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rex.RexNode)
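
The column order that convertAggregateScanInputScanToLogicalProject produces can be sketched with plain strings standing in for RexNode expressions. This is a simplified illustration under that assumption; the class ProjectLayoutSketch and the method projectColumns are hypothetical and do not exist in Beam.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Hypothetical illustration; not part of Beam or Calcite.
class ProjectLayoutSketch {

    // Builds the column order of the intermediate projection: GROUP BY
    // expressions first, then the arguments of each aggregate, in order.
    // An aggregate with no arguments contributes no column.
    static List<String> projectColumns(List<String> groupByExprs, List<List<String>> aggregateArgs) {
        List<String> projects = new ArrayList<>(groupByExprs);
        for (List<String> args : aggregateArgs) {
            projects.addAll(args);
        }
        return projects;
    }

    public static void main(String[] args) {
        // SELECT dept, SUM(salary), COUNT(*) FROM emp GROUP BY dept
        List<String> columns = projectColumns(
            Arrays.asList("dept"),
            Arrays.asList(Arrays.asList("salary"), Collections.<String>emptyList()));
        System.out.println(columns); // [dept, salary]
    }
}
```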

Example 8 with LogicalAggregate

use of org.apache.calcite.rel.logical.LogicalAggregate in project samza by apache.

the class QueryTranslator method translate.

/**
 * Translate Calcite plan to Samza stream operators.
 * @param relRoot Calcite plan in the form of {@link RelRoot}. RelRoot should not include the sink ({@link TableModify})
 * @param outputSystemStream Sink associated with the Calcite plan.
 * @param translatorContext Context maintained across translations.
 * @param queryId Index of the SQL statement that corresponds to the Calcite plan in a multi-statement scenario,
 *                starting at 0.
 */
public void translate(RelRoot relRoot, String outputSystemStream, TranslatorContext translatorContext, int queryId) {
    final RelNode node = relRoot.project();
    ScanTranslator scanTranslator = new ScanTranslator(sqlConfig.getSamzaRelConverters(), sqlConfig.getInputSystemStreamConfigBySource(), queryId);
    /* update input metrics */
    String queryLogicalId = String.format(TranslatorConstants.LOGSQLID_TEMPLATE, queryId);
    opId = 0;
    node.accept(new RelShuttleImpl() {

        @Override
        public RelNode visit(RelNode relNode) {
            // There should never be a TableModify in the calcite plan.
            Validate.isTrue(!(relNode instanceof TableModify));
            return super.visit(relNode);
        }

        @Override
        public RelNode visit(TableScan scan) {
            RelNode node = super.visit(scan);
            String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "scan", opId++);
            scanTranslator.translate(scan, queryLogicalId, logicalOpId, translatorContext, systemDescriptors, inputMsgStreams);
            return node;
        }

        @Override
        public RelNode visit(LogicalFilter filter) {
            RelNode node = visitChild(filter, 0, filter.getInput());
            String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "filter", opId++);
            new FilterTranslator(queryId).translate(filter, logicalOpId, translatorContext);
            return node;
        }

        @Override
        public RelNode visit(LogicalProject project) {
            RelNode node = super.visit(project);
            String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "project", opId++);
            new ProjectTranslator(queryId).translate(project, logicalOpId, translatorContext);
            return node;
        }

        @Override
        public RelNode visit(LogicalJoin join) {
            RelNode node = super.visit(join);
            String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "join", opId++);
            new JoinTranslator(logicalOpId, sqlConfig.getMetadataTopicPrefix(), queryId).translate(join, translatorContext);
            return node;
        }

        @Override
        public RelNode visit(LogicalAggregate aggregate) {
            RelNode node = super.visit(aggregate);
            String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "window", opId++);
            new LogicalAggregateTranslator(logicalOpId, sqlConfig.getMetadataTopicPrefix()).translate(aggregate, translatorContext);
            return node;
        }
    });
    String logicalOpId = String.format(TranslatorConstants.LOGOPID_TEMPLATE, queryId, "insert", opId);
    sendToOutputStream(queryLogicalId, logicalOpId, outputSystemStream, streamAppDescriptor, translatorContext, node, queryId);
}
Also used : TableScan(org.apache.calcite.rel.core.TableScan) LogicalFilter(org.apache.calcite.rel.logical.LogicalFilter) RelShuttleImpl(org.apache.calcite.rel.RelShuttleImpl) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RelNode(org.apache.calcite.rel.RelNode) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) LogicalProject(org.apache.calcite.rel.logical.LogicalProject) TableModify(org.apache.calcite.rel.core.TableModify)
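
Each visit method above recurses into the node's inputs before it builds the operator ID, so IDs are handed out in post-order. A minimal sketch of that traversal follows, assuming a toy Node type and a "kind_id" label format; both are hypothetical, and the real format comes from TranslatorConstants.LOGOPID_TEMPLATE, which is not shown here.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration; not part of Samza or Calcite.
class PostOrderIdSketch {

    static final class Node {
        final String kind;
        final List<Node> inputs;

        Node(String kind, Node... inputs) {
            this.kind = kind;
            this.inputs = Arrays.asList(inputs);
        }
    }

    private int opId = 0;
    private final List<String> opIds = new ArrayList<>();

    // Visits the inputs first, then labels the node itself, so a scan
    // always gets a lower ID than the operators that consume it.
    void visit(Node node) {
        for (Node input : node.inputs) {
            visit(input);
        }
        opIds.add(node.kind + "_" + opId++);
    }

    static List<String> assignIds(Node root) {
        PostOrderIdSketch sketch = new PostOrderIdSketch();
        sketch.visit(root);
        return sketch.opIds;
    }

    public static void main(String[] args) {
        Node plan = new Node("project", new Node("filter", new Node("scan")));
        System.out.println(assignIds(plan)); // [scan_0, filter_1, project_2]
    }
}
```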

Example 9 with LogicalAggregate

use of org.apache.calcite.rel.logical.LogicalAggregate in project calcite by apache.

the class RelMetadataTest method checkAverageRowSize.

private void checkAverageRowSize(RelOptCluster cluster, RelOptTable empTable, RelOptTable deptTable) {
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    final LogicalTableScan empScan = LogicalTableScan.create(cluster, empTable);
    Double rowSize = mq.getAverageRowSize(empScan);
    List<Double> columnSizes = mq.getAverageColumnSizes(empScan);
    assertThat(columnSizes.size(), equalTo(empScan.getRowType().getFieldCount()));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 40.0, 20.0, 4.0, 8.0, 4.0, 4.0, 4.0, 1.0)));
    assertThat(rowSize, equalTo(89.0));
    // Empty values
    final LogicalValues emptyValues = LogicalValues.createEmpty(cluster, empTable.getRowType());
    rowSize = mq.getAverageRowSize(emptyValues);
    columnSizes = mq.getAverageColumnSizes(emptyValues);
    assertThat(columnSizes.size(), equalTo(emptyValues.getRowType().getFieldCount()));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 40.0, 20.0, 4.0, 8.0, 4.0, 4.0, 4.0, 1.0)));
    assertThat(rowSize, equalTo(89.0));
    // Values
    final RelDataType rowType = cluster.getTypeFactory().builder().add("a", SqlTypeName.INTEGER).add("b", SqlTypeName.VARCHAR).add("c", SqlTypeName.VARCHAR).build();
    final ImmutableList.Builder<ImmutableList<RexLiteral>> tuples = ImmutableList.builder();
    addRow(tuples, rexBuilder, 1, "1234567890", "ABC");
    addRow(tuples, rexBuilder, 2, "1", "A");
    addRow(tuples, rexBuilder, 3, "2", null);
    final LogicalValues values = LogicalValues.create(cluster, rowType, tuples.build());
    rowSize = mq.getAverageRowSize(values);
    columnSizes = mq.getAverageColumnSizes(values);
    assertThat(columnSizes.size(), equalTo(values.getRowType().getFieldCount()));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 8.0, 3.0)));
    assertThat(rowSize, equalTo(15.0));
    // Union
    final LogicalUnion union = LogicalUnion.create(ImmutableList.<RelNode>of(empScan, emptyValues), true);
    rowSize = mq.getAverageRowSize(union);
    columnSizes = mq.getAverageColumnSizes(union);
    assertThat(columnSizes.size(), equalTo(9));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 40.0, 20.0, 4.0, 8.0, 4.0, 4.0, 4.0, 1.0)));
    assertThat(rowSize, equalTo(89.0));
    // Filter
    final LogicalTableScan deptScan = LogicalTableScan.create(cluster, deptTable);
    final LogicalFilter filter = LogicalFilter.create(deptScan, rexBuilder.makeCall(SqlStdOperatorTable.LESS_THAN, rexBuilder.makeInputRef(deptScan, 0), rexBuilder.makeExactLiteral(BigDecimal.TEN)));
    rowSize = mq.getAverageRowSize(filter);
    columnSizes = mq.getAverageColumnSizes(filter);
    assertThat(columnSizes.size(), equalTo(2));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 20.0)));
    assertThat(rowSize, equalTo(24.0));
    // Project
    final LogicalProject deptProject = LogicalProject.create(filter, ImmutableList.of(rexBuilder.makeInputRef(filter, 0), rexBuilder.makeInputRef(filter, 1), rexBuilder.makeCall(SqlStdOperatorTable.PLUS, rexBuilder.makeInputRef(filter, 0), rexBuilder.makeExactLiteral(BigDecimal.ONE)), rexBuilder.makeCall(SqlStdOperatorTable.CHAR_LENGTH, rexBuilder.makeInputRef(filter, 1))), (List<String>) null);
    rowSize = mq.getAverageRowSize(deptProject);
    columnSizes = mq.getAverageColumnSizes(deptProject);
    assertThat(columnSizes.size(), equalTo(4));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 20.0, 4.0, 4.0)));
    assertThat(rowSize, equalTo(32.0));
    // Join
    final LogicalJoin join = LogicalJoin.create(empScan, deptProject, rexBuilder.makeLiteral(true), ImmutableSet.<CorrelationId>of(), JoinRelType.INNER);
    rowSize = mq.getAverageRowSize(join);
    columnSizes = mq.getAverageColumnSizes(join);
    assertThat(columnSizes.size(), equalTo(13));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 40.0, 20.0, 4.0, 8.0, 4.0, 4.0, 4.0, 1.0, 4.0, 20.0, 4.0, 4.0)));
    assertThat(rowSize, equalTo(121.0));
    // Aggregate
    final LogicalAggregate aggregate = LogicalAggregate.create(join, ImmutableBitSet.of(2, 0), ImmutableList.<ImmutableBitSet>of(), ImmutableList.of(AggregateCall.create(SqlStdOperatorTable.COUNT, false, false, ImmutableIntList.of(), -1, 2, join, null, null)));
    rowSize = mq.getAverageRowSize(aggregate);
    columnSizes = mq.getAverageColumnSizes(aggregate);
    assertThat(columnSizes.size(), equalTo(3));
    assertThat(columnSizes, equalTo(Arrays.asList(4.0, 20.0, 8.0)));
    assertThat(rowSize, equalTo(32.0));
    // Smoke test Parallelism and Memory metadata providers
    assertThat(mq.memory(aggregate), nullValue());
    assertThat(mq.cumulativeMemoryWithinPhase(aggregate), nullValue());
    assertThat(mq.cumulativeMemoryWithinPhaseSplit(aggregate), nullValue());
    assertThat(mq.isPhaseTransition(aggregate), is(false));
    assertThat(mq.splitCount(aggregate), is(1));
}
Also used : RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) LogicalUnion(org.apache.calcite.rel.logical.LogicalUnion) ImmutableList(com.google.common.collect.ImmutableList) LogicalFilter(org.apache.calcite.rel.logical.LogicalFilter) RelDataType(org.apache.calcite.rel.type.RelDataType) LogicalTableScan(org.apache.calcite.rel.logical.LogicalTableScan) LogicalValues(org.apache.calcite.rel.logical.LogicalValues) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) RexBuilder(org.apache.calcite.rex.RexBuilder) LogicalProject(org.apache.calcite.rel.logical.LogicalProject)
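
In every assertion pair above, the expected row size equals the sum of the expected column sizes, and the join's columns are the left input's columns followed by the right input's. The sketch below only checks that arithmetic on the test's expected values; it is an observation about these numbers, not a description of how Calcite's metadata providers compute them. The class RowSizeSketch and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration; not part of Calcite.
class RowSizeSketch {

    // Sums the per-column average sizes.
    static double rowSize(List<Double> columnSizes) {
        double total = 0.0;
        for (double size : columnSizes) {
            total += size;
        }
        return total;
    }

    // Column sizes of a join: left columns followed by right columns.
    static List<Double> joinColumnSizes(List<Double> left, List<Double> right) {
        List<Double> joined = new ArrayList<>(left);
        joined.addAll(right);
        return joined;
    }

    public static void main(String[] args) {
        List<Double> emp = Arrays.asList(4.0, 40.0, 20.0, 4.0, 8.0, 4.0, 4.0, 4.0, 1.0);
        List<Double> deptProject = Arrays.asList(4.0, 20.0, 4.0, 4.0);
        System.out.println(rowSize(emp)); // 89.0
        System.out.println(rowSize(joinColumnSizes(emp, deptProject))); // 121.0
    }
}
```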

Example 10 with LogicalAggregate

use of org.apache.calcite.rel.logical.LogicalAggregate in project calcite by apache.

the class RelMetadataTest method testCustomProvider.

@Test
public void testCustomProvider() {
    final List<String> buf = Lists.newArrayList();
    ColTypeImpl.THREAD_LIST.set(buf);
    final String sql = "select deptno, count(*) from emp where deptno > 10 " + "group by deptno having count(*) = 0";
    final RelRoot root = tester.withClusterFactory(new Function<RelOptCluster, RelOptCluster>() {

        public RelOptCluster apply(RelOptCluster cluster) {
            // Create a custom provider that includes ColType.
            // Include the same provider twice just to be devious.
            final ImmutableList<RelMetadataProvider> list = ImmutableList.of(ColTypeImpl.SOURCE, ColTypeImpl.SOURCE, cluster.getMetadataProvider());
            cluster.setMetadataProvider(ChainedRelMetadataProvider.of(list));
            return cluster;
        }
    }).convertSqlToRel(sql);
    final RelNode rel = root.rel;
    // Top node is a filter. Its metadata uses getColType(RelNode, int).
    assertThat(rel, instanceOf(LogicalFilter.class));
    final RelMetadataQuery mq = RelMetadataQuery.instance();
    assertThat(colType(mq, rel, 0), equalTo("DEPTNO-rel"));
    assertThat(colType(mq, rel, 1), equalTo("EXPR$1-rel"));
    // Next node is an aggregate. Its metadata uses
    // getColType(LogicalAggregate, int).
    final RelNode input = rel.getInput(0);
    assertThat(input, instanceOf(LogicalAggregate.class));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    // There is no caching. Another request causes another call to the provider.
    assertThat(buf.toString(), equalTo("[DEPTNO-rel, EXPR$1-rel, DEPTNO-agg]"));
    assertThat(buf.size(), equalTo(3));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(4));
    // Now add a cache. Only the first request for each piece of metadata
    // generates a new call to the provider.
    final RelOptPlanner planner = rel.getCluster().getPlanner();
    rel.getCluster().setMetadataProvider(new CachingRelMetadataProvider(rel.getCluster().getMetadataProvider(), planner));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(5));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(5));
    assertThat(colType(mq, input, 1), equalTo("EXPR$1-agg"));
    assertThat(buf.size(), equalTo(6));
    assertThat(colType(mq, input, 1), equalTo("EXPR$1-agg"));
    assertThat(buf.size(), equalTo(6));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(6));
    // With a different timestamp, a metadata item is re-computed on first call.
    long timestamp = planner.getRelMetadataTimestamp(rel);
    assertThat(timestamp, equalTo(0L));
    ((MockRelOptPlanner) planner).setRelMetadataTimestamp(timestamp + 1);
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(7));
    assertThat(colType(mq, input, 0), equalTo("DEPTNO-agg"));
    assertThat(buf.size(), equalTo(7));
}
Also used : RelOptCluster(org.apache.calcite.plan.RelOptCluster) RelMetadataQuery(org.apache.calcite.rel.metadata.RelMetadataQuery) CachingRelMetadataProvider(org.apache.calcite.rel.metadata.CachingRelMetadataProvider) LogicalFilter(org.apache.calcite.rel.logical.LogicalFilter) RelRoot(org.apache.calcite.rel.RelRoot) RelOptPlanner(org.apache.calcite.plan.RelOptPlanner) Function(com.google.common.base.Function) LogicalAggregate(org.apache.calcite.rel.logical.LogicalAggregate) RelNode(org.apache.calcite.rel.RelNode) JaninoRelMetadataProvider(org.apache.calcite.rel.metadata.JaninoRelMetadataProvider) DefaultRelMetadataProvider(org.apache.calcite.rel.metadata.DefaultRelMetadataProvider) ChainedRelMetadataProvider(org.apache.calcite.rel.metadata.ChainedRelMetadataProvider) ReflectiveRelMetadataProvider(org.apache.calcite.rel.metadata.ReflectiveRelMetadataProvider) RelMetadataProvider(org.apache.calcite.rel.metadata.RelMetadataProvider) Test(org.junit.Test)
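
The caching behaviour the test asserts (one provider call per item, and a recomputation after the planner's timestamp changes) can be sketched as a map whose entries remember the timestamp at which they were computed. This is a simplified model under that assumption, not the CachingRelMetadataProvider implementation; the class TimestampedCacheSketch, its string keys, and the providerCalls counter are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical illustration; not part of Calcite.
class TimestampedCacheSketch {

    private static final class Entry {
        final long timestamp;
        final String value;

        Entry(long timestamp, String value) {
            this.timestamp = timestamp;
            this.value = value;
        }
    }

    private final Map<String, Entry> cache = new HashMap<>();
    private final Function<String, String> provider;
    private long timestamp = 0;
    int providerCalls = 0;

    TimestampedCacheSketch(Function<String, String> provider) {
        this.provider = provider;
    }

    void setTimestamp(long timestamp) {
        this.timestamp = timestamp;
    }

    // Returns the cached value unless it is missing or was computed
    // under a different timestamp, in which case the provider is called.
    String get(String key) {
        Entry entry = cache.get(key);
        if (entry == null || entry.timestamp != timestamp) {
            providerCalls++;
            entry = new Entry(timestamp, provider.apply(key));
            cache.put(key, entry);
        }
        return entry.value;
    }

    public static void main(String[] args) {
        TimestampedCacheSketch cache = new TimestampedCacheSketch(key -> key + "-agg");
        cache.get("DEPTNO");
        cache.get("DEPTNO");
        System.out.println(cache.providerCalls); // 1
        cache.setTimestamp(1);
        cache.get("DEPTNO");
        cache.get("DEPTNO");
        System.out.println(cache.providerCalls); // 2
    }
}
```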
