Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject in project streamline by hortonworks.
From class TestRelNodeCompiler, method testFilter:
/**
 * Checks the text that {@code RelNodeCompiler} writes for the filter and project nodes of a
 * simple query: the filter output must contain {@code "> 3"} and the project output must
 * contain {@code " + 1"}.
 */
@Test
public void testFilter() throws Exception {
  TestCompilerUtils.CalciteState state =
      TestCompilerUtils.sqlOverDummyTable("SELECT ID + 1 FROM FOO WHERE ID > 3");
  JavaTypeFactory typeFactory = new JavaTypeFactoryImpl(RelDataTypeSystem.DEFAULT);

  // The plan root is the projection; the filter is its only input.
  LogicalProject project = (LogicalProject) state.tree();
  LogicalFilter filter = (LogicalFilter) project.getInput();

  // Filter node: the emitted text must carry the comparison.
  try (StringWriter filterText = new StringWriter();
      PrintWriter filterOut = new PrintWriter(filterText)) {
    // standalone mode doesn't use inputstreams argument
    new RelNodeCompiler(filterOut, typeFactory).visitFilter(filter, Collections.EMPTY_LIST);
    filterOut.flush();
    Assert.assertThat(filterText.toString(), containsString("> 3"));
  }

  // Project node: the emitted text must carry the arithmetic expression.
  try (StringWriter projectText = new StringWriter();
      PrintWriter projectOut = new PrintWriter(projectText)) {
    // standalone mode doesn't use inputstreams argument
    new RelNodeCompiler(projectOut, typeFactory).visitProject(project, Collections.EMPTY_LIST);
    projectOut.flush();
    Assert.assertThat(projectText.toString(), containsString(" + 1"));
  }
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject in project beam by apache.
From class BeamUnnestRule, method onMatch:
/**
 * Replaces the matched correlate with a {@link BeamUnnestRel} when the match has the one shape
 * this rule handles; in every other case the method returns without transforming anything.
 *
 * <p>Required shape: correlation id 0, exactly one required column, an inner join, and a right
 * side that is an {@link Uncollect} (optionally under one extra single-input node) whose input
 * is a {@link LogicalProject} with a single {@link RexFieldAccess} expression.
 *
 * @param call the rule call; rel(0) is the correlate, rel(1) its outer input and rel(2) the
 *     node on the unnest side
 */
@Override
public void onMatch(RelOptRuleCall call) {
  LogicalCorrelate correlate = call.rel(0);
  RelNode outer = call.rel(1);
  RelNode uncollect = call.rel(2);

  // Only one level of correlation nesting is supported, only a single column can be
  // unnested, and only inner joins are handled.
  boolean supportedCorrelate =
      correlate.getCorrelationId().getId() == 0
          && correlate.getRequiredColumns().cardinality() == 1
          && correlate.getJoinType() == JoinRelType.INNER;
  if (!supportedCorrelate) {
    return;
  }

  if (!(uncollect instanceof Uncollect)) {
    // Drop projection: look one level below the matched node, through a RelSubset if needed.
    RelNode belowProjection = ((SingleRel) uncollect).getInput();
    if (belowProjection instanceof RelSubset) {
      belowProjection = ((RelSubset) belowProjection).getOriginal();
    }
    if (!(belowProjection instanceof Uncollect)) {
      return;
    }
    uncollect = belowProjection;
  }

  RelNode unnestInput = ((Uncollect) uncollect).getInput();
  if (unnestInput instanceof RelSubset) {
    unnestInput = ((RelSubset) unnestInput).getOriginal();
  }
  if (!(unnestInput instanceof LogicalProject)) {
    return;
  }
  LogicalProject project = (LogicalProject) unnestInput;
  if (project.getProjects().size() != 1) {
    // can only unnest a single column
    return;
  }
  RexNode unnested = project.getProjects().get(0);
  if (!(unnested instanceof RexFieldAccess)) {
    return;
  }

  // Walk the access chain outwards, collecting field indices.
  // Innermost field index comes first (e.g. struct.field1.field2 => [2, 1])
  ImmutableList.Builder<Integer> fieldAccessIndices = ImmutableList.builder();
  RexNode cursor = unnested;
  do {
    RexFieldAccess access = (RexFieldAccess) cursor;
    fieldAccessIndices.add(access.getField().getIndex());
    cursor = access.getReferenceExpr();
  } while (cursor instanceof RexFieldAccess);

  // The row type is taken from the originally matched rel(2), not from the unwrapped Uncollect.
  call.transformTo(
      new BeamUnnestRel(
          correlate.getCluster(),
          correlate.getTraitSet().replace(BeamLogicalConvention.INSTANCE),
          convert(outer, outer.getTraitSet().replace(BeamLogicalConvention.INSTANCE)),
          call.rel(2).getRowType(),
          fieldAccessIndices.build()));
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject in project beam by apache.
From class AggregateScanConverter, method convertAggregateScanInputScanToLogicalProject:
/**
 * Builds the {@link LogicalProject} that sits between {@code input} and the aggregate.
 *
 * <p>The projection lists the converted group-by expressions first, followed by the converted
 * arguments of every aggregate function call. Each projected expression gets the resolved alias
 * of the computed column it came from.
 *
 * @param node the resolved aggregate scan supplying the group-by and aggregate lists
 * @param input the already converted input relation
 * @return a projection over {@code input} with the collected expressions and field names
 */
private LogicalProject convertAggregateScanInputScanToLogicalProject(ResolvedAggregateScan node, RelNode input) {
  // AggregateScan's input is the source of data (e.g. TableScan), which is different from the
  // design of CalciteSQL, in which the LogicalAggregate's input is a LogicalProject, whose input
  // is a LogicalTableScan. When AggregateScan's input is WithRefScan, the WithRefScan is
  // equivalent to a LogicalTableScan. So it's still required to build another LogicalProject as
  // the input of LogicalAggregate.
  List<RexNode> projects = new ArrayList<>();
  List<String> fieldNames = new ArrayList<>();

  // Group-by keys come first.
  for (ResolvedComputedColumn groupByColumn : node.getGroupByList()) {
    RexNode groupByExpr =
        getExpressionConverter()
            .convertRexNodeFromResolvedExpr(
                groupByColumn.getExpr(),
                node.getInputScan().getColumnList(),
                input.getRowType().getFieldList(),
                ImmutableMap.of());
    projects.add(groupByExpr);
    fieldNames.add(getTrait().resolveAlias(groupByColumn.getColumn()));
  }

  // TODO: remove duplicate columns in projects.
  for (ResolvedComputedColumn aggregateColumn : node.getAggregateList()) {
    // Should create Calcite's RexInputRef from ResolvedColumn from ResolvedComputedColumn.
    // TODO: handle aggregate function with more than one argument and handle OVER
    // TODO: is there is general way for column reference tracking and deduplication for
    // aggregation?
    ResolvedAggregateFunctionCall aggregateCall =
        (ResolvedAggregateFunctionCall) aggregateColumn.getExpr();
    List<ResolvedExpr> arguments = aggregateCall.getArgumentList();
    if (arguments == null || arguments.isEmpty()) {
      // Nothing to project for an aggregate call without arguments.
      continue;
    }

    for (int position = 0; position < arguments.size(); position++) {
      ResolvedExpr argument = arguments.get(position);
      // Only the first argument is converted against the input scan's columns and fields;
      // the remaining arguments go through the single-argument overload.
      // TODO: assume aggregate function's input is either a ColumnRef or a cast(ColumnRef).
      // TODO: user might use multiple CAST so we need to handle this rare case.
      RexNode converted =
          position == 0
              ? getExpressionConverter()
                  .convertRexNodeFromResolvedExpr(
                      argument,
                      node.getInputScan().getColumnList(),
                      input.getRowType().getFieldList(),
                      ImmutableMap.of())
              : getExpressionConverter().convertRexNodeFromResolvedExpr(argument);
      projects.add(converted);
      // Every argument of one aggregate call shares that call's column alias.
      fieldNames.add(getTrait().resolveAlias(aggregateColumn.getColumn()));
    }
  }

  return LogicalProject.create(input, ImmutableList.of(), projects, fieldNames);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject in project samza by apache.
From class TestQueryPlanner, method testRemoteJoinWithFilterHelper:
/**
 * Converts a stream-to-remote-table join query with a WHERE clause and asserts the shape of the
 * resulting plan: project over filter over join, with the filter on the page-view side pushed
 * below the join only when the optimizer is enabled.
 *
 * @param enableOptimizer value written to the plan-optimizer config before conversion
 */
void testRemoteJoinWithFilterHelper(boolean enableOptimizer) throws SamzaSqlValidatorException {
  String sql =
      "Insert into testavro.enrichedPageViewTopic "
          + "select pv.pageKey as __key__, pv.pageKey as pageKey, coalesce(null, 'N/A') as companyName,"
          + " p.name as profileName, p.address as profileAddress "
          + "from testavro.PAGEVIEW as pv "
          + "join testRemoteStore.Profile.`$table` as p "
          + " on p.__key__ = pv.profileId"
          + " where p.name = pv.pageKey AND p.name = 'Mike' AND pv.profileId = 1";

  Map<String, String> staticConfigs = SamzaSqlTestConfig.fetchStaticConfigsWithFactories(1);
  staticConfigs.put(SamzaSqlApplicationConfig.CFG_SQL_STMT, sql);
  staticConfigs.put(
      SamzaSqlApplicationConfig.CFG_SQL_ENABLE_PLAN_OPTIMIZER, Boolean.toString(enableOptimizer));
  Config samzaConfig = new MapConfig(staticConfigs);
  DslConverter dslConverter = new SamzaSqlDslConverterFactory().create(samzaConfig);
  Collection<RelRoot> relRoots = dslConverter.convertDsl(sql);

  /*
   Query plan without optimization:
   LogicalProject(__key__=[$1], pageKey=[$1], companyName=['N/A'], profileName=[$5], profileAddress=[$7])
     LogicalFilter(condition=[AND(=($5, $1), =($5, 'Mike'), =($2, 1))])
       LogicalJoin(condition=[=($3, $2)], joinType=[inner])
         LogicalTableScan(table=[[testavro, PAGEVIEW]])
         LogicalTableScan(table=[[testRemoteStore, Profile, $table]])

   Query plan with optimization:
   LogicalProject(__key__=[$1], pageKey=[$1], companyName=['N/A'], profileName=[$5], profileAddress=[$7])
     LogicalFilter(condition=[AND(=($5, $1), =($5, 'Mike'))])
       LogicalJoin(condition=[=($3, $2)], joinType=[inner])
         LogicalFilter(condition=[=($2, 1)])
           LogicalTableScan(table=[[testavro, PAGEVIEW]])
         LogicalTableScan(table=[[testRemoteStore, Profile, $table]])
  */
  assertEquals(1, relRoots.size());
  RelRoot relRoot = relRoots.iterator().next();

  // Level 1: the projection at the root.
  RelNode projectNode = relRoot.rel;
  assertTrue(projectNode instanceof LogicalProject);

  // Level 2: the filter; it keeps the profileId predicate only when nothing was pushed down.
  RelNode filterNode = projectNode.getInput(0);
  assertTrue(filterNode instanceof LogicalFilter);
  String expectedTopCondition =
      enableOptimizer
          ? "AND(=($1, $5), =($5, 'Mike'))"
          : "AND(=(1, $2), =($1, $5), =($5, 'Mike'))";
  assertEquals(expectedTopCondition, ((LogicalFilter) filterNode).getCondition().toString());

  // Level 3: the join with its two inputs.
  RelNode joinNode = filterNode.getInput(0);
  assertTrue(joinNode instanceof LogicalJoin);
  assertEquals(2, joinNode.getInputs().size());
  LogicalJoin join = (LogicalJoin) joinNode;
  RelNode left = join.getLeft();
  RelNode right = join.getRight();
  assertTrue(right instanceof LogicalTableScan);

  // Level 4: with the optimizer on, the left input is the pushed-down filter over the scan.
  if (!enableOptimizer) {
    assertTrue(left instanceof LogicalTableScan);
    return;
  }
  assertTrue(left instanceof LogicalFilter);
  assertEquals("=(1, $2)", ((LogicalFilter) left).getCondition().toString());
  assertTrue(left.getInput(0) instanceof LogicalTableScan);
}
Use of org.apache.beam.vendor.calcite.v1_28_0.org.apache.calcite.rel.logical.LogicalProject in project samza by apache.
From class TestProjectTranslator, method testTranslate:
/**
 * Exercises {@code ProjectTranslator.translate()} against a mocked {@link LogicalProject}.
 *
 * <p>The test checks, in order: that the project node and its output stream are registered with
 * the translator context; that the registered stream's operator is a MAP; that initializing the
 * transform function binds the translator context, project and compiled expression onto the map
 * function; that the metrics registry holds the expected gauges and counters; and that applying
 * the map function to a message yields the projected field and increments both counters.
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
  // Set up the mocks that ProjectTranslator.translate() reads from
  LogicalProject mockProject = PowerMockito.mock(LogicalProject.class);
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
  RelNode mockInput = mock(RelNode.class);
  List<RelNode> inputs = new ArrayList<>();
  inputs.add(mockInput);
  // The input node has id 1 and the project node id 2; streams are looked up by these ids below
  when(mockInput.getId()).thenReturn(1);
  when(mockProject.getId()).thenReturn(2);
  when(mockProject.getInputs()).thenReturn(inputs);
  when(mockProject.getInput()).thenReturn(mockInput);
  RelDataType mockRowType = mock(RelDataType.class);
  when(mockRowType.getFieldCount()).thenReturn(1);
  when(mockProject.getRowType()).thenReturn(mockRowType);
  // A single projected expression named "test_field"
  RexNode mockRexField = mock(RexNode.class);
  List<Pair<RexNode, String>> namedProjects = new ArrayList<>();
  namedProjects.add(Pair.of(mockRexField, "test_field"));
  when(mockProject.getNamedProjects()).thenReturn(namedProjects);
  StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockAppDesc, mockInputOp);
  when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);

  // Apply translate() method to verify that we are getting the correct map operator constructed
  ProjectTranslator projectTranslator = new ProjectTranslator(1);
  projectTranslator.translate(mockProject, LOGICAL_OP_ID, mockTranslatorContext);
  // make sure that context has been registered with LogicalProject and output message streams
  verify(mockTranslatorContext, times(1)).registerRelNode(2, mockProject);
  verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
  when(mockTranslatorContext.getRelNode(2)).thenReturn(mockProject);
  when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
  StreamOperatorSpec projectSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
  assertNotNull(projectSpec);
  // JUnit's assertEquals takes (expected, actual)
  assertEquals(OperatorSpec.OpCode.MAP, projectSpec.getOpCode());

  // Verify that initializing the transform function establishes the context for the map function
  Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
  mockContexts.put(1, mockTranslatorContext);
  when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
  projectSpec.getTransformFn().init(mockContext);
  MapFunction mapFn = (MapFunction) Whitebox.getInternalState(projectSpec, "mapFn");
  assertNotNull(mapFn);
  assertEquals(mockTranslatorContext, Whitebox.getInternalState(mapFn, "translatorContext"));
  assertEquals(mockProject, Whitebox.getInternalState(mapFn, "project"));
  assertEquals(mockExpr, Whitebox.getInternalState(mapFn, "expr"));

  // Verify TestMetricsRegistryImpl works with Project
  assertEquals(1, testMetricsRegistryImpl.getGauges().size());
  assertEquals(2, testMetricsRegistryImpl.getGauges().get(LOGICAL_OP_ID).size());
  assertEquals(1, testMetricsRegistryImpl.getCounters().size());
  assertEquals(2, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).size());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());

  // Calling mapFn.apply() to verify the project function is correctly applied to the input message
  SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
  SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
  DataContext dataContext = mock(DataContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
  when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
  Object[] result = new Object[1];
  final Object mockFieldObj = new Object();
  // The mocked expression writes its projected value into the output array (argument index 4)
  doAnswer(invocation -> {
    Object[] retValue = invocation.getArgumentAt(4, Object[].class);
    retValue[0] = mockFieldObj;
    return null;
  }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
  SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
  assertEquals(Collections.singletonList("test_field"), retMsg.getSamzaSqlRelRecord().getFieldNames());
  assertEquals(Collections.singletonList(mockFieldObj), retMsg.getSamzaSqlRelRecord().getFieldValues());

  // Verify mapFn.apply() updates the TestMetricsRegistryImpl metrics
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
Aggregations