use of org.apache.samza.operators.MessageStream in project samza by apache.
the class TestProjectTranslator method testTranslate.
/**
 * Exercises {@code ProjectTranslator.translate()} against a mocked {@code LogicalProject}.
 *
 * <p>The test checks, in order, that the translator registers the project node and its output
 * stream under id 2, that the registered stream is backed by a MAP operator, that initializing
 * the operator's transform function leaves the map function holding the expected
 * {@code translatorContext}, {@code project} and {@code expr}, and that applying the map
 * function emits the projected field while incrementing both counters.
 *
 * @throws IOException declared by the test signature
 * @throws ClassNotFoundException declared by the test signature
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
  // Set up the mocked LogicalProject and the collaborators handed to ProjectTranslator
  LogicalProject mockProject = PowerMockito.mock(LogicalProject.class);
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
  RelNode mockInput = mock(RelNode.class);
  List<RelNode> inputs = new ArrayList<>();
  inputs.add(mockInput);
  when(mockInput.getId()).thenReturn(1);
  when(mockProject.getId()).thenReturn(2);
  when(mockProject.getInputs()).thenReturn(inputs);
  when(mockProject.getInput()).thenReturn(mockInput);
  RelDataType mockRowType = mock(RelDataType.class);
  when(mockRowType.getFieldCount()).thenReturn(1);
  when(mockProject.getRowType()).thenReturn(mockRowType);
  RexNode mockRexField = mock(RexNode.class);
  List<Pair<RexNode, String>> namedProjects = new ArrayList<>();
  namedProjects.add(Pair.of(mockRexField, "test_field"));
  when(mockProject.getNamedProjects()).thenReturn(namedProjects);
  StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockAppDesc, mockInputOp);
  when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
  // Capture whatever stream the translator registers for id 2 so it can be inspected below
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);

  // Apply translate() method to verify that we are getting the correct map operator constructed
  ProjectTranslator projectTranslator = new ProjectTranslator(1);
  projectTranslator.translate(mockProject, LOGICAL_OP_ID, mockTranslatorContext);

  // Make sure that the context has been registered with the LogicalProject and the output message stream
  verify(mockTranslatorContext, times(1)).registerRelNode(2, mockProject);
  verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
  when(mockTranslatorContext.getRelNode(2)).thenReturn(mockProject);
  when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
  StreamOperatorSpec projectSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
  assertNotNull(projectSpec);
  assertEquals(OperatorSpec.OpCode.MAP, projectSpec.getOpCode());

  // Verify that init() on the transform function establishes the context for the map function
  Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
  mockContexts.put(1, mockTranslatorContext);
  when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
  projectSpec.getTransformFn().init(mockContext);
  MapFunction mapFn = (MapFunction) Whitebox.getInternalState(projectSpec, "mapFn");
  assertNotNull(mapFn);
  assertEquals(mockTranslatorContext, Whitebox.getInternalState(mapFn, "translatorContext"));
  assertEquals(mockProject, Whitebox.getInternalState(mapFn, "project"));
  assertEquals(mockExpr, Whitebox.getInternalState(mapFn, "expr"));

  // Verify TestMetricsRegistryImpl works with Project
  assertEquals(1, testMetricsRegistryImpl.getGauges().size());
  assertEquals(2, testMetricsRegistryImpl.getGauges().get(LOGICAL_OP_ID).size());
  assertEquals(1, testMetricsRegistryImpl.getCounters().size());
  assertEquals(2, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).size());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());

  // Call mapFn.apply() to verify the projection is correctly applied to the input message
  SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
  SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
  DataContext dataContext = mock(DataContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
  when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
  Object[] result = new Object[1];
  final Object mockFieldObj = new Object();
  // The stubbed expression writes its output into the fifth argument (the result array)
  doAnswer(invocation -> {
    Object[] retValue = invocation.getArgumentAt(4, Object[].class);
    retValue[0] = mockFieldObj;
    return null;
  }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
  SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
  assertEquals(Collections.singletonList("test_field"), retMsg.getSamzaSqlRelRecord().getFieldNames());
  assertEquals(Collections.singletonList(mockFieldObj), retMsg.getSamzaSqlRelRecord().getFieldValues());

  // Verify mapFn.apply() updates the TestMetricsRegistryImpl metrics
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
use of org.apache.samza.operators.MessageStream in project samza by apache.
the class RepartitionExample method describe.
/**
 * Builds the repartition pipeline: reads {@code pageViewEvent} from the "tracking" Kafka
 * system, re-keys each event by member id, counts events per key in 5-minute keyed tumbling
 * windows, and sends one {@code MyStreamOutput} per window pane to
 * {@code pageViewEventPerMember}.
 *
 * @param appDescriptor descriptor on which the input stream, output stream and default system
 *                      are declared
 */
