Search in sources:

Example 1 with RexToJavaCompiler

use of org.apache.samza.sql.data.RexToJavaCompiler in project samza by apache.

In class TranslatorContext, method createExpressionCompiler:

/**
 * Builds the expression compiler for a query plan.
 *
 * <p>The type factory is taken from the cluster of the plan's projected root and handed to a
 * {@link SamzaSqlRexBuilder}, which in turn backs the returned {@link RexToJavaCompiler}.
 *
 * @param relRoot root of the relational plan whose cluster supplies the type factory
 * @return a new compiler backed by a {@code SamzaSqlRexBuilder}
 */
private RexToJavaCompiler createExpressionCompiler(RelRoot relRoot) {
    final RelDataTypeFactory typeFactory = relRoot.project().getCluster().getTypeFactory();
    final RexBuilder samzaRexBuilder = new SamzaSqlRexBuilder(typeFactory);
    final RexToJavaCompiler compiler = new RexToJavaCompiler(samzaRexBuilder);
    return compiler;
}
Also used : RelDataTypeFactory(org.apache.calcite.rel.type.RelDataTypeFactory) RexBuilder(org.apache.calcite.rex.RexBuilder) RexToJavaCompiler(org.apache.samza.sql.data.RexToJavaCompiler)

Example 2 with RexToJavaCompiler

use of org.apache.samza.sql.data.RexToJavaCompiler in project samza by apache.

In class TestFilterTranslator, method testTranslate:

