Usage of `org.apache.samza.operators.functions.JoinFunction` in the Apache Samza project.
Example from class `TestOperatorImpls`, method `testJoinChain`:
/**
 * Verifies that joining two message streams yields one operator chain per input, each with a
 * single, distinct join operator, and that both join operators forward to the same map operator.
 *
 * <p>{@code createOpsMethod} and {@code nextOperatorsField} are reflective handles declared
 * elsewhere in this test class.
 */
@Test
@SuppressWarnings("unchecked") // reflective field reads return Object and are cast to Set<OperatorImpl>
public void testJoinChain() throws IllegalAccessException, InvocationTargetException {
  // Two inputs on the same mocked graph, plus a task context backed by a real metrics registry.
  StreamGraphImpl mockGraph = mock(StreamGraphImpl.class);
  MessageStreamImpl<TestMessageEnvelope> input1 = TestMessageStreamImplUtil.getMessageStreamImpl(mockGraph);
  MessageStreamImpl<TestMessageEnvelope> input2 = TestMessageStreamImplUtil.getMessageStreamImpl(mockGraph);
  TaskContext mockContext = mock(TaskContext.class);
  when(mockContext.getMetricsRegistry()).thenReturn(new MetricsRegistryMap());
  Config mockConfig = mock(Config.class);

  // Join both inputs on the envelope key; the output carries the summed message-value lengths.
  // The trailing identity map is the operator both join branches are expected to share.
  input1.join(input2, new JoinFunction<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope>() {
    @Override
    public TestOutputMessageEnvelope apply(TestMessageEnvelope m1, TestMessageEnvelope m2) {
      return new TestOutputMessageEnvelope(m1.getKey(), m1.getMessage().getValue().length() + m2.getMessage().getValue().length());
    }

    @Override
    public String getFirstKey(TestMessageEnvelope message) {
      return message.getKey();
    }

    @Override
    public String getSecondKey(TestMessageEnvelope message) {
      return message.getKey();
    }
  }, Duration.ofMinutes(1)).map(m -> m);

  OperatorImplGraph opGraph = new OperatorImplGraph();
  // Create the chained operators rooted at each input source.
  RootOperatorImpl chain1 = (RootOperatorImpl) createOpsMethod.invoke(opGraph, input1, mockConfig, mockContext);
  RootOperatorImpl chain2 = (RootOperatorImpl) createOpsMethod.invoke(opGraph, input2, mockConfig, mockContext);

  // First branch of the join: root -> join operator -> map operator.
  Set<OperatorImpl> subsSet = (Set<OperatorImpl>) nextOperatorsField.get(chain1);
  assertEquals(1, subsSet.size());
  OperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> joinOp1 = subsSet.iterator().next();
  Set<OperatorImpl> subsOps = (Set<OperatorImpl>) nextOperatorsField.get(joinOp1);
  assertEquals(1, subsOps.size());
  // The map operator consumes the common join output; this is where the two branches merge.
  OperatorImpl mapImpl = subsOps.iterator().next();

  // Second branch of the join: it must have its own join operator instance.
  subsSet = (Set<OperatorImpl>) nextOperatorsField.get(chain2);
  assertEquals(1, subsSet.size());
  OperatorImpl<TestMessageEnvelope, TestOutputMessageEnvelope> joinOp2 = subsSet.iterator().next();
  assertNotSame(joinOp1, joinOp2);
  subsOps = (Set<OperatorImpl>) nextOperatorsField.get(joinOp2);
  assertEquals(1, subsOps.size());

  // Both join operators must forward to the same map operator.
  assertEquals(mapImpl, subsOps.iterator().next());
}
Usage of `org.apache.samza.operators.functions.JoinFunction` in the Apache Samza project.
Example from class `TestOperatorImplGraph`, method `testJoinChain`:
/**
 * Builds an {@code OperatorImplGraph} for a two-input join and checks four things:
 * <ul>
 *   <li>the join function is initialized exactly once;</li>
 *   <li>each input gets its own partial-join operator sharing one operator spec;</li>
 *   <li>a left message followed by a right message produces exactly one join result;</li>
 *   <li>that result pairs the left message with the right message.</li>
 * </ul>
 */
