From 1b104ffcd8bc4573924754552508f5416573a7a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pedro=20J=2E=20Est=C3=A9banez?=
Date: Tue, 9 Apr 2024 17:26:45 +0200
Subject: [PATCH] WorkerThreadPool: Support daemon-like tasks (via yield
 semantics)

---
 core/object/worker_thread_pool.cpp           | 164 +++++++++++--------
 core/object/worker_thread_pool.h             |  19 ++-
 tests/core/templates/test_command_queue.h    |  44 +++--
 tests/core/threads/test_worker_thread_pool.h |  67 ++++++++
 4 files changed, 217 insertions(+), 77 deletions(-)

diff --git a/core/object/worker_thread_pool.cpp b/core/object/worker_thread_pool.cpp
index 4ec0ab4df25..c10f491a115 100644
--- a/core/object/worker_thread_pool.cpp
+++ b/core/object/worker_thread_pool.cpp
@@ -35,6 +35,8 @@
 #include "core/os/thread_safe.h"
 #include "core/templates/command_queue_mt.h"
 
+WorkerThreadPool::Task *const WorkerThreadPool::ThreadData::YIELDING = (Task *)1;
+
 void WorkerThreadPool::Task::free_template_userdata() {
 	ERR_FAIL_NULL(template_userdata);
 	ERR_FAIL_NULL(native_func_userdata);
@@ -391,70 +393,11 @@ Error WorkerThreadPool::wait_for_task_completion(TaskID p_task_id) {
 	task_mutex.unlock();
 
 	if (caller_pool_thread) {
-		while (true) {
-			Task *task_to_process = nullptr;
-			{
-				MutexLock lock(task_mutex);
-				bool was_signaled = caller_pool_thread->signaled;
-				caller_pool_thread->signaled = false;
-
-				if (task->completed) {
-					// This thread was awaken also for some reason, but it's about to exit.
-					// Let's find out what may be pending and forward the requests.
-					if (!exit_threads && was_signaled) {
-						uint32_t to_process = task_queue.first() ? 1 : 0;
-						uint32_t to_promote = caller_pool_thread->current_task->low_priority && low_priority_task_queue.first() ? 1 : 0;
-						if (to_process || to_promote) {
-							// This thread must be left alone since it won't loop again.
-							caller_pool_thread->signaled = true;
-							_notify_threads(caller_pool_thread, to_process, to_promote);
-						}
-					}
-
-					task->waiting_pool--;
-					if (task->waiting_pool == 0 && task->waiting_user == 0) {
-						tasks.erase(p_task_id);
-						task_allocator.free(task);
-					}
-
-					break;
-				}
-
-				if (!exit_threads) {
-					// This is a thread from the pool. It shouldn't just idle.
-					// Let's try to process other tasks while we wait.
-
-					if (caller_pool_thread->current_task->low_priority && low_priority_task_queue.first()) {
-						if (_try_promote_low_priority_task()) {
-							_notify_threads(caller_pool_thread, 1, 0);
-						}
-					}
-
-					if (singleton->task_queue.first()) {
-						task_to_process = task_queue.first()->self();
-						task_queue.remove(task_queue.first());
-					}
-
-					if (!task_to_process) {
-						caller_pool_thread->awaited_task = task;
-
-						if (flushing_cmd_queue) {
-							flushing_cmd_queue->unlock();
-						}
-						caller_pool_thread->cond_var.wait(lock);
-						if (flushing_cmd_queue) {
-							flushing_cmd_queue->lock();
-						}
-
-						DEV_ASSERT(exit_threads || caller_pool_thread->signaled || task->completed);
-						caller_pool_thread->awaited_task = nullptr;
-					}
-				}
-			}
-
-			if (task_to_process) {
-				_process_task(task_to_process);
-			}
+		_wait_collaboratively(caller_pool_thread, task);
+		task->waiting_pool--;
+		if (task->waiting_pool == 0 && task->waiting_user == 0) {
+			tasks.erase(p_task_id);
+			task_allocator.free(task);
 		}
 	} else {
 		task->done_semaphore.wait();
@@ -470,6 +413,99 @@ Error WorkerThreadPool::wait_for_task_completion(TaskID p_task_id) {
 	return OK;
 }
 
+void WorkerThreadPool::_wait_collaboratively(ThreadData *p_caller_pool_thread, Task *p_task) {
+	// Keep processing tasks until the condition to stop waiting is met.
+
+#define IS_WAIT_OVER (unlikely(p_task == ThreadData::YIELDING) ? p_caller_pool_thread->yield_is_over : p_task->completed)
+
+	while (true) {
+		Task *task_to_process = nullptr;
+		{
+			MutexLock lock(task_mutex);
+			bool was_signaled = p_caller_pool_thread->signaled;
+			p_caller_pool_thread->signaled = false;
+
+			if (IS_WAIT_OVER) {
+				p_caller_pool_thread->yield_is_over = false;
+				if (!exit_threads && was_signaled) {
+					// This thread was awakened for some additional reason, but it's about to stop waiting.
+					// Let's find out what may be pending and forward the requests.
+					uint32_t to_process = task_queue.first() ? 1 : 0;
+					uint32_t to_promote = p_caller_pool_thread->current_task->low_priority && low_priority_task_queue.first() ? 1 : 0;
+					if (to_process || to_promote) {
+						// This thread must be left alone since it won't loop again.
+						p_caller_pool_thread->signaled = true;
+						_notify_threads(p_caller_pool_thread, to_process, to_promote);
+					}
+				}
+
+				break;
+			}
+
+			if (!exit_threads) {
+				if (p_caller_pool_thread->current_task->low_priority && low_priority_task_queue.first()) {
+					if (_try_promote_low_priority_task()) {
+						_notify_threads(p_caller_pool_thread, 1, 0);
+					}
+				}
+
+				if (singleton->task_queue.first()) {
+					task_to_process = task_queue.first()->self();
+					task_queue.remove(task_queue.first());
+				}
+
+				if (!task_to_process) {
+					p_caller_pool_thread->awaited_task = p_task;
+
+					if (flushing_cmd_queue) {
+						flushing_cmd_queue->unlock();
+					}
+					p_caller_pool_thread->cond_var.wait(lock);
+					if (flushing_cmd_queue) {
+						flushing_cmd_queue->lock();
+					}
+
+					DEV_ASSERT(exit_threads || p_caller_pool_thread->signaled || IS_WAIT_OVER);
+					p_caller_pool_thread->awaited_task = nullptr;
+				}
+			}
+		}
+
+		if (task_to_process) {
+			_process_task(task_to_process);
+		}
+	}
+}
+
+void WorkerThreadPool::yield() {
+	int th_index = get_thread_index();
+	ERR_FAIL_COND_MSG(th_index == -1, "This function can only be called from a worker thread.");
+	_wait_collaboratively(&threads[th_index], ThreadData::YIELDING);
+}
+
+void WorkerThreadPool::notify_yield_over(TaskID p_task_id) {
+	task_mutex.lock();
+	Task **taskp = tasks.getptr(p_task_id);
+	if (!taskp) {
+		task_mutex.unlock();
+		ERR_FAIL_MSG("Invalid Task ID.");
+	}
+	Task *task = *taskp;
+
+#ifdef DEBUG_ENABLED
+	if (task->pool_thread_index == get_thread_index()) {
+		WARN_PRINT("A worker thread is attempting to notify itself. That makes no sense.");
+	}
+#endif
+
+	ThreadData &td = threads[task->pool_thread_index];
+	td.yield_is_over = true;
+	td.signaled = true;
+	td.cond_var.notify_one();
+
+	task_mutex.unlock();
+}
+
 WorkerThreadPool::GroupID WorkerThreadPool::_add_group_task(const Callable &p_callable, void (*p_func)(void *, uint32_t), void *p_userdata, BaseTemplateUserdata *p_template_userdata, int p_elements, int p_tasks, bool p_high_priority, const String &p_description) {
 	ERR_FAIL_COND_V(p_elements < 0, INVALID_TASK_ID);
 	if (p_tasks < 0) {
diff --git a/core/object/worker_thread_pool.h b/core/object/worker_thread_pool.h
index fdddc9a647f..64f24df79fd 100644
--- a/core/object/worker_thread_pool.h
+++ b/core/object/worker_thread_pool.h
@@ -107,13 +107,21 @@ private:
 	BinaryMutex task_mutex;
 
 	struct ThreadData {
+		static Task *const YIELDING; // Too bad constexpr doesn't work here.
+
 		uint32_t index = 0;
 		Thread thread;
-		bool ready_for_scripting = false;
-		bool signaled = false;
+		bool ready_for_scripting : 1;
+		bool signaled : 1;
+		bool yield_is_over : 1;
 		Task *current_task = nullptr;
-		Task *awaited_task = nullptr; // Null if not awaiting the condition variable. Special value for idle-waiting.
+		Task *awaited_task = nullptr; // Null if not awaiting the condition variable, or special value (YIELDING).
 		ConditionVariable cond_var;
+
+		ThreadData() :
+				ready_for_scripting(false),
+				signaled(false),
+				yield_is_over(false) {}
 	};
 
 	TightLocalVector<ThreadData> threads;
@@ -177,6 +185,8 @@ private:
 		}
 	};
 
+	void _wait_collaboratively(ThreadData *p_caller_pool_thread, Task *p_task);
+
 protected:
 	static void _bind_methods();
 
@@ -196,6 +206,9 @@ public:
 	bool is_task_completed(TaskID p_task_id) const;
 	Error wait_for_task_completion(TaskID p_task_id);
 
+	void yield();
+	void notify_yield_over(TaskID p_task_id);
+
 	template <typename C, typename M, typename U>
 	GroupID add_template_group_task(C *p_instance, M p_method, U p_userdata, int p_elements, int p_tasks = -1, bool p_high_priority = false, const String &p_description = String()) {
 		typedef GroupUserData<C, M, U> GroupUD;
diff --git a/tests/core/templates/test_command_queue.h b/tests/core/templates/test_command_queue.h
index e94c108694d..d2957b5c402 100644
--- a/tests/core/templates/test_command_queue.h
+++ b/tests/core/templates/test_command_queue.h
@@ -33,6 +33,7 @@
 
 #include "core/config/project_settings.h"
 #include "core/math/random_number_generator.h"
+#include "core/object/worker_thread_pool.h"
 #include "core/os/os.h"
 #include "core/os/thread.h"
 #include "core/templates/command_queue_mt.h"
@@ -100,7 +101,7 @@ public:
 	ThreadWork reader_threadwork;
 	ThreadWork writer_threadwork;
 
-	CommandQueueMT command_queue = CommandQueueMT(true);
+	CommandQueueMT command_queue;
 
 	enum TestMsgType {
 		TEST_MSG_FUNC1_TRANSFORM,
@@ -119,6 +120,7 @@ public:
 	bool exit_threads = false;
 
 	Thread reader_thread;
+	WorkerThreadPool::TaskID reader_task_id = WorkerThreadPool::INVALID_TASK_ID;
 	Thread writer_thread;
 
 	int func1_count = 0;
@@ -148,11 +150,16 @@ public:
 	void reader_thread_loop() {
 		reader_threadwork.thread_wait_for_work();
 		while (!exit_threads) {
-			if (message_count_to_read < 0) {
+			if (reader_task_id == WorkerThreadPool::INVALID_TASK_ID) {
 				command_queue.flush_all();
-			}
-			for (int i = 0; i < message_count_to_read; i++) {
-				command_queue.wait_and_flush();
+			} else {
+				if (message_count_to_read < 0) {
+					command_queue.flush_all();
+				}
+				for (int i = 0; i < message_count_to_read; i++) {
+					WorkerThreadPool::get_singleton()->yield();
+					command_queue.wait_and_flush();
+				}
 			}
 			message_count_to_read = 0;
 
@@ -216,8 +223,13 @@ public:
 		sts->writer_thread_loop();
 	}
 
-	void init_threads() {
-		reader_thread.start(&SharedThreadState::static_reader_thread_loop, this);
+	void init_threads(bool p_use_thread_pool_sync = false) {
+		if (p_use_thread_pool_sync) {
+			reader_task_id = WorkerThreadPool::get_singleton()->add_native_task(&SharedThreadState::static_reader_thread_loop, this, true);
+			command_queue.set_pump_task_id(reader_task_id);
+		} else {
+			reader_thread.start(&SharedThreadState::static_reader_thread_loop, this);
+		}
 		writer_thread.start(&SharedThreadState::static_writer_thread_loop, this);
 	}
 	void destroy_threads() {
@@ -225,16 +237,20 @@ public:
 		exit_threads = true;
 		reader_threadwork.main_start_work();
 		writer_threadwork.main_start_work();
 
-		reader_thread.wait_to_finish();
+		if (reader_task_id != WorkerThreadPool::INVALID_TASK_ID) {
+			WorkerThreadPool::get_singleton()->wait_for_task_completion(reader_task_id);
+		} else {
+			reader_thread.wait_to_finish();
+		}
 		writer_thread.wait_to_finish();
 	}
 };
 
-TEST_CASE("[CommandQueue] Test Queue Basics") {
+static void test_command_queue_basic(bool p_use_thread_pool_sync) {
 	const char *COMMAND_QUEUE_SETTING = "memory/limits/command_queue/multithreading_queue_size_kb";
 	ProjectSettings::get_singleton()->set_setting(COMMAND_QUEUE_SETTING, 1);
 	SharedThreadState sts;
-	sts.init_threads();
+	sts.init_threads(p_use_thread_pool_sync);
 
 	sts.add_msg_to_write(SharedThreadState::TEST_MSG_FUNC1_TRANSFORM);
 	sts.writer_threadwork.main_start_work();
@@ -272,6 +288,14 @@ TEST_CASE("[CommandQueue] Test Queue Basics") {
 			ProjectSettings::get_singleton()->property_get_revert(COMMAND_QUEUE_SETTING));
 }
 
+TEST_CASE("[CommandQueue] Test Queue Basics") {
+	test_command_queue_basic(false);
+}
+
+TEST_CASE("[CommandQueue] Test Queue Basics with WorkerThreadPool sync.") {
+	test_command_queue_basic(true);
+}
+
 TEST_CASE("[CommandQueue] Test Queue Wrapping to same spot.") {
 	const char *COMMAND_QUEUE_SETTING = "memory/limits/command_queue/multithreading_queue_size_kb";
 	ProjectSettings::get_singleton()->set_setting(COMMAND_QUEUE_SETTING, 1);
diff --git a/tests/core/threads/test_worker_thread_pool.h b/tests/core/threads/test_worker_thread_pool.h
index e9a762b57bb..0a0291d11b2 100644
--- a/tests/core/threads/test_worker_thread_pool.h
+++ b/tests/core/threads/test_worker_thread_pool.h
@@ -38,6 +38,7 @@
 namespace TestWorkerThreadPool {
 
 static LocalVector<SafeNumeric<int>> counter;
+static SafeFlag exit;
 
 static void static_test(void *p_arg) {
 	counter[(uint64_t)p_arg].increment();
@@ -106,6 +107,72 @@ TEST_CASE("[WorkerThreadPool] Process elements using group tasks") {
 	}
 }
 
+static void static_test_daemon(void *p_arg) {
+	while (!exit.is_set()) {
+		counter[0].add(1);
+		WorkerThreadPool::get_singleton()->yield();
+	}
+}
+
+static void static_busy_task(void *p_arg) {
+	while (!exit.is_set()) {
+		OS::get_singleton()->delay_usec(1);
+	}
+}
+
+static void static_legit_task(void *p_arg) {
+	*((bool *)p_arg) = counter[0].get() > 0;
+	counter[1].add(1);
+}
+
+TEST_CASE("[WorkerThreadPool] Run a yielding daemon as the only hope for other tasks to run") {
+	exit.clear();
+	counter.clear();
+	counter.resize(2);
+
+	WorkerThreadPool::TaskID daemon_task_id = WorkerThreadPool::get_singleton()->add_native_task(static_test_daemon, nullptr, true);
+
+	int num_threads = WorkerThreadPool::get_singleton()->get_thread_count();
+
+	// Keep all the other threads busy.
+	LocalVector<WorkerThreadPool::TaskID> task_ids;
+	for (int i = 0; i < num_threads - 1; i++) {
+		task_ids.push_back(WorkerThreadPool::get_singleton()->add_native_task(static_busy_task, nullptr, true));
+	}
+
+	LocalVector<WorkerThreadPool::TaskID> legit_task_ids;
+	LocalVector<bool> legit_task_needed_yield;
+	int legit_tasks_count = num_threads * 4;
+	legit_task_needed_yield.resize(legit_tasks_count);
+	for (int i = 0; i < legit_tasks_count; i++) {
+		legit_task_needed_yield[i] = false;
+		task_ids.push_back(WorkerThreadPool::get_singleton()->add_native_task(static_legit_task, &legit_task_needed_yield[i], i >= legit_tasks_count / 2));
+	}
+
+	while (counter[1].get() != legit_tasks_count) {
+		OS::get_singleton()->delay_usec(1);
+	}
+
+	exit.set();
+	for (uint32_t i = 0; i < task_ids.size(); i++) {
+		WorkerThreadPool::get_singleton()->wait_for_task_completion(task_ids[i]);
+	}
+	WorkerThreadPool::get_singleton()->notify_yield_over(daemon_task_id);
+	WorkerThreadPool::get_singleton()->wait_for_task_completion(daemon_task_id);
+
+	CHECK_MESSAGE(counter[0].get() > 0, "Daemon task should have looped at least once.");
+	CHECK_MESSAGE(counter[1].get() == legit_tasks_count, "All legit tasks should have been able to run.");
+
+	bool all_needed_yield = true;
+	for (int i = 0; i < legit_tasks_count; i++) {
+		if (!legit_task_needed_yield[i]) {
+			all_needed_yield = false;
+			break;
+		}
+	}
+	CHECK_MESSAGE(all_needed_yield, "All legit tasks should have needed the daemon yielding to run.");
+}
+
 } // namespace TestWorkerThreadPool
 
 #endif // TEST_WORKER_THREAD_POOL_H
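
Usage sketch (illustrative note, not part of the diff): the code below shows the daemon-like task pattern these yield semantics enable, modeled on the new test above. `ExampleDaemon`, `example_daemon_loop`, and `run_example` are hypothetical names invented for this note; the `WorkerThreadPool` calls (`add_native_task`, `yield`, `notify_yield_over`, `wait_for_task_completion`) are the API added or exercised by the patch, and the `SafeFlag` include path is assumed.

	#include "core/object/worker_thread_pool.h"
	#include "core/templates/safe_refcount.h" // SafeFlag (assumed header location).

	struct ExampleDaemon {
		SafeFlag exit;
	};

	// Runs on a pool thread. Instead of busy-waiting (and starving the pool),
	// the daemon gives its thread back to the pool inside yield(), until some
	// other thread calls notify_yield_over() on its task.
	static void example_daemon_loop(void *p_arg) {
		ExampleDaemon *daemon = (ExampleDaemon *)p_arg;
		while (!daemon->exit.is_set()) {
			// ... do one unit of background work here ...
			WorkerThreadPool::get_singleton()->yield();
		}
	}

	static void run_example() {
		ExampleDaemon daemon;
		WorkerThreadPool::TaskID daemon_id = WorkerThreadPool::get_singleton()->add_native_task(example_daemon_loop, &daemon, true);

		// Whenever new work arrives for the daemon, wake it up:
		WorkerThreadPool::get_singleton()->notify_yield_over(daemon_id);

		// Shutdown, ordered as in the new test: raise the flag first, then wake
		// the daemon one last time so it can observe the flag, then join.
		daemon.exit.set();
		WorkerThreadPool::get_singleton()->notify_yield_over(daemon_id);
		WorkerThreadPool::get_singleton()->wait_for_task_completion(daemon_id);
	}

One caveat: notify_yield_over() fails with "Invalid Task ID." once the task has completed and been erased, so a real design must ensure the daemon is still alive (e.g., still blocked in yield(), as the test guarantees) whenever it is notified.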