Usage of org.apache.samza.operators.stream.IntermediateMessageStreamImpl in the Apache Samza project.
From class TestMessageStreamImpl, method testPartitionBy.
/**
 * Checks that {@code partitionBy(...)} chains a PARTITION_BY operator spec onto the upstream
 * operator, and that this spec carries the intermediate stream's output stream together with the
 * key and value functions handed to {@code partitionBy}.
 */
@Test
public void testPartitionBy() throws IOException {
  final String opName = "mockName";

  // Every op id requested from the descriptor resolves to opName; the intermediate-stream stub
  // below is keyed on that id.
  StreamApplicationDescriptorImpl graph = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec upstreamOpSpec = mock(OperatorSpec.class);
  when(graph.getNextOpId(anyObject(), anyObject())).thenReturn(opName);

  // Keyed, non-broadcast intermediate stream returned for (opName, kvSerde).
  OutputStreamImpl outputStream = mock(OutputStreamImpl.class);
  KVSerde kvSerde = mock(KVSerde.class);
  IntermediateMessageStreamImpl intermediateStream = mock(IntermediateMessageStreamImpl.class);
  when(graph.getIntermediateStream(eq(opName), eq(kvSerde), eq(false))).thenReturn(intermediateStream);
  when(intermediateStream.getOutputStream()).thenReturn(outputStream);
  when(intermediateStream.isKeyed()).thenReturn(true);

  MessageStreamImpl<TestMessageEnvelope> source = new MessageStreamImpl<>(graph, upstreamOpSpec);
  MapFunction keyFn = mock(MapFunction.class);
  MapFunction valueFn = mock(MapFunction.class);

  source.partitionBy(keyFn, valueFn, kvSerde, "p1");

  // Grab whatever operator spec partitionBy registered on the upstream operator.
  ArgumentCaptor<OperatorSpec> captor = ArgumentCaptor.forClass(OperatorSpec.class);
  verify(upstreamOpSpec).registerNextOperatorSpec(captor.capture());
  OperatorSpec<?, TestMessageEnvelope> registered = captor.getValue();

  assertTrue(registered instanceof PartitionByOperatorSpec);
  assertEquals(OpCode.PARTITION_BY, registered.getOpCode());

  // Safe after the instanceof assertion above.
  PartitionByOperatorSpec partitionBySpec = (PartitionByOperatorSpec) registered;
  assertEquals(outputStream, partitionBySpec.getOutputStream());
  assertEquals(keyFn, partitionBySpec.getKeyFunction());
  assertEquals(valueFn, partitionBySpec.getValueFunction());
}
Usage of org.apache.samza.operators.stream.IntermediateMessageStreamImpl in the Apache Samza project.
From class RepartitionJoinWindowApp, method describe.
/**
 * Builds the pipeline:
 * <ol>
 *   <li>repartition page views and ad clicks by view id;</li>
 *   <li>join them within a 1-minute window;</li>
 *   <li>repartition the joined records by user id;</li>
 *   <li>count each user's records in a 3-second keyed session window;</li>
 *   <li>send (userId, count) to the configured output topic.</li>
 * </ol>
 * The ids of the three intermediate (repartition) streams are recorded in
 * {@code intermediateStreamIds}.
 */