@Test
public void testJoinChain() {
  String inputStreamId1 = "input1";
  String inputStreamId2 = "input2";
  String inputSystem = "input-system";
  String inputPhysicalName1 = "input-stream1";
  String inputPhysicalName2 = "input-stream2";
  // Operator id the join gets from job name "jobName", job id "jobId" and join id "j1".
  String joinOpId = "jobName-jobId-join-j1";

  HashMap<String, String> configs = new HashMap<>();
  configs.put(JobConfig.JOB_NAME, "jobName");
  configs.put(JobConfig.JOB_ID, "jobId");
  StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputPhysicalName1);
  StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputPhysicalName2);
  Config config = new MapConfig(configs);
  when(this.context.getJobContext().getConfig()).thenReturn(config);

  // Both key functions return the same constant key, so every left/right pair is a key match.
  // Integer.valueOf replaces the deprecated Integer(int) constructor; store stubs match via equals.
  Integer joinKey = Integer.valueOf(1);
  Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey;
  JoinFunction testJoinFunction = new TestJoinFunction(joinOpId, (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn);

  StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
    GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass");
    GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class));
    GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class));
    MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1);
    MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2);
    inputStream1.join(inputStream2, testJoinFunction, mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1");
  }, config);

  TaskName mockTaskName = mock(TaskName.class);
  TaskModel taskModel = mock(TaskModel.class);
  when(taskModel.getTaskName()).thenReturn(mockTaskName);
  when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel);

  // Left ("-L") and right ("-R") join state stores, served by mocks.
  KeyValueStore mockLeftStore = mock(KeyValueStore.class);
  when(this.context.getTaskContext().getStore(eq(joinOpId + "-L"))).thenReturn(mockLeftStore);
  KeyValueStore mockRightStore = mock(KeyValueStore.class);
  when(this.context.getTaskContext().getStore(eq(joinOpId + "-R"))).thenReturn(mockRightStore);

  OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class));

  // The join function must be initialized once, not once per partial-join operator.
  assertEquals(1, TestJoinFunction.getInstanceByTaskName(mockTaskName, joinOpId).numInitCalled);

  InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName1));
  InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName2));
  PartialJoinOperatorImpl leftPartialJoinOpImpl = (PartialJoinOperatorImpl) inputOpImpl1.registeredOperators.iterator().next();
  PartialJoinOperatorImpl rightPartialJoinOpImpl = (PartialJoinOperatorImpl) inputOpImpl2.registeredOperators.iterator().next();
  // Two distinct partial-join instances built from the same operator spec.
  assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec());
  assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl);

  // Stub the left store under joinKey, then deliver a message on the first (left) input.
  Object mockLeftMessage = mock(Object.class);
  long currentTimeMillis = System.currentTimeMillis();
  when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis));
  IncomingMessageEnvelope leftMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockLeftMessage);
  inputOpImpl1.onMessage(leftMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

  // Stub the right store under joinKey, then deliver a message on the second (right) input.
  Object mockRightMessage = mock(Object.class);
  when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis));
  IncomingMessageEnvelope rightMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockRightMessage);
  inputOpImpl2.onMessage(rightMessage, mock(MessageCollector.class), mock(TaskCoordinator.class));

  // Exactly one join result is expected, pairing the left message with the right message.
  TestJoinFunction joinFunction = (TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, joinOpId);
  assertEquals(1, joinFunction.joinResults.size());
  KV joinResult = (KV) joinFunction.joinResults.iterator().next();
  assertEquals(mockLeftMessage, joinResult.getKey());
  assertEquals(mockRightMessage, joinResult.getValue());
}
Usage of `org.apache.samza.operators.functions.JoinFunction` in the Apache Samza project.
Example from class `TestJobGraphJsonGenerator`, method `testRepartitionedJoinStreamApplication`:
/**
 * Plans an application with two repartitioned joins and checks the stream and operator counts
 * in the JSON produced by {@code ExecutionPlan.getPlanAsJson()}.
 */