/**
 * Exercises {@code FilterTranslator.translate()} against a fully mocked translator context.
 *
 * <p>The test checks, in order, that:
 * <ul>
 *   <li>the filter rel node and its output message stream are registered under id 2;</li>
 *   <li>the registered stream carries a {@code FILTER} operator spec;</li>
 *   <li>after {@code init()}, the filter function holds the mocked translator context, filter
 *       node and compiled expression;</li>
 *   <li>gauges and counters are registered under {@code LOGICAL_OP_ID};</li>
 *   <li>{@code apply()} returns whatever boolean the mocked expression writes into the result
 *       array, and bumps the counters.</li>
 * </ul>
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
    // setup mock values to the constructor of FilterTranslator
    LogicalFilter mockFilter = PowerMockito.mock(LogicalFilter.class);
    Context mockContext = mock(Context.class);
    ContainerContext mockContainerContext = mock(ContainerContext.class);
    TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
    TestMetricsRegistryImpl metricsRegistry = new TestMetricsRegistryImpl();
    RelNode mockInput = mock(RelNode.class);
    when(mockFilter.getInput()).thenReturn(mockInput);
    when(mockInput.getId()).thenReturn(1);
    when(mockFilter.getId()).thenReturn(2);
    StreamApplicationDescriptorImpl mockGraph = mock(StreamApplicationDescriptorImpl.class);
    OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockGraph, mockInputOp);
    when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
    // Capture the stream registered for the filter (id 2) so it can be inspected below
    doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
    RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
    when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
    Expression mockExpr = mock(Expression.class);
    when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
    when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
    when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(metricsRegistry);
    // Apply translate() method to verify that we are getting the correct filter operator constructed
    FilterTranslator filterTranslator = new FilterTranslator(1);
    filterTranslator.translate(mockFilter, LOGICAL_OP_ID, mockTranslatorContext);
    // make sure that context has been registered with LogicalFilter and output message streams
    verify(mockTranslatorContext, times(1)).registerRelNode(2, mockFilter);
    verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
    when(mockTranslatorContext.getRelNode(2)).thenReturn(mockFilter);
    when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
    StreamOperatorSpec filterSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
    assertNotNull(filterSpec);
    // JUnit convention: expected value first, actual value second
    assertEquals(OperatorSpec.OpCode.FILTER, filterSpec.getOpCode());
    // Verify that init() will establish the context for the filter function
    Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
    mockContexts.put(1, mockTranslatorContext);
    when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
    filterSpec.getTransformFn().init(mockContext);
    FilterFunction filterFn = (FilterFunction) Whitebox.getInternalState(filterSpec, "filterFn");
    assertNotNull(filterFn);
    assertEquals(mockTranslatorContext, Whitebox.getInternalState(filterFn, "translatorContext"));
    assertEquals(mockFilter, Whitebox.getInternalState(filterFn, "filter"));
    assertEquals(mockExpr, Whitebox.getInternalState(filterFn, "expr"));
    // Verify MetricsRegistry works with Filter
    assertEquals(1, metricsRegistry.getGauges().size());
    assertTrue(metricsRegistry.getGauges().get(LOGICAL_OP_ID).size() > 0);
    assertEquals(1, metricsRegistry.getCounters().size());
    assertEquals(3, metricsRegistry.getCounters().get(LOGICAL_OP_ID).size());
    assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
    // Calling filterFn.apply() to verify the filter function is correctly applied to the input message
    SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
    SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
    DataContext dataContext = mock(DataContext.class);
    when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
    when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
    Object[] result = new Object[1];
    // The mocked expression reports its outcome by writing into the result array (argument 4)
    doAnswer(invocation -> {
        Object[] retValue = invocation.getArgumentAt(4, Object[].class);
        // Boolean.TRUE instead of the deprecated new Boolean(true)
        retValue[0] = Boolean.TRUE;
        return null;
    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
    assertTrue(filterFn.apply(mockInputMsg));
    // Re-stub the same call so the expression now rejects the message
    doAnswer(invocation -> {
        Object[] retValue = invocation.getArgumentAt(4, Object[].class);
        // Boolean.FALSE instead of the deprecated new Boolean(false)
        retValue[0] = Boolean.FALSE;
        return null;
    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
    assertFalse(filterFn.apply(mockInputMsg));
    // Verify filterFn.apply() updates the MetricsRegistry metrics
    assertEquals(2, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(1, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
Also used : MessageStreamImpl(org.apache.samza.operators.MessageStreamImpl) FilterFunction(org.apache.samza.operators.functions.FilterFunction) HashMap(java.util.HashMap) ContainerContext(org.apache.samza.context.ContainerContext) StreamOperatorSpec(org.apache.samza.operators.spec.StreamOperatorSpec) DataContext(org.apache.calcite.DataContext) StreamApplicationDescriptorImpl(org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl) SamzaSqlApplicationContext(org.apache.samza.sql.runner.SamzaSqlApplicationContext) MessageStream(org.apache.samza.operators.MessageStream) ContainerContext(org.apache.samza.context.ContainerContext) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) DataContext(org.apache.calcite.DataContext) Context(org.apache.samza.context.Context) SamzaSqlApplicationContext(org.apache.samza.sql.runner.SamzaSqlApplicationContext) SamzaSqlRelMsgMetadata(org.apache.samza.sql.data.SamzaSqlRelMsgMetadata) LogicalFilter(org.apache.calcite.rel.logical.LogicalFilter) RexToJavaCompiler(org.apache.samza.sql.data.RexToJavaCompiler) TestMetricsRegistryImpl(org.apache.samza.sql.util.TestMetricsRegistryImpl) RelNode(org.apache.calcite.rel.RelNode) Expression(org.apache.samza.sql.data.Expression) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) SamzaSqlRelMessage(org.apache.samza.sql.data.SamzaSqlRelMessage) PrepareForTest(org.powermock.core.classloader.annotations.PrepareForTest) Test(org.junit.Test)

Example 3 with RexToJavaCompiler

use of org.apache.samza.sql.data.RexToJavaCompiler in project samza by apache.

In class TestJoinTranslator, method testTranslateStreamToTableJoin:

/**
 * Shared body for the stream-to-table join translation tests.
 *
 * <p>Builds a mocked {@code LogicalJoin} (id 3) whose left input is a table scan (id 1) and whose
 * right input is a plain rel node (id 2), runs {@code JoinTranslator.translate()}, and verifies
 * the registered output stream, the resulting {@code JOIN} operator spec, and the internal state
 * of the join function.
 *
 * @param isRemoteTable when {@code true} the table descriptor is a mocked
 *        {@code RemoteTableDescriptor} and a {@code SamzaSqlRemoteTableJoinFunction} is expected;
 *        otherwise a mocked {@code RocksDbTableDescriptor} with an intermediate partitioned
 *        stream is used and a {@code SamzaSqlLocalTableJoinFunction} is expected
 */
