summaryrefslogtreecommitdiff
path: root/src/tests/common/fibers.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/tests/common/fibers.cpp')
-rw-r--r--src/tests/common/fibers.cpp214
1 files changed, 214 insertions, 0 deletions
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
new file mode 100644
index 000000000..ff840afa6
--- /dev/null
+++ b/src/tests/common/fibers.cpp
@@ -0,0 +1,214 @@
1// Copyright 2020 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <atomic>
6#include <cstdlib>
7#include <functional>
8#include <memory>
9#include <thread>
10#include <unordered_map>
11#include <vector>
12
13#include <catch2/catch.hpp>
14#include <math.h>
15#include "common/common_types.h"
16#include "common/fiber.h"
17#include "common/spin_lock.h"
18
19namespace Common {
20
21class TestControl1 {
22public:
23 TestControl1() = default;
24
25 void DoWork();
26
27 void ExecuteThread(u32 id);
28
29 std::unordered_map<std::thread::id, u32> ids;
30 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
31 std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
32 std::vector<u32> items;
33 std::vector<u32> results;
34};
35
36static void WorkControl1(void* control) {
37 TestControl1* test_control = static_cast<TestControl1*>(control);
38 test_control->DoWork();
39}
40
41void TestControl1::DoWork() {
42 std::thread::id this_id = std::this_thread::get_id();
43 u32 id = ids[this_id];
44 u32 value = items[id];
45 for (u32 i = 0; i < id; i++) {
46 value++;
47 }
48 results[id] = value;
49 Fiber::YieldTo(work_fibers[id], thread_fibers[id]);
50}
51
52void TestControl1::ExecuteThread(u32 id) {
53 std::thread::id this_id = std::this_thread::get_id();
54 ids[this_id] = id;
55 auto thread_fiber = Fiber::ThreadToFiber();
56 thread_fibers[id] = thread_fiber;
57 work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
58 items[id] = rand() % 256;
59 Fiber::YieldTo(thread_fibers[id], work_fibers[id]);
60 thread_fibers[id]->Exit();
61}
62
63static void ThreadStart1(u32 id, TestControl1& test_control) {
64 test_control.ExecuteThread(id);
65}
66
67
68TEST_CASE("Fibers::Setup", "[common]") {
69 constexpr u32 num_threads = 7;
70 TestControl1 test_control{};
71 test_control.thread_fibers.resize(num_threads, nullptr);
72 test_control.work_fibers.resize(num_threads, nullptr);
73 test_control.items.resize(num_threads, 0);
74 test_control.results.resize(num_threads, 0);
75 std::vector<std::thread> threads;
76 for (u32 i = 0; i < num_threads; i++) {
77 threads.emplace_back(ThreadStart1, i, std::ref(test_control));
78 }
79 for (u32 i = 0; i < num_threads; i++) {
80 threads[i].join();
81 }
82 for (u32 i = 0; i < num_threads; i++) {
83 REQUIRE(test_control.items[i] + i == test_control.results[i]);
84 }
85}
86
87class TestControl2 {
88public:
89 TestControl2() = default;
90
91 void DoWork1() {
92 trap2 = false;
93 while (trap.load());
94 for (u32 i = 0; i < 12000; i++) {
95 value1 += i;
96 }
97 Fiber::YieldTo(fiber1, fiber3);
98 std::thread::id this_id = std::this_thread::get_id();
99 u32 id = ids[this_id];
100 assert1 = id == 1;
101 value2 += 5000;
102 Fiber::YieldTo(fiber1, thread_fibers[id]);
103 }
104
105 void DoWork2() {
106 while (trap2.load());
107 value2 = 2000;
108 trap = false;
109 Fiber::YieldTo(fiber2, fiber1);
110 assert3 = false;
111 }
112
113 void DoWork3() {
114 std::thread::id this_id = std::this_thread::get_id();
115 u32 id = ids[this_id];
116 assert2 = id == 0;
117 value1 += 1000;
118 Fiber::YieldTo(fiber3, thread_fibers[id]);
119 }
120
121 void ExecuteThread(u32 id);
122
123 void CallFiber1() {
124 std::thread::id this_id = std::this_thread::get_id();
125 u32 id = ids[this_id];
126 Fiber::YieldTo(thread_fibers[id], fiber1);
127 }
128
129 void CallFiber2() {
130 std::thread::id this_id = std::this_thread::get_id();
131 u32 id = ids[this_id];
132 Fiber::YieldTo(thread_fibers[id], fiber2);
133 }
134
135 void Exit();
136
137 bool assert1{};
138 bool assert2{};
139 bool assert3{true};
140 u32 value1{};
141 u32 value2{};
142 std::atomic<bool> trap{true};
143 std::atomic<bool> trap2{true};
144 std::unordered_map<std::thread::id, u32> ids;
145 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
146 std::shared_ptr<Common::Fiber> fiber1;
147 std::shared_ptr<Common::Fiber> fiber2;
148 std::shared_ptr<Common::Fiber> fiber3;
149};
150
151static void WorkControl2_1(void* control) {
152 TestControl2* test_control = static_cast<TestControl2*>(control);
153 test_control->DoWork1();
154}
155
156static void WorkControl2_2(void* control) {
157 TestControl2* test_control = static_cast<TestControl2*>(control);
158 test_control->DoWork2();
159}
160
161static void WorkControl2_3(void* control) {
162 TestControl2* test_control = static_cast<TestControl2*>(control);
163 test_control->DoWork3();
164}
165
166void TestControl2::ExecuteThread(u32 id) {
167 std::thread::id this_id = std::this_thread::get_id();
168 ids[this_id] = id;
169 auto thread_fiber = Fiber::ThreadToFiber();
170 thread_fibers[id] = thread_fiber;
171}
172
173void TestControl2::Exit() {
174 std::thread::id this_id = std::this_thread::get_id();
175 u32 id = ids[this_id];
176 thread_fibers[id]->Exit();
177}
178
179static void ThreadStart2_1(u32 id, TestControl2& test_control) {
180 test_control.ExecuteThread(id);
181 test_control.CallFiber1();
182 test_control.Exit();
183}
184
185static void ThreadStart2_2(u32 id, TestControl2& test_control) {
186 test_control.ExecuteThread(id);
187 test_control.CallFiber2();
188 test_control.Exit();
189}
190
191TEST_CASE("Fibers::InterExchange", "[common]") {
192 TestControl2 test_control{};
193 test_control.thread_fibers.resize(2, nullptr);
194 test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control);
195 test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control);
196 test_control.fiber3 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control);
197 std::thread thread1(ThreadStart2_1, 0, std::ref(test_control));
198 std::thread thread2(ThreadStart2_2, 1, std::ref(test_control));
199 thread1.join();
200 thread2.join();
201 REQUIRE(test_control.assert1);
202 REQUIRE(test_control.assert2);
203 REQUIRE(test_control.assert3);
204 REQUIRE(test_control.value2 == 7000);
205 u32 cal_value = 0;
206 for (u32 i = 0; i < 12000; i++) {
207 cal_value += i;
208 }
209 cal_value += 1000;
210 REQUIRE(test_control.value1 == cal_value);
211}
212
213
214} // namespace Common