summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/tests/common/fibers.cpp71
1 files changed, 40 insertions, 31 deletions
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
index 4fd92428f..4757dd2b4 100644
--- a/src/tests/common/fibers.cpp
+++ b/src/tests/common/fibers.cpp
@@ -6,18 +6,40 @@
6#include <cstdlib> 6#include <cstdlib>
7#include <functional> 7#include <functional>
8#include <memory> 8#include <memory>
9#include <mutex>
10#include <stdexcept>
9#include <thread> 11#include <thread>
10#include <unordered_map> 12#include <unordered_map>
11#include <vector> 13#include <vector>
12 14
13#include <catch2/catch.hpp> 15#include <catch2/catch.hpp>
14#include <math.h> 16
15#include "common/common_types.h" 17#include "common/common_types.h"
16#include "common/fiber.h" 18#include "common/fiber.h"
17#include "common/spin_lock.h"
18 19
19namespace Common { 20namespace Common {
20 21
22class ThreadIds {
23public:
24 void Register(u32 id) {
25 const auto thread_id = std::this_thread::get_id();
26 std::scoped_lock lock{mutex};
27 if (ids.contains(thread_id)) {
28 throw std::logic_error{"Registering the same thread twice"};
29 }
30 ids.emplace(thread_id, id);
31 }
32
33 [[nodiscard]] u32 Get() const {
34 std::scoped_lock lock{mutex};
35 return ids.at(std::this_thread::get_id());
36 }
37
38private:
39 mutable std::mutex mutex;
40 std::unordered_map<std::thread::id, u32> ids;
41};
42
21class TestControl1 { 43class TestControl1 {
22public: 44public:
23 TestControl1() = default; 45 TestControl1() = default;
@@ -26,7 +48,7 @@ public:
26 48
27 void ExecuteThread(u32 id); 49 void ExecuteThread(u32 id);
28 50
29 std::unordered_map<std::thread::id, u32> ids; 51 ThreadIds thread_ids;
30 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; 52 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
31 std::vector<std::shared_ptr<Common::Fiber>> work_fibers; 53 std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
32 std::vector<u32> items; 54 std::vector<u32> items;
@@ -39,8 +61,7 @@ static void WorkControl1(void* control) {
39} 61}
40 62
41void TestControl1::DoWork() { 63void TestControl1::DoWork() {
42 std::thread::id this_id = std::this_thread::get_id(); 64 const u32 id = thread_ids.Get();
43 u32 id = ids[this_id];
44 u32 value = items[id]; 65 u32 value = items[id];
45 for (u32 i = 0; i < id; i++) { 66 for (u32 i = 0; i < id; i++) {
46 value++; 67 value++;
@@ -50,8 +71,7 @@ void TestControl1::DoWork() {
50} 71}
51 72
52void TestControl1::ExecuteThread(u32 id) { 73void TestControl1::ExecuteThread(u32 id) {
53 std::thread::id this_id = std::this_thread::get_id(); 74 thread_ids.Register(id);
54 ids[this_id] = id;
55 auto thread_fiber = Fiber::ThreadToFiber(); 75 auto thread_fiber = Fiber::ThreadToFiber();
56 thread_fibers[id] = thread_fiber; 76 thread_fibers[id] = thread_fiber;
57 work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this); 77 work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
@@ -98,8 +118,7 @@ public:
98 value1 += i; 118 value1 += i;
99 } 119 }
100 Fiber::YieldTo(fiber1, fiber3); 120 Fiber::YieldTo(fiber1, fiber3);
101 std::thread::id this_id = std::this_thread::get_id(); 121 const u32 id = thread_ids.Get();
102 u32 id = ids[this_id];
103 assert1 = id == 1; 122 assert1 = id == 1;
104 value2 += 5000; 123 value2 += 5000;
105 Fiber::YieldTo(fiber1, thread_fibers[id]); 124 Fiber::YieldTo(fiber1, thread_fibers[id]);
@@ -115,8 +134,7 @@ public:
115 } 134 }
116 135
117 void DoWork3() { 136 void DoWork3() {
118 std::thread::id this_id = std::this_thread::get_id(); 137 const u32 id = thread_ids.Get();
119 u32 id = ids[this_id];
120 assert2 = id == 0; 138 assert2 = id == 0;
121 value1 += 1000; 139 value1 += 1000;
122 Fiber::YieldTo(fiber3, thread_fibers[id]); 140 Fiber::YieldTo(fiber3, thread_fibers[id]);
@@ -125,14 +143,12 @@ public:
125 void ExecuteThread(u32 id); 143 void ExecuteThread(u32 id);
126 144
127 void CallFiber1() { 145 void CallFiber1() {
128 std::thread::id this_id = std::this_thread::get_id(); 146 const u32 id = thread_ids.Get();
129 u32 id = ids[this_id];
130 Fiber::YieldTo(thread_fibers[id], fiber1); 147 Fiber::YieldTo(thread_fibers[id], fiber1);
131 } 148 }
132 149
133 void CallFiber2() { 150 void CallFiber2() {
134 std::thread::id this_id = std::this_thread::get_id(); 151 const u32 id = thread_ids.Get();
135 u32 id = ids[this_id];
136 Fiber::YieldTo(thread_fibers[id], fiber2); 152 Fiber::YieldTo(thread_fibers[id], fiber2);
137 } 153 }
138 154
@@ -145,7 +161,7 @@ public:
145 u32 value2{}; 161 u32 value2{};
146 std::atomic<bool> trap{true}; 162 std::atomic<bool> trap{true};
147 std::atomic<bool> trap2{true}; 163 std::atomic<bool> trap2{true};
148 std::unordered_map<std::thread::id, u32> ids; 164 ThreadIds thread_ids;
149 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; 165 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
150 std::shared_ptr<Common::Fiber> fiber1; 166 std::shared_ptr<Common::Fiber> fiber1;
151 std::shared_ptr<Common::Fiber> fiber2; 167 std::shared_ptr<Common::Fiber> fiber2;
@@ -168,15 +184,13 @@ static void WorkControl2_3(void* control) {
168} 184}
169 185
170void TestControl2::ExecuteThread(u32 id) { 186void TestControl2::ExecuteThread(u32 id) {
171 std::thread::id this_id = std::this_thread::get_id(); 187 thread_ids.Register(id);
172 ids[this_id] = id;
173 auto thread_fiber = Fiber::ThreadToFiber(); 188 auto thread_fiber = Fiber::ThreadToFiber();
174 thread_fibers[id] = thread_fiber; 189 thread_fibers[id] = thread_fiber;
175} 190}
176 191
177void TestControl2::Exit() { 192void TestControl2::Exit() {
178 std::thread::id this_id = std::this_thread::get_id(); 193 const u32 id = thread_ids.Get();
179 u32 id = ids[this_id];
180 thread_fibers[id]->Exit(); 194 thread_fibers[id]->Exit();
181} 195}
182 196
@@ -228,24 +242,21 @@ public:
228 void DoWork1() { 242 void DoWork1() {
229 value1 += 1; 243 value1 += 1;
230 Fiber::YieldTo(fiber1, fiber2); 244 Fiber::YieldTo(fiber1, fiber2);
231 std::thread::id this_id = std::this_thread::get_id(); 245 const u32 id = thread_ids.Get();
232 u32 id = ids[this_id];
233 value3 += 1; 246 value3 += 1;
234 Fiber::YieldTo(fiber1, thread_fibers[id]); 247 Fiber::YieldTo(fiber1, thread_fibers[id]);
235 } 248 }
236 249
237 void DoWork2() { 250 void DoWork2() {
238 value2 += 1; 251 value2 += 1;
239 std::thread::id this_id = std::this_thread::get_id(); 252 const u32 id = thread_ids.Get();
240 u32 id = ids[this_id];
241 Fiber::YieldTo(fiber2, thread_fibers[id]); 253 Fiber::YieldTo(fiber2, thread_fibers[id]);
242 } 254 }
243 255
244 void ExecuteThread(u32 id); 256 void ExecuteThread(u32 id);
245 257
246 void CallFiber1() { 258 void CallFiber1() {
247 std::thread::id this_id = std::this_thread::get_id(); 259 const u32 id = thread_ids.Get();
248 u32 id = ids[this_id];
249 Fiber::YieldTo(thread_fibers[id], fiber1); 260 Fiber::YieldTo(thread_fibers[id], fiber1);
250 } 261 }
251 262
@@ -254,7 +265,7 @@ public:
254 u32 value1{}; 265 u32 value1{};
255 u32 value2{}; 266 u32 value2{};
256 u32 value3{}; 267 u32 value3{};
257 std::unordered_map<std::thread::id, u32> ids; 268 ThreadIds thread_ids;
258 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers; 269 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
259 std::shared_ptr<Common::Fiber> fiber1; 270 std::shared_ptr<Common::Fiber> fiber1;
260 std::shared_ptr<Common::Fiber> fiber2; 271 std::shared_ptr<Common::Fiber> fiber2;
@@ -271,15 +282,13 @@ static void WorkControl3_2(void* control) {
271} 282}
272 283
273void TestControl3::ExecuteThread(u32 id) { 284void TestControl3::ExecuteThread(u32 id) {
274 std::thread::id this_id = std::this_thread::get_id(); 285 thread_ids.Register(id);
275 ids[this_id] = id;
276 auto thread_fiber = Fiber::ThreadToFiber(); 286 auto thread_fiber = Fiber::ThreadToFiber();
277 thread_fibers[id] = thread_fiber; 287 thread_fibers[id] = thread_fiber;
278} 288}
279 289
280void TestControl3::Exit() { 290void TestControl3::Exit() {
281 std::thread::id this_id = std::this_thread::get_id(); 291 const u32 id = thread_ids.Get();
282 u32 id = ids[this_id];
283 thread_fibers[id]->Exit(); 292 thread_fibers[id]->Exit();
284} 293}
285 294