private void testTranslateStreamToTableJoin(boolean isRemoteTable) throws IOException, ClassNotFoundException {
    // setup mock values to the constructor of JoinTranslator
    final String logicalOpId = "sql0_join3";
    final int queryId = 0;
    LogicalJoin mockJoin = PowerMockito.mock(LogicalJoin.class);
    TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
    RelNode mockLeftInput = PowerMockito.mock(LogicalTableScan.class);
    RelNode mockRightInput = mock(RelNode.class);
    List<RelNode> inputs = new ArrayList<>();
    inputs.add(mockLeftInput);
    inputs.add(mockRightInput);
    RelOptTable mockLeftTable = mock(RelOptTable.class);
    when(mockLeftInput.getTable()).thenReturn(mockLeftTable);
    List<String> qualifiedTableName = Arrays.asList("test", "LeftTable");
    when(mockLeftTable.getQualifiedName()).thenReturn(qualifiedTableName);
    when(mockLeftInput.getId()).thenReturn(1);
    when(mockRightInput.getId()).thenReturn(2);
    when(mockJoin.getId()).thenReturn(3);
    when(mockJoin.getInputs()).thenReturn(inputs);
    when(mockJoin.getLeft()).thenReturn(mockLeftInput);
    when(mockJoin.getRight()).thenReturn(mockRightInput);
    // Join condition: an EQUALS call over two input refs, both at index 0
    RexCall mockJoinCondition = mock(RexCall.class);
    when(mockJoinCondition.isAlwaysTrue()).thenReturn(false);
    when(mockJoinCondition.getKind()).thenReturn(SqlKind.EQUALS);
    when(mockJoin.getCondition()).thenReturn(mockJoinCondition);
    RexInputRef mockLeftConditionInput = mock(RexInputRef.class);
    RexInputRef mockRightConditionInput = mock(RexInputRef.class);
    when(mockLeftConditionInput.getIndex()).thenReturn(0);
    when(mockRightConditionInput.getIndex()).thenReturn(0);
    List<RexNode> condOperands = new ArrayList<>();
    condOperands.add(mockLeftConditionInput);
    condOperands.add(mockRightConditionInput);
    when(mockJoinCondition.getOperands()).thenReturn(condOperands);
    RelDataType mockLeftCondDataType = mock(RelDataType.class);
    RelDataType mockRightCondDataType = mock(RelDataType.class);
    when(mockLeftCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
    when(mockRightCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
    when(mockLeftConditionInput.getType()).thenReturn(mockLeftCondDataType);
    when(mockRightConditionInput.getType()).thenReturn(mockRightCondDataType);
    RelDataType mockLeftRowType = mock(RelDataType.class);
    // NOTE(review): the left row type reports 0 fields even though one field name is stubbed
    // below; the reason is not evident from this test -- confirm against JoinTranslator.
    when(mockLeftRowType.getFieldCount()).thenReturn(0);
    when(mockLeftInput.getRowType()).thenReturn(mockLeftRowType);
    List<String> leftFieldNames = Collections.singletonList("test_table_field1");
    List<String> rightStreamFieldNames = Collections.singletonList("test_stream_field1");
    when(mockLeftRowType.getFieldNames()).thenReturn(leftFieldNames);
    RelDataType mockRightRowType = mock(RelDataType.class);
    when(mockRightInput.getRowType()).thenReturn(mockRightRowType);
    when(mockRightRowType.getFieldNames()).thenReturn(rightStreamFieldNames);
    StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
    OperatorSpec<Object, SamzaSqlRelMessage> mockLeftInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockLeftInputStream = new MessageStreamImpl<>(mockAppDesc, mockLeftInputOp);
    when(mockTranslatorContext.getMessageStream(eq(mockLeftInput.getId()))).thenReturn(mockLeftInputStream);
    OperatorSpec<Object, SamzaSqlRelMessage> mockRightInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockRightInputStream = new MessageStreamImpl<>(mockAppDesc, mockRightInputOp);
    when(mockTranslatorContext.getMessageStream(eq(mockRightInput.getId()))).thenReturn(mockRightInputStream);
    when(mockTranslatorContext.getStreamAppDescriptor()).thenReturn(mockAppDesc);
    InputOperatorSpec mockInputOp = mock(InputOperatorSpec.class);
    OutputStreamImpl mockOutputStream = mock(OutputStreamImpl.class);
    when(mockInputOp.isKeyed()).thenReturn(true);
    when(mockOutputStream.isKeyed()).thenReturn(true);
    // Capture the stream registered for the join (id 3) so it can be inspected below
    doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(3), any(MessageStream.class));
    RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
    when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
    Expression mockExpr = mock(Expression.class);
    when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
    if (isRemoteTable) {
        doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RemoteTableDescriptor.class));
    } else {
        // Local table path additionally needs an intermediate (partitioned) stream
        IntermediateMessageStreamImpl mockPartitionedStream = new IntermediateMessageStreamImpl(mockAppDesc, mockInputOp, mockOutputStream);
        when(mockAppDesc.getIntermediateStream(any(String.class), any(Serde.class), eq(false))).thenReturn(mockPartitionedStream);
        doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RocksDbTableDescriptor.class));
    }
    when(mockJoin.getJoinType()).thenReturn(JoinRelType.INNER);
    SamzaSqlExecutionContext mockExecutionContext = mock(SamzaSqlExecutionContext.class);
    when(mockTranslatorContext.getExecutionContext()).thenReturn(mockExecutionContext);
    SamzaSqlApplicationConfig mockAppConfig = mock(SamzaSqlApplicationConfig.class);
    when(mockExecutionContext.getSamzaSqlApplicationConfig()).thenReturn(mockAppConfig);
    Map<String, SqlIOConfig> ssConfigBySource = mock(HashMap.class);
    when(mockAppConfig.getInputSystemStreamConfigBySource()).thenReturn(ssConfigBySource);
    SqlIOConfig mockIOConfig = mock(SqlIOConfig.class);
    TableDescriptor mockTableDesc;
    if (isRemoteTable) {
        mockTableDesc = mock(RemoteTableDescriptor.class);
    } else {
        mockTableDesc = mock(RocksDbTableDescriptor.class);
    }
    // The IO config is looked up by the dot-joined qualified table name ("test.LeftTable")
    when(ssConfigBySource.get(String.join(".", qualifiedTableName))).thenReturn(mockIOConfig);
    when(mockIOConfig.getTableDescriptor()).thenReturn(Optional.of(mockTableDesc));
    JoinTranslator joinTranslator = new JoinTranslator(logicalOpId, "", queryId);
    // Verify Metrics Works with Join
    Context mockContext = mock(Context.class);
    ContainerContext mockContainerContext = mock(ContainerContext.class);
    TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
    when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
    when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
    TranslatorInputMetricsMapFunction inputMetricsMF = joinTranslator.getInputMetricsMF();
    assertNotNull(inputMetricsMF);
    inputMetricsMF.init(mockContext);
    TranslatorOutputMetricsMapFunction outputMetricsMF = joinTranslator.getOutputMetricsMF();
    assertNotNull(outputMetricsMF);
    outputMetricsMF.init(mockContext);
    assertEquals(1, testMetricsRegistryImpl.getCounters().size());
    assertEquals(2, testMetricsRegistryImpl.getCounters().get(logicalOpId).size());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(0).getCount());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(1).getCount());
    assertEquals(1, testMetricsRegistryImpl.getGauges().size());
    // Apply translate() method to verify that we are getting the correct join operator constructed
    joinTranslator.translate(mockJoin, mockTranslatorContext);
    // make sure that the join's output message stream has been registered with the context
    verify(mockTranslatorContext, times(1)).registerMessageStream(3, this.getRegisteredMessageStream(3));
    when(mockTranslatorContext.getRelNode(3)).thenReturn(mockJoin);
    when(mockTranslatorContext.getMessageStream(3)).thenReturn(this.getRegisteredMessageStream(3));
    StreamTableJoinOperatorSpec joinSpec = (StreamTableJoinOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(3), "operatorSpec");
    assertNotNull(joinSpec);
    // JUnit convention: expected value first, actual value second
    assertEquals(OperatorSpec.OpCode.JOIN, joinSpec.getOpCode());
    // Verify joinSpec has the corresponding setup
    StreamTableJoinFunction joinFn = joinSpec.getJoinFn();
    assertNotNull(joinFn);
    if (isRemoteTable) {
        assertTrue(joinFn instanceof SamzaSqlRemoteTableJoinFunction);
    } else {
        assertTrue(joinFn instanceof SamzaSqlLocalTableJoinFunction);
    }
    // assertEquals (rather than assertTrue(x.equals(false))) reports the actual value on failure
    assertEquals(Boolean.FALSE, Whitebox.getInternalState(joinFn, "isTablePosOnRight"));
    assertEquals(Collections.singletonList(0), Whitebox.getInternalState(joinFn, "streamFieldIds"));
    assertEquals(leftFieldNames, Whitebox.getInternalState(joinFn, "tableFieldNames"));
    // Expected output schema: left (table) field names followed by right (stream) field names
    List<String> outputFieldNames = new ArrayList<>();
    outputFieldNames.addAll(leftFieldNames);
    outputFieldNames.addAll(rightStreamFieldNames);
    assertEquals(outputFieldNames, Whitebox.getInternalState(joinFn, "outFieldNames"));
}
Also used : Serde(org.apache.samza.serializers.Serde) IntermediateMessageStreamImpl(org.apache.samza.operators.stream.IntermediateMessageStreamImpl) MessageStreamImpl(org.apache.samza.operators.MessageStreamImpl) OutputStreamImpl(org.apache.samza.operators.spec.OutputStreamImpl) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) RexCall(org.apache.calcite.rex.RexCall) ContainerContext(org.apache.samza.context.ContainerContext) StreamApplicationDescriptorImpl(org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl) RocksDbTableDescriptor(org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor) MessageStream(org.apache.samza.operators.MessageStream) StreamTableJoinFunction(org.apache.samza.operators.functions.StreamTableJoinFunction) SqlIOConfig(org.apache.samza.sql.interfaces.SqlIOConfig) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) Context(org.apache.samza.context.Context) ContainerContext(org.apache.samza.context.ContainerContext) InputOperatorSpec(org.apache.samza.operators.spec.InputOperatorSpec) SamzaSqlApplicationConfig(org.apache.samza.sql.runner.SamzaSqlApplicationConfig) RemoteTableDescriptor(org.apache.samza.table.descriptors.RemoteTableDescriptor) IntermediateMessageStreamImpl(org.apache.samza.operators.stream.IntermediateMessageStreamImpl) RexToJavaCompiler(org.apache.samza.sql.data.RexToJavaCompiler) TableDescriptor(org.apache.samza.table.descriptors.TableDescriptor) RocksDbTableDescriptor(org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor) RemoteTableDescriptor(org.apache.samza.table.descriptors.RemoteTableDescriptor) TestMetricsRegistryImpl(org.apache.samza.sql.util.TestMetricsRegistryImpl) RelNode(org.apache.calcite.rel.RelNode) Expression(org.apache.samza.sql.data.Expression) LogicalJoin(org.apache.calcite.rel.logical.LogicalJoin) RexInputRef(org.apache.calcite.rex.RexInputRef) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) RelOptTable(org.apache.calcite.plan.RelOptTable) StreamTableJoinOperatorSpec(org.apache.samza.operators.spec.StreamTableJoinOperatorSpec) SamzaSqlRelMessage(org.apache.samza.sql.data.SamzaSqlRelMessage) RexNode(org.apache.calcite.rex.RexNode)