@Test
public void testRepartitionedJoinStreamApplication() throws Exception {
  /*
   * The graph looks like the following.
   * Number in parentheses () indicates number of stream partitions.
   * Number in parentheses in quotes ("") indicates expected partition count.
   * Number in square brackets [] indicates operator ID.
   *
   * input3 (32) -> filter [7] -> partitionBy [8] ("64") -> map [10] -> join [14] -> sendTo(output2) [15] (16)
   *                                                                      |
   * input2 (16) -> partitionBy [3] ("64") -> filter [5] -| -> sink [13]
   *                                                      |
   * input1 (64) -> map [1] -> join [11] -> sendTo(output1) [12] (8)
   *
   * The output of filter [5] fans out: it is the second input of join [11] and join [14],
   * and it also feeds sink [13].
   */
  Map<String, String> configMap = new HashMap<>();
  configMap.put(JobConfig.JOB_NAME, "test-app");
  configMap.put(JobConfig.JOB_DEFAULT_SYSTEM, "test-system");
  StreamTestUtils.addStreamConfigs(configMap, "input1", "system1", "input1");
  StreamTestUtils.addStreamConfigs(configMap, "input2", "system2", "input2");
  StreamTestUtils.addStreamConfigs(configMap, "input3", "system2", "input3");
  StreamTestUtils.addStreamConfigs(configMap, "output1", "system1", "output1");
  StreamTestUtils.addStreamConfigs(configMap, "output2", "system2", "output2");
  Config config = new MapConfig(configMap);

  // Partition counts reported by each (mocked) external system.
  Map<String, Integer> system1Map = new HashMap<>();
  system1Map.put("input1", 64);
  system1Map.put("output1", 8);
  Map<String, Integer> system2Map = new HashMap<>();
  system2Map.put("input2", 16);
  system2Map.put("input3", 32);
  system2Map.put("output2", 16);
  SystemAdmin systemAdmin1 = createSystemAdmin(system1Map);
  SystemAdmin systemAdmin2 = createSystemAdmin(system2Map);
  SystemAdmins systemAdmins = mock(SystemAdmins.class);
  when(systemAdmins.getSystemAdmin("system1")).thenReturn(systemAdmin1);
  when(systemAdmins.getSystemAdmin("system2")).thenReturn(systemAdmin2);
  StreamManager streamManager = new StreamManager(systemAdmins);

  // NOTE: statement order inside this lambda determines the operator IDs shown in the diagram.
  StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> {
    KVSerde<Object, Object> kvSerde = new KVSerde<>(new NoOpSerde(), new NoOpSerde());
    String mockSystemFactoryClass = "factory.class.name";
    GenericSystemDescriptor system1 = new GenericSystemDescriptor("system1", mockSystemFactoryClass);
    GenericSystemDescriptor system2 = new GenericSystemDescriptor("system2", mockSystemFactoryClass);
    GenericInputDescriptor<KV<Object, Object>> input1Descriptor = system1.getInputDescriptor("input1", kvSerde);
    GenericInputDescriptor<KV<Object, Object>> input2Descriptor = system2.getInputDescriptor("input2", kvSerde);
    GenericInputDescriptor<KV<Object, Object>> input3Descriptor = system2.getInputDescriptor("input3", kvSerde);
    GenericOutputDescriptor<KV<Object, Object>> output1Descriptor = system1.getOutputDescriptor("output1", kvSerde);
    GenericOutputDescriptor<KV<Object, Object>> output2Descriptor = system2.getOutputDescriptor("output2", kvSerde);
    MessageStream<KV<Object, Object>> messageStream1 = appDesc.getInputStream(input1Descriptor).map(m -> m);
    MessageStream<KV<Object, Object>> messageStream2 = appDesc.getInputStream(input2Descriptor).partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p1").filter(m -> true);
    MessageStream<KV<Object, Object>> messageStream3 = appDesc.getInputStream(input3Descriptor).filter(m -> true).partitionBy(m -> m.key, m -> m.value, mock(KVSerde.class), "p2").map(m -> m);
    OutputStream<KV<Object, Object>> outputStream1 = appDesc.getOutputStream(output1Descriptor);
    OutputStream<KV<Object, Object>> outputStream2 = appDesc.getOutputStream(output2Descriptor);
    messageStream1.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1").sendTo(outputStream1);
    messageStream2.sink((message, collector, coordinator) -> {
    });
    messageStream3.join(messageStream2, (JoinFunction<Object, KV<Object, Object>, KV<Object, Object>, KV<Object, Object>>) mock(JoinFunction.class), mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2").sendTo(outputStream2);
  }, config);

  ExecutionPlanner planner = new ExecutionPlanner(config, streamManager);
  ExecutionPlan plan = planner.plan(graphSpec);
  String json = plan.getPlanAsJson();

  // Deserialize the plan JSON and check its shape.
  ObjectMapper mapper = new ObjectMapper();
  JobGraphJsonGenerator.JobGraphJson nodes = mapper.readValue(json, JobGraphJsonGenerator.JobGraphJson.class);
  JobGraphJsonGenerator.OperatorGraphJson operatorGraph = nodes.jobs.get(0).operatorGraph;
  // Presumably 5 input streams = 3 source streams + 2 intermediate (partitionBy) streams.
  assertEquals(5, operatorGraph.inputStreams.size());
  assertEquals(11, operatorGraph.operators.size());
  assertEquals(3, nodes.sourceStreams.size());
  assertEquals(2, nodes.sinkStreams.size());
  assertEquals(2, nodes.intermediateStreams.size());
}
Aggregations