@Override
public void describe(StreamApplicationDescriptor appDescriptor) {
  // Tests need offset.default = oldest: the checkpoint topic is empty at start-up and messages are
  // published before the application is run.
  Config cfg = appDescriptor.getConfig();
  String pageViewTopic = cfg.get(INPUT_TOPIC_1_CONFIG_KEY);
  String adClickTopic = cfg.get(INPUT_TOPIC_2_CONFIG_KEY);
  String resultTopic = cfg.get(OUTPUT_TOPIC_CONFIG_KEY);

  KafkaSystemDescriptor kafka = new KafkaSystemDescriptor(SYSTEM);
  KafkaInputDescriptor<PageView> pageViewInput =
      kafka.getInputDescriptor(pageViewTopic, new JsonSerdeV2<>(PageView.class));
  KafkaInputDescriptor<AdClick> adClickInput =
      kafka.getInputDescriptor(adClickTopic, new JsonSerdeV2<>(AdClick.class));

  MessageStream<PageView> pageViews = appDescriptor.getInputStream(pageViewInput);
  MessageStream<AdClick> adClicks = appDescriptor.getInputStream(adClickInput);

  // Repartition both inputs by view id, then drop the key again before joining.
  // NOTE: operator ids are generated in call order, so keep this sequence as-is.
  MessageStream<KV<String, PageView>> keyedPageViews = pageViews.partitionBy(
      PageView::getViewId,
      view -> view,
      new KVSerde<>(new StringSerde(), new JsonSerdeV2<>(PageView.class)),
      "pageViewsByViewId");
  MessageStream<PageView> pageViewValues = keyedPageViews.map(KV::getValue);

  MessageStream<KV<String, AdClick>> keyedAdClicks = adClicks.partitionBy(
      AdClick::getViewId,
      click -> click,
      new KVSerde<>(new StringSerde(), new JsonSerdeV2<>(AdClick.class)),
      "adClicksByViewId");
  MessageStream<AdClick> adClickValues = keyedAdClicks.map(KV::getValue);

  MessageStream<UserPageAdClick> joined = pageViewValues.join(
      adClickValues,
      new UserPageViewAdClicksJoiner(),
      new StringSerde(),
      new JsonSerdeV2<>(PageView.class),
      new JsonSerdeV2<>(AdClick.class),
      Duration.ofMinutes(1),
      "pageViewAdClickJoin");

  MessageStream<KV<String, UserPageAdClick>> keyedUserPageAdClicks = joined.partitionBy(
      UserPageAdClick::getUserId,
      record -> record,
      KVSerde.of(new StringSerde(), new JsonSerdeV2<>(UserPageAdClick.class)),
      "userPageAdClicksByUserId");

  // Per-user count in a keyed session window. Each result triggers a commit request for all
  // tasks in the container before it is sent.
  // NOTE(review): the output system is hard-coded as "kafka" rather than SYSTEM; confirm they match.
  keyedUserPageAdClicks
      .map(KV::getValue)
      .window(Windows.keyedSessionWindow(UserPageAdClick::getUserId, Duration.ofSeconds(3),
          new StringSerde(), new JsonSerdeV2<>(UserPageAdClick.class)), "userAdClickWindow")
      .map(pane -> KV.of(pane.getKey().getKey(), String.valueOf(pane.getMessage().size())))
      .sink((kv, collector, coordinator) -> {
        coordinator.commit(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER);
        collector.send(new OutgoingMessageEnvelope(
            new SystemStream("kafka", resultTopic), null, kv.getKey(), kv.getValue()));
      });

  // Expose the repartition stream ids (presumably consumed by the test harness).
  intermediateStreamIds.add(((IntermediateMessageStreamImpl) keyedPageViews).getStreamId());
  intermediateStreamIds.add(((IntermediateMessageStreamImpl) keyedAdClicks).getStreamId());
  intermediateStreamIds.add(((IntermediateMessageStreamImpl) keyedUserPageAdClicks).getStreamId());
}
Usage of org.apache.samza.operators.stream.IntermediateMessageStreamImpl in the Apache Samza project.
From class TestJoinTranslator, method testTranslateStreamToTableJoin.
/**
 * Translates a mocked INNER stream-table join and checks the resulting operator spec and join
 * function. The left input is a table scan (LeftTable) and the right input is a stream. The test
 * also checks the metrics registered by the translator's input/output metrics map functions.
 *
 * @param isRemoteTable if true, the table is mocked with a RemoteTableDescriptor and a
 *     SamzaSqlRemoteTableJoinFunction is expected. Otherwise it is mocked with a
 *     RocksDbTableDescriptor, getIntermediateStream is stubbed, and a
 *     SamzaSqlLocalTableJoinFunction is expected.
 */
