Search in sources :

Example 1 with ClusterStateTaskExecutor

use of org.opensearch.cluster.ClusterStateTaskExecutor in project OpenSearch by opensearch-project.

The method submitStateUpdateTasks of the class MasterService.

/**
 * Enqueues a group of cluster state update tasks in a single submission. Tasks submitted together are
 * guaranteed to be processed together, potentially alongside more tasks of the same executor.
 * <p>
 * The call returns without doing anything when the lifecycle has not been started. A rejection raised
 * during submission is swallowed only if the lifecycle is already stopped or closed; in every other
 * case it is rethrown to the caller.
 *
 * @param source   the source of the cluster state update task
 * @param tasks    a map of update tasks and their corresponding listeners
 * @param config   the cluster state update task configuration (supplies the priority and the timeout)
 * @param executor the cluster state update task executor; tasks that share the same executor are
 *                 executed in batches on this executor
 * @param <T>      the type of the cluster state update task state
 */
public <T> void submitStateUpdateTasks(
    final String source,
    final Map<T, ClusterStateTaskListener> tasks,
    final ClusterStateTaskConfig config,
    final ClusterStateTaskExecutor<T> executor
) {
    if (lifecycle.started() == false) {
        return;
    }
    final ThreadContext context = threadPool.getThreadContext();
    // Captured before the context is stashed below; handed to safe(...) together with each listener.
    final Supplier<ThreadContext.StoredContext> callerContext = context.newRestorableContext(true);
    try (ThreadContext.StoredContext ignored = context.stashContext()) {
        context.markAsSystemContext();
        // One UpdateTask per map entry, each carrying its task and the wrapped listener.
        final List<Batcher.UpdateTask> batch = tasks.entrySet()
            .stream()
            .map(
                entry -> taskBatcher.new UpdateTask(
                    config.priority(),
                    source,
                    entry.getKey(),
                    safe(entry.getValue(), callerContext),
                    executor
                )
            )
            .collect(Collectors.toList());
        taskBatcher.submitTasks(batch, config.timeout());
    } catch (OpenSearchRejectedExecutionException rejection) {
        // A rejection is only ignored once the lifecycle is stopped or closed.
        if (lifecycle.stoppedOrClosed()) {
            return;
        }
        throw rejection;
    }
}
Also used : ClusterStateTaskListener(org.opensearch.cluster.ClusterStateTaskListener) OpenSearchRejectedExecutionException(org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException) AckedClusterStateTaskListener(org.opensearch.cluster.AckedClusterStateTaskListener) DiscoveryNodes(org.opensearch.cluster.node.DiscoveryNodes) Arrays(java.util.Arrays) Metadata(org.opensearch.cluster.metadata.Metadata) CountDown(org.opensearch.common.util.concurrent.CountDown) PrioritizedOpenSearchThreadPoolExecutor(org.opensearch.common.util.concurrent.PrioritizedOpenSearchThreadPoolExecutor) ThreadPool(org.opensearch.threadpool.ThreadPool) Priority(org.opensearch.common.Priority) Node(org.opensearch.node.Node) FutureUtils(org.opensearch.common.util.concurrent.FutureUtils) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) ParameterizedMessage(org.apache.logging.log4j.message.ParameterizedMessage) OpenSearchExecutors(org.opensearch.common.util.concurrent.OpenSearchExecutors) Supplier(java.util.function.Supplier) ClusterState(org.opensearch.cluster.ClusterState) DiscoveryNode(org.opensearch.cluster.node.DiscoveryNode) ClusterTasksResult(org.opensearch.cluster.ClusterStateTaskExecutor.ClusterTasksResult) PlainActionFuture(org.opensearch.action.support.PlainActionFuture) Locale(java.util.Locale) OpenSearchExecutors.daemonThreadFactory(org.opensearch.common.util.concurrent.OpenSearchExecutors.daemonThreadFactory) Assertions(org.opensearch.Assertions) Map(java.util.Map) ClusterStatePublisher(org.opensearch.cluster.coordination.ClusterStatePublisher) ClusterStateTaskConfig(org.opensearch.cluster.ClusterStateTaskConfig) ClusterSettings(org.opensearch.common.settings.ClusterSettings) ProcessClusterEventTimeoutException(org.opensearch.cluster.metadata.ProcessClusterEventTimeoutException) Setting(org.opensearch.common.settings.Setting) TimeValue(org.opensearch.common.unit.TimeValue) ClusterStateTaskExecutor(org.opensearch.cluster.ClusterStateTaskExecutor) 
Settings(org.opensearch.common.settings.Settings) Discovery(org.opensearch.discovery.Discovery) Collectors(java.util.stream.Collectors) Nullable(org.opensearch.common.Nullable) FailedToCommitClusterStateException(org.opensearch.cluster.coordination.FailedToCommitClusterStateException) Objects(java.util.Objects) TimeUnit(java.util.concurrent.TimeUnit) AbstractLifecycleComponent(org.opensearch.common.component.AbstractLifecycleComponent) List(java.util.List) Logger(org.apache.logging.log4j.Logger) Builder(org.opensearch.cluster.ClusterState.Builder) RoutingTable(org.opensearch.cluster.routing.RoutingTable) LogManager(org.apache.logging.log4j.LogManager) Collections(java.util.Collections) Text(org.opensearch.common.text.Text) Scheduler(org.opensearch.threadpool.Scheduler) ClusterChangedEvent(org.opensearch.cluster.ClusterChangedEvent) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) OpenSearchRejectedExecutionException(org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException)

