Use of org.apache.samza.sql.data.SamzaSqlRelMessage in the Apache Samza project.
Class TestSamzaSqlRemoteTableJoinFunction, method testWithInnerJoinWithTableOnRight.
/**
 * Inner join of a stream message (left side) with a remote-table record (right side).
 * The joined message is expected to contain the stream fields followed by the table fields.
 */
@Test
public void testWithInnerJoinWithTableOnRight() {
  // Schema provider and Avro converter for the table-side records.
  SystemStream systemStream = new SystemStream("test", "nestedRecord");
  Map<String, String> schemaConfig = new HashMap<>();
  String schemaKey = String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, systemStream.getSystem(), systemStream.getStream());
  schemaConfig.put(schemaKey, SimpleRecord.SCHEMA$.toString());
  ConfigBasedAvroRelSchemaProviderFactory providerFactory = new ConfigBasedAvroRelSchemaProviderFactory();
  AvroRelSchemaProvider avroSchemaProvider = (AvroRelSchemaProvider) providerFactory.create(systemStream, new MapConfig(schemaConfig));
  AvroRelConverter avroConverter = new AvroRelConverter(systemStream, avroSchemaProvider, new MapConfig());
  SamzaRelTableKeyConverter keyConverter = new SampleRelTableKeyConverter();
  String tableSourceName = "testDb.testTable.$table";

  // The single table row the stream message is joined against.
  GenericData.Record avroRow = new GenericData.Record(SimpleRecord.SCHEMA$);
  avroRow.put("id", 1);
  avroRow.put("name", "name1");

  SamzaSqlRelMessage leftMsg = new SamzaSqlRelMessage(streamFieldNames, streamFieldValues, new SamzaSqlRelMsgMetadata(0L, 0L));
  SamzaSqlRelMessage rightMsg = avroConverter.convertToRelMessage(new KV(avroRow.get("id"), avroRow));
  List<String> tableFieldNames = rightMsg.getSamzaSqlRelRecord().getFieldNames();

  JoinRelType relType = JoinRelType.INNER;
  List<Integer> leftKeyIds = Arrays.asList(1);
  List<Integer> rightKeyIds = Arrays.asList(0);
  KV<Object, GenericRecord> tableEntry = KV.of(avroRow.get("id"), avroRow);

  // Table input node sits on the right of the join.
  JoinInputNode tableNode = mock(JoinInputNode.class);
  when(tableNode.getKeyIds()).thenReturn(rightKeyIds);
  when(tableNode.isPosOnRight()).thenReturn(true);
  when(tableNode.getFieldNames()).thenReturn(tableFieldNames);
  when(tableNode.getSourceName()).thenReturn(tableSourceName);

  // Stream input node sits on the left of the join.
  JoinInputNode streamNode = mock(JoinInputNode.class);
  when(streamNode.getKeyIds()).thenReturn(leftKeyIds);
  when(streamNode.isPosOnRight()).thenReturn(false);
  when(streamNode.getFieldNames()).thenReturn(streamFieldNames);

  SamzaSqlRemoteTableJoinFunction remoteJoin = new SamzaSqlRemoteTableJoinFunction(avroConverter, keyConverter, streamNode, tableNode, relType, 0);
  SamzaSqlRelMessage joined = remoteJoin.apply(leftMsg, tableEntry);

  // Names and values of the joined record must line up one-to-one.
  Assert.assertEquals(joined.getSamzaSqlRelRecord().getFieldValues().size(), joined.getSamzaSqlRelRecord().getFieldNames().size());

  // Expected layout: stream fields first, then table fields.
  List<String> wantedNames = new ArrayList<>(streamFieldNames);
  wantedNames.addAll(rightMsg.getSamzaSqlRelRecord().getFieldNames());
  List<Object> wantedValues = new ArrayList<>(streamFieldValues);
  wantedValues.addAll(rightMsg.getSamzaSqlRelRecord().getFieldValues());
  Assert.assertEquals(wantedNames, joined.getSamzaSqlRelRecord().getFieldNames());
  Assert.assertEquals(wantedValues, joined.getSamzaSqlRelRecord().getFieldValues());
}
Use of org.apache.samza.sql.data.SamzaSqlRelMessage in the Apache Samza project.
Class TestProjectTranslator, method testTranslate.
/**
 * Verifies that {@code ProjectTranslator.translate()} registers a MAP operator for a
 * {@code LogicalProject}, that the resulting map function picks up its context on {@code init()},
 * and that applying it produces the projected field and updates the registered counters.
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
  // Set up the mocked LogicalProject, its input and the contexts handed to ProjectTranslator.translate()
  LogicalProject mockProject = PowerMockito.mock(LogicalProject.class);
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
  RelNode mockInput = mock(RelNode.class);
  List<RelNode> inputs = new ArrayList<>();
  inputs.add(mockInput);
  when(mockInput.getId()).thenReturn(1);
  when(mockProject.getId()).thenReturn(2);
  when(mockProject.getInputs()).thenReturn(inputs);
  when(mockProject.getInput()).thenReturn(mockInput);
  RelDataType mockRowType = mock(RelDataType.class);
  when(mockRowType.getFieldCount()).thenReturn(1);
  when(mockProject.getRowType()).thenReturn(mockRowType);
  RexNode mockRexField = mock(RexNode.class);
  List<Pair<RexNode, String>> namedProjects = new ArrayList<>();
  namedProjects.add(Pair.of(mockRexField, "test_field"));
  when(mockProject.getNamedProjects()).thenReturn(namedProjects);
  StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockAppDesc, mockInputOp);
  when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);

  // Apply translate() method to verify that we are getting the correct map operator constructed
  ProjectTranslator projectTranslator = new ProjectTranslator(1);
  projectTranslator.translate(mockProject, LOGICAL_OP_ID, mockTranslatorContext);

  // Make sure that the LogicalProject and its output message stream have been registered with the context
  verify(mockTranslatorContext, times(1)).registerRelNode(2, mockProject);
  verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
  when(mockTranslatorContext.getRelNode(2)).thenReturn(mockProject);
  when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
  StreamOperatorSpec projectSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
  assertNotNull(projectSpec);
  // JUnit order is (expected, actual); keeps failure messages truthful.
  assertEquals(OperatorSpec.OpCode.MAP, projectSpec.getOpCode());

  // Verify that init() on the transform function establishes the context for the map function
  Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
  mockContexts.put(1, mockTranslatorContext);
  when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
  projectSpec.getTransformFn().init(mockContext);
  MapFunction mapFn = (MapFunction) Whitebox.getInternalState(projectSpec, "mapFn");
  assertNotNull(mapFn);
  assertEquals(mockTranslatorContext, Whitebox.getInternalState(mapFn, "translatorContext"));
  assertEquals(mockProject, Whitebox.getInternalState(mapFn, "project"));
  assertEquals(mockExpr, Whitebox.getInternalState(mapFn, "expr"));

  // Verify TestMetricsRegistryImpl works with Project
  assertEquals(1, testMetricsRegistryImpl.getGauges().size());
  assertEquals(2, testMetricsRegistryImpl.getGauges().get(LOGICAL_OP_ID).size());
  assertEquals(1, testMetricsRegistryImpl.getCounters().size());
  assertEquals(2, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).size());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());

  // Calling mapFn.apply() to verify the project function is correctly applied to the input message
  SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
  SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
  DataContext dataContext = mock(DataContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
  when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
  Object[] result = new Object[1];
  final Object mockFieldObj = new Object();
  // The stubbed expression writes the projected value into the output array (5th argument).
  doAnswer(invocation -> {
    Object[] retValue = invocation.getArgumentAt(4, Object[].class);
    retValue[0] = mockFieldObj;
    return null;
  }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
  SamzaSqlRelMessage retMsg = (SamzaSqlRelMessage) mapFn.apply(mockInputMsg);
  assertEquals(Collections.singletonList("test_field"), retMsg.getSamzaSqlRelRecord().getFieldNames());
  assertEquals(Collections.singletonList(mockFieldObj), retMsg.getSamzaSqlRelRecord().getFieldValues());

  // Verify mapFn.apply() updates the TestMetricsRegistryImpl metrics
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(1, testMetricsRegistryImpl.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
Use of org.apache.samza.sql.data.SamzaSqlRelMessage in the Apache Samza project.
Class TestSamzaSqlRelRecordSerde, method testWithDifferentFields.
/**
 * Serializes the record built from {@code names}/{@code values} and checks that deserializing
 * the bytes gives back the same field names and field values.
 */