@Override
public void describe(StreamApplicationDescriptor appDescriptor) {
  KafkaSystemDescriptor kafka = new KafkaSystemDescriptor("tracking");
  appDescriptor.withDefaultSystem(kafka);

  KafkaInputDescriptor<PageViewEvent> pageViewInput =
      kafka.getInputDescriptor("pageViewEvent", new JsonSerdeV2<>(PageViewEvent.class));
  KafkaOutputDescriptor<KV<String, MyStreamOutput>> perMemberOutput =
      kafka.getOutputDescriptor(
          "pageViewEventPerMember",
          KVSerde.of(new StringSerde(), new JsonSerdeV2<>(MyStreamOutput.class)));

  MessageStream<PageViewEvent> source = appDescriptor.getInputStream(pageViewInput);
  OutputStream<KV<String, MyStreamOutput>> sink = appDescriptor.getOutputStream(perMemberOutput);

  // Serde for the repartitioned stream: member id as key, the original event as value
  KVSerde<String, PageViewEvent> repartitionSerde =
      KVSerde.of(new StringSerde(), new JsonSerdeV2<>(PageViewEvent.class));

  source
      .partitionBy(event -> event.getMemberId(), event -> event, repartitionSerde, "partitionBy")
      // Count starts at 0 and is incremented once per message in the window
      .window(
          Windows.keyedTumblingWindow(
              KV::getKey, Duration.ofMinutes(5), () -> 0, (msg, count) -> count + 1, null, null),
          "window")
      .map(pane -> KV.of(pane.getKey().getKey(), new MyStreamOutput(pane)))
      .sendTo(sink);
}
use of org.apache.samza.operators.MessageStream in project samza by apache.
the class AsyncApplicationExample method describe.
/**
 * Builds the async enrichment pipeline: reads {@code adClickEvent} from the "tracking" Kafka
 * system, passes each event through {@code AsyncApplicationExample::enrichAdClickEvent} via
 * {@code flatMapAsync}, keys every enriched event by its country, and sends it to
 * {@code enrichedAdClickEvent}.
 *
 * @param appDescriptor descriptor on which the input and output streams are declared
 */
@Override
public void describe(StreamApplicationDescriptor appDescriptor) {
  KafkaSystemDescriptor kafka = new KafkaSystemDescriptor("tracking");

  KafkaInputDescriptor<AdClickEvent> clicksInput =
      kafka.getInputDescriptor("adClickEvent", new JsonSerdeV2<>(AdClickEvent.class));
  KafkaOutputDescriptor<KV<String, EnrichedAdClickEvent>> enrichedOutput =
      kafka.getOutputDescriptor(
          "enrichedAdClickEvent",
          KVSerde.of(new StringSerde(), new JsonSerdeV2<>(EnrichedAdClickEvent.class)));

  MessageStream<AdClickEvent> clicks = appDescriptor.getInputStream(clicksInput);
  OutputStream<KV<String, EnrichedAdClickEvent>> sink = appDescriptor.getOutputStream(enrichedOutput);

  clicks
      .flatMapAsync(AsyncApplicationExample::enrichAdClickEvent)
      .map(enriched -> KV.of(enriched.getCountry(), enriched))
      .sendTo(sink);
}
use of org.apache.samza.operators.MessageStream in project samza by apache.
the class TestOperatorImplGraph method testJoinChain.
/**
 * Builds a two-input join graph and checks how {@code OperatorImplGraph} wires it.
 *
 * <p>Asserts that the join function is initialized exactly once, that each input operator has
 * its own partial-join operator sharing one operator spec, and that feeding one message to each
 * input (with the opposite side's store stubbed to return a match) yields exactly one join
 * result pairing the left message with the right message.
 */
