summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/common/fiber.cpp32
-rw-r--r--src/common/fiber.h19
-rw-r--r--src/tests/CMakeLists.txt1
-rw-r--r--src/tests/common/fibers.cpp214
4 files changed, 247 insertions, 19 deletions
diff --git a/src/common/fiber.cpp b/src/common/fiber.cpp
index eb59f1aa9..a2c0401c4 100644
--- a/src/common/fiber.cpp
+++ b/src/common/fiber.cpp
@@ -3,18 +3,21 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include "common/fiber.h" 5#include "common/fiber.h"
6#ifdef _MSC_VER
7#include <windows.h>
8#else
9#include <boost/context/detail/fcontext.hpp>
10#endif
6 11
7namespace Common { 12namespace Common {
8 13
9#ifdef _MSC_VER 14#ifdef _MSC_VER
10#include <windows.h>
11 15
12struct Fiber::FiberImpl { 16struct Fiber::FiberImpl {
13 LPVOID handle = nullptr; 17 LPVOID handle = nullptr;
14}; 18};
15 19
16void Fiber::_start([[maybe_unused]] void* parameter) { 20void Fiber::start() {
17 guard.lock();
18 if (previous_fiber) { 21 if (previous_fiber) {
19 previous_fiber->guard.unlock(); 22 previous_fiber->guard.unlock();
20 previous_fiber = nullptr; 23 previous_fiber = nullptr;
@@ -22,10 +25,10 @@ void Fiber::_start([[maybe_unused]] void* parameter) {
22 entry_point(start_parameter); 25 entry_point(start_parameter);
23} 26}
24 27
25static void __stdcall FiberStartFunc(LPVOID lpFiberParameter) 28void __stdcall Fiber::FiberStartFunc(void* fiber_parameter)
26{ 29{
27 auto fiber = static_cast<Fiber *>(lpFiberParameter); 30 auto fiber = static_cast<Fiber *>(fiber_parameter);
28 fiber->_start(nullptr); 31 fiber->start();
29} 32}
30 33
31Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) 34Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter)
@@ -74,30 +77,26 @@ std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
74 77
75#else 78#else
76 79
77#include <boost/context/detail/fcontext.hpp>
78
79constexpr std::size_t default_stack_size = 1024 * 1024 * 4; // 4MB 80constexpr std::size_t default_stack_size = 1024 * 1024 * 4; // 4MB
80 81
81struct Fiber::FiberImpl { 82struct alignas(64) Fiber::FiberImpl {
82 boost::context::detail::fcontext_t context;
83 std::array<u8, default_stack_size> stack; 83 std::array<u8, default_stack_size> stack;
84 boost::context::detail::fcontext_t context;
84}; 85};
85 86
86void Fiber::_start(void* parameter) { 87void Fiber::start(boost::context::detail::transfer_t& transfer) {
87 guard.lock();
88 boost::context::detail::transfer_t* transfer = static_cast<boost::context::detail::transfer_t*>(parameter);
89 if (previous_fiber) { 88 if (previous_fiber) {
90 previous_fiber->impl->context = transfer->fctx; 89 previous_fiber->impl->context = transfer.fctx;
91 previous_fiber->guard.unlock(); 90 previous_fiber->guard.unlock();
92 previous_fiber = nullptr; 91 previous_fiber = nullptr;
93 } 92 }
94 entry_point(start_parameter); 93 entry_point(start_parameter);
95} 94}
96 95
97static void FiberStartFunc(boost::context::detail::transfer_t transfer) 96void Fiber::FiberStartFunc(boost::context::detail::transfer_t transfer)
98{ 97{
99 auto fiber = static_cast<Fiber *>(transfer.data); 98 auto fiber = static_cast<Fiber *>(transfer.data);
100 fiber->_start(&transfer); 99 fiber->start(transfer);
101} 100}
102 101
103Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) 102Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter)
@@ -139,6 +138,7 @@ void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) {
139 138
140std::shared_ptr<Fiber> Fiber::ThreadToFiber() { 139std::shared_ptr<Fiber> Fiber::ThreadToFiber() {
141 std::shared_ptr<Fiber> fiber = std::shared_ptr<Fiber>{new Fiber()}; 140 std::shared_ptr<Fiber> fiber = std::shared_ptr<Fiber>{new Fiber()};
141 fiber->guard.lock();
142 fiber->is_thread_fiber = true; 142 fiber->is_thread_fiber = true;
143 return fiber; 143 return fiber;
144} 144}
diff --git a/src/common/fiber.h b/src/common/fiber.h
index ab44905cf..812d6644a 100644
--- a/src/common/fiber.h
+++ b/src/common/fiber.h
@@ -10,6 +10,12 @@
10#include "common/common_types.h" 10#include "common/common_types.h"
11#include "common/spin_lock.h" 11#include "common/spin_lock.h"
12 12
13#ifndef _MSC_VER
14namespace boost::context::detail {
15 struct transfer_t;
16}
17#endif
18
13namespace Common { 19namespace Common {
14 20
15class Fiber { 21class Fiber {
@@ -31,9 +37,6 @@ public:
31 /// Only call from main thread's fiber 37 /// Only call from main thread's fiber
32 void Exit(); 38 void Exit();
33 39
34 /// Used internally but required to be public, Shall not be used
35 void _start(void* parameter);
36
37 /// Changes the start parameter of the fiber. Has no effect if the fiber already started 40 /// Changes the start parameter of the fiber. Has no effect if the fiber already started
38 void SetStartParameter(void* new_parameter) { 41 void SetStartParameter(void* new_parameter) {
39 start_parameter = new_parameter; 42 start_parameter = new_parameter;
@@ -42,6 +45,16 @@ public:
42private: 45private:
43 Fiber(); 46 Fiber();
44 47
48#ifdef _MSC_VER
49 void start();
50 static void FiberStartFunc(void* fiber_parameter);
51#else
52 void start(boost::context::detail::transfer_t& transfer);
53 static void FiberStartFunc(boost::context::detail::transfer_t transfer);
54#endif
55
56
57
45 struct FiberImpl; 58 struct FiberImpl;
46 59
47 SpinLock guard; 60 SpinLock guard;
diff --git a/src/tests/CMakeLists.txt b/src/tests/CMakeLists.txt
index c7038b217..47ef30aa9 100644
--- a/src/tests/CMakeLists.txt
+++ b/src/tests/CMakeLists.txt
@@ -1,6 +1,7 @@
1add_executable(tests 1add_executable(tests
2 common/bit_field.cpp 2 common/bit_field.cpp
3 common/bit_utils.cpp 3 common/bit_utils.cpp
4 common/fibers.cpp
4 common/multi_level_queue.cpp 5 common/multi_level_queue.cpp
5 common/param_package.cpp 6 common/param_package.cpp
6 common/ring_buffer.cpp 7 common/ring_buffer.cpp
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp
new file mode 100644
index 000000000..ff840afa6
--- /dev/null
+++ b/src/tests/common/fibers.cpp
@@ -0,0 +1,214 @@
1// Copyright 2020 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <atomic>
6#include <cstdlib>
7#include <functional>
8#include <memory>
9#include <thread>
10#include <unordered_map>
11#include <vector>
12
13#include <catch2/catch.hpp>
14#include <math.h>
15#include "common/common_types.h"
16#include "common/fiber.h"
17#include "common/spin_lock.h"
18
19namespace Common {
20
21class TestControl1 {
22public:
23 TestControl1() = default;
24
25 void DoWork();
26
27 void ExecuteThread(u32 id);
28
29 std::unordered_map<std::thread::id, u32> ids;
30 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
31 std::vector<std::shared_ptr<Common::Fiber>> work_fibers;
32 std::vector<u32> items;
33 std::vector<u32> results;
34};
35
36static void WorkControl1(void* control) {
37 TestControl1* test_control = static_cast<TestControl1*>(control);
38 test_control->DoWork();
39}
40
41void TestControl1::DoWork() {
42 std::thread::id this_id = std::this_thread::get_id();
43 u32 id = ids[this_id];
44 u32 value = items[id];
45 for (u32 i = 0; i < id; i++) {
46 value++;
47 }
48 results[id] = value;
49 Fiber::YieldTo(work_fibers[id], thread_fibers[id]);
50}
51
52void TestControl1::ExecuteThread(u32 id) {
53 std::thread::id this_id = std::this_thread::get_id();
54 ids[this_id] = id;
55 auto thread_fiber = Fiber::ThreadToFiber();
56 thread_fibers[id] = thread_fiber;
57 work_fibers[id] = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl1}, this);
58 items[id] = rand() % 256;
59 Fiber::YieldTo(thread_fibers[id], work_fibers[id]);
60 thread_fibers[id]->Exit();
61}
62
63static void ThreadStart1(u32 id, TestControl1& test_control) {
64 test_control.ExecuteThread(id);
65}
66
67
68TEST_CASE("Fibers::Setup", "[common]") {
69 constexpr u32 num_threads = 7;
70 TestControl1 test_control{};
71 test_control.thread_fibers.resize(num_threads, nullptr);
72 test_control.work_fibers.resize(num_threads, nullptr);
73 test_control.items.resize(num_threads, 0);
74 test_control.results.resize(num_threads, 0);
75 std::vector<std::thread> threads;
76 for (u32 i = 0; i < num_threads; i++) {
77 threads.emplace_back(ThreadStart1, i, std::ref(test_control));
78 }
79 for (u32 i = 0; i < num_threads; i++) {
80 threads[i].join();
81 }
82 for (u32 i = 0; i < num_threads; i++) {
83 REQUIRE(test_control.items[i] + i == test_control.results[i]);
84 }
85}
86
87class TestControl2 {
88public:
89 TestControl2() = default;
90
91 void DoWork1() {
92 trap2 = false;
93 while (trap.load());
94 for (u32 i = 0; i < 12000; i++) {
95 value1 += i;
96 }
97 Fiber::YieldTo(fiber1, fiber3);
98 std::thread::id this_id = std::this_thread::get_id();
99 u32 id = ids[this_id];
100 assert1 = id == 1;
101 value2 += 5000;
102 Fiber::YieldTo(fiber1, thread_fibers[id]);
103 }
104
105 void DoWork2() {
106 while (trap2.load());
107 value2 = 2000;
108 trap = false;
109 Fiber::YieldTo(fiber2, fiber1);
110 assert3 = false;
111 }
112
113 void DoWork3() {
114 std::thread::id this_id = std::this_thread::get_id();
115 u32 id = ids[this_id];
116 assert2 = id == 0;
117 value1 += 1000;
118 Fiber::YieldTo(fiber3, thread_fibers[id]);
119 }
120
121 void ExecuteThread(u32 id);
122
123 void CallFiber1() {
124 std::thread::id this_id = std::this_thread::get_id();
125 u32 id = ids[this_id];
126 Fiber::YieldTo(thread_fibers[id], fiber1);
127 }
128
129 void CallFiber2() {
130 std::thread::id this_id = std::this_thread::get_id();
131 u32 id = ids[this_id];
132 Fiber::YieldTo(thread_fibers[id], fiber2);
133 }
134
135 void Exit();
136
137 bool assert1{};
138 bool assert2{};
139 bool assert3{true};
140 u32 value1{};
141 u32 value2{};
142 std::atomic<bool> trap{true};
143 std::atomic<bool> trap2{true};
144 std::unordered_map<std::thread::id, u32> ids;
145 std::vector<std::shared_ptr<Common::Fiber>> thread_fibers;
146 std::shared_ptr<Common::Fiber> fiber1;
147 std::shared_ptr<Common::Fiber> fiber2;
148 std::shared_ptr<Common::Fiber> fiber3;
149};
150
151static void WorkControl2_1(void* control) {
152 TestControl2* test_control = static_cast<TestControl2*>(control);
153 test_control->DoWork1();
154}
155
156static void WorkControl2_2(void* control) {
157 TestControl2* test_control = static_cast<TestControl2*>(control);
158 test_control->DoWork2();
159}
160
161static void WorkControl2_3(void* control) {
162 TestControl2* test_control = static_cast<TestControl2*>(control);
163 test_control->DoWork3();
164}
165
166void TestControl2::ExecuteThread(u32 id) {
167 std::thread::id this_id = std::this_thread::get_id();
168 ids[this_id] = id;
169 auto thread_fiber = Fiber::ThreadToFiber();
170 thread_fibers[id] = thread_fiber;
171}
172
173void TestControl2::Exit() {
174 std::thread::id this_id = std::this_thread::get_id();
175 u32 id = ids[this_id];
176 thread_fibers[id]->Exit();
177}
178
179static void ThreadStart2_1(u32 id, TestControl2& test_control) {
180 test_control.ExecuteThread(id);
181 test_control.CallFiber1();
182 test_control.Exit();
183}
184
185static void ThreadStart2_2(u32 id, TestControl2& test_control) {
186 test_control.ExecuteThread(id);
187 test_control.CallFiber2();
188 test_control.Exit();
189}
190
191TEST_CASE("Fibers::InterExchange", "[common]") {
192 TestControl2 test_control{};
193 test_control.thread_fibers.resize(2, nullptr);
194 test_control.fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_1}, &test_control);
195 test_control.fiber2 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_2}, &test_control);
196 test_control.fiber3 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl2_3}, &test_control);
197 std::thread thread1(ThreadStart2_1, 0, std::ref(test_control));
198 std::thread thread2(ThreadStart2_2, 1, std::ref(test_control));
199 thread1.join();
200 thread2.join();
201 REQUIRE(test_control.assert1);
202 REQUIRE(test_control.assert2);
203 REQUIRE(test_control.assert3);
204 REQUIRE(test_control.value2 == 7000);
205 u32 cal_value = 0;
206 for (u32 i = 0; i < 12000; i++) {
207 cal_value += i;
208 }
209 cal_value += 1000;
210 REQUIRE(test_control.value1 == cal_value);
211}
212
213
214} // namespace Common