private void testTranslateStreamToTableJoin(boolean isRemoteTable) throws IOException, ClassNotFoundException {
// setup mock values to the constructor of JoinTranslator
final String logicalOpId = "sql0_join3";
final int queryId = 0;
LogicalJoin mockJoin = PowerMockito.mock(LogicalJoin.class);
TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
// Left input is a table scan (the table side); right input is a plain RelNode (the stream side).
RelNode mockLeftInput = PowerMockito.mock(LogicalTableScan.class);
RelNode mockRightInput = mock(RelNode.class);
List<RelNode> inputs = new ArrayList<>();
inputs.add(mockLeftInput);
inputs.add(mockRightInput);
RelOptTable mockLeftTable = mock(RelOptTable.class);
when(mockLeftInput.getTable()).thenReturn(mockLeftTable);
List<String> qualifiedTableName = Arrays.asList("test", "LeftTable");
when(mockLeftTable.getQualifiedName()).thenReturn(qualifiedTableName);
// Rel ids: left = 1, right = 2, join = 3 (the join's output stream is registered under id 3).
when(mockLeftInput.getId()).thenReturn(1);
when(mockRightInput.getId()).thenReturn(2);
when(mockJoin.getId()).thenReturn(3);
when(mockJoin.getInputs()).thenReturn(inputs);
when(mockJoin.getLeft()).thenReturn(mockLeftInput);
when(mockJoin.getRight()).thenReturn(mockRightInput);
// Join condition: an EQUALS call on field 0 of each side.
RexCall mockJoinCondition = mock(RexCall.class);
when(mockJoinCondition.isAlwaysTrue()).thenReturn(false);
when(mockJoinCondition.getKind()).thenReturn(SqlKind.EQUALS);
when(mockJoin.getCondition()).thenReturn(mockJoinCondition);
RexInputRef mockLeftConditionInput = mock(RexInputRef.class);
RexInputRef mockRightConditionInput = mock(RexInputRef.class);
when(mockLeftConditionInput.getIndex()).thenReturn(0);
when(mockRightConditionInput.getIndex()).thenReturn(0);
List<RexNode> condOperands = new ArrayList<>();
condOperands.add(mockLeftConditionInput);
condOperands.add(mockRightConditionInput);
when(mockJoinCondition.getOperands()).thenReturn(condOperands);
RelDataType mockLeftCondDataType = mock(RelDataType.class);
RelDataType mockRightCondDataType = mock(RelDataType.class);
when(mockLeftCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
when(mockRightCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
when(mockLeftConditionInput.getType()).thenReturn(mockLeftCondDataType);
when(mockRightConditionInput.getType()).thenReturn(mockRightCondDataType);
RelDataType mockLeftRowType = mock(RelDataType.class);
// NOTE(review): the left row type reports 0 fields even though getFieldNames() below returns one
// name. The reason is not clear from this test; presumably it affects how the translator maps the
// right side's condition index. TODO confirm.
when(mockLeftRowType.getFieldCount()).thenReturn(0);
when(mockLeftInput.getRowType()).thenReturn(mockLeftRowType);
List<String> leftFieldNames = Collections.singletonList("test_table_field1");
List<String> rightStreamFieldNames = Collections.singletonList("test_stream_field1");
when(mockLeftRowType.getFieldNames()).thenReturn(leftFieldNames);
RelDataType mockRightRowType = mock(RelDataType.class);
when(mockRightInput.getRowType()).thenReturn(mockRightRowType);
when(mockRightRowType.getFieldNames()).thenReturn(rightStreamFieldNames);
// Real MessageStreamImpl instances over mocked operator specs for both join inputs.
StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
OperatorSpec<Object, SamzaSqlRelMessage> mockLeftInputOp = mock(OperatorSpec.class);
MessageStream<SamzaSqlRelMessage> mockLeftInputStream = new MessageStreamImpl<>(mockAppDesc, mockLeftInputOp);
when(mockTranslatorContext.getMessageStream(eq(mockLeftInput.getId()))).thenReturn(mockLeftInputStream);
OperatorSpec<Object, SamzaSqlRelMessage> mockRightInputOp = mock(OperatorSpec.class);
MessageStream<SamzaSqlRelMessage> mockRightInputStream = new MessageStreamImpl<>(mockAppDesc, mockRightInputOp);
when(mockTranslatorContext.getMessageStream(eq(mockRightInput.getId()))).thenReturn(mockRightInputStream);
when(mockTranslatorContext.getStreamAppDescriptor()).thenReturn(mockAppDesc);
InputOperatorSpec mockInputOp = mock(InputOperatorSpec.class);
OutputStreamImpl mockOutputStream = mock(OutputStreamImpl.class);
when(mockInputOp.isKeyed()).thenReturn(true);
when(mockOutputStream.isKeyed()).thenReturn(true);
// Capture the message stream the translator registers for the join (rel id 3).
doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(3), any(MessageStream.class));
RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
Expression mockExpr = mock(Expression.class);
when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
if (isRemoteTable) {
doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RemoteTableDescriptor.class));
} else {
// Local-table path: any non-broadcast getIntermediateStream request returns a real
// IntermediateMessageStreamImpl built over the keyed mock input op and output stream.
IntermediateMessageStreamImpl mockPartitionedStream = new IntermediateMessageStreamImpl(mockAppDesc, mockInputOp, mockOutputStream);
when(mockAppDesc.getIntermediateStream(any(String.class), any(Serde.class), eq(false))).thenReturn(mockPartitionedStream);
doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RocksDbTableDescriptor.class));
}
when(mockJoin.getJoinType()).thenReturn(JoinRelType.INNER);
SamzaSqlExecutionContext mockExecutionContext = mock(SamzaSqlExecutionContext.class);
when(mockTranslatorContext.getExecutionContext()).thenReturn(mockExecutionContext);
SamzaSqlApplicationConfig mockAppConfig = mock(SamzaSqlApplicationConfig.class);
when(mockExecutionContext.getSamzaSqlApplicationConfig()).thenReturn(mockAppConfig);
// Source config lookup: "test.LeftTable" resolves to an IO config carrying the table descriptor.
Map<String, SqlIOConfig> ssConfigBySource = mock(HashMap.class);
when(mockAppConfig.getInputSystemStreamConfigBySource()).thenReturn(ssConfigBySource);
SqlIOConfig mockIOConfig = mock(SqlIOConfig.class);
TableDescriptor mockTableDesc;
if (isRemoteTable) {
mockTableDesc = mock(RemoteTableDescriptor.class);
} else {
mockTableDesc = mock(RocksDbTableDescriptor.class);
}
when(ssConfigBySource.get(String.join(".", qualifiedTableName))).thenReturn(mockIOConfig);
when(mockIOConfig.getTableDescriptor()).thenReturn(Optional.of(mockTableDesc));
JoinTranslator joinTranslator = new JoinTranslator(logicalOpId, "", queryId);
// Verify Metrics Works with Join
Context mockContext = mock(Context.class);
ContainerContext mockContainerContext = mock(ContainerContext.class);
TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
TranslatorInputMetricsMapFunction inputMetricsMF = joinTranslator.getInputMetricsMF();
assertNotNull(inputMetricsMF);
inputMetricsMF.init(mockContext);
TranslatorOutputMetricsMapFunction outputMetricsMF = joinTranslator.getOutputMetricsMF();
assertNotNull(outputMetricsMF);
outputMetricsMF.init(mockContext);
// After init: one counter group (under logicalOpId) with two zero-valued counters, and one gauge group.
assertEquals(1, testMetricsRegistryImpl.getCounters().size());
assertEquals(2, testMetricsRegistryImpl.getCounters().get(logicalOpId).size());
assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(0).getCount());
assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(1).getCount());
assertEquals(1, testMetricsRegistryImpl.getGauges().size());
// Apply translate() method to verify that we are getting the correct map operator constructed
joinTranslator.translate(mockJoin, mockTranslatorContext);
// make sure that context has been registered with LogicFilter and output message streams
verify(mockTranslatorContext, times(1)).registerMessageStream(3, this.getRegisteredMessageStream(3));
when(mockTranslatorContext.getRelNode(3)).thenReturn(mockJoin);
when(mockTranslatorContext.getMessageStream(3)).thenReturn(this.getRegisteredMessageStream(3));
StreamTableJoinOperatorSpec joinSpec = (StreamTableJoinOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(3), "operatorSpec");
assertNotNull(joinSpec);
// Expected value first, so that a failure message reports expected/actual correctly.
assertEquals(OperatorSpec.OpCode.JOIN, joinSpec.getOpCode());
// Verify joinSpec has the corresponding setup
StreamTableJoinFunction joinFn = joinSpec.getJoinFn();
assertNotNull(joinFn);
if (isRemoteTable) {
assertTrue(joinFn instanceof SamzaSqlRemoteTableJoinFunction);
} else {
assertTrue(joinFn instanceof SamzaSqlLocalTableJoinFunction);
}
// The table is the left input, so it must not be flagged as being on the right.
assertTrue(Whitebox.getInternalState(joinFn, "isTablePosOnRight").equals(false));
assertEquals(Collections.singletonList(0), Whitebox.getInternalState(joinFn, "streamFieldIds"));
assertEquals(leftFieldNames, Whitebox.getInternalState(joinFn, "tableFieldNames"));
// Output schema is the left (table) fields followed by the right (stream) fields.
List<String> outputFieldNames = new ArrayList<>();
outputFieldNames.addAll(leftFieldNames);
outputFieldNames.addAll(rightStreamFieldNames);
assertEquals(outputFieldNames, Whitebox.getInternalState(joinFn, "outFieldNames"));
}
Usage of org.apache.samza.operators.stream.IntermediateMessageStreamImpl in the Apache Samza project.
From class StreamApplicationDescriptorImpl, method getIntermediateStream.
/**
 * Internal helper for {@link MessageStreamImpl} to add an intermediate {@link MessageStream} to the graph.
 * An intermediate {@link MessageStream} is both an output and an input stream. This method registers an
 * {@link InputOperatorSpec} and an {@link OutputStreamImpl} under the same {@code streamId} and wraps both
 * in the returned stream.
 *
 * @param streamId the id of the stream to be created.
 * @param serde the {@link Serde} to use for the message in the intermediate stream. Must not be null.
 *              If it is a {@link KVSerde}, the stream is treated as keyed.
 * @param isBroadcast whether the stream is a broadcast stream.
 * @param <M> the type of messages in the intermediate {@link MessageStream}
 * @return the intermediate {@link MessageStreamImpl}
 * @throws NullPointerException if {@code serde} is null
 * @throws IllegalStateException if an input or output stream is already registered for {@code streamId}
 */