Example 4 with RexToJavaCompiler

use of org.apache.samza.sql.data.RexToJavaCompiler in project samza by apache.

In class TestProjectTranslator, method testTranslate:

/**
 * Exercises {@code ProjectTranslator.translate()} against a fully mocked translator context.
 *
 * <p>The test checks, in order, that:
 * <ul>
 *   <li>the project rel node and its output message stream are registered under id 2;</li>
 *   <li>the registered stream carries a {@code MAP} operator spec;</li>
 *   <li>after {@code init()}, the map function holds the mocked translator context, project node
 *       and compiled expression;</li>
 *   <li>gauges and counters are registered under {@code LOGICAL_OP_ID};</li>
 *   <li>{@code apply()} produces a message with the single projected field {@code "test_field"}
 *       whose value is what the mocked expression wrote, and bumps the counters.</li>
 * </ul>
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
    // setup mock values to the constructor of ProjectTranslator
    LogicalProject mockProject = PowerMockito.mock(LogicalProject.class);
    Context mockContext = mock(Context.class);
    ContainerContext mockContainerContext = mock(ContainerContext.class);
    TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
    TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
    RelNode mockInput = mock(RelNode.class);
    List<RelNode> inputs = new ArrayList<>();
    inputs.add(mockInput);
    when(mockInput.getId()).thenReturn(1);
    when(mockProject.getId()).thenReturn(2);
    when(mockProject.getInputs()).thenReturn(inputs);
    when(mockProject.getInput()).thenReturn(mockInput);
    RelDataType mockRowType = mock(RelDataType.class);
    when(mockRowType.getFieldCount()).thenReturn(1);
    when(mockProject.getRowType()).thenReturn(mockRowType);
    // A single projected expression named "test_field"
    RexNode mockRexField = mock(RexNode.class);
    List<Pair<RexNode, String>> namedProjects = new ArrayList<>();
    namedProjects.add(Pair.of(mockRexField, "test_field"));
    when(mockProject.getNamedProjects()).thenReturn(namedProjects);
    StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
    OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockAppDesc, mockInputOp);
    when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
    // Capture the stream registered for the project (id 2) so it can be inspected below
    doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
    RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
    when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
    Expression mockExpr = mock(Expression.class);
    when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
    when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
    when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
    // Apply translate() method to verify that we are getting the correct map operator constructed
    ProjectTranslator projectTranslator = new ProjectTranslator(1);
    projectTranslator.translate(mockProject, LOGICAL_OP_ID, mockTranslatorContext);
    // make sure that context has been registered with LogicalProject and output message streams
    verify(mockTranslatorContext, times(1)).registerRelNode(2, mockProject);
    verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
    when(mockTranslatorContext.getRelNode(2)).thenReturn(mockProject);
    when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
    StreamOperatorSpec projectSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
    assertNotNull(projectSpec);
    // JUnit convention: expected value first, actual value second
    assertEquals(OperatorSpec.OpCode.MAP, projectSpec.getOpCode());
    // Verify that init() will establish the context for the map function
    Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
    mockContexts.put(1, mockTranslatorContext);
    when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
    projectSpec.getTransformFn().init(mockContext);
    MapFunction mapFn = (MapFunction) Whitebox.getInternalState(projectSpec, "mapFn");
    assertNotNull(mapFn);
    assertEquals(mockTranslatorContext, Whitebox.getInternalState(mapFn, "translatorContext"));
    assertEquals(mockProject, Whitebox.getInternalState(mapFn, "project"));
    assertEquals(mockExpr, Whitebox.getInternalState(mapFn, "expr"));
    // Verify TestMetricsRegistryImpl works with Project
    assertEquals(1, testMetricsRegistryImpl.getGauges().size());
    assertEquals(2, testMetricsRegistryImpl.getGauges().get(LOGICAL_OP_ID).size());
    assertEquals(1, testMetricsRegistryImpl.getCounters().size());
    assertEquals(2, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).size());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
    // Calling mapFn.apply() to verify the map function is correctly applied to the input message
    SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
    SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
    DataContext dataContext = mock(DataContext.class);
    when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
    when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
    Object[] result = new Object[1];
    final Object mockFieldObj = new Object();
    // The mocked expression reports its outcome by writing into the result array (argument 4)
    doAnswer(invocation -> {
        Object[] retValue = invocation.getArgumentAt(4, Object[].class);
        retValue[0] = mockFieldObj;
        return null;
    }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
    SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
    assertEquals(Collections.singletonList("test_field"), retMsg.getSamzaSqlRelRecord().getFieldNames());
    assertEquals(Collections.singletonList(mockFieldObj), retMsg.getSamzaSqlRelRecord().getFieldValues());
    // Verify mapFn.apply() updates the TestMetricsRegistryImpl metrics
    assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
Also used : MessageStreamImpl(org.apache.samza.operators.MessageStreamImpl) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) RelDataType(org.apache.calcite.rel.type.RelDataType) MapFunction(org.apache.samza.operators.functions.MapFunction) ContainerContext(org.apache.samza.context.ContainerContext) StreamOperatorSpec(org.apache.samza.operators.spec.StreamOperatorSpec) DataContext(org.apache.calcite.DataContext) StreamApplicationDescriptorImpl(org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl) SamzaSqlApplicationContext(org.apache.samza.sql.runner.SamzaSqlApplicationContext) MessageStream(org.apache.samza.operators.MessageStream) Pair(org.apache.calcite.util.Pair) ContainerContext(org.apache.samza.context.ContainerContext) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) DataContext(org.apache.calcite.DataContext) Context(org.apache.samza.context.Context) SamzaSqlApplicationContext(org.apache.samza.sql.runner.SamzaSqlApplicationContext) SamzaSqlRelMsgMetadata(org.apache.samza.sql.data.SamzaSqlRelMsgMetadata) RexToJavaCompiler(org.apache.samza.sql.data.RexToJavaCompiler) TestMetricsRegistryImpl(org.apache.samza.sql.util.TestMetricsRegistryImpl) RelNode(org.apache.calcite.rel.RelNode) Expression(org.apache.samza.sql.data.Expression) SamzaSqlExecutionContext(org.apache.samza.sql.data.SamzaSqlExecutionContext) LogicalProject(org.apache.calcite.rel.logical.LogicalProject) SamzaSqlRelMessage(org.apache.samza.sql.data.SamzaSqlRelMessage) RexNode(org.apache.calcite.rex.RexNode) PrepareForTest(org.powermock.core.classloader.annotations.PrepareForTest) Test(org.junit.Test)

Aggregations

RexToJavaCompiler (org.apache.samza.sql.data.RexToJavaCompiler)4 RelNode (org.apache.calcite.rel.RelNode)3 StreamApplicationDescriptorImpl (org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl)3 ContainerContext (org.apache.samza.context.ContainerContext)3 Context (org.apache.samza.context.Context)3 MessageStream (org.apache.samza.operators.MessageStream)3 MessageStreamImpl (org.apache.samza.operators.MessageStreamImpl)3 Expression (org.apache.samza.sql.data.Expression)3 SamzaSqlExecutionContext (org.apache.samza.sql.data.SamzaSqlExecutionContext)3 SamzaSqlRelMessage (org.apache.samza.sql.data.SamzaSqlRelMessage)3 TestMetricsRegistryImpl (org.apache.samza.sql.util.TestMetricsRegistryImpl)3 ArrayList (java.util.ArrayList)2 HashMap (java.util.HashMap)2 DataContext (org.apache.calcite.DataContext)2 RelDataType (org.apache.calcite.rel.type.RelDataType)2 RexNode (org.apache.calcite.rex.RexNode)2 StreamOperatorSpec (org.apache.samza.operators.spec.StreamOperatorSpec)2 SamzaSqlRelMsgMetadata (org.apache.samza.sql.data.SamzaSqlRelMsgMetadata)2 SamzaSqlApplicationContext (org.apache.samza.sql.runner.SamzaSqlApplicationContext)2 Test (org.junit.Test)2