use of org.apache.samza.operators.spec.StreamTableJoinOperatorSpec in project samza by apache.
the class OperatorSpecGraphAnalyzer method getJoinToInputOperatorSpecs.
/**
 * Groups {@link InputOperatorSpec}s by the join operator specs that can be reached from them, where a join is
 * either a {@link JoinOperatorSpec} or a {@link StreamTableJoinOperatorSpec}.
 *
 * The keys of the returned Multimap are typed as {@link OperatorSpec} because {@link JoinOperatorSpec} and
 * {@link StreamTableJoinOperatorSpec} have no stricter common base type. Key objects are nevertheless
 * intended to be instances of one of those two types only.
 *
 * @param inputOpSpecs the input operator specs from which the graph traversals start
 * @return a multimap from every join operator spec observed by a {@code JoinVisitor} to the input operator
 *         specs whose traversal encountered it
 */
public static Multimap<OperatorSpec, InputOperatorSpec> getJoinToInputOperatorSpecs(Collection<InputOperatorSpec> inputOpSpecs) {
  // Successor function used for every traversal below. It is built by getCustomGetNextOpSpecs(), which
  // answers for SendToTableOperatorSpecs (terminal operator specs) with the StreamTableJoinOperatorSpecs
  // associated to them, so that streams feeding a table can be linked to the Stream-Table Joins on that table.
  Function<OperatorSpec, Iterable<OperatorSpec>> successorsOf = getCustomGetNextOpSpecs(inputOpSpecs);

  Multimap<OperatorSpec, InputOperatorSpec> inputsByJoin = HashMultimap.create();

  // Run one traversal per input operator spec, each with its own visitor, and record which joins it reached.
  for (InputOperatorSpec source : inputOpSpecs) {
    JoinVisitor visitor = new JoinVisitor();
    traverse(source, visitor, successorsOf);

    // Every join seen during this traversal is reachable from the current input operator spec.
    for (OperatorSpec join : visitor.getJoins()) {
      inputsByJoin.put(join, source);
    }
  }

  return inputsByJoin;
}
use of org.apache.samza.operators.spec.StreamTableJoinOperatorSpec in project samza by apache.
the class OperatorSpecGraphAnalyzer method getCustomGetNextOpSpecs.
/**
 * Builds the successor function used when traversing the operator spec graph rooted at the given
 * {@link InputOperatorSpec}s.
 *
 * For a {@link SendToTableOperatorSpec}, the returned function yields an unmodifiable view of the
 * {@link StreamTableJoinOperatorSpec}s that a {@code TableJoinVisitor} associated with it (those referencing
 * the same table), as if they were actually connected. For every other {@link OperatorSpec} it yields the
 * operator spec's registered operator specs.
 *
 * @param inputOpSpecs the input operator specs from which the graph is explored to collect the associations
 * @return a function mapping an operator spec to the operator specs to visit next
 */
private static Function<OperatorSpec, Iterable<OperatorSpec>> getCustomGetNextOpSpecs(Iterable<InputOperatorSpec> inputOpSpecs) {
  // First pass over the graph, following only the registered connections, so that a single visitor can
  // collect the SendToTableOperatorSpec -> StreamTableJoinOperatorSpec associations.
  TableJoinVisitor visitor = new TableJoinVisitor();
  for (InputOperatorSpec source : inputOpSpecs) {
    traverse(source, visitor, opSpec -> opSpec.getRegisteredOperatorSpecs());
  }
  Multimap<SendToTableOperatorSpec, StreamTableJoinOperatorSpec> joinsBySendToTable = visitor.getSendToTableOpSpecToStreamTableJoinOpSpecs();

  return spec -> {
    // Anything that is not a SendToTableOperatorSpec simply exposes its registered operator specs.
    if (!(spec instanceof SendToTableOperatorSpec)) {
      return spec.getRegisteredOperatorSpecs();
    }
    // A SendToTableOperatorSpec is treated as if it led to the Stream-Table Joins collected for it.
    return Collections.unmodifiableCollection(joinsBySendToTable.get((SendToTableOperatorSpec) spec));
  };
}
use of org.apache.samza.operators.spec.StreamTableJoinOperatorSpec in project samza by apache.
the class TestStreamTableJoinOperatorImpl method testHandleMessage.
/**
 * Drives a {@link StreamTableJoinOperatorImpl} with one message whose key the mocked table resolves to a
 * record and one message whose key resolves to {@code null}, and checks the size and content of the
 * collections returned by {@code handleMessage}.
 */
