Use of org.opensearch.cluster.ClusterStateTaskExecutor in project OpenSearch by opensearch-project.
From the class MasterService, method submitStateUpdateTasks.
/**
 * Queues a group of cluster state update tasks as one submission. Tasks handed over in a single call are
 * guaranteed to be processed together, potentially along with further tasks that use the same executor.
 *
 * @param source   the source of the cluster state update tasks
 * @param tasks    the update tasks to submit, each mapped to its corresponding listener
 * @param config   the cluster state update task configuration; its priority and timeout are used for the submission
 * @param executor the cluster state update task executor; tasks that share the same executor will be executed
 *                 in batches on this executor
 * @param <T>      the type of the cluster state update task state
 */
public <T> void submitStateUpdateTasks(final String source, final Map<T, ClusterStateTaskListener> tasks, final ClusterStateTaskConfig config, final ClusterStateTaskExecutor<T> executor) {
    if (lifecycle.started() == false) {
        // the service is not started: the submission is dropped without notifying anyone
        return;
    }

    final ThreadContext context = threadPool.getThreadContext();
    // obtained before the context is stashed below; handed to safe(...) together with each listener
    final Supplier<ThreadContext.StoredContext> restorableContext = context.newRestorableContext(true);

    try (ThreadContext.StoredContext ignored = context.stashContext()) {
        context.markAsSystemContext();

        // wrap every (task, listener) pair into an UpdateTask bound to the current batcher
        final List<Batcher.UpdateTask> updateTasks = tasks.entrySet()
            .stream()
            .map(entry -> {
                return taskBatcher.new UpdateTask(
                    config.priority(),
                    source,
                    entry.getKey(),
                    safe(entry.getValue(), restorableContext),
                    executor
                );
            })
            .collect(Collectors.toList());

        taskBatcher.submitTasks(updateTasks, config.timeout());
    } catch (OpenSearchRejectedExecutionException rejection) {
        // a rejection is swallowed only while the service is stopped or closed (nothing useful can be
        // done at that point); in every other lifecycle state it is propagated to the caller
        if (lifecycle.stoppedOrClosed() == false) {
            throw rejection;
        }
    }
}
Use of org.opensearch.cluster.ClusterStateTaskExecutor in project OpenSearch by opensearch-project.
From the class MasterServiceTests, method testClusterStateBatchedUpdates.
/**
 * Exercises batched cluster state updates under concurrent submission.
 * <p>
 * Several threads each submit a number of randomly sized task groups to a master service, every group being bound
 * to one of a small pool of executors. The test then verifies that
 * <ul>
 * <li>every group is seen by its executor either completely or not at all within one {@code execute} call,</li>
 * <li>every task is executed exactly once (a second execution throws {@link IllegalStateException}),</li>
 * <li>each executor executed as many tasks as were assigned to it, and saw as many
 * {@code clusterStatePublished} callbacks as batches for which it built a new cluster state,</li>
 * <li>the number of {@code clusterStateProcessed} events per source matches the number of tasks submitted by
 * the thread of that name.</li>
 * </ul>
 *
 * @throws BrokenBarrierException if the barrier coordinating the test thread and the submitter threads is broken
 * @throws InterruptedException if the test thread is interrupted while waiting on the barrier, latch or semaphore
 */
