Use of org.apache.flink.runtime.io.network.partition.consumer.RemoteChannelStateChecker in project Flink by Apache.
The class TaskTest, method testOnPartitionStateUpdate.
/**
 * Verifies how a task reset to {@code initialTaskState} reacts to every possible producer
 * {@link ExecutionState} passed through {@link RemoteChannelStateChecker}.
 *
 * <p>Producer states for which consumption may proceed must leave the task state untouched and
 * make the checker return {@code true}. A canceled, canceling or failed producer must move the
 * task to {@link ExecutionState#CANCELING}. Any other producer state must fail the task.
 *
 * @param initialTaskState state the consuming task is reset to before each producer state is
 *     checked
 */
public void testOnPartitionStateUpdate(ExecutionState initialTaskState) throws Exception {
    final ResultPartitionID partitionId = new ResultPartitionID();
    final Task task = createTaskBuilder().setInvokable(InvokableBlockingInInvoke.class).build();
    final RemoteChannelStateChecker checker = new RemoteChannelStateChecker(partitionId, "test task");

    // Producer states for which consumption may proceed: the task keeps its state.
    final ExecutionState[] producerReadyStates = {
        ExecutionState.INITIALIZING,
        ExecutionState.RUNNING,
        ExecutionState.SCHEDULED,
        ExecutionState.DEPLOYING,
        ExecutionState.FINISHED
    };
    // Producer states that abort consumption by moving the consuming task to CANCELING.
    final ExecutionState[] producerAbortedStates = {
        ExecutionState.CANCELED, ExecutionState.CANCELING, ExecutionState.FAILED
    };

    // Expected task state for each producer state; any state not listed above fails the task.
    final Map<ExecutionState, ExecutionState> expected = new HashMap<>(ExecutionState.values().length);
    for (ExecutionState state : ExecutionState.values()) {
        expected.put(state, ExecutionState.FAILED);
    }
    for (ExecutionState state : producerReadyStates) {
        expected.put(state, initialTaskState);
    }
    for (ExecutionState state : producerAbortedStates) {
        expected.put(state, ExecutionState.CANCELING);
    }

    int producingStateCounter = 0;
    for (ExecutionState producerState : ExecutionState.values()) {
        // Reset the consumer so that each producer state is checked independently.
        TestTaskBuilder.setTaskState(task, initialTaskState);
        if (checker.isProducerReadyOrAbortConsumption(
                task.new PartitionProducerStateResponseHandle(producerState, null))) {
            producingStateCounter++;
        }
        assertThat(
                "Unexpected task state for producer state " + producerState,
                task.getExecutionState(),
                is(expected.get(producerState)));
    }
    // Derived from the ready-state list rather than a magic number so the two cannot drift apart.
    assertEquals(producerReadyStates.length, producingStateCounter);
}
Use of org.apache.flink.runtime.io.network.partition.consumer.RemoteChannelStateChecker in project Flink by Apache.
The class TaskTest, method testTriggerPartitionStateUpdate.
/**
 * Tests the trigger partition state update future completions.
 *
 * <p>Covers each way the producer-state future returned by the mocked
 * {@link PartitionProducerStateChecker} can complete:
 * <ul>
 *   <li>with a {@link PartitionProducerDisposedException} (the task ends up CANCELING);
 *   <li>with any other exception (the task ends up FAILED);
 *   <li>with a {@link TimeoutException} (the state checker signals a retry);
 *   <li>successfully with a RUNNING producer.
 * </ul>
 */