@Test
public void testWithDifferentFields() {
  SamzaSqlRelMsgMetadata metadata = new SamzaSqlRelMsgMetadata(0L, 0L);
  SamzaSqlRelRecord original = new SamzaSqlRelMessage(names, values, metadata).getSamzaSqlRelRecord();
  SamzaSqlRelRecordSerde recordSerde = (SamzaSqlRelRecordSerde) new SamzaSqlRelRecordSerdeFactory().getSerde(null, null);

  // Round trip: record -> bytes -> record.
  byte[] serialized = recordSerde.toBytes(original);
  SamzaSqlRelRecord roundTripped = recordSerde.fromBytes(serialized);

  Assert.assertEquals(names, roundTripped.getFieldNames());
  Assert.assertEquals(values, roundTripped.getFieldValues());
}
Use of org.apache.samza.sql.data.SamzaSqlRelMessage in the Apache Samza project.
Class JoinTranslator, method getTable.
/**
 * Resolves the table used on the table side of the join.
 *
 * <p>A remote table is returned as obtained from the application descriptor. For a local table,
 * the stream registered for the table node is repartitioned by the join key and sent to the table
 * before the table is returned.
 *
 * @param tableNode join input node describing the table side
 * @param context translator context used to look up configs, streams and the app descriptor
 * @return the table to join against
 * @throws SamzaException if no IO config with a table descriptor can be resolved for the node
 */
