summaryrefslogtreecommitdiff
path: root/src/tests/common/fibers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tests/common/fibers.cpp')
-rw-r--r--src/tests/common/fibers.cpp367
1 files changed, 367 insertions, 0 deletions
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
new file mode 100644
index 000000000..4757dd2b4
--- /dev/null
+++ b/src/tests/common/fibers.cpp
@@ -0,0 +1,367 @@
1// Copyright 2020 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <atomic>
6#include <cstdlib>
7#include <functional>
8#include <memory>
9#include <mutex>
10#include <stdexcept>
11#include <thread>
12#include <unordered_map>
13#include <vector>
14
15#include <catch2/catch.hpp>
16
17#include "common/common_types.h"
18#include "common/fiber.h"
19
20namespace Common {
21
22class ThreadIds {
23public:
24 void Register(u32 id) {
25 const auto thread_id = std::this_thread::get_id();
26 std::scoped_lock lock{mutex};
27 if (ids.contains(thread_id)) {
28 throw std::logic_error{"Registering the same thread twice"};
29 }
30 ids.emplace(thread_id, id);
31 }
32
33 [[nodiscard]] u32 Get() const {
34 std::scoped_lock lock{mutex};
35 return ids.at(std::this_thread::get_id());
36 }
37
38private:
39 mutable std::mutex mutex;
40 std::unordered_map<std::thread::id, u32> ids;
41};
42
43class TestControl1 {
44public:
45 TestControl1() = default;
46
47 void DoWork();
48
49 void ExecuteThread(u32 id);
50
51 ThreadIds thread_ids;
52 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
53 std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
54 std::vector<u32> items;
55 std::vector<u32> results;
56};
57
58static void WorkControl1(void* control) {
59 auto* test_control = static_cast<TestControl1*>(control);
60 test_control->DoWork();
61}
62
63void TestControl1::DoWork() {
64 const u32 id = thread_ids.Get();
65 u32 value = items[id];
66 for (u32 i = 0; i < id; i++) {
67 value++;
68 }
69 results[id] = value;
70 Fiber::YieldTo(work_fibers[id], thread_fibers[id]);
71}
72
73void TestControl1::ExecuteThread(u32 id) {
74 thread_ids.Register(id);
75 auto thread_fiber = Fiber::ThreadToFiber();
76 thread_fibers[id] = thread_fiber;
77 work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
78 items[id] = rand() % 256;
79 Fiber::YieldTo(thread_fibers[id], work_fibers[id]);
80 thread_fibers[id]->Exit();
81}
82
83static void ThreadStart1(u32 id, TestControl1& test_control) {
84 test_control.ExecuteThread(id);
85}
86
87/** This test checks for fiber setup configuration and validates that fibers are
88 * doing all the work required.
89 */
90TEST_CASE("Fibers::Setup", "[common]") {
91 constexpr std::size_t num_threads = 7;
92 TestControl1 test_control{};
93 test_control.thread_fibers.resize(num_threads);
94 test_control.work_fibers.resize(num_threads);
95 test_control.items.resize(num_threads, 0);
96 test_control.results.resize(num_threads, 0);
97 std::vector<std::thread> threads;
98 for (u32 i = 0; i < num_threads; i++) {
99 threads.emplace_back(ThreadStart1, i, std::ref(test_control));
100 }
101 for (u32 i = 0; i < num_threads; i++) {
102 threads[i].join();
103 }
104 for (u32 i = 0; i < num_threads; i++) {
105 REQUIRE(test_control.items[i] + i == test_control.results[i]);
106 }
107}
108
109class TestControl2 {
110public:
111 TestControl2() = default;
112
113 void DoWork1() {
114 trap2 = false;
115 while (trap.load())
116 ;
117 for (u32 i = 0; i < 12000; i++) {
118 value1 += i;
119 }
120 Fiber::YieldTo(fiber1, fiber3);
121 const u32 id = thread_ids.Get();
122 assert1 = id == 1;
123 value2 += 5000;
124 Fiber::YieldTo(fiber1, thread_fibers[id]);
125 }
126
127 void DoWork2() {
128 while (trap2.load())
129 ;
130 value2 = 2000;
131 trap = false;
132 Fiber::YieldTo(fiber2, fiber1);
133 assert3 = false;
134 }
135
136 void DoWork3() {
137 const u32 id = thread_ids.Get();
138 assert2 = id == 0;
139 value1 += 1000;
140 Fiber::YieldTo(fiber3, thread_fibers[id]);
141 }
142
143 void ExecuteThread(u32 id);
144
145 void CallFiber1() {
146 const u32 id = thread_ids.Get();
147 Fiber::YieldTo(thread_fibers[id], fiber1);
148 }
149
150 void CallFiber2() {
151 const u32 id = thread_ids.Get();
152 Fiber::YieldTo(thread_fibers[id], fiber2);
153 }
154
155 void Exit();
156
157 bool assert1{};
158 bool assert2{};
159 bool assert3{true};
160 u32 value1{};
161 u32 value2{};
162 std::atomic<bool> trap{true};
163 std::atomic<bool> trap2{true};
164 ThreadIds thread_ids;
165 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
166 std::shared_ptr<Common::Fiber> fiber1;
167 std::shared_ptr<Common::Fiber> fiber2;
168 std::shared_ptr<Common::Fiber> fiber3;
169};
170
171static void WorkControl2_1(void* control) {
172 auto* test_control = static_cast<TestControl2*>(control);
173 test_control->DoWork1();
174}
175
176static void WorkControl2_2(void* control) {
177 auto* test_control = static_cast<TestControl2*>(control);
178 test_control->DoWork2();
179}
180
181static void WorkControl2_3(void* control) {
182 auto* test_control = static_cast<TestControl2*>(control);
183 test_control->DoWork3();
184}
185
186void TestControl2::ExecuteThread(u32 id) {
187 thread_ids.Register(id);
188 auto thread_fiber = Fiber::ThreadToFiber();
189 thread_fibers[id] = thread_fiber;
190}
191
192void TestControl2::Exit() {
193 const u32 id = thread_ids.Get();
194 thread_fibers[id]->Exit();
195}
196
197static void ThreadStart2_1(u32 id, TestControl2& test_control) {
198 test_control.ExecuteThread(id);
199 test_control.CallFiber1();
200 test_control.Exit();
201}
202
203static void ThreadStart2_2(u32 id, TestControl2& test_control) {
204 test_control.ExecuteThread(id);
205 test_control.CallFiber2();
206 test_control.Exit();
207}
208
209/** This test checks for fiber thread exchange configuration and validates that fibers are
210 * that a fiber has been succesfully transfered from one thread to another and that the TLS
211 * region of the thread is kept while changing fibers.
212 */
213TEST_CASE("Fibers::InterExchange", "[common]") {
214 TestControl2 test_control{};
215 test_control.thread_fibers.resize(2);
216 test_control.fiber1 =
217 std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control);
218 test_control.fiber2 =
219 std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control);
220 test_control.fiber3 =
221 std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control);
222 std::thread thread1(ThreadStart2_1, 0, std::ref(test_control));
223 std::thread thread2(ThreadStart2_2, 1, std::ref(test_control));
224 thread1.join();
225 thread2.join();
226 REQUIRE(test_control.assert1);
227 REQUIRE(test_control.assert2);
228 REQUIRE(test_control.assert3);
229 REQUIRE(test_control.value2 == 7000);
230 u32 cal_value = 0;
231 for (u32 i = 0; i < 12000; i++) {
232 cal_value += i;
233 }
234 cal_value += 1000;
235 REQUIRE(test_control.value1 == cal_value);
236}
237
238class TestControl3 {
239public:
240 TestControl3() = default;
241
242 void DoWork1() {
243 value1 += 1;
244 Fiber::YieldTo(fiber1, fiber2);
245 const u32 id = thread_ids.Get();
246 value3 += 1;
247 Fiber::YieldTo(fiber1, thread_fibers[id]);
248 }
249
250 void DoWork2() {
251 value2 += 1;
252 const u32 id = thread_ids.Get();
253 Fiber::YieldTo(fiber2, thread_fibers[id]);
254 }
255
256 void ExecuteThread(u32 id);
257
258 void CallFiber1() {
259 const u32 id = thread_ids.Get();
260 Fiber::YieldTo(thread_fibers[id], fiber1);
261 }
262
263 void Exit();
264
265 u32 value1{};
266 u32 value2{};
267 u32 value3{};
268 ThreadIds thread_ids;
269 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
270 std::shared_ptr<Common::Fiber> fiber1;
271 std::shared_ptr<Common::Fiber> fiber2;
272};
273
274static void WorkControl3_1(void* control) {
275 auto* test_control = static_cast<TestControl3*>(control);
276 test_control->DoWork1();
277}
278
279static void WorkControl3_2(void* control) {
280 auto* test_control = static_cast<TestControl3*>(control);
281 test_control->DoWork2();
282}
283
284void TestControl3::ExecuteThread(u32 id) {
285 thread_ids.Register(id);
286 auto thread_fiber = Fiber::ThreadToFiber();
287 thread_fibers[id] = thread_fiber;
288}
289
290void TestControl3::Exit() {
291 const u32 id = thread_ids.Get();
292 thread_fibers[id]->Exit();
293}
294
295static void ThreadStart3(u32 id, TestControl3& test_control) {
296 test_control.ExecuteThread(id);
297 test_control.CallFiber1();
298 test_control.Exit();
299}
300
301/** This test checks for one two threads racing for starting the same fiber.
302 * It checks execution occured in an ordered manner and by no time there were
303 * two contexts at the same time.
304 */
305TEST_CASE("Fibers::StartRace", "[common]") {
306 TestControl3 test_control{};
307 test_control.thread_fibers.resize(2);
308 test_control.fiber1 =
309 std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_1}, &test_control);
310 test_control.fiber2 =
311 std::make_shared<Fiber>(std::function<void(void*)>{WorkControl3_2}, &test_control);
312 std::thread thread1(ThreadStart3, 0, std::ref(test_control));
313 std::thread thread2(ThreadStart3, 1, std::ref(test_control));
314 thread1.join();
315 thread2.join();
316 REQUIRE(test_control.value1 == 1);
317 REQUIRE(test_control.value2 == 1);
318 REQUIRE(test_control.value3 == 1);
319}
320
321class TestControl4;
322
323static void WorkControl4(void* control);
324
325class TestControl4 {
326public:
327 TestControl4() {
328 fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl4}, this);
329 goal_reached = false;
330 rewinded = false;
331 }
332
333 void Execute() {
334 thread_fiber = Fiber::ThreadToFiber();
335 Fiber::YieldTo(thread_fiber, fiber1);
336 thread_fiber->Exit();
337 }
338
339 void DoWork() {
340 fiber1->SetRewindPoint(std::function<void(void*)>{WorkControl4}, this);
341 if (rewinded) {
342 goal_reached = true;
343 Fiber::YieldTo(fiber1, thread_fiber);
344 }
345 rewinded = true;
346 fiber1->Rewind();
347 }
348
349 std::shared_ptr<Common::Fiber> fiber1;
350 std::shared_ptr<Common::Fiber> thread_fiber;
351 bool goal_reached;
352 bool rewinded;
353};
354
355static void WorkControl4(void* control) {
356 auto* test_control = static_cast<TestControl4*>(control);
357 test_control->DoWork();
358}
359
360TEST_CASE("Fibers::Rewind", "[common]") {
361 TestControl4 test_control{};
362 test_control.Execute();
363 REQUIRE(test_control.goal_reached);
364 REQUIRE(test_control.rewinded);
365}
366
367} // namespace Common