public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
// global count of successful Task.execute() calls (distinct from the per-executor counter field further down)
AtomicInteger counter = new AtomicInteger();
// A task identified by its id. Equality and hashing are based on the id only, which is what the
// expectedSet::contains check in TaskExecutor.execute relies on.
class Task {
// flipped from false to true on first execution, used to detect double execution
private AtomicBoolean state = new AtomicBoolean();
private final int id;
Task(int id) {
this.id = id;
}
// Marks the task as executed and bumps the global counter; fails if the task was already executed.
public void execute() {
if (!state.compareAndSet(false, true)) {
throw new IllegalStateException();
} else {
counter.incrementAndGet();
}
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Task task = (Task) o;
return id == task.id;
}
@Override
public int hashCode() {
return id;
}
@Override
public String toString() {
return Integer.toString(id);
}
}
int numberOfThreads = randomIntBetween(2, 8);
int taskSubmissionsPerThread = randomIntBetween(1, 64);
// one executor per four threads, but at least one
int numberOfExecutors = Math.max(1, numberOfThreads / 4);
// One permit is taken in execute() whenever a new cluster state instance is built and handed back in
// clusterStatePublished(); the test thread later re-acquires all permits to wait for outstanding callbacks.
final Semaphore semaphore = new Semaphore(numberOfExecutors);
// Executor that checks the batching invariant, executes the tasks and tracks its own statistics.
class TaskExecutor implements ClusterStateTaskExecutor<Task> {
// all submitted groups (shared list, filled below before any task is submitted)
private final List<Set<Task>> taskGroups;
// number of tasks executed by this executor; shadows the outer local 'counter' inside this class
private AtomicInteger counter = new AtomicInteger();
// number of execute() calls that returned a newly built cluster state
private AtomicInteger batches = new AtomicInteger();
// number of clusterStatePublished() callbacks received
private AtomicInteger published = new AtomicInteger();
TaskExecutor(List<Set<Task>> taskGroups) {
this.taskGroups = taskGroups;
}
@Override
public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
// for every known group, the batch must contain either none or all of its tasks
for (Set<Task> expectedSet : taskGroups) {
long count = tasks.stream().filter(expectedSet::contains).count();
assertThat("batched set should be executed together or not at all. Expected " + expectedSet + "s. Executing " + tasks, count, anyOf(equalTo(0L), equalTo((long) expectedSet.size())));
}
tasks.forEach(Task::execute);
counter.addAndGet(tasks.size());
ClusterState maybeUpdatedClusterState = currentState;
// randomly either return the current state unchanged or return a freshly built instance
if (randomBoolean()) {
maybeUpdatedClusterState = ClusterState.builder(currentState).build();
batches.incrementAndGet();
// paired with the release() in clusterStatePublished()
semaphore.acquire();
}
return ClusterTasksResult.<Task>builder().successes(tasks).build(maybeUpdatedClusterState);
}
// NOTE(review): presumably invoked by the master service once per batch that produced a new state;
// confirm against MasterService. The batches == published assertion below depends on it.
@Override
public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
published.incrementAndGet();
semaphore.release();
}
}
// number of clusterStateProcessed events received, keyed by the source passed at submission
ConcurrentMap<String, AtomicInteger> processedStates = new ConcurrentHashMap<>();
List<Set<Task>> taskGroups = new ArrayList<>();
List<TaskExecutor> executors = new ArrayList<>();
for (int i = 0; i < numberOfExecutors; i++) {
executors.add(new TaskExecutor(taskGroups));
}
// randomly assign tasks to executors
// (numberOfThreads * taskSubmissionsPerThread groups of 1 to 4 tasks each, with globally unique task ids)
List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
int taskId = 0;
for (int i = 0; i < numberOfThreads; i++) {
for (int j = 0; j < taskSubmissionsPerThread; j++) {
TaskExecutor executor = randomFrom(executors);
Set<Task> tasks = new HashSet<>();
for (int t = randomInt(3); t >= 0; t--) {
tasks.add(new Task(taskId++));
}
taskGroups.add(tasks);
assignments.add(Tuple.tuple(executor, tasks));
}
}
// expected number of tasks per executor, and overall
Map<TaskExecutor, Integer> counts = new HashMap<>();
int totalTaskCount = 0;
for (Tuple<TaskExecutor, Set<Task>> assignment : assignments) {
final int taskCount = assignment.v2().size();
counts.merge(assignment.v1(), taskCount, (previous, count) -> previous + count);
totalTaskCount += taskCount;
}
// counted down once per clusterStateProcessed event, i.e. once per task
final CountDownLatch updateLatch = new CountDownLatch(totalTaskCount);
// single listener shared by all tasks
final ClusterStateTaskListener listener = new ClusterStateTaskListener() {
@Override
public void onFailure(String source, Exception e) {
// no task is expected to fail in this test
throw new AssertionError(e);
}
@Override
public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
processedStates.computeIfAbsent(source, key -> new AtomicInteger()).incrementAndGet();
updateLatch.countDown();
}
};
try (MasterService masterService = createMasterService(true)) {
// number of tasks each thread submitted, keyed by thread name (the thread name is also used as the source)
final ConcurrentMap<String, AtomicInteger> submittedTasksPerThread = new ConcurrentHashMap<>();
// parties: the test thread plus every submitter thread
CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
for (int i = 0; i < numberOfThreads; i++) {
final int index = i;
Thread thread = new Thread(() -> {
final String threadName = Thread.currentThread().getName();
try {
// start submitting only once all threads are ready
barrier.await();
for (int j = 0; j < taskSubmissionsPerThread; j++) {
// each thread works through its own contiguous slice of the assignments
Tuple<TaskExecutor, Set<Task>> assignment = assignments.get(index * taskSubmissionsPerThread + j);
final Set<Task> tasks = assignment.v2();
submittedTasksPerThread.computeIfAbsent(threadName, key -> new AtomicInteger()).addAndGet(tasks.size());
final TaskExecutor executor = assignment.v1();
// single-task groups go through submitStateUpdateTask, larger groups through submitStateUpdateTasks
if (tasks.size() == 1) {
masterService.submitStateUpdateTask(threadName, tasks.stream().findFirst().get(), ClusterStateTaskConfig.build(randomFrom(Priority.values())), executor, listener);
} else {
Map<Task, ClusterStateTaskListener> taskListeners = new HashMap<>();
tasks.forEach(t -> taskListeners.put(t, listener));
masterService.submitStateUpdateTasks(threadName, taskListeners, ClusterStateTaskConfig.build(randomFrom(Priority.values())), executor);
}
}
// signal that this thread has submitted everything
barrier.await();
} catch (BrokenBarrierException | InterruptedException e) {
throw new AssertionError(e);
}
});
thread.start();
}
// wait for all threads to be ready
barrier.await();
// wait for all threads to finish
barrier.await();
// wait until all the cluster state updates have been processed
updateLatch.await();
// and until all of the publication callbacks have completed
// (all permits are available again only after every acquire() in execute() was matched by a release())
semaphore.acquire(numberOfExecutors);
// assert the number of executed tasks is correct
assertEquals(totalTaskCount, counter.get());
// assert each executor executed the correct number of tasks
for (TaskExecutor executor : executors) {
// an executor may have received no assignment at all, since assignment is random
if (counts.containsKey(executor)) {
assertEquals((int) counts.get(executor), executor.counter.get());
assertEquals(executor.batches.get(), executor.published.get());
}
}
// assert the correct number of clusterStateProcessed events were triggered
for (Map.Entry<String, AtomicInteger> entry : processedStates.entrySet()) {
assertThat(submittedTasksPerThread, hasKey(entry.getKey()));
// NOTE(review): the processed count is passed in the "expected" position and the submitted count as
// "actual"; this looks swapped and only affects the wording of a failure message - confirm.
assertEquals("not all tasks submitted by " + entry.getKey() + " received a processed event", entry.getValue().get(), submittedTasksPerThread.get(entry.getKey()).get());
}
}
}
Aggregations