use of org.apache.samza.context.Context in project samza by apache.
the class TestJoinOperator method createStreamOperatorTask.
/**
 * Builds a {@link StreamOperatorTask} for {@code graphSpec} and initializes it against a {@code MockContext}.
 *
 * <p>The context is configured for job {@code jobName}/{@code jobId} with input streams {@code instream} and
 * {@code instream2} on system {@code insystem}. Partition 0 of each is assigned to the task. The task and
 * container metrics registries are fresh {@code MetricsRegistryMap}s. Separate in-memory stores are returned
 * for the left ({@code -L}) and right ({@code -R}) stores of join {@code j1}.
 *
 * @param clock clock handed to the operator task
 * @param graphSpec application descriptor whose operator spec graph the task executes
 * @return the initialized task
 * @throws Exception if construction or {@code init} fails
 */
private StreamOperatorTask createStreamOperatorTask(Clock clock, StreamApplicationDescriptorImpl graphSpec) throws Exception {
  Map<String, String> configEntries = new HashMap<>();
  configEntries.put("job.name", "jobName");
  configEntries.put("job.id", "jobId");
  StreamTestUtils.addStreamConfigs(configEntries, "inStream", "insystem", "instream");
  StreamTestUtils.addStreamConfigs(configEntries, "inStream2", "insystem", "instream2");
  Context context = new MockContext(new MapConfig(configEntries));

  // The task owns partition 0 of both input streams.
  SystemStreamPartition firstInput = new SystemStreamPartition("insystem", "instream", new Partition(0));
  SystemStreamPartition secondInput = new SystemStreamPartition("insystem", "instream2", new Partition(0));
  TaskModel taskModel = mock(TaskModel.class);
  when(taskModel.getSystemStreamPartitions()).thenReturn(ImmutableSet.of(firstInput, secondInput));

  when(context.getTaskContext().getTaskModel()).thenReturn(taskModel);
  when(context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap());
  when(context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap());

  // Each side of join "j1" must get its own store instance.
  IntegerSerde keySerde = new IntegerSerde();
  TimestampedValueSerde valueSerde = new TimestampedValueSerde(new KVSerde(keySerde, keySerde));
  when(context.getTaskContext().getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(new TestInMemoryStore(keySerde, valueSerde));
  when(context.getTaskContext().getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(new TestInMemoryStore(keySerde, valueSerde));

  StreamOperatorTask task = new StreamOperatorTask(graphSpec.getOperatorSpecGraph(), clock);
  task.init(context);
  return task;
}
use of org.apache.samza.context.Context in project samza by apache.
the class TestLocalTableRead method createTable.
/**
 * Creates a {@code LocalTable} named {@code t1} over the test's {@code kvStore} and initializes it.
 *
 * <p>The mocked {@code Context} exposes a job config and the test's container {@code metricsRegistry}.
 *
 * @param isTimerDisabled when true, the job config sets {@code MetricsConfig.METRICS_TIMER_ENABLED} to
 *     {@code "false"}; otherwise the config is empty
 * @return the initialized table
 */
private LocalTable createTable(boolean isTimerDisabled) {
  Map<String, String> configEntries = new HashMap<>();
  if (isTimerDisabled) {
    configEntries.put(MetricsConfig.METRICS_TIMER_ENABLED, "false");
  }

  // Leaf mocks first, then the root Context that hands them out.
  JobContext jobContext = mock(JobContext.class);
  when(jobContext.getConfig()).thenReturn(new MapConfig(configEntries));
  ContainerContext containerContext = mock(ContainerContext.class);
  when(containerContext.getContainerMetricsRegistry()).thenReturn(metricsRegistry);

  Context context = mock(Context.class);
  when(context.getJobContext()).thenReturn(jobContext);
  when(context.getContainerContext()).thenReturn(containerContext);

  LocalTable table = new LocalTable("t1", kvStore);
  table.init(context);
  return table;
}
use of org.apache.samza.context.Context in project samza by apache.
the class TestFilterTranslator method testTranslate.
/**
 * Verifies that {@link FilterTranslator} registers a FILTER operator for a {@code LogicalFilter}.
 *
 * <p>Also checks four follow-on behaviours:
 * <ul>
 *   <li>initializing its transform function wires the filter function to the translator context, filter node
 *       and compiled expression;</li>
 *   <li>gauges and counters are registered under {@code LOGICAL_OP_ID} in the container metrics registry;</li>
 *   <li>{@code apply()} returns the boolean the compiled expression writes into its result array;</li>
 *   <li>{@code apply()} updates the counters.</li>
 * </ul>
 */
@Test
public void testTranslate() throws IOException, ClassNotFoundException {
  // setup mock values to the constructor of FilterTranslator
  LogicalFilter mockFilter = PowerMockito.mock(LogicalFilter.class);
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  TestMetricsRegistryImpl metricsRegistry = new TestMetricsRegistryImpl();
  RelNode mockInput = mock(RelNode.class);
  when(mockFilter.getInput()).thenReturn(mockInput);
  when(mockInput.getId()).thenReturn(1);
  when(mockFilter.getId()).thenReturn(2);
  StreamApplicationDescriptorImpl mockGraph = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockGraph, mockInputOp);
  when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(2), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(metricsRegistry);
  // Apply translate() method to verify that we are getting the correct filter operator constructed
  FilterTranslator filterTranslator = new FilterTranslator(1);
  filterTranslator.translate(mockFilter, LOGICAL_OP_ID, mockTranslatorContext);
  // make sure that context has been registered with LogicalFilter and output message streams
  verify(mockTranslatorContext, times(1)).registerRelNode(2, mockFilter);
  verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
  when(mockTranslatorContext.getRelNode(2)).thenReturn(mockFilter);
  when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
  StreamOperatorSpec filterSpec = (StreamOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
  assertNotNull(filterSpec);
  // JUnit convention: expected value first, actual value second
  assertEquals(OperatorSpec.OpCode.FILTER, filterSpec.getOpCode());
  // Verify that init() on the transform function establishes the context for the filter function
  Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
  mockContexts.put(1, mockTranslatorContext);
  when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
  filterSpec.getTransformFn().init(mockContext);
  FilterFunction filterFn = (FilterFunction) Whitebox.getInternalState(filterSpec, "filterFn");
  assertNotNull(filterFn);
  assertEquals(mockTranslatorContext, Whitebox.getInternalState(filterFn, "translatorContext"));
  assertEquals(mockFilter, Whitebox.getInternalState(filterFn, "filter"));
  assertEquals(mockExpr, Whitebox.getInternalState(filterFn, "expr"));
  // Verify MetricsRegistry works with Filter
  assertEquals(1, metricsRegistry.getGauges().size());
  assertTrue(metricsRegistry.getGauges().get(LOGICAL_OP_ID).size() > 0);
  assertEquals(1, metricsRegistry.getCounters().size());
  assertEquals(3, metricsRegistry.getCounters().get(LOGICAL_OP_ID).size());
  assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
  // Calling filterFn.apply() to verify the filter function is correctly applied to the input message
  SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(), new SamzaSqlRelMsgMetadata(0L, 0L));
  SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
  DataContext dataContext = mock(DataContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
  when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
  Object[] result = new Object[1];
  // The compiled expression reports its outcome by writing into its 5th argument (the result array).
  doAnswer(invocation -> {
    Object[] retValue = invocation.getArgumentAt(4, Object[].class);
    retValue[0] = Boolean.TRUE;
    return null;
  }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
  assertTrue(filterFn.apply(mockInputMsg));
  doAnswer(invocation -> {
    Object[] retValue = invocation.getArgumentAt(4, Object[].class);
    retValue[0] = Boolean.FALSE;
    return null;
  }).when(mockExpr).execute(eq(executionContext), eq(mockContext), eq(dataContext), eq(mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray()), eq(result));
  assertFalse(filterFn.apply(mockInputMsg));
  // Verify filterFn.apply() updates the MetricsRegistry metrics
  assertEquals(2, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
  assertEquals(1, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
}
use of org.apache.samza.context.Context in project samza by apache.
the class TestJoinTranslator method testTranslateStreamToTableJoin.
/**
 * Verifies that {@link JoinTranslator} translates an INNER {@code LogicalJoin} into a stream-table JOIN
 * operator.
 *
 * <p>The left input is a table scan over {@code test.LeftTable} and the right input is a stream. The test also
 * checks two things:
 * <ul>
 *   <li>the input/output metrics map functions register their counters and gauge under the logical op id;</li>
 *   <li>the resulting join function has the expected type, table position and field lists.</li>
 * </ul>
 *
 * @param isRemoteTable true to back the table side with a {@code RemoteTableDescriptor}
 *     (expects {@code SamzaSqlRemoteTableJoinFunction}); false to use a {@code RocksDbTableDescriptor} with an
 *     intermediate partitioned stream (expects {@code SamzaSqlLocalTableJoinFunction})
 */
private void testTranslateStreamToTableJoin(boolean isRemoteTable) throws IOException, ClassNotFoundException {
  // setup mock values to the constructor of JoinTranslator
  final String logicalOpId = "sql0_join3";
  final int queryId = 0;
  LogicalJoin mockJoin = PowerMockito.mock(LogicalJoin.class);
  TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
  RelNode mockLeftInput = PowerMockito.mock(LogicalTableScan.class);
  RelNode mockRightInput = mock(RelNode.class);
  List<RelNode> inputs = new ArrayList<>();
  inputs.add(mockLeftInput);
  inputs.add(mockRightInput);
  RelOptTable mockLeftTable = mock(RelOptTable.class);
  when(mockLeftInput.getTable()).thenReturn(mockLeftTable);
  List<String> qualifiedTableName = Arrays.asList("test", "LeftTable");
  when(mockLeftTable.getQualifiedName()).thenReturn(qualifiedTableName);
  when(mockLeftInput.getId()).thenReturn(1);
  when(mockRightInput.getId()).thenReturn(2);
  when(mockJoin.getId()).thenReturn(3);
  when(mockJoin.getInputs()).thenReturn(inputs);
  when(mockJoin.getLeft()).thenReturn(mockLeftInput);
  when(mockJoin.getRight()).thenReturn(mockRightInput);
  // Join condition: an EQUALS call whose two operands are input refs at index 0.
  RexCall mockJoinCondition = mock(RexCall.class);
  when(mockJoinCondition.isAlwaysTrue()).thenReturn(false);
  when(mockJoinCondition.getKind()).thenReturn(SqlKind.EQUALS);
  when(mockJoin.getCondition()).thenReturn(mockJoinCondition);
  RexInputRef mockLeftConditionInput = mock(RexInputRef.class);
  RexInputRef mockRightConditionInput = mock(RexInputRef.class);
  when(mockLeftConditionInput.getIndex()).thenReturn(0);
  when(mockRightConditionInput.getIndex()).thenReturn(0);
  List<RexNode> condOperands = new ArrayList<>();
  condOperands.add(mockLeftConditionInput);
  condOperands.add(mockRightConditionInput);
  when(mockJoinCondition.getOperands()).thenReturn(condOperands);
  RelDataType mockLeftCondDataType = mock(RelDataType.class);
  RelDataType mockRightCondDataType = mock(RelDataType.class);
  when(mockLeftCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
  when(mockRightCondDataType.getSqlTypeName()).thenReturn(SqlTypeName.BOOLEAN);
  when(mockLeftConditionInput.getType()).thenReturn(mockLeftCondDataType);
  when(mockRightConditionInput.getType()).thenReturn(mockRightCondDataType);
  RelDataType mockLeftRowType = mock(RelDataType.class);
  // NOTE(review): the left row type reports 0 fields even though getFieldNames() below returns one name.
  // The original author was unsure why this is required. Presumably it affects how JoinTranslator assigns
  // condition indices to the left/right inputs — confirm before changing it.
  when(mockLeftRowType.getFieldCount()).thenReturn(0);
  when(mockLeftInput.getRowType()).thenReturn(mockLeftRowType);
  List<String> leftFieldNames = Collections.singletonList("test_table_field1");
  List<String> rightStreamFieldNames = Collections.singletonList("test_stream_field1");
  when(mockLeftRowType.getFieldNames()).thenReturn(leftFieldNames);
  RelDataType mockRightRowType = mock(RelDataType.class);
  when(mockRightInput.getRowType()).thenReturn(mockRightRowType);
  when(mockRightRowType.getFieldNames()).thenReturn(rightStreamFieldNames);
  StreamApplicationDescriptorImpl mockAppDesc = mock(StreamApplicationDescriptorImpl.class);
  OperatorSpec<Object, SamzaSqlRelMessage> mockLeftInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockLeftInputStream = new MessageStreamImpl<>(mockAppDesc, mockLeftInputOp);
  when(mockTranslatorContext.getMessageStream(eq(mockLeftInput.getId()))).thenReturn(mockLeftInputStream);
  OperatorSpec<Object, SamzaSqlRelMessage> mockRightInputOp = mock(OperatorSpec.class);
  MessageStream<SamzaSqlRelMessage> mockRightInputStream = new MessageStreamImpl<>(mockAppDesc, mockRightInputOp);
  when(mockTranslatorContext.getMessageStream(eq(mockRightInput.getId()))).thenReturn(mockRightInputStream);
  when(mockTranslatorContext.getStreamAppDescriptor()).thenReturn(mockAppDesc);
  InputOperatorSpec mockInputOp = mock(InputOperatorSpec.class);
  OutputStreamImpl mockOutputStream = mock(OutputStreamImpl.class);
  when(mockInputOp.isKeyed()).thenReturn(true);
  when(mockOutputStream.isKeyed()).thenReturn(true);
  doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext).registerMessageStream(eq(3), any(MessageStream.class));
  RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
  when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
  Expression mockExpr = mock(Expression.class);
  when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
  if (isRemoteTable) {
    doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RemoteTableDescriptor.class));
  } else {
    // The local-table path also requests an intermediate (partitioned) stream from the app descriptor.
    IntermediateMessageStreamImpl mockPartitionedStream = new IntermediateMessageStreamImpl(mockAppDesc, mockInputOp, mockOutputStream);
    when(mockAppDesc.getIntermediateStream(any(String.class), any(Serde.class), eq(false))).thenReturn(mockPartitionedStream);
    doAnswer(this.getRegisteredTableAnswer()).when(mockAppDesc).getTable(any(RocksDbTableDescriptor.class));
  }
  when(mockJoin.getJoinType()).thenReturn(JoinRelType.INNER);
  SamzaSqlExecutionContext mockExecutionContext = mock(SamzaSqlExecutionContext.class);
  when(mockTranslatorContext.getExecutionContext()).thenReturn(mockExecutionContext);
  SamzaSqlApplicationConfig mockAppConfig = mock(SamzaSqlApplicationConfig.class);
  when(mockExecutionContext.getSamzaSqlApplicationConfig()).thenReturn(mockAppConfig);
  Map<String, SqlIOConfig> ssConfigBySource = mock(HashMap.class);
  when(mockAppConfig.getInputSystemStreamConfigBySource()).thenReturn(ssConfigBySource);
  SqlIOConfig mockIOConfig = mock(SqlIOConfig.class);
  TableDescriptor mockTableDesc;
  if (isRemoteTable) {
    mockTableDesc = mock(RemoteTableDescriptor.class);
  } else {
    mockTableDesc = mock(RocksDbTableDescriptor.class);
  }
  // The table's source is looked up by its dot-joined qualified name ("test.LeftTable").
  when(ssConfigBySource.get(String.join(".", qualifiedTableName))).thenReturn(mockIOConfig);
  when(mockIOConfig.getTableDescriptor()).thenReturn(Optional.of(mockTableDesc));
  JoinTranslator joinTranslator = new JoinTranslator(logicalOpId, "", queryId);
  // Verify Metrics Works with Join
  Context mockContext = mock(Context.class);
  ContainerContext mockContainerContext = mock(ContainerContext.class);
  TestMetricsRegistryImpl testMetricsRegistryImpl = new TestMetricsRegistryImpl();
  when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
  when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(testMetricsRegistryImpl);
  TranslatorInputMetricsMapFunction inputMetricsMF = joinTranslator.getInputMetricsMF();
  assertNotNull(inputMetricsMF);
  inputMetricsMF.init(mockContext);
  TranslatorOutputMetricsMapFunction outputMetricsMF = joinTranslator.getOutputMetricsMF();
  assertNotNull(outputMetricsMF);
  outputMetricsMF.init(mockContext);
  assertEquals(1, testMetricsRegistryImpl.getCounters().size());
  assertEquals(2, testMetricsRegistryImpl.getCounters().get(logicalOpId).size());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(0).getCount());
  assertEquals(0, testMetricsRegistryImpl.getCounters().get(logicalOpId).get(1).getCount());
  assertEquals(1, testMetricsRegistryImpl.getGauges().size());
  // Apply translate() method to verify that we are getting the correct join operator constructed
  joinTranslator.translate(mockJoin, mockTranslatorContext);
  // make sure that the output message stream of the LogicalJoin has been registered with the context
  verify(mockTranslatorContext, times(1)).registerMessageStream(3, this.getRegisteredMessageStream(3));
  when(mockTranslatorContext.getRelNode(3)).thenReturn(mockJoin);
  when(mockTranslatorContext.getMessageStream(3)).thenReturn(this.getRegisteredMessageStream(3));
  StreamTableJoinOperatorSpec joinSpec = (StreamTableJoinOperatorSpec) Whitebox.getInternalState(this.getRegisteredMessageStream(3), "operatorSpec");
  assertNotNull(joinSpec);
  // JUnit convention: expected value first, actual value second
  assertEquals(OperatorSpec.OpCode.JOIN, joinSpec.getOpCode());
  // Verify joinSpec has the corresponding setup
  StreamTableJoinFunction joinFn = joinSpec.getJoinFn();
  assertNotNull(joinFn);
  if (isRemoteTable) {
    assertTrue(joinFn instanceof SamzaSqlRemoteTableJoinFunction);
  } else {
    assertTrue(joinFn instanceof SamzaSqlLocalTableJoinFunction);
  }
  // The table is the left input, so it must not be flagged as being on the right.
  assertEquals(Boolean.FALSE, Whitebox.getInternalState(joinFn, "isTablePosOnRight"));
  assertEquals(Collections.singletonList(0), Whitebox.getInternalState(joinFn, "streamFieldIds"));
  assertEquals(leftFieldNames, Whitebox.getInternalState(joinFn, "tableFieldNames"));
  // Output fields are the table (left) fields followed by the stream (right) fields.
  List<String> outputFieldNames = new ArrayList<>();
  outputFieldNames.addAll(leftFieldNames);
  outputFieldNames.addAll(rightStreamFieldNames);
  assertEquals(outputFieldNames, Whitebox.getInternalState(joinFn, "outFieldNames"));
}
use of org.apache.samza.context.Context in project samza by apache.
the class TestEmbeddedTaggedRateLimiter method initRateLimiter.
/**
 * Initializes {@code rateLimiter} with a mocked {@code Context} that describes a single-container job.
 *
 * <p>Container {@code container-1} owns {@code NUMBER_OF_TASKS} tasks named {@code task-0},
 * {@code task-1}, and so on, each mapped to its own mocked {@code TaskModel}. The current task context returns
 * that job model and a separate mocked {@code TaskModel} for the current task.
 *
 * @param rateLimiter the rate limiter to initialize
 */
static void initRateLimiter(RateLimiter rateLimiter) {
  Map<TaskName, TaskModel> tasks = IntStream.range(0, NUMBER_OF_TASKS)
      .mapToObj(index -> new TaskName("task-" + index))
      .collect(Collectors.toMap(Function.identity(), taskName -> mock(TaskModel.class)));

  // One container holding every task.
  ContainerModel containerModel = mock(ContainerModel.class);
  when(containerModel.getTasks()).thenReturn(tasks);
  Map<String, ContainerModel> containers = new HashMap<>();
  containers.put("container-1", containerModel);
  JobModel jobModel = mock(JobModel.class);
  when(jobModel.getContainers()).thenReturn(containers);

  // Expose the job model through the task context reachable from the Context.
  TaskContextImpl taskContext = mock(TaskContextImpl.class);
  when(taskContext.getJobModel()).thenReturn(jobModel);
  Context context = mock(Context.class);
  when(context.getTaskContext()).thenReturn(taskContext);
  when(context.getTaskContext().getTaskModel()).thenReturn(mock(TaskModel.class));

  rateLimiter.init(context);
}
Aggregations