diff options
| -rw-r--r-- | src/common/fiber.cpp | 32 | ||||
| -rw-r--r-- | src/common/fiber.h | 8 | ||||
| -rw-r--r-- | src/tests/common/fibers.cpp | 46 |
3 files changed, 84 insertions, 2 deletions
diff --git a/src/common/fiber.cpp b/src/common/fiber.cpp index e4ecc73df..f61479e13 100644 --- a/src/common/fiber.cpp +++ b/src/common/fiber.cpp | |||
| @@ -12,10 +12,13 @@ | |||
| 12 | 12 | ||
| 13 | namespace Common { | 13 | namespace Common { |
| 14 | 14 | ||
| 15 | constexpr std::size_t default_stack_size = 256 * 1024; // 256kb | ||
| 16 | |||
| 15 | #if defined(_WIN32) || defined(WIN32) | 17 | #if defined(_WIN32) || defined(WIN32) |
| 16 | 18 | ||
| 17 | struct Fiber::FiberImpl { | 19 | struct Fiber::FiberImpl { |
| 18 | LPVOID handle = nullptr; | 20 | LPVOID handle = nullptr; |
| 21 | LPVOID rewind_handle = nullptr; | ||
| 19 | }; | 22 | }; |
| 20 | 23 | ||
| 21 | void Fiber::start() { | 24 | void Fiber::start() { |
| @@ -26,15 +29,29 @@ void Fiber::start() { | |||
| 26 | UNREACHABLE(); | 29 | UNREACHABLE(); |
| 27 | } | 30 | } |
| 28 | 31 | ||
| 32 | void Fiber::onRewind() { | ||
| 33 | ASSERT(impl->handle != nullptr); | ||
| 34 | DeleteFiber(impl->handle); | ||
| 35 | impl->handle = impl->rewind_handle; | ||
| 36 | impl->rewind_handle = nullptr; | ||
| 37 | rewind_point(rewind_parameter); | ||
| 38 | UNREACHABLE(); | ||
| 39 | } | ||
| 40 | |||
| 29 | void __stdcall Fiber::FiberStartFunc(void* fiber_parameter) { | 41 | void __stdcall Fiber::FiberStartFunc(void* fiber_parameter) { |
| 30 | auto fiber = static_cast<Fiber*>(fiber_parameter); | 42 | auto fiber = static_cast<Fiber*>(fiber_parameter); |
| 31 | fiber->start(); | 43 | fiber->start(); |
| 32 | } | 44 | } |
| 33 | 45 | ||
| 46 | void __stdcall Fiber::RewindStartFunc(void* fiber_parameter) { | ||
| 47 | auto fiber = static_cast<Fiber*>(fiber_parameter); | ||
| 48 | fiber->onRewind(); | ||
| 49 | } | ||
| 50 | |||
| 34 | Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) | 51 | Fiber::Fiber(std::function<void(void*)>&& entry_point_func, void* start_parameter) |
| 35 | : entry_point{std::move(entry_point_func)}, start_parameter{start_parameter} { | 52 | : entry_point{std::move(entry_point_func)}, start_parameter{start_parameter} { |
| 36 | impl = std::make_unique<FiberImpl>(); | 53 | impl = std::make_unique<FiberImpl>(); |
| 37 | impl->handle = CreateFiber(0, &FiberStartFunc, this); | 54 | impl->handle = CreateFiber(default_stack_size, &FiberStartFunc, this); |
| 38 | } | 55 | } |
| 39 | 56 | ||
| 40 | Fiber::Fiber() { | 57 | Fiber::Fiber() { |
| @@ -60,6 +77,18 @@ void Fiber::Exit() { | |||
| 60 | guard.unlock(); | 77 | guard.unlock(); |
| 61 | } | 78 | } |
| 62 | 79 | ||
| 80 | void Fiber::SetRewindPoint(std::function<void(void*)>&& rewind_func, void* start_parameter) { | ||
| 81 | rewind_point = std::move(rewind_func); | ||
| 82 | rewind_parameter = start_parameter; | ||
| 83 | } | ||
| 84 | |||
| 85 | void Fiber::Rewind() { | ||
| 86 | ASSERT(rewind_point); | ||
| 87 | ASSERT(impl->rewind_handle == nullptr); | ||
| 88 | impl->rewind_handle = CreateFiber(default_stack_size, &RewindStartFunc, this); | ||
| 89 | SwitchToFiber(impl->rewind_handle); | ||
| 90 | } | ||
| 91 | |||
| 63 | void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { | 92 | void Fiber::YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to) { |
| 64 | ASSERT_MSG(from != nullptr, "Yielding fiber is null!"); | 93 | ASSERT_MSG(from != nullptr, "Yielding fiber is null!"); |
| 65 | ASSERT_MSG(to != nullptr, "Next fiber is null!"); | 94 | ASSERT_MSG(to != nullptr, "Next fiber is null!"); |
| @@ -81,7 +110,6 @@ std::shared_ptr<Fiber> Fiber::ThreadToFiber() { | |||
| 81 | } | 110 | } |
| 82 | 111 | ||
| 83 | #else | 112 | #else |
| 84 | constexpr std::size_t default_stack_size = 1024 * 1024; // 1MB | ||
| 85 | 113 | ||
| 86 | struct Fiber::FiberImpl { | 114 | struct Fiber::FiberImpl { |
| 87 | alignas(64) std::array<u8, default_stack_size> stack; | 115 | alignas(64) std::array<u8, default_stack_size> stack; |
diff --git a/src/common/fiber.h b/src/common/fiber.h index 7e3b130a4..a710df257 100644 --- a/src/common/fiber.h +++ b/src/common/fiber.h | |||
| @@ -46,6 +46,10 @@ public: | |||
| 46 | static void YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to); | 46 | static void YieldTo(std::shared_ptr<Fiber> from, std::shared_ptr<Fiber> to); |
| 47 | static std::shared_ptr<Fiber> ThreadToFiber(); | 47 | static std::shared_ptr<Fiber> ThreadToFiber(); |
| 48 | 48 | ||
| 49 | void SetRewindPoint(std::function<void(void*)>&& rewind_func, void* start_parameter); | ||
| 50 | |||
| 51 | void Rewind(); | ||
| 52 | |||
| 49 | /// Only call from main thread's fiber | 53 | /// Only call from main thread's fiber |
| 50 | void Exit(); | 54 | void Exit(); |
| 51 | 55 | ||
| @@ -58,8 +62,10 @@ private: | |||
| 58 | Fiber(); | 62 | Fiber(); |
| 59 | 63 | ||
| 60 | #if defined(_WIN32) || defined(WIN32) | 64 | #if defined(_WIN32) || defined(WIN32) |
| 65 | void onRewind(); | ||
| 61 | void start(); | 66 | void start(); |
| 62 | static void FiberStartFunc(void* fiber_parameter); | 67 | static void FiberStartFunc(void* fiber_parameter); |
| 68 | static void RewindStartFunc(void* fiber_parameter); | ||
| 63 | #else | 69 | #else |
| 64 | void start(boost::context::detail::transfer_t& transfer); | 70 | void start(boost::context::detail::transfer_t& transfer); |
| 65 | static void FiberStartFunc(boost::context::detail::transfer_t transfer); | 71 | static void FiberStartFunc(boost::context::detail::transfer_t transfer); |
| @@ -69,6 +75,8 @@ private: | |||
| 69 | 75 | ||
| 70 | SpinLock guard{}; | 76 | SpinLock guard{}; |
| 71 | std::function<void(void*)> entry_point{}; | 77 | std::function<void(void*)> entry_point{}; |
| 78 | std::function<void(void*)> rewind_point{}; | ||
| 79 | void* rewind_parameter{}; | ||
| 72 | void* start_parameter{}; | 80 | void* start_parameter{}; |
| 73 | std::shared_ptr<Fiber> previous_fiber{}; | 81 | std::shared_ptr<Fiber> previous_fiber{}; |
| 74 | std::unique_ptr<FiberImpl> impl; | 82 | std::unique_ptr<FiberImpl> impl; |
diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp index 0d3d5153d..12536b6d8 100644 --- a/src/tests/common/fibers.cpp +++ b/src/tests/common/fibers.cpp | |||
| @@ -309,4 +309,50 @@ TEST_CASE("Fibers::StartRace", "[common]") { | |||
| 309 | REQUIRE(test_control.value3 == 1); | 309 | REQUIRE(test_control.value3 == 1); |
| 310 | } | 310 | } |
| 311 | 311 | ||
| 312 | class TestControl4; | ||
| 313 | |||
| 314 | static void WorkControl4(void* control); | ||
| 315 | |||
| 316 | class TestControl4 { | ||
| 317 | public: | ||
| 318 | TestControl4() { | ||
| 319 | fiber1 = std::make_shared<Fiber>(std::function<void(void*)>{WorkControl4}, this); | ||
| 320 | goal_reached = false; | ||
| 321 | rewinded = false; | ||
| 322 | } | ||
| 323 | |||
| 324 | void Execute() { | ||
| 325 | thread_fiber = Fiber::ThreadToFiber(); | ||
| 326 | Fiber::YieldTo(thread_fiber, fiber1); | ||
| 327 | thread_fiber->Exit(); | ||
| 328 | } | ||
| 329 | |||
| 330 | void DoWork() { | ||
| 331 | fiber1->SetRewindPoint(std::function<void(void*)>{WorkControl4}, this); | ||
| 332 | if (rewinded) { | ||
| 333 | goal_reached = true; | ||
| 334 | Fiber::YieldTo(fiber1, thread_fiber); | ||
| 335 | } | ||
| 336 | rewinded = true; | ||
| 337 | fiber1->Rewind(); | ||
| 338 | } | ||
| 339 | |||
| 340 | std::shared_ptr<Common::Fiber> fiber1; | ||
| 341 | std::shared_ptr<Common::Fiber> thread_fiber; | ||
| 342 | bool goal_reached; | ||
| 343 | bool rewinded; | ||
| 344 | }; | ||
| 345 | |||
| 346 | static void WorkControl4(void* control) { | ||
| 347 | auto* test_control = static_cast<TestControl4*>(control); | ||
| 348 | test_control->DoWork(); | ||
| 349 | } | ||
| 350 | |||
| 351 | TEST_CASE("Fibers::Rewind", "[common]") { | ||
| 352 | TestControl4 test_control{}; | ||
| 353 | test_control.Execute(); | ||
| 354 | REQUIRE(test_control.goal_reached); | ||
| 355 | REQUIRE(test_control.rewinded); | ||
| 356 | } | ||
| 357 | |||
| 312 | } // namespace Common | 358 | } // namespace Common |