diff options
Diffstat (limited to 'src/common/thread_worker.h')
| -rw-r--r-- | src/common/thread_worker.h | 103 |
1 files changed, 95 insertions, 8 deletions
diff --git a/src/common/thread_worker.h b/src/common/thread_worker.h index f1859971f..cd0017726 100644 --- a/src/common/thread_worker.h +++ b/src/common/thread_worker.h | |||
| @@ -5,26 +5,113 @@ | |||
| 5 | #pragma once | 5 | #pragma once |
| 6 | 6 | ||
| 7 | #include <atomic> | 7 | #include <atomic> |
| 8 | #include <condition_variable> | ||
| 8 | #include <functional> | 9 | #include <functional> |
| 9 | #include <mutex> | 10 | #include <mutex> |
| 11 | #include <stop_token> | ||
| 10 | #include <string> | 12 | #include <string> |
| 13 | #include <thread> | ||
| 14 | #include <type_traits> | ||
| 11 | #include <vector> | 15 | #include <vector> |
| 12 | #include <queue> | 16 | #include <queue> |
| 13 | 17 | ||
| 18 | #include "common/thread.h" | ||
| 19 | #include "common/unique_function.h" | ||
| 20 | |||
| 14 | namespace Common { | 21 | namespace Common { |
| 15 | 22 | ||
| 16 | class ThreadWorker final { | 23 | template <class StateType = void> |
| 24 | class StatefulThreadWorker { | ||
| 25 | static constexpr bool with_state = !std::is_same_v<StateType, void>; | ||
| 26 | |||
| 27 | struct DummyCallable { | ||
| 28 | int operator()() const noexcept { | ||
| 29 | return 0; | ||
| 30 | } | ||
| 31 | }; | ||
| 32 | |||
| 33 | using Task = | ||
| 34 | std::conditional_t<with_state, UniqueFunction<void, StateType*>, UniqueFunction<void>>; | ||
| 35 | using StateMaker = std::conditional_t<with_state, std::function<StateType()>, DummyCallable>; | ||
| 36 | |||
| 17 | public: | 37 | public: |
| 18 | explicit ThreadWorker(std::size_t num_workers, const std::string& name); | 38 | explicit StatefulThreadWorker(size_t num_workers, std::string name, StateMaker func = {}) |
| 19 | ~ThreadWorker(); | 39 | : workers_queued{num_workers}, thread_name{std::move(name)} { |
| 20 | void QueueWork(std::function<void()>&& work); | 40 | const auto lambda = [this, func](std::stop_token stop_token) { |
| 41 | Common::SetCurrentThreadName(thread_name.c_str()); | ||
| 42 | { | ||
| 43 | [[maybe_unused]] std::conditional_t<with_state, StateType, int> state{func()}; | ||
| 44 | while (!stop_token.stop_requested()) { | ||
| 45 | Task task; | ||
| 46 | { | ||
| 47 | std::unique_lock lock{queue_mutex}; | ||
| 48 | if (requests.empty()) { | ||
| 49 | wait_condition.notify_all(); | ||
| 50 | } | ||
| 51 | condition.wait(lock, stop_token, [this] { return !requests.empty(); }); | ||
| 52 | if (stop_token.stop_requested()) { | ||
| 53 | break; | ||
| 54 | } | ||
| 55 | task = std::move(requests.front()); | ||
| 56 | requests.pop(); | ||
| 57 | } | ||
| 58 | if constexpr (with_state) { | ||
| 59 | task(&state); | ||
| 60 | } else { | ||
| 61 | task(); | ||
| 62 | } | ||
| 63 | ++work_done; | ||
| 64 | } | ||
| 65 | } | ||
| 66 | ++workers_stopped; | ||
| 67 | wait_condition.notify_all(); | ||
| 68 | }; | ||
| 69 | threads.reserve(num_workers); | ||
| 70 | for (size_t i = 0; i < num_workers; ++i) { | ||
| 71 | threads.emplace_back(lambda); | ||
| 72 | } | ||
| 73 | } | ||
| 74 | |||
| 75 | StatefulThreadWorker& operator=(const StatefulThreadWorker&) = delete; | ||
| 76 | StatefulThreadWorker(const StatefulThreadWorker&) = delete; | ||
| 77 | |||
| 78 | StatefulThreadWorker& operator=(StatefulThreadWorker&&) = delete; | ||
| 79 | StatefulThreadWorker(StatefulThreadWorker&&) = delete; | ||
| 80 | |||
| 81 | void QueueWork(Task work) { | ||
| 82 | { | ||
| 83 | std::unique_lock lock{queue_mutex}; | ||
| 84 | requests.emplace(std::move(work)); | ||
| 85 | ++work_scheduled; | ||
| 86 | } | ||
| 87 | condition.notify_one(); | ||
| 88 | } | ||
| 89 | |||
| 90 | void WaitForRequests(std::stop_token stop_token = {}) { | ||
| 91 | std::stop_callback callback(stop_token, [this] { | ||
| 92 | for (auto& thread : threads) { | ||
| 93 | thread.request_stop(); | ||
| 94 | } | ||
| 95 | }); | ||
| 96 | std::unique_lock lock{queue_mutex}; | ||
| 97 | wait_condition.wait(lock, [this] { | ||
| 98 | return workers_stopped >= workers_queued || work_done >= work_scheduled; | ||
| 99 | }); | ||
| 100 | } | ||
| 21 | 101 | ||
| 22 | private: | 102 | private: |
| 23 | std::vector<std::thread> threads; | 103 | std::queue<Task> requests; |
| 24 | std::queue<std::function<void()>> requests; | ||
| 25 | std::mutex queue_mutex; | 104 | std::mutex queue_mutex; |
| 26 | std::condition_variable condition; | 105 | std::condition_variable_any condition; |
| 27 | std::atomic_bool stop{}; | 106 | std::condition_variable wait_condition; |
| 107 | std::atomic<size_t> work_scheduled{}; | ||
| 108 | std::atomic<size_t> work_done{}; | ||
| 109 | std::atomic<size_t> workers_stopped{}; | ||
| 110 | std::atomic<size_t> workers_queued{}; | ||
| 111 | std::string thread_name; | ||
| 112 | std::vector<std::jthread> threads; | ||
| 28 | }; | 113 | }; |
| 29 | 114 | ||
| 115 | using ThreadWorker = StatefulThreadWorker<>; | ||
| 116 | |||
| 30 | } // namespace Common | 117 | } // namespace Common |