Example 2 with ClusterStateTaskExecutor

use of org.opensearch.cluster.ClusterStateTaskExecutor in project OpenSearch by opensearch-project.

The method testClusterStateBatchedUpdates of the class MasterServiceTests.

/**
 * Submits randomly sized groups of tasks to a {@code MasterService} from several threads at once and
 * verifies that
 * <ul>
 *   <li>every task is executed exactly once,</li>
 *   <li>tasks submitted together are always handed to the executor in the same batch,</li>
 *   <li>each executor sees exactly the tasks assigned to it, and</li>
 *   <li>the number of {@code clusterStateProcessed} callbacks per source matches the number of tasks
 *       submitted under that source.</li>
 * </ul>
 */
public void testClusterStateBatchedUpdates() throws BrokenBarrierException, InterruptedException {
    // Global count of task executions across all executors; incremented by Task.execute().
    AtomicInteger counter = new AtomicInteger();
    class Task {

        // Flips to true on first execution; a second execution is reported as an error.
        private AtomicBoolean state = new AtomicBoolean();

        private final int id;

        Task(int id) {
            this.id = id;
        }

        // Marks the task as executed and bumps the global counter; throws if the task already ran.
        public void execute() {
            if (!state.compareAndSet(false, true)) {
                throw new IllegalStateException();
            } else {
                counter.incrementAndGet();
            }
        }

        // Identity of a task is its id only (equals and hashCode agree on that).
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            Task task = (Task) o;
            return id == task.id;
        }

        @Override
        public int hashCode() {
            return id;
        }

        @Override
        public String toString() {
            return Integer.toString(id);
        }
    }
    int numberOfThreads = randomIntBetween(2, 8);
    int taskSubmissionsPerThread = randomIntBetween(1, 64);
    // At least one executor, otherwise a quarter as many executors as submitting threads.
    int numberOfExecutors = Math.max(1, numberOfThreads / 4);
    // One permit per executor; a permit is taken in execute() when a new state is built and
    // given back in clusterStatePublished().
    final Semaphore semaphore = new Semaphore(numberOfExecutors);
    class TaskExecutor implements ClusterStateTaskExecutor<Task> {

        // Shared list of all task groups that were submitted together (filled in further below).
        private final List<Set<Task>> taskGroups;

        // Per-executor count of executed tasks; this field hides the method-local counter of the
        // same name inside this class.
        private AtomicInteger counter = new AtomicInteger();

        // Number of execute() calls that produced a new ClusterState instance.
        private AtomicInteger batches = new AtomicInteger();

        // Number of clusterStatePublished() callbacks received.
        private AtomicInteger published = new AtomicInteger();

        TaskExecutor(List<Set<Task>> taskGroups) {
            this.taskGroups = taskGroups;
        }

        @Override
        public ClusterTasksResult<Task> execute(ClusterState currentState, List<Task> tasks) throws Exception {
            // For every group: the current batch must contain either none of its tasks or all of them.
            for (Set<Task> expectedSet : taskGroups) {
                long count = tasks.stream().filter(expectedSet::contains).count();
                assertThat("batched set should be executed together or not at all. Expected " + expectedSet + "s. Executing " + tasks, count, anyOf(equalTo(0L), equalTo((long) expectedSet.size())));
            }
            tasks.forEach(Task::execute);
            counter.addAndGet(tasks.size());
            ClusterState maybeUpdatedClusterState = currentState;
            // Randomly return either the unchanged state or a freshly built copy of it.
            if (randomBoolean()) {
                maybeUpdatedClusterState = ClusterState.builder(currentState).build();
                batches.incrementAndGet();
                // Released again in clusterStatePublished(); the test later compares batches with published.
                semaphore.acquire();
            }
            return ClusterTasksResult.<Task>builder().successes(tasks).build(maybeUpdatedClusterState);
        }

        @Override
        public void clusterStatePublished(ClusterChangedEvent clusterChangedEvent) {
            published.incrementAndGet();
            semaphore.release();
        }
    }
    // Number of clusterStateProcessed callbacks received, keyed by task source.
    ConcurrentMap<String, AtomicInteger> processedStates = new ConcurrentHashMap<>();
    List<Set<Task>> taskGroups = new ArrayList<>();
    List<TaskExecutor> executors = new ArrayList<>();
    for (int i = 0; i < numberOfExecutors; i++) {
        executors.add(new TaskExecutor(taskGroups));
    }
    // randomly assign tasks to executors
    List<Tuple<TaskExecutor, Set<Task>>> assignments = new ArrayList<>();
    int taskId = 0;
    for (int i = 0; i < numberOfThreads; i++) {
        for (int j = 0; j < taskSubmissionsPerThread; j++) {
            TaskExecutor executor = randomFrom(executors);
            Set<Task> tasks = new HashSet<>();
            // Adds randomInt(3) + 1 tasks with consecutive ids
            // (presumably one to four tasks, assuming randomInt(3) yields 0..3 - TODO confirm).
            for (int t = randomInt(3); t >= 0; t--) {
                tasks.add(new Task(taskId++));
            }
            taskGroups.add(tasks);
            assignments.add(Tuple.tuple(executor, tasks));
        }
    }
    // Expected number of tasks per executor, plus the overall total.
    Map<TaskExecutor, Integer> counts = new HashMap<>();
    int totalTaskCount = 0;
    for (Tuple<TaskExecutor, Set<Task>> assignment : assignments) {
        final int taskCount = assignment.v2().size();
        counts.merge(assignment.v1(), taskCount, (previous, count) -> previous + count);
        totalTaskCount += taskCount;
    }
    // Reaches zero once clusterStateProcessed has been invoked once per submitted task.
    final CountDownLatch updateLatch = new CountDownLatch(totalTaskCount);
    // Single listener shared by all tasks.
    final ClusterStateTaskListener listener = new ClusterStateTaskListener() {

        @Override
        public void onFailure(String source, Exception e) {
            throw new AssertionError(e);
        }

        @Override
        public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) {
            processedStates.computeIfAbsent(source, key -> new AtomicInteger()).incrementAndGet();
            updateLatch.countDown();
        }
    };
    try (MasterService masterService = createMasterService(true)) {
        // Number of tasks each thread submitted, keyed by thread name (which is also used as the task source).
        final ConcurrentMap<String, AtomicInteger> submittedTasksPerThread = new ConcurrentHashMap<>();
        // All submitting threads plus this test thread meet at the barrier.
        CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
        for (int i = 0; i < numberOfThreads; i++) {
            final int index = i;
            Thread thread = new Thread(() -> {
                final String threadName = Thread.currentThread().getName();
                try {
                    barrier.await();
                    for (int j = 0; j < taskSubmissionsPerThread; j++) {
                        // Each thread works through its own contiguous slice of the assignments list.
                        Tuple<TaskExecutor, Set<Task>> assignment = assignments.get(index * taskSubmissionsPerThread + j);
                        final Set<Task> tasks = assignment.v2();
                        submittedTasksPerThread.computeIfAbsent(threadName, key -> new AtomicInteger()).addAndGet(tasks.size());
                        final TaskExecutor executor = assignment.v1();
                        // Single-task groups go through submitStateUpdateTask, larger groups through
                        // submitStateUpdateTasks, so both submission methods are exercised.
                        if (tasks.size() == 1) {
                            masterService.submitStateUpdateTask(threadName, tasks.stream().findFirst().get(), ClusterStateTaskConfig.build(randomFrom(Priority.values())), executor, listener);
                        } else {
                            Map<Task, ClusterStateTaskListener> taskListeners = new HashMap<>();
                            tasks.forEach(t -> taskListeners.put(t, listener));
                            masterService.submitStateUpdateTasks(threadName, taskListeners, ClusterStateTaskConfig.build(randomFrom(Priority.values())), executor);
                        }
                    }
                    barrier.await();
                } catch (BrokenBarrierException | InterruptedException e) {
                    throw new AssertionError(e);
                }
            });
            thread.start();
        }
        // wait for all threads to be ready
        barrier.await();
        // wait for all threads to finish
        barrier.await();
        // wait until all the cluster state updates have been processed
        updateLatch.await();
        // and until all of the publication callbacks have completed
        // (taking every permit only succeeds once each permit acquired in execute() has been released)
        semaphore.acquire(numberOfExecutors);
        // assert the number of executed tasks is correct
        assertEquals(totalTaskCount, counter.get());
        // assert each executor executed the correct number of tasks
        for (TaskExecutor executor : executors) {
            // Executors that randomly received no assignment have no entry in counts and are skipped.
            if (counts.containsKey(executor)) {
                assertEquals((int) counts.get(executor), executor.counter.get());
                assertEquals(executor.batches.get(), executor.published.get());
            }
        }
        // assert the correct number of clusterStateProcessed events were triggered
        for (Map.Entry<String, AtomicInteger> entry : processedStates.entrySet()) {
            assertThat(submittedTasksPerThread, hasKey(entry.getKey()));
            // NOTE(review): if this is JUnit's assertEquals(message, expected, actual), the processed count sits
            // in the "expected" slot and the submitted count in the "actual" slot - confirm whether that is intended.
            assertEquals("not all tasks submitted by " + entry.getKey() + " received a processed event", entry.getValue().get(), submittedTasksPerThread.get(entry.getKey()).get());
        }
    }
}
Also used : TestThreadPool(org.opensearch.threadpool.TestThreadPool) Level(org.apache.logging.log4j.Level) Version(org.opensearch.Version) OpenSearchException(org.opensearch.OpenSearchException) ThreadContext(org.opensearch.common.util.concurrent.ThreadContext) Matchers.hasKey(org.hamcrest.Matchers.hasKey) DiscoveryNode(org.opensearch.cluster.node.DiscoveryNode) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) Map(java.util.Map) ClusterStatePublisher(org.opensearch.cluster.coordination.ClusterStatePublisher) AfterClass(org.junit.AfterClass) CyclicBarrier(java.util.concurrent.CyclicBarrier) TimeValue(org.opensearch.common.unit.TimeValue) OpenSearchTestCase(org.opensearch.test.OpenSearchTestCase) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) Set(java.util.Set) ClusterStateTaskExecutor(org.opensearch.cluster.ClusterStateTaskExecutor) Settings(org.opensearch.common.settings.Settings) Nullable(org.opensearch.common.Nullable) Tuple(org.opensearch.common.collect.Tuple) FailedToCommitClusterStateException(org.opensearch.cluster.coordination.FailedToCommitClusterStateException) CountDownLatch(java.util.concurrent.CountDownLatch) List(java.util.List) ClusterStateUpdateTask(org.opensearch.cluster.ClusterStateUpdateTask) Matchers.equalTo(org.hamcrest.Matchers.equalTo) Matchers.anyOf(org.hamcrest.Matchers.anyOf) Matchers.containsString(org.hamcrest.Matchers.containsString) ClusterStateTaskListener(org.opensearch.cluster.ClusterStateTaskListener) DiscoveryNodes(org.opensearch.cluster.node.DiscoveryNodes) MockLogAppender(org.opensearch.test.MockLogAppender) BeforeClass(org.junit.BeforeClass) ThreadPool(org.opensearch.threadpool.ThreadPool) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) Priority(org.opensearch.common.Priority) HashMap(java.util.HashMap) Node(org.opensearch.node.Node) AtomicReference(java.util.concurrent.atomic.AtomicReference) ArrayList(java.util.ArrayList) ConcurrentMap(java.util.concurrent.ConcurrentMap) HashSet(java.util.HashSet) 
ClusterState(org.opensearch.cluster.ClusterState) AckedClusterStateUpdateTask(org.opensearch.cluster.AckedClusterStateUpdateTask) ClusterStateTaskConfig(org.opensearch.cluster.ClusterStateTaskConfig) ClusterSettings(org.opensearch.common.settings.ClusterSettings) ClusterBlocks(org.opensearch.cluster.block.ClusterBlocks) Before(org.junit.Before) Collections.emptyMap(java.util.Collections.emptyMap) Collections.emptySet(java.util.Collections.emptySet) Semaphore(java.util.concurrent.Semaphore) BrokenBarrierException(java.util.concurrent.BrokenBarrierException) BaseFuture(org.opensearch.common.util.concurrent.BaseFuture) LocalClusterUpdateTask(org.opensearch.cluster.LocalClusterUpdateTask) TestLogging(org.opensearch.test.junit.annotations.TestLogging) TimeUnit(java.util.concurrent.TimeUnit) ClusterName(org.opensearch.cluster.ClusterName) LogManager(org.apache.logging.log4j.LogManager) Collections(java.util.Collections) ClusterChangedEvent(org.opensearch.cluster.ClusterChangedEvent) ClusterStateUpdateTask(org.opensearch.cluster.ClusterStateUpdateTask) AckedClusterStateUpdateTask(org.opensearch.cluster.AckedClusterStateUpdateTask) LocalClusterUpdateTask(org.opensearch.cluster.LocalClusterUpdateTask) Set(java.util.Set) HashSet(java.util.HashSet) Collections.emptySet(java.util.Collections.emptySet) BrokenBarrierException(java.util.concurrent.BrokenBarrierException) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) ClusterChangedEvent(org.opensearch.cluster.ClusterChangedEvent) Semaphore(java.util.concurrent.Semaphore) Matchers.containsString(org.hamcrest.Matchers.containsString) List(java.util.List) ArrayList(java.util.ArrayList) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashSet(java.util.HashSet) ClusterStateTaskListener(org.opensearch.cluster.ClusterStateTaskListener) ClusterState(org.opensearch.cluster.ClusterState) CountDownLatch(java.util.concurrent.CountDownLatch) 
OpenSearchException(org.opensearch.OpenSearchException) FailedToCommitClusterStateException(org.opensearch.cluster.coordination.FailedToCommitClusterStateException) BrokenBarrierException(java.util.concurrent.BrokenBarrierException) CyclicBarrier(java.util.concurrent.CyclicBarrier) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) AtomicBoolean(java.util.concurrent.atomic.AtomicBoolean) ClusterStateTaskExecutor(org.opensearch.cluster.ClusterStateTaskExecutor) AtomicInteger(java.util.concurrent.atomic.AtomicInteger) ClusterStateTaskExecutor(org.opensearch.cluster.ClusterStateTaskExecutor) Map(java.util.Map) ConcurrentHashMap(java.util.concurrent.ConcurrentHashMap) HashMap(java.util.HashMap) ConcurrentMap(java.util.concurrent.ConcurrentMap) Collections.emptyMap(java.util.Collections.emptyMap) Tuple(org.opensearch.common.collect.Tuple)

Aggregations

Collections (java.util.Collections)2 List (java.util.List)2 Map (java.util.Map)2 TimeUnit (java.util.concurrent.TimeUnit)2 LogManager (org.apache.logging.log4j.LogManager)2 ClusterChangedEvent (org.opensearch.cluster.ClusterChangedEvent)2 ClusterState (org.opensearch.cluster.ClusterState)2 ClusterStateTaskConfig (org.opensearch.cluster.ClusterStateTaskConfig)2 ClusterStateTaskExecutor (org.opensearch.cluster.ClusterStateTaskExecutor)2 ClusterStateTaskListener (org.opensearch.cluster.ClusterStateTaskListener)2 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 Collections.emptyMap (java.util.Collections.emptyMap)1 Collections.emptySet (java.util.Collections.emptySet)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 Locale (java.util.Locale)1 Objects (java.util.Objects)1 Set (java.util.Set)1 BrokenBarrierException (java.util.concurrent.BrokenBarrierException)1