private Table getTable(JoinInputNode tableNode, TranslatorContext context) {
  SqlIOConfig tableIoConfig = resolveSQlIOForTable(tableNode.getRelNode(), context.getExecutionContext().getSamzaSqlApplicationConfig().getInputSystemStreamConfigBySource());
  boolean hasTableDescriptor = tableIoConfig != null && tableIoConfig.getTableDescriptor().isPresent();
  if (!hasTableDescriptor) {
    String errMsg = "Failed to resolve table source in join operation: node=" + tableNode.getRelNode();
    log.error(errMsg);
    throw new SamzaException(errMsg);
  }

  Table<KV<SamzaSqlRelRecord, SamzaSqlRelMessage>> table = context.getStreamAppDescriptor().getTable(tableIoConfig.getTableDescriptor().get());

  if (!tableNode.isRemoteTable()) {
    // Local table: it has to be loaded. The composite key is made of the fields in the join
    // condition and the value is the relational message itself; the input stream denoted as
    // 'table' feeds the created table store.
    MessageStream<SamzaSqlRelMessage> tableSideStream = context.getMessageStream(tableNode.getRelNode().getId());
    SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde keySerde = (SamzaSqlRelRecordSerdeFactory.SamzaSqlRelRecordSerde) new SamzaSqlRelRecordSerdeFactory().getSerde(null, null);
    SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde valueSerde = (SamzaSqlRelMessageSerdeFactory.SamzaSqlRelMessageSerde) new SamzaSqlRelMessageSerdeFactory().getSerde(null, null);
    List<Integer> joinKeyIds = tableNode.getKeyIds();

    // Always repartition by the join fields before writing to the table, so that the stream and
    // the table being joined share the same partitioning scheme (same partition key and count).
    // Bootstrap semantics are not propagated to intermediate streams (see SAMZA-1613), so results
    // are consistent only once the local table has caught up.
    KVSerde<SamzaSqlRelRecord, SamzaSqlRelMessage> kvSerde = KVSerde.of(keySerde, valueSerde);
    String repartitionId = intermediateStreamPrefix + "table_" + logicalOpId;
    tableSideStream
        .partitionBy(msg -> createSamzaSqlCompositeKey(msg, joinKeyIds), msg -> msg, kvSerde, repartitionId)
        .sendTo(table);
  }
  return table;
}
Use of org.apache.samza.sql.data.SamzaSqlRelMessage in the Apache Samza project.
Class JoinTranslator, method translate.
/**
 * Translates a {@code LogicalJoin} between a stream and a table.
 *
 * <p>Determines which side is the table, collects the join-key field indices of each side from
 * the join condition, builds the stream and table input nodes, joins the stream with the table
 * and registers the resulting message stream under the join's id.
 *
 * @param join the logical join to translate
 * @param translatorContext context providing configs and the registered message streams
 */