@Test
public void testHandleMessage() {
  String tableId = "t1";

  // Join function under test: joins key "1", drops key "2" (no record), and rejects any other key.
  StreamTableJoinFunction<String, KV<String, String>, KV<String, String>, String> joinFn =
      new StreamTableJoinFunction<String, KV<String, String>, KV<String, String>, String>() {
        @Override
        public String apply(KV<String, String> message, KV<String, String> record) {
          String key = message.getKey();
          if ("1".equals(key)) {
            Assert.assertEquals("m1", message.getValue());
            Assert.assertEquals("r1", record.getValue());
            return "m1r1";
          }
          if ("2".equals(key)) {
            Assert.assertEquals("m2", message.getValue());
            Assert.assertNull(record);
            return null;
          }
          throw new SamzaException("Should never reach here!");
        }

        @Override
        public String getMessageKey(KV<String, String> message) {
          return message.getKey();
        }

        @Override
        public String getRecordKey(KV<String, String> record) {
          return record.getKey();
        }
      };

  // Operator spec stub exposing the table id, empty args and the join function above.
  StreamTableJoinOperatorSpec joinOpSpec = mock(StreamTableJoinOperatorSpec.class);
  when(joinOpSpec.getTableId()).thenReturn(tableId);
  when(joinOpSpec.getArgs()).thenReturn(new Object[0]);
  when(joinOpSpec.getJoinFn()).thenReturn(joinFn);

  // Table stub: key "1" resolves to "r1", key "2" resolves to null.
  ReadWriteUpdateTable joinTable = mock(ReadWriteUpdateTable.class);
  when(joinTable.getAsync("1")).thenReturn(CompletableFuture.completedFuture("r1"));
  when(joinTable.getAsync("2")).thenReturn(CompletableFuture.completedFuture(null));

  Context context = new MockContext();
  when(context.getTaskContext().getUpdatableTable(tableId)).thenReturn(joinTable);

  MessageCollector collector = mock(MessageCollector.class);
  TaskCoordinator coordinator = mock(TaskCoordinator.class);
  StreamTableJoinOperatorImpl joinOperator = new StreamTableJoinOperatorImpl(joinOpSpec, context);

  // Table has the key
  Collection<TestMessageEnvelope> matched = joinOperator.handleMessage(KV.of("1", "m1"), collector, coordinator);
  Assert.assertEquals(1, matched.size());
  Assert.assertEquals("m1r1", matched.iterator().next());

  // Table doesn't have the key
  Collection<TestMessageEnvelope> unmatched = joinOperator.handleMessage(KV.of("2", "m2"), collector, coordinator);
  Assert.assertEquals(0, unmatched.size());
}
use of org.apache.samza.operators.spec.StreamTableJoinOperatorSpec in project samza by apache.
the class TestJoinTranslator method testTranslateStreamToTableJoin.
/**
 * Shared driver that runs {@link JoinTranslator#translate} against a fully mocked inner join whose left input
 * is a table scan and whose right input is a stream, then inspects the registered
 * {@link StreamTableJoinOperatorSpec} and its join function.
 *
 * @param isRemoteTable when {@code true} the table descriptor is a mocked {@link RemoteTableDescriptor} and the
 *                      join function is expected to be a {@link SamzaSqlRemoteTableJoinFunction}; otherwise a
 *                      mocked {@link RocksDbTableDescriptor} is used, an intermediate (partitioned) stream is
 *                      stubbed, and a {@link SamzaSqlLocalTableJoinFunction} is expected
 * @throws IOException declared by the original signature
 * @throws ClassNotFoundException declared by the original signature
 */