@Test
public void testJoinChain() {
  String inputStreamId1 = "input1";
  String inputStreamId2 = "input2";
  String inputSystem = "input-system";
  String inputPhysicalName1 = "input-stream1";
  String inputPhysicalName2 = "input-stream2";
  // Operator id of the join; the two store names below are this id with "-L" / "-R" appended
  String joinOpId = "jobName-jobId-join-j1";

  HashMap<String, String> configs = new HashMap<>();
  configs.put(JobConfig.JOB_NAME, "jobName");
  configs.put(JobConfig.JOB_ID, "jobId");
  StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputPhysicalName1);
  StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputPhysicalName2);
  Config config = new MapConfig(configs);
  when(this.context.getJobContext().getConfig()).thenReturn(config);

  // Integer.valueOf replaces the deprecated Integer(int) constructor; the key is only compared by equals()
  Integer joinKey = Integer.valueOf(1);
  Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey;
  JoinFunction testJoinFunction = new TestJoinFunction(joinOpId, (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn);
  StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
    GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
    GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class));
    GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class));
    MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1);
    MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2);
    inputStream1.join(inputStream2, testJoinFunction, mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1");
  }, config);

  TaskName mockTaskName = mock(TaskName.class);
  TaskModel taskModel = mock(TaskModel.class);
  when(taskModel.getTaskName()).thenReturn(mockTaskName);
  when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
  KeyValueStore mockLeftStore = mock(KeyValueStore.class);
  when(this.context.getTaskContext().getStore(eq(joinOpId + "-L"))).thenReturn(mockLeftStore);
  KeyValueStore mockRightStore = mock(KeyValueStore.class);
  when(this.context.getTaskContext().getStore(eq(joinOpId + "-R"))).thenReturn(mockRightStore);
  OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));

  // verify that join function is initialized once.
  assertEquals(1, TestJoinFunction.getInstanceByTaskName(mockTaskName, joinOpId).numInitCalled);

  InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName1));
  InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName2));
  PartialJoinOperatorImpl leftPartialJoinOpImpl = (PartialJoinOperatorImpl) inputOpImpl1.registeredOperators.iterator().next();
  PartialJoinOperatorImpl rightPartialJoinOpImpl = (PartialJoinOperatorImpl) inputOpImpl2.registeredOperators.iterator().next();
  // Both sides share one operator spec but must be distinct operator instances
  assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec());
  assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl);

  // verify that left partial join operator calls getFirstKey
  Object mockLeftMessage = mock(Object.class);
  long currentTimeMillis = System.currentTimeMillis();
  when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis));
  IncomingMessageEnvelope leftMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockLeftMessage);
  inputOpImpl1.onMessage(leftMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

  // verify that right partial join operator calls getSecondKey
  Object mockRightMessage = mock(Object.class);
  when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis));
  IncomingMessageEnvelope rightMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockRightMessage);
  inputOpImpl2.onMessage(rightMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

  // verify that the join function apply is called with the correct messages on match
  TestJoinFunction joinFn = (TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, joinOpId);
  assertEquals(1, joinFn.joinResults.size());
  KV joinResult = (KV) joinFn.joinResults.iterator().next();
  assertEquals(mockLeftMessage, joinResult.getKey());
  assertEquals(mockRightMessage, joinResult.getValue());
}
use of org.apache.samza.operators.MessageStream in project samza by apache.
the class TestOperatorImplGraph method testMergeChain.
/**
 * Builds a graph in which one input is split into a filtered stream and a mapped stream that
 * are merged again and followed by a map, then checks the resulting {@code OperatorImplGraph}.
 *
 * <p>Asserts that exactly one MERGE operator exists, that it has a single downstream operator
 * with op code MAP, and that the downstream map function was initialized exactly once even
 * though the merge is reachable through two upstream paths.
 */
@Test
public void testMergeChain() {
  String inputStreamId = "input";
  String inputSystem = "input-system";
  StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
    GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
    GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class));
    MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor);
    MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class));
    MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class));
    stream1.merge(Collections.singleton(stream2)).map(new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m));
  }, getConfig());

  TaskName mockTaskName = mock(TaskName.class);
  TaskModel taskModel = mock(TaskModel.class);
  when(taskModel.getTaskName()).thenReturn(mockTaskName);
  when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);
  OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));

  // Collect every operator reachable from the input operators into one set
  Set<OperatorImpl> opSet = opImplGraph.getAllInputOperators().stream().collect(HashSet::new, (s, op) -> addOperatorRecursively(s, op), HashSet::addAll);
  Object[] mergeOps = opSet.stream().filter(op -> op.getOperatorSpec().getOpCode() == OpCode.MERGE).toArray();
  assertEquals(1, mergeOps.length);
  assertEquals(1, ((OperatorImpl) mergeOps[0]).registeredOperators.size());
  OperatorImpl mapOp = (OperatorImpl) ((OperatorImpl) mergeOps[0]).registeredOperators.iterator().next();
  assertEquals(OpCode.MAP, mapOp.getOperatorSpec().getOpCode());

  // verify that the DAG after merge is only traversed & initialized once
  assertEquals(1, TestMapFunction.getInstanceByTaskName(mockTaskName, "test-map-1").numInitCalled);
}
Aggregations