void translate(final LogicalJoin join, final TranslatorContext translatorContext) {
  JoinInputNode.InputType inputTypeOnLeft = JoinInputNode.getInputType(join.getLeft(), translatorContext.getExecutionContext().getSamzaSqlApplicationConfig().getInputSystemStreamConfigBySource());
  JoinInputNode.InputType inputTypeOnRight = JoinInputNode.getInputType(join.getRight(), translatorContext.getExecutionContext().getSamzaSqlApplicationConfig().getInputSystemStreamConfigBySource());

  // Validate the join query before doing anything else.
  validateJoinQuery(join, inputTypeOnLeft, inputTypeOnRight);

  // From here on one side is treated as the table: anything on the right that is not a stream.
  final boolean tableOnRight = inputTypeOnRight != JoinInputNode.InputType.STREAM;

  // Indices (relative to each side) of the fields that appear in the join condition; they are
  // used to extract the join key names and values from the stream and table records.
  final List<Integer> streamKeyIds = new LinkedList<>();
  final List<Integer> tableKeyIds = new LinkedList<>();

  // Index ranges of each side within the joined row.
  final int leftFieldCount = join.getLeft().getRowType().getFieldCount();
  final int tableStartIdx;
  final int streamStartIdx;
  final int tableEndIdx;
  if (tableOnRight) {
    tableStartIdx = leftFieldCount;
    streamStartIdx = 0;
    tableEndIdx = join.getRowType().getFieldCount();
  } else {
    tableStartIdx = 0;
    streamStartIdx = leftFieldCount;
    tableEndIdx = leftFieldCount;
  }

  // Walk the join condition and sort every referenced field into the table or stream key list.
  join.getCondition().accept(new RexShuttle() {
    @Override
    public RexNode visitInputRef(RexInputRef inputRef) {
      // Validate the type of the input ref.
      validateJoinKeyType(inputRef);
      int refIdx = inputRef.getIndex();
      boolean refersToTable = refIdx >= tableStartIdx && refIdx < tableEndIdx;
      if (refersToTable) {
        tableKeyIds.add(refIdx - tableStartIdx);
      } else {
        streamKeyIds.add(refIdx - streamStartIdx);
      }
      return inputRef;
    }
  });
  Collections.sort(tableKeyIds);
  Collections.sort(streamKeyIds);

  // Build the two input nodes (stream first, then table) for the join.
  final JoinInputNode streamNode;
  final JoinInputNode tableNode;
  if (tableOnRight) {
    streamNode = new JoinInputNode(join.getLeft(), streamKeyIds, inputTypeOnLeft, false);
    tableNode = new JoinInputNode(join.getRight(), tableKeyIds, inputTypeOnRight, true);
  } else {
    streamNode = new JoinInputNode(join.getRight(), streamKeyIds, inputTypeOnRight, true);
    tableNode = new JoinInputNode(join.getLeft(), tableKeyIds, inputTypeOnLeft, false);
  }

  MessageStream<SamzaSqlRelMessage> streamSide = translatorContext.getMessageStream(streamNode.getRelNode().getId());
  Table table = getTable(tableNode, translatorContext);
  MessageStream<SamzaSqlRelMessage> joinedStream = joinStreamWithTable(streamSide, table, streamNode, tableNode, join, translatorContext);
  translatorContext.registerMessageStream(join.getId(), joinedStream);

  // Attach the output-metrics map function to the joined stream.
  joinedStream.map(outputMetricsMF);
}
Aggregations