private void testTranslateStreamToTableJoin(boolean isRemoteTable) throws IOException, ClassNotFoundException {
  // setup mock values to the constructor of JoinTranslator
  final String logicalOpId = "sql0_join3";
  final int queryId = 0;
  LogicalJoin mockJoin = PowerMockito.mock(LogicalJoin.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  RelNode mockLeftInput = PowerMockito.mock(LogicalTableScan.class);
  RelNode mockRightInput = mock(RelNode.class);
  List<RelNode> inputs = new ArrayList<>();
  inputs.add(mockLeftInput);
  inputs.add(mockRightInput);

  // Left input is the table "test.LeftTable"; node ids are 1 (left), 2 (right) and 3 (join).
  RelOptTable mockLeftTable = mock(RelOptTable.class);
  when(mockLeftInput.getTable()).thenReturn(mockLeftTable);
  List<String> qualifiedTableName = Arrays.asList("test", "LeftTable");
  when(mockLeftTable.getQualifiedName()).thenReturn(qualifiedTableName);
  when(mockLeftInput.getId()).thenReturn(1);
  when(mockRightInput.getId()).thenReturn(2);
  when(mockJoin.getId()).thenReturn(3);
  when(mockJoin.getInputs()).thenReturn(inputs);
  when(mockJoin.getLeft()).thenReturn(mockLeftInput);
  when(mockJoin.getRight()).thenReturn(mockRightInput);

  // Join condition: an EQUALS call over two input references, both with index 0 and BOOLEAN type.
  RexCall mockJoinCondition = mock(RexCall.class);
  when(mockJoinCondition.isAlwaysTrue()).thenReturn(false);
  when(mockJoinCondition.getKind()).thenReturn(SqlKind.EQUALS);
  when(mockJoin.getCondition()).thenReturn(mockJoinCondition);
  RexInputRef mockLeftConditionInput = mock(RexInputRef.class);
  RexInputRef mockRightConditionInput = mock(RexInputRef.class);
  when(mockLeftConditionInput.getIndex()).thenReturn(0);
  when(mockRightConditionInput.getIndex()).thenReturn(0);
  List<RexNode> condOperands = new ArrayList<>();
  condOperands.add(mockLeftConditionInput);
  condOperands.add(mockRightConditionInput);
  when(mockJoinCondition.getOperands()).thenReturn(condOperands);
  RelDataType mockLeftCondDataType = mock(RelDataType.class);
  RelDataType mockRightCondDataType = mock(RelDataType.class);
  when(mockLeftCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
  when(mockRightCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
  when(mockLeftConditionInput.getType()).thenReturn(mockLeftCondDataType);
  when(mockRightConditionInput.getType()).thenReturn(mockRightCondDataType);

  // Row types and field names of both inputs.
  RelDataType mockLeftRowType = mock(RelDataType.class);
  // NOTE(review): the left row type reports zero fields although one field name is stubbed just below; the
  // reason is not evident from this test (the original comment here was "?? why ??") - confirm against
  // JoinTranslator before changing.
  when(mockLeftRowType.getFieldCount()).thenReturn(0);
  when(mockLeftInput.getRowType()).thenReturn(mockLeftRowType);
  List<String> leftFieldNames = Collections.singletonList("test_table_field1");
  List<String> rightStreamFieldNames = Collections.singletonList("test_stream_field1");
  when(mockLeftRowType.getFieldNames()).thenReturn(leftFieldNames);
  RelDataType mockRightRowType = mock(RelDataType.class);
  when(mockRightInput.getRowType()).thenReturn(mockRightRowType);
  when(mockRightRowType.getFieldNames()).thenReturn(rightStreamFieldNames);

  // Message streams the translator context hands out for the left and right inputs.
  StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockLeftInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockLeftInputStream = new MessageStreamImpl<>(mockAppDesc, mockLeftInputOp);
  when(mockTranslatorContext.getMessageStream(eq(mockLeftInput.getId()))).thenReturn(mockLeftInputStream);
  OperatorSpec<Object, SamzaSqlRelMessage> mockRightInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockRightInputStream = new MessageStreamImpl<>(mockAppDesc, mockRightInputOp);
  when(mockTranslatorContext.getMessageStream(eq(mockRightInput.getId()))).thenReturn(mockRightInputStream);
  when(mockTranslatorContext.getStreamAppDescriptor()).thenReturn(mockAppDesc);
  InputOperatorSpec mockInputOp = mock(InputOperatorSpec.class);
  OutputStreamImpl mockOutputStream = mock(OutputStreamImpl.class);
  when(mockInputOp.isKeyed()).thenReturn(true);
  when(mockOutputStream.isKeyed()).thenReturn(true);
  // Capture the stream registered under the join's id (3) so it can be inspected after translate().
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(3), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);

  // Table lookup on the application descriptor differs between the remote and the local (RocksDB) case;
  // only the local case stubs an intermediate stream.
  if (isRemoteTable) {
    doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RemoteTableDescriptor.class));
  } else {
    IntermediateMessageStreamImpl mockPartitionedStream = new IntermediateMessageStreamImpl(mockAppDesc, mockInputOp, mockOutputStream);
    when(mockAppDesc.getIntermediateStream(any(String.class), any(Serde.class), eq(false))).thenReturn(mockPartitionedStream);
    doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RocksDbTableDescriptor.class));
  }
  when(mockJoin.getJoinType()).thenReturn(JoinRelType.INNER);

  // SQL application config: the source "test.LeftTable" resolves to an IO config carrying the table descriptor.
  SamzaSqlExecutionContext mockExecutionContext = mock(SamzaSqlExecutionContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(mockExecutionContext);
  SamzaSqlApplicationConfig mockAppConfig = mock(SamzaSqlApplicationConfig.class);
  when(mockExecutionContext.getSamzaSqlApplicationConfig()).thenReturn(mockAppConfig);
  Map<String, SqlIOConfig> ssConfigBySource = mock(HashMap.class);
  when(mockAppConfig.getInputSystemStreamConfigBySource()).thenReturn(ssConfigBySource);
  SqlIOConfig mockIOConfig = mock(SqlIOConfig.class);
  TableDescriptor mockTableDesc;
  if (isRemoteTable) {
    mockTableDesc = mock(RemoteTableDescriptor.class);
  } else {
    mockTableDesc = mock(RocksDbTableDescriptor.class);
  }
  when(ssConfigBySource.get(String.join(".", qualifiedTableName))).thenReturn(mockIOConfig);
  when(mockIOConfig.getTableDescriptor()).thenReturn(Optional.of(mockTableDesc));
  JoinTranslator joinTranslator = new JoinTranslator(logicalOpId, "", queryId);

  // Verify Metrics Works with Join
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
  TranslatorInputMetricsMapFunction inputMetricsMF = joinTranslator.getInputMetricsMF();
  assertNotNull(inputMetricsMF);
  inputMetricsMF.init(mockContext);
  TranslatorOutputMetricsMapFunction outputMetricsMF = joinTranslator.getOutputMetricsMF();
  assertNotNull(outputMetricsMF);
  outputMetricsMF.init(mockContext);
  assertEquals(1, testMetricsRegistryImpl.getCounters().size());
  assertEquals(2, testMetricsRegistryImpl.getCounters().get(logicalOpId).size());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(0).getCount());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(1).getCount());
  assertEquals(1, testMetricsRegistryImpl.getGauges().size());

  // Apply translate() method to verify that we are getting the correct map operator constructed
  joinTranslator.translate(mockJoin, mockTranslatorContext);
  // make sure that context has been registered with LogicFilter and output message streams
  verify(mockTranslatorContext, times(1)).registerMessageStream(3, this.getRegisteredMessageStream(3));
  when(mockTranslatorContext.getRelNode(3)).thenReturn(mockJoin);
  when(mockTranslatorContext.getMessageStream(3)).thenReturn(this.getRegisteredMessageStream(3));
  StreamTableJoinOperatorSpec joinSpec = (StreamTableJoinOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(3), "operatorSpec");
  assertNotNull(joinSpec);
  // Expected value goes first so a failure reports expected/actual the right way round.
  assertEquals(OperatorSpec.OpCode.JOIN, joinSpec.getOpCode());

  // Verify joinSpec has the corresponding setup
  StreamTableJoinFunction joinFn = joinSpec.getJoinFn();
  assertNotNull(joinFn);
  if (isRemoteTable) {
    assertTrue(joinFn instanceof SamzaSqlRemoteTableJoinFunction);
  } else {
    assertTrue(joinFn instanceof SamzaSqlLocalTableJoinFunction);
  }
  // assertEquals (rather than assertTrue(x.equals(false))) so the actual value shows up in a failure message.
  assertEquals(Boolean.FALSE, Whitebox.getInternalState(joinFn, "isTablePosOnRight"));
  assertEquals(Collections.singletonList(0), Whitebox.getInternalState(joinFn, "streamFieldIds"));
  assertEquals(leftFieldNames, Whitebox.getInternalState(joinFn, "tableFieldNames"));
  // Output fields are expected to be the table (left) fields followed by the stream (right) fields.
  List<String> outputFieldNames = new ArrayList<>();
  outputFieldNames.addAll(leftFieldNames);
  outputFieldNames.addAll(rightStreamFieldNames);
  assertEquals(outputFieldNames, Whitebox.getInternalState(joinFn, "outFieldNames"));
}
use of org.apache.samza.operators.spec.StreamTableJoinOperatorSpec in project samza by apache.
the class TestStreamTableJoinOperatorImpl method testJoinFunctionIsInvokedOnlyOnce.
/**
 * Ensure join function is invoked exactly once when join function returns null on the first invocation.
 *
 * The invocation count is tracked with a {@link CountDownLatch} initialised to 2: zero invocations leave the
 * count at 2, exactly one leaves it at 1, and two or more leave it at 0 (counting down at zero is a no-op).
 * A latch initialised to 1 could not distinguish one invocation from several.
 */