@VisibleForTesting
public <M> IntermediateMessageStreamImpl<M> getIntermediateStream(String streamId, Serde<M> serde, boolean isBroadcast) {
  Preconditions.checkNotNull(serde, "serde must not be null for intermediate stream: " + streamId);
  Preconditions.checkState(!inputOperators.containsKey(streamId) && !outputStreams.containsKey(streamId),
      "getIntermediateStream must not be called multiple times with the same streamId: " + streamId);

  if (isBroadcast) {
    intermediateBroadcastStreamIds.add(streamId);
  }

  // A KV serde marks the stream as keyed for both its read and write sides.
  boolean isKeyed = serde instanceof KVSerde;
  KV<Serde, Serde> kvSerdes = getOrCreateStreamSerdes(streamId, serde);

  // Use the default system descriptor's input transformer, if there is one.
  InputTransformer transformer = (InputTransformer) getDefaultSystemDescriptor()
      .flatMap(SystemDescriptor::getTransformer)
      .orElse(null);

  // Register the same stream id as an input (read side) and as an output (write side).
  InputOperatorSpec inputOperatorSpec = OperatorSpecs.createInputOperatorSpec(streamId, kvSerdes.getKey(),
      kvSerdes.getValue(), transformer, isKeyed, this.getNextOpId(OpCode.INPUT, null));
  inputOperators.put(streamId, inputOperatorSpec);
  OutputStreamImpl outputStream = new OutputStreamImpl(streamId, kvSerdes.getKey(), kvSerdes.getValue(), isKeyed);
  outputStreams.put(streamId, outputStream);

  return new IntermediateMessageStreamImpl<>(this, inputOperatorSpec, outputStream);
}
Aggregations