@Test
public void testTriggerPartitionStateUpdate() throws Exception {
    final IntermediateDataSetID resultId = new IntermediateDataSetID();
    final ResultPartitionID partitionId = new ResultPartitionID();
    final PartitionProducerStateChecker partitionChecker = mock(PartitionProducerStateChecker.class);
    final ResultPartitionConsumableNotifier consumableNotifier = new NoOpResultPartitionConsumableNotifier();
    final AtomicInteger callCount = new AtomicInteger(0);
    final RemoteChannelStateChecker remoteChannelStateChecker = new RemoteChannelStateChecker(partitionId, "test task");

    // Test all branches of trigger partition state check
    {
        // Reset latches
        setup();

        // PartitionProducerDisposedException
        final Task task = createTaskBuilder()
                .setInvokable(InvokableBlockingInInvoke.class)
                .setConsumableNotifier(consumableNotifier)
                .setPartitionProducerStateChecker(partitionChecker)
                .setExecutor(Executors.directExecutor())
                .build();
        TestTaskBuilder.setTaskState(task, ExecutionState.RUNNING);

        final CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
        when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);

        // Record the checker's verdict and assert on the test thread: an AssertionError thrown
        // inside the response consumer runs in the Task's callback handling and may not reach
        // the test runner.
        final CompletableFuture<Boolean> checkerVerdict = new CompletableFuture<>();
        task.requestPartitionProducerState(
                resultId,
                partitionId,
                checkResult -> checkerVerdict.complete(
                        remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)));
        promise.completeExceptionally(new PartitionProducerDisposedException(partitionId));

        assertEquals(ExecutionState.CANCELING, task.getExecutionState());
        // The task uses a direct executor, so the consumer should already have run;
        // a null verdict means it was never invoked.
        assertThat(checkerVerdict.getNow(null), is(false));
    }

    {
        // Reset latches
        setup();

        // Any other exception
        final Task task = createTaskBuilder()
                .setInvokable(InvokableBlockingInInvoke.class)
                .setConsumableNotifier(consumableNotifier)
                .setPartitionProducerStateChecker(partitionChecker)
                .setExecutor(Executors.directExecutor())
                .build();
        TestTaskBuilder.setTaskState(task, ExecutionState.RUNNING);

        final CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
        when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);

        // Same pattern as above. Here it matters even more: the expected end state is FAILED,
        // so an assertion failing inside the consumer could be indistinguishable from success.
        final CompletableFuture<Boolean> checkerVerdict = new CompletableFuture<>();
        task.requestPartitionProducerState(
                resultId,
                partitionId,
                checkResult -> checkerVerdict.complete(
                        remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)));
        promise.completeExceptionally(new RuntimeException("Any other exception"));

        assertEquals(ExecutionState.FAILED, task.getExecutionState());
        assertThat(checkerVerdict.getNow(null), is(false));
    }

    {
        callCount.set(0);

        // Reset latches
        setup();

        // TimeoutException handled special => retry
        final Task task = createTaskBuilder()
                .setInvokable(InvokableBlockingInInvoke.class)
                .setConsumableNotifier(consumableNotifier)
                .setPartitionProducerStateChecker(partitionChecker)
                .setExecutor(Executors.directExecutor())
                .build();

        try {
            task.startTaskThread();
            awaitLatch.await();

            final CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
            when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);

            task.requestPartitionProducerState(resultId, partitionId, checkResult -> {
                if (remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)) {
                    callCount.incrementAndGet();
                }
            });
            promise.completeExceptionally(new TimeoutException());

            assertEquals(ExecutionState.RUNNING, task.getExecutionState());
            assertEquals(1, callCount.get());
        } finally {
            task.getExecutingThread().interrupt();
            task.getExecutingThread().join();
        }
    }

    {
        callCount.set(0);

        // Reset latches
        setup();

        // Success
        final Task task = createTaskBuilder()
                .setInvokable(InvokableBlockingInInvoke.class)
                .setConsumableNotifier(consumableNotifier)
                .setPartitionProducerStateChecker(partitionChecker)
                .setExecutor(Executors.directExecutor())
                .build();

        try {
            task.startTaskThread();
            awaitLatch.await();

            final CompletableFuture<ExecutionState> promise = new CompletableFuture<>();
            when(partitionChecker.requestPartitionProducerState(eq(task.getJobID()), eq(resultId), eq(partitionId))).thenReturn(promise);

            task.requestPartitionProducerState(resultId, partitionId, checkResult -> {
                if (remoteChannelStateChecker.isProducerReadyOrAbortConsumption(checkResult)) {
                    callCount.incrementAndGet();
                }
            });
            promise.complete(ExecutionState.RUNNING);

            assertEquals(ExecutionState.RUNNING, task.getExecutionState());
            assertEquals(1, callCount.get());
        } finally {
            task.getExecutingThread().interrupt();
            task.getExecutingThread().join();
        }
    }
}
Aggregations