@Test
public void testJoinFunctionIsInvokedOnlyOnce() {
  final String tableId = "testTable";
  // Starts at 2 so that a second invocation is observable as the count dropping from 1 to 0.
  final CountDownLatch joinInvokedLatch = new CountDownLatch(2);
  StreamTableJoinOperatorSpec mockJoinOpSpec = mock(StreamTableJoinOperatorSpec.class);
  when(mockJoinOpSpec.getTableId()).thenReturn(tableId);
  when(mockJoinOpSpec.getArgs()).thenReturn(new Object[0]);
  when(mockJoinOpSpec.getJoinFn()).thenReturn(new StreamTableJoinFunction<String, KV<String, String>, KV<String, String>, String>() {
    @Override
    public String apply(KV<String, String> message, KV<String, String> record) {
      // Record the invocation, then return null to exercise the "no join result" path.
      joinInvokedLatch.countDown();
      return null;
    }

    @Override
    public String getMessageKey(KV<String, String> message) {
      return message.getKey();
    }

    @Override
    public String getRecordKey(KV<String, String> record) {
      return record.getKey();
    }
  });
  ReadWriteUpdateTable table = mock(ReadWriteUpdateTable.class);
  when(table.getAsync("1")).thenReturn(CompletableFuture.completedFuture("r1"));
  Context context = new MockContext();
  when(context.getTaskContext().getUpdatableTable(tableId)).thenReturn(table);
  MessageCollector mockMessageCollector = mock(MessageCollector.class);
  TaskCoordinator mockTaskCoordinator = mock(TaskCoordinator.class);
  StreamTableJoinOperatorImpl streamTableJoinOperator = new StreamTableJoinOperatorImpl(mockJoinOpSpec, context);
  // Table has the key
  streamTableJoinOperator.handleMessage(KV.of("1", "m1"), mockMessageCollector, mockTaskCoordinator);
  // Exactly one of the two counts must have been consumed.
  assertEquals("Join function should only be invoked once", 1, joinInvokedLatch.getCount());
}
Aggregations