diff --git a/trpc/coroutine/BUILD b/trpc/coroutine/BUILD index e09fbaa3..2229cf3b 100644 --- a/trpc/coroutine/BUILD +++ b/trpc/coroutine/BUILD @@ -233,6 +233,7 @@ cc_test( srcs = ["fiber_test.cc"], deps = [ "//trpc/coroutine/testing:fiber_runtime_test", + "//trpc/util:latch", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], diff --git a/trpc/coroutine/fiber.cc b/trpc/coroutine/fiber.cc index 9a1decae..60c23444 100644 --- a/trpc/coroutine/fiber.cc +++ b/trpc/coroutine/fiber.cc @@ -10,6 +10,7 @@ #include "trpc/coroutine/fiber.h" +#include #include #include @@ -151,8 +152,18 @@ void FiberYield() { } void FiberSleepUntil(const std::chrono::steady_clock::time_point& expires_at) { - fiber::detail::WaitableTimer wt(expires_at); - wt.wait(); + if (trpc::fiber::detail::IsFiberContextPresent()) { + fiber::detail::WaitableTimer wt(expires_at); + wt.wait(); + return; + } + + auto now = ReadSteadyClock(); + if (expires_at <= now) { + return; + } + + std::this_thread::sleep_for(expires_at - now); } void FiberSleepFor(const std::chrono::nanoseconds& expires_in) { diff --git a/trpc/coroutine/fiber.h b/trpc/coroutine/fiber.h index 6f3f326d..5f9415b9 100644 --- a/trpc/coroutine/fiber.h +++ b/trpc/coroutine/fiber.h @@ -183,23 +183,23 @@ bool BatchStartFiberDetached(std::vector>&& start_procs); /// @note It only uses in fiber runtime. void FiberYield(); -/// @brief Block calling fiber until `expires_at`. -/// @note It only uses in fiber runtime. +/// @brief Block calling pthread or calling fiber until `expires_at`. +/// @note It can be used in pthread context and fiber context. void FiberSleepUntil(const std::chrono::steady_clock::time_point& expires_at); -/// @brief Block calling fiber for `expires_in`. -/// @note It only uses in fiber runtime. +/// @brief Block calling pthread or calling fiber for `expires_in`. +/// @note It can be used in pthread context and fiber context. void FiberSleepFor(const std::chrono::nanoseconds& expires_in); /// @brief `SleepUntil` for clocks other than `std::steady_clock`. -/// @note It only uses in fiber runtime. +/// @note It can be used in pthread context and fiber context. template void FiberSleepUntil(const std::chrono::time_point& expires_at) { return FiberSleepUntil(ReadSteadyClock() + (expires_at - Clock::now())); } /// @brief `SleepFor` for durations other than `std::chrono::nanoseconds`. -/// @note It only uses in fiber runtime. +/// @note It can be used in pthread context and fiber context. template void FiberSleepFor(const std::chrono::duration& expires_in) { return FiberSleepFor(static_cast(expires_in)); diff --git a/trpc/coroutine/fiber_condition_variable.h b/trpc/coroutine/fiber_condition_variable.h index 4d8aa964..f274ee66 100644 --- a/trpc/coroutine/fiber_condition_variable.h +++ b/trpc/coroutine/fiber_condition_variable.h @@ -18,8 +18,7 @@ namespace trpc { -/// @brief Analogous to `std::condition_variable`, but it's for fiber. -/// @note It only uses in fiber runtime. +/// @brief Adaptive condition variable primitive for both fiber and pthread context. class FiberConditionVariable { public: /// @brief Wake up one waiter. diff --git a/trpc/coroutine/fiber_condition_variable_test.cc b/trpc/coroutine/fiber_condition_variable_test.cc index 5d9778a8..a18fd1f5 100644 --- a/trpc/coroutine/fiber_condition_variable_test.cc +++ b/trpc/coroutine/fiber_condition_variable_test.cc @@ -22,7 +22,7 @@ namespace trpc { -TEST(FiberConditionVariable, All) { +TEST(FiberConditionVariable, UseInFiberContext) { RunAsFiber([] { for (int k = 0; k != 10; ++k) { constexpr auto N = 600; @@ -64,4 +64,125 @@ TEST(FiberConditionVariable, All) { }); } +TEST(FiberConditionVariable, UseInPthreadContext) { + constexpr auto N = 64; + std::atomic run{0}; + FiberMutex lock[N]; + FiberConditionVariable cv[N]; + bool set[N] = {false}; + std::vector prod(N); + std::vector cons(N); + + for (int i = 0; i != N; ++i) { + prod[i] = std::thread([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++run; + }); + + cons[i] = std::thread([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++run; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + ASSERT_EQ(N * 2, run); +} + +TEST(FiberConditionVariable, NotifyPthreadFromFiber) { + RunAsFiber([] { + constexpr auto N = 64; + std::atomic run{0}; + FiberMutex lock[N]; + FiberConditionVariable cv[N]; + bool set[N] = {false}; + std::vector prod(N); + std::vector cons(N); + + for (int i = 0; i != N; ++i) { + prod[i] = std::thread([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++run; + }); + + cons[i] = Fiber([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++run; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.Joinable()); + e.Join(); + } + + ASSERT_EQ(N * 2, run); + }); +} + +TEST(FiberConditionVariable, NotifyFiberFromPthread) { + RunAsFiber([] { + constexpr auto N = 64; + std::atomic run{0}; + FiberMutex lock[N]; + FiberConditionVariable cv[N]; + bool set[N] = {false}; + std::vector prod(N); + std::vector cons(N); + + for (int i = 0; i != N; ++i) { + prod[i] = Fiber([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++run; + }); + + cons[i] = std::thread([&run, i, &cv, &lock, &set] { + FiberSleepFor(Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++run; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.Joinable()); + e.Join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + ASSERT_EQ(N * 2, run); + }); +} + } // namespace trpc diff --git a/trpc/coroutine/fiber_event.h b/trpc/coroutine/fiber_event.h index f0c6a2ab..69864aa1 100644 --- a/trpc/coroutine/fiber_event.h +++ b/trpc/coroutine/fiber_event.h @@ -14,16 +14,14 @@ namespace trpc { -/// @brief Event for fiber. +/// @brief Adaptive event primitive for both fiber and pthread context. class FiberEvent { public: /// @brief Wait until `Set()` is called. /// If `Set()` is called before `Wait()`, this method returns immediately. - /// @note This method only uses in fiber runtime. void Wait() { event_.Wait(); } - /// @brief Wake up fibers blockings on `Wait()`. - /// @note You can call this method outside of fiber runtime. + /// @brief Wake up fibers and pthreads blockings on `Wait()`. void Set() { event_.Set(); } private: diff --git a/trpc/coroutine/fiber_event_test.cc b/trpc/coroutine/fiber_event_test.cc index 8bc2ac32..568a56f6 100644 --- a/trpc/coroutine/fiber_event_test.cc +++ b/trpc/coroutine/fiber_event_test.cc @@ -16,19 +16,70 @@ #include "gtest/gtest.h" +#include "trpc/coroutine/fiber.h" #include "trpc/coroutine/testing/fiber_runtime.h" +#include "trpc/util/algorithm/random.h" namespace trpc::testing { -TEST(FiberEvent, EventOnWakeup) { +TEST(FiberEvent, WaitInFiberSetFromPthread) { RunAsFiber([]() { for (int i = 0; i != 1000; ++i) { - auto ev = std::make_unique(); - std::thread t([&] { ev->Set(); }); + auto ev = std::make_unique(); + std::thread t([&] { + // Random sleep to make Set before Wait or after Wait. + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + }); + // Random sleep to make Wait return immediately or awakened by Set. + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); ev->Wait(); t.join(); } }); } +TEST(FiberEvent, WaitInFiberSetFromFiber) { + RunAsFiber([]() { + for (int i = 0; i != 1000; ++i) { + auto ev = std::make_unique(); + trpc::Fiber f = trpc::Fiber([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + f.Join(); + } + }); +} + +TEST(FiberEvent, WaitInPthreadSetFromFiber) { + RunAsFiber([]() { + for (int i = 0; i != 1000; ++i) { + auto ev = std::make_unique(); + std::thread t([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + t.join(); + } + }); +} + +TEST(FiberEvent, WaitInPthreadSetFromPthread) { + for (int i = 0; i != 1000; ++i) { + auto ev = std::make_unique(); + std::thread t([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + t.join(); + } +} + } // namespace trpc::testing diff --git a/trpc/coroutine/fiber_latch.h b/trpc/coroutine/fiber_latch.h index a20d2b5f..6f3928a1 100644 --- a/trpc/coroutine/fiber_latch.h +++ b/trpc/coroutine/fiber_latch.h @@ -16,8 +16,7 @@ namespace trpc { -/// @brief Analogous to `std::latch`, but it's for fiber. -/// @note It only uses in fiber runtime. +/// @brief Adaptive latch primitive for both fiber and pthread context. class FiberLatch { public: explicit FiberLatch(std::ptrdiff_t count); diff --git a/trpc/coroutine/fiber_latch_test.cc b/trpc/coroutine/fiber_latch_test.cc index f4393e96..27d4cf3c 100644 --- a/trpc/coroutine/fiber_latch_test.cc +++ b/trpc/coroutine/fiber_latch_test.cc @@ -19,59 +19,86 @@ #include "trpc/coroutine/fiber.h" #include "trpc/coroutine/testing/fiber_runtime.h" +#include "trpc/util/algorithm/random.h" #include "trpc/util/chrono/chrono.h" namespace trpc { -std::atomic exiting{false}; - -void RunTest() { - std::atomic local_count = 0, remote_count = 0; - while (!exiting) { - FiberLatch l(1); - auto called = std::make_shared>(false); - Fiber([called, &l, &remote_count] { - if (!called->exchange(true)) { - FiberYield(); - l.CountDown(); - ++remote_count; - } - }).Detach(); - FiberYield(); - if (!called->exchange(true)) { - l.CountDown(); - ++local_count; +TEST(FiberLatch, WaitInFiberCountDownFromFiber) { + RunAsFiber([]() { + for (int i = 0; i != 100; ++i) { + auto fl = std::make_unique(1); + trpc::Fiber fb = trpc::Fiber([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->CountDown(); + fb.Join(); } - l.Wait(); - } - std::cout << local_count << " " << remote_count << std::endl; + }); } -TEST(Latch, Torture) { - RunAsFiber([] { - Fiber fs[10]; - for (auto&& f : fs) { - f = Fiber(RunTest); +TEST(FiberLatch, WaitInFiberCountDownFromPthread) { + RunAsFiber([]() { + for (int i = 0; i != 100; ++i) { + auto fl = std::make_unique(1); + std::thread t([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->CountDown(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->Wait(); + t.join(); } - std::this_thread::sleep_for(std::chrono::seconds(1)); - exiting = true; - for (auto&& f : fs) { - f.Join(); + }); +} + +TEST(FiberLatch, WaitInPthreadCountDownFromFiber) { + RunAsFiber([]() { + for (int i = 0; i != 100; ++i) { + auto fl = std::make_unique(1); + std::thread t([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->CountDown(); + t.join(); } }); } -TEST(Latch, CountDownTwo) { +TEST(FiberLatch, WaitInPthreadCountDownFromPthread) { + for (int i = 0; i != 100; ++i) { + auto fl = std::make_unique(1); + std::thread t([&] { + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->Wait(); + }); + trpc::FiberSleepFor(Random(10) * std::chrono::milliseconds(1)); + fl->CountDown(); + t.join(); + } +} + +TEST(FiberLatch, CountDownTwoInFiber) { RunAsFiber([] { FiberLatch l(2); ASSERT_FALSE(l.TryWait()); l.ArriveAndWait(2); - ASSERT_TRUE(1); ASSERT_TRUE(l.TryWait()); }); } -TEST(Latch, WaitFor) { +TEST(FiberLatch, CountDownTwoInPthread) { + FiberLatch l(2); + ASSERT_FALSE(l.TryWait()); + l.ArriveAndWait(2); + ASSERT_TRUE(l.TryWait()); +} + +TEST(FiberLatch, WaitForInFiber) { RunAsFiber([] { FiberLatch l(1); ASSERT_FALSE(l.WaitFor(std::chrono::milliseconds(100))); @@ -80,7 +107,14 @@ TEST(Latch, WaitFor) { }); } -TEST(Latch, WaitUntil) { +TEST(FiberLatch, WaitForInPthread) { + FiberLatch l(1); + ASSERT_FALSE(l.WaitFor(std::chrono::milliseconds(100))); + l.CountDown(); + ASSERT_TRUE(l.WaitFor(std::chrono::milliseconds(0))); +} + +TEST(FiberLatch, WaitUntilInFiber) { RunAsFiber([] { FiberLatch l(1); ASSERT_FALSE(l.WaitUntil(ReadSteadyClock() + std::chrono::milliseconds(100))); @@ -89,4 +123,11 @@ TEST(Latch, WaitUntil) { }); } +TEST(FiberLatch, WaitUntilInPthread) { + FiberLatch l(1); + ASSERT_FALSE(l.WaitUntil(ReadSteadyClock() + std::chrono::milliseconds(100))); + l.CountDown(); + ASSERT_TRUE(l.WaitUntil(ReadSteadyClock())); +} + } // namespace trpc diff --git a/trpc/coroutine/fiber_mutex.h b/trpc/coroutine/fiber_mutex.h index e1c682c4..dd05f0c6 100644 --- a/trpc/coroutine/fiber_mutex.h +++ b/trpc/coroutine/fiber_mutex.h @@ -14,8 +14,7 @@ namespace trpc { -/// @brief Analogous to `std::mutex`, but it's for fiber. -/// @note It only uses in fiber runtime. +/// @brief Adaptive mutex primitive for both fiber and pthread context. using FiberMutex = ::trpc::fiber::detail::Mutex; } // namespace trpc diff --git a/trpc/coroutine/fiber_shared_mutex.h b/trpc/coroutine/fiber_shared_mutex.h index 4240c2f4..97f3870b 100644 --- a/trpc/coroutine/fiber_shared_mutex.h +++ b/trpc/coroutine/fiber_shared_mutex.h @@ -19,7 +19,7 @@ namespace trpc { -/// @brief Analogous to `std::shared_mutex`, but it for fiber. +/// @brief Adaptive shared mutex primitive for both fiber and pthread context. /// @note Performance-wise, reader-writer lock does NOT perform well unless /// your critical section is sufficient large. In certain cases, reader-writer /// lock can perform worse than `Mutex`. If reader performance is critical to diff --git a/trpc/coroutine/fiber_shared_mutex_test.cc b/trpc/coroutine/fiber_shared_mutex_test.cc index 60e4d35a..2835ccfb 100644 --- a/trpc/coroutine/fiber_shared_mutex_test.cc +++ b/trpc/coroutine/fiber_shared_mutex_test.cc @@ -111,4 +111,102 @@ TEST(FiberShardMutex, AllWriter) { }); } +TEST(FiberSharedMutex, UseInPthreadContext) { + static constexpr auto B = 64; + Latch l(B); + + FiberSharedMutex rwlock; + ASSERT_EQ(true, rwlock.try_lock()); + rwlock.unlock(); + + ASSERT_EQ(true, rwlock.try_lock_shared()); + rwlock.unlock_shared(); + + std::thread ts[B]; + int counter1 = 0; + int counter2 = 0; + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&] { + for (int _ = 0; _ != 100; ++_) { + auto op = Random(100); + if (op < 80) { + std::shared_lock _(rwlock); + EXPECT_EQ(counter1, counter2); + } else { + std::scoped_lock _(rwlock); + ++counter1; + ++counter2; + EXPECT_EQ(counter1, counter2); + } + } + + l.count_down(); + }); + } + + l.wait(); + for (auto&& t : ts) { + t.join(); + } +} + +TEST(FiberSharedMutex, UseInMixedContext) { + RunAsFiber([] { + static constexpr auto B = 64; + FiberLatch l(2 * B); + + FiberSharedMutex rwlock; + + std::thread ts[B]; + std::vector fibers; + int counter1 = 0; + int counter2 = 0; + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&] { + for (int _ = 0; _ != 100; ++_) { + auto op = Random(100); + if (op < 80) { + std::shared_lock _(rwlock); + EXPECT_EQ(counter1, counter2); + } else { + std::scoped_lock _(rwlock); + ++counter1; + ++counter2; + EXPECT_EQ(counter1, counter2); + } + } + + l.CountDown(); + }); + + fibers.emplace_back([&] { + for (int _ = 0; _ != 100; ++_) { + auto op = Random(100); + if (op < 70) { + std::shared_lock _(rwlock); + EXPECT_EQ(counter1, counter2); + } else { + std::scoped_lock _(rwlock); + ++counter1; + ++counter2; + EXPECT_EQ(counter1, counter2); + } + } + + l.CountDown(); + }); + } + + l.Wait(); + + for (auto&& t : ts) { + t.join(); + } + + for (auto&& e : fibers) { + e.Join(); + } + }); +} + } // namespace trpc diff --git a/trpc/coroutine/fiber_test.cc b/trpc/coroutine/fiber_test.cc index d42ae864..bd943d37 100644 --- a/trpc/coroutine/fiber_test.cc +++ b/trpc/coroutine/fiber_test.cc @@ -24,9 +24,7 @@ #include "trpc/coroutine/fiber_latch.h" #include "trpc/coroutine/testing/fiber_runtime.h" -DECLARE_bool(trpc_fiber_stack_enable_guard_page); -DECLARE_int32(trpc_cross_numa_work_stealing_ratio); -DECLARE_int32(trpc_fiber_run_queue_size); +#define COUT std::cout << __FILE__ << ":" << __LINE__ << ":" << __FUNCTION__ << "|" namespace trpc { @@ -150,4 +148,154 @@ TEST(Fiber, GetFiberCount) { ASSERT_EQ(0, trpc::GetFiberCount() - fiber_count); }); } + +TEST(Fiber, FiberSleepInFiberContext) { + RunAsFiber([] { + FiberLatch l(1); + StartFiberDetached([&l] { + // Test FiberSleepFor. + auto sleep_for = 100 * std::chrono::milliseconds(1); + auto start = ReadSystemClock(); // Used system_clock intentionally. + FiberSleepFor(sleep_for); + + auto use_time = (ReadSystemClock() - start) / std::chrono::milliseconds(1); + auto expect_time = sleep_for / std::chrono::milliseconds(1); + auto error_time = use_time - expect_time; + COUT << "use_time:" << use_time << ",expect_time:" << expect_time + << ",error_time:" << error_time << std::endl; + + ASSERT_NEAR((ReadSystemClock() - start) / std::chrono::milliseconds(1), + sleep_for / std::chrono::milliseconds(1), 20); + + // Test FiberSleepUntil. + auto sleep_until = ReadSystemClock() + 100 * std::chrono::milliseconds(1); + FiberSleepUntil(sleep_until); + + use_time = (ReadSystemClock().time_since_epoch()) / std::chrono::milliseconds(1); + expect_time = sleep_until.time_since_epoch() / std::chrono::milliseconds(1); + error_time = use_time - expect_time; + + COUT << "use_time:" << use_time << ",expect_time:" << expect_time + << ",error_time:" << error_time << std::endl; + + ASSERT_NEAR(use_time, expect_time, 20); + + l.CountDown(); + }); + + l.Wait(); + }); +} + +TEST(Fiber, FiberSleepInPthreadContext) { + auto sleep_for = 100 * std::chrono::milliseconds(1); + auto start = ReadSystemClock(); + FiberSleepFor(sleep_for); + + auto use_time = (ReadSystemClock() - start) / std::chrono::milliseconds(1); + auto expect_time = sleep_for / std::chrono::milliseconds(1); + auto error_time = use_time - expect_time; + COUT << "use_time:" << use_time << ",expect_time:" << expect_time + << ",error_time:" << error_time << std::endl; + + ASSERT_NEAR((ReadSystemClock() - start) / std::chrono::milliseconds(1), + sleep_for / std::chrono::milliseconds(1), 20); + + auto sleep_until = ReadSystemClock() + 100 * std::chrono::milliseconds(1); + FiberSleepUntil(sleep_until); + + use_time = (ReadSystemClock().time_since_epoch()) / std::chrono::milliseconds(1); + expect_time = sleep_until.time_since_epoch() / std::chrono::milliseconds(1); + error_time = use_time - expect_time; + + COUT << "use_time:" << use_time << ",expect_time:" << expect_time + << ",error_time:" << error_time << std::endl; + + ASSERT_NEAR(use_time, expect_time, 20); +} + +TEST(Fiber, FiberMutexInFiberContext) { + RunAsFiber([] { + static constexpr auto B = 1000; + FiberLatch l(B); + + FiberMutex m; + int value = 0; + for (int i = 0; i != B; ++i) { + StartFiberDetached([&l, &m, &value] { + std::scoped_lock _(m); + ++value; + + l.CountDown(); + }); + } + + l.Wait(); + + COUT << "value:" << value << std::endl; + + ASSERT_EQ(B, value); + }); +} + +TEST(Fiber, FiberMutexInPthreadContext) { + static constexpr auto B = 64; + Latch l(B); + + FiberMutex m; + ASSERT_EQ(true, m.try_lock()); + m.unlock(); + + int value = 0; + std::thread ts[B]; + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&l, &m, &value] { + std::scoped_lock _(m); + ++value; + + l.count_down(); + }); + } + + l.wait(); + for (auto&& t : ts) { + t.join(); + } + + COUT << "value:" << value << std::endl; + ASSERT_EQ(B, value); +} + +TEST(Fiber, FiberMutexInMixedContext) { + RunAsFiber([] { + static constexpr auto B = 64; + FiberLatch l(2 * B); + FiberMutex m; + + int value = 0; + std::thread ts[B]; + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&l, &m, &value] { + std::scoped_lock _(m); + ++value; + l.CountDown(); + }); + + StartFiberDetached([&l, &m, &value] { + std::scoped_lock _(m); + ++value; + l.CountDown(); + }); + } + + l.Wait(); + for (auto&& t : ts) { + t.join(); + } + + COUT << "value:" << value << std::endl; + ASSERT_EQ(2 * B, value); + }); +} + } // namespace trpc diff --git a/trpc/runtime/threadmodel/fiber/detail/BUILD b/trpc/runtime/threadmodel/fiber/detail/BUILD index f01b84e0..d20e5762 100644 --- a/trpc/runtime/threadmodel/fiber/detail/BUILD +++ b/trpc/runtime/threadmodel/fiber/detail/BUILD @@ -73,7 +73,6 @@ cc_library( "stack_allocator_impl", ":assembly", ":context", - "//trpc/util/internal:casting", "//trpc/log:trpc_log", "//trpc/runtime/threadmodel/common:worker_thread", "//trpc/tvar/compound_ops:internal_latency", @@ -91,9 +90,11 @@ cc_library( "//trpc/util:unique_id", "//trpc/util/chrono", "//trpc/util/chrono:tsc", + "//trpc/util/internal:casting", "//trpc/util/object_pool:object_pool_ptr", "//trpc/util/queue:bounded_mpmc_queue", "//trpc/util/queue:bounded_spmc_queue", + "//trpc/util/thread:futex_notifier", "//trpc/util/thread:latch", "//trpc/util/thread:predicate_notifier", "//trpc/util/thread:spinlock", @@ -169,6 +170,8 @@ cc_test( deps = [ ":fiber_impl", ":testing", + "//trpc/coroutine:fiber", + "//trpc/coroutine/testing:fiber_runtime_test", "@com_google_googletest//:gtest_main", ], ) diff --git a/trpc/runtime/threadmodel/fiber/detail/scheduling/v1/scheduling_impl.cc b/trpc/runtime/threadmodel/fiber/detail/scheduling/v1/scheduling_impl.cc index 6dcc0847..f386fb01 100644 --- a/trpc/runtime/threadmodel/fiber/detail/scheduling/v1/scheduling_impl.cc +++ b/trpc/runtime/threadmodel/fiber/detail/scheduling/v1/scheduling_impl.cc @@ -114,7 +114,9 @@ void SchedulingImpl::Enter(std::size_t index) noexcept { worker_index_ = index; // Initialize master fiber for this worker. - SetUpMasterFiberEntity(); + if (worker_index_ != SchedulingGroup::kTimerWorkerIndex) { + SetUpMasterFiberEntity(); + } } void SchedulingImpl::Schedule() noexcept { diff --git a/trpc/runtime/threadmodel/fiber/detail/scheduling/v2/scheduling_impl.cc b/trpc/runtime/threadmodel/fiber/detail/scheduling/v2/scheduling_impl.cc index 8ba315f8..994bd7a7 100644 --- a/trpc/runtime/threadmodel/fiber/detail/scheduling/v2/scheduling_impl.cc +++ b/trpc/runtime/threadmodel/fiber/detail/scheduling/v2/scheduling_impl.cc @@ -74,7 +74,9 @@ void SchedulingImpl::Enter(std::size_t index) noexcept { vtm_[worker_index_] = worker_index_; } - SetUpMasterFiberEntity(); + if (worker_index_ != SchedulingGroup::kTimerWorkerIndex) { + SetUpMasterFiberEntity(); + } } void SchedulingImpl::Leave() noexcept { diff --git a/trpc/runtime/threadmodel/fiber/detail/waitable.cc b/trpc/runtime/threadmodel/fiber/detail/waitable.cc index b1606189..7b1c7595 100644 --- a/trpc/runtime/threadmodel/fiber/detail/waitable.cc +++ b/trpc/runtime/threadmodel/fiber/detail/waitable.cc @@ -100,7 +100,6 @@ class AsyncWaker { bool Waitable::AddWaiter(WaitBlock* waiter) { std::scoped_lock _(lock_); - TRPC_CHECK(waiter->waiter); if (persistent_awakened_) { return false; } @@ -113,7 +112,7 @@ bool Waitable::TryRemoveWaiter(WaitBlock* waiter) { return waiters_.erase(waiter); } -FiberEntity* Waitable::WakeOne() { +WaitBlock* Waitable::WakeOne() { std::scoped_lock _(lock_); while (true) { auto waiter = waiters_.pop_front(); @@ -121,23 +120,28 @@ FiberEntity* Waitable::WakeOne() { return nullptr; } // Memory order is guaranteed by `lock_`. - if (waiter->satisfied.exchange(true, std::memory_order_relaxed)) { - continue; // It's awakened by someone else. + if (waiter->waiter && waiter->satisfied.exchange(true, std::memory_order_relaxed)) { + // For fiber waiter, it's awakened by someone else. + continue; } - return waiter->waiter; + return waiter; } } -void Waitable::SetPersistentAwakened(FiberEntityList& wbs) { +void Waitable::SetPersistentAwakened(WaitBlockList& wbs) { std::scoped_lock _(lock_); persistent_awakened_ = true; while (auto ptr = waiters_.pop_front()) { // Same as `WakeOne`. - if (ptr->satisfied.exchange(true, std::memory_order_relaxed)) { + // `satisfied` is only applied to fiber context. + if ((ptr->waiter != nullptr) && ptr->satisfied.exchange(true, std::memory_order_relaxed)) { continue; } - wbs.push_back(ptr->waiter); + + // 1. Push back all fiber context which are not satisfied. + // 2. Push back all pthread context. + wbs.push_back(ptr); } } @@ -173,46 +177,55 @@ void WaitableTimer::wait() { } void WaitableTimer::OnTimerExpired(RefPtr ref) { - FiberEntityList fibers; - ref->SetPersistentAwakened(fibers); + WaitBlockList wbs; + ref->SetPersistentAwakened(wbs); while (true) { - auto* e = fibers.pop_front(); - if (!e) { + auto* wb = wbs.pop_front(); + if (!wb) { return; } - e->scheduling_group->Resume(e, std::unique_lock(e->scheduler_lock)); + // This is fiber waiter. + if (wb->waiter) { + wb->waiter->scheduling_group->Resume(wb->waiter, std::unique_lock(wb->waiter->scheduler_lock)); + continue; + } + + // This is pthread waiter. + wb->futex.Wake(1); } } // Implementation of `Mutex` goes below. void Mutex::unlock() { - TRPC_DCHECK(IsFiberContextPresent()); - - if (auto was = count_.fetch_sub(1, std::memory_order_release); was == 1) { + auto was = count_.fetch_sub(1, std::memory_order_release); + if (was == 1) { // Lucky day, no one is waiting on the mutex. // // Nothing to do. - } else { - TRPC_CHECK_GT(was, std::uint32_t(1)); - - // We need this lock so as to see a consistent state between `count_` and - // `impl_` ('s internal wait queue). - std::unique_lock splk(slow_path_lock_); - auto fiber = impl_.WakeOne(); - TRPC_CHECK(fiber); // Otherwise `was` must be 1 (as there's no waiter). - splk.unlock(); - fiber->scheduling_group->Resume(fiber, std::unique_lock(fiber->scheduler_lock)); + return; } -} + TRPC_CHECK_GT(was, std::uint32_t(1)); -void Mutex::LockSlow() { - TRPC_DCHECK(IsFiberContextPresent()); + // We need this lock so as to see a consistent state between `count_` and + // `impl_` ('s internal wait queue). + std::unique_lock splk(slow_path_lock_); + auto* wb = impl_.WakeOne(); + TRPC_CHECK(wb); // Otherwise `was` must be 1 (as there's no waiter). + splk.unlock(); - if (try_lock()) { - return; // Your lucky day. + + // This is fiber waiter. + if (wb->waiter) { + wb->waiter->scheduling_group->Resume(wb->waiter, std::unique_lock(wb->waiter->scheduler_lock)); + return; } + // This is pthread waiter. + wb->futex.Wake(1); +} + +void Mutex::LockSlowFromFiber() { // It's locked, take the slow path. std::unique_lock splk(slow_path_lock_); @@ -228,8 +241,7 @@ void Mutex::LockSlow() { auto current = GetCurrentFiberEntity(); std::unique_lock slk(current->scheduler_lock); WaitBlock wb = {.waiter = current}; - TRPC_CHECK(impl_.AddWaiter(&wb)); // This can't fail as we never call - // `SetPersistentAwakened()`. + TRPC_CHECK(impl_.AddWaiter(&wb)); // This can't fail as we never call `SetPersistentAwakened()`. // Now the slow path lock can be unlocked. // @@ -250,10 +262,27 @@ void Mutex::LockSlow() { return; } +void Mutex::LockSlowFromPthread() { + std::unique_lock splk(slow_path_lock_); + + if (count_.fetch_add(1, std::memory_order_acquire) == 0) { + return; + } + + // Set waiter to nullptr as we are in pthread context. + WaitBlock wb = {.waiter = nullptr}; + TRPC_CHECK(impl_.AddWaiter(&wb)); + trpc::FutexNotifier::State st = wb.futex.GetState(); + + splk.unlock(); + + // No need to handle return value as we only wait for awakening. + wb.futex.Wait(st, nullptr); +} + // Implementation of `ConditionVariable` goes below. void ConditionVariable::wait(std::unique_lock& lock) { - TRPC_DCHECK(IsFiberContextPresent()); TRPC_DCHECK(lock.owns_lock()); wait_until(lock, std::chrono::steady_clock::time_point::max()); @@ -262,8 +291,15 @@ void ConditionVariable::wait(std::unique_lock& lock) { bool ConditionVariable::wait_until( std::unique_lock& lock, std::chrono::steady_clock::time_point expires_at) { - TRPC_DCHECK(IsFiberContextPresent()); + if (IsFiberContextPresent()) { + return WaitUntilFromFiber(lock, expires_at); + } + return WaitUntilFromPthread(lock, expires_at); +} + +bool ConditionVariable::WaitUntilFromFiber(std::unique_lock& lock, + std::chrono::steady_clock::time_point expires_at) { auto current = GetCurrentFiberEntity(); auto sg = current->scheduling_group; bool use_timeout = expires_at != std::chrono::steady_clock::time_point::max(); @@ -298,55 +334,50 @@ bool ConditionVariable::wait_until( return !timeout; } -void ConditionVariable::notify_one() noexcept { - TRPC_DCHECK(IsFiberContextPresent()); +bool ConditionVariable::WaitUntilFromPthread(std::unique_lock& lock, + std::chrono::steady_clock::time_point expires_at) { + // Set waiter to nullptr as we are in pthread context. + WaitBlock wb = {.waiter = nullptr}; + TRPC_CHECK(impl_.AddWaiter(&wb)); + trpc::FutexNotifier::State st = wb.futex.GetState(); + lock.unlock(); - auto fiber = impl_.WakeOne(); - if (!fiber) { + bool ret = wb.futex.Wait(st, &expires_at); + impl_.TryRemoveWaiter(&wb); + lock.lock(); + return ret; +} + +void ConditionVariable::notify_one() noexcept { + auto* wb = impl_.WakeOne(); + if (!wb) { return; } - fiber->scheduling_group->Resume(fiber, std::unique_lock(fiber->scheduler_lock)); + // This is fiber waiter. + if (wb->waiter) { + wb->waiter->scheduling_group->Resume(wb->waiter, std::unique_lock(wb->waiter->scheduler_lock)); + return; + } + + // This is pthread waiter. + wb->futex.Wake(1); } void ConditionVariable::notify_all() noexcept { - TRPC_DCHECK(IsFiberContextPresent()); - - // We cannot keep calling `notify_one` here. If a waiter immediately goes to - // sleep again after we wake up it, it's possible that we wake it again when - // we try to drain the wait chain. - // - // So we remove all waiters first, and schedule them then. - std::array fibers_quick; - std::size_t array_usage = 0; - // We don't want to touch this in most cases. - // - // Given that `std::vector::vector()` is not allowed to throw, I do believe it - // won't allocated memory on construction. - FiberEntityList fibers_slow; - while (true) { - auto fiber = impl_.WakeOne(); - if (!fiber) { - break; - } - if (TRPC_LIKELY(array_usage < std::size(fibers_quick))) { - fibers_quick[array_usage++] = fiber; - } else { - fibers_slow.push_back(fiber); + auto* wb = impl_.WakeOne(); + if (!wb) { + return; } - } - // Schedule the waiters. - for (std::size_t index = 0; index != array_usage; ++index) { - auto&& e = fibers_quick[index]; - e->scheduling_group->Resume(e, std::unique_lock(e->scheduler_lock)); - } - while (true) { - auto* e = fibers_slow.pop_front(); - if (!e) { - return; + // This is fiber waiter. + if (wb->waiter) { + wb->waiter->scheduling_group->Resume(wb->waiter, std::unique_lock(wb->waiter->scheduler_lock)); + continue; } - e->scheduling_group->Resume(e, std::unique_lock(e->scheduler_lock)); + + // This is pthread waiter. + wb->futex.Wake(1); } } @@ -384,8 +415,15 @@ void ExitBarrier::Wait() { // Implementation of `Event` goes below. void Event::Wait() { - TRPC_DCHECK(IsFiberContextPresent()); + if (IsFiberContextPresent()) { + WaitFromFiber(); + return; + } + + WaitFromPthread(); +} +void Event::WaitFromFiber() { auto current = GetCurrentFiberEntity(); WaitBlock wb = {.waiter = current}; std::unique_lock lk(current->scheduler_lock); @@ -396,20 +434,38 @@ void Event::Wait() { } } +void Event::WaitFromPthread() { + // Set waiter to nullptr as we are in pthread context. + WaitBlock wb = {.waiter = nullptr}; + // Must get state before AddWaiter. + trpc::FutexNotifier::State st = wb.futex.GetState(); + if (impl_.AddWaiter(&wb)) { + wb.futex.Wait(st, nullptr); + } else { + // The event is set already, return immediately. + } +} + void Event::Set() { - // `IsFiberContextPresent()` is not checked. This method is explicitly allowed - // to be called out of fiber context. - FiberEntityList fibers; - impl_.SetPersistentAwakened(fibers); + WaitBlockList wbs; + impl_.SetPersistentAwakened(wbs); // Fiber wake-up must be delayed until we're done with `impl_`, otherwise // `impl_` can be destroyed after its emptied but before we touch it again. while (true) { - auto* e = fibers.pop_front(); - if (!e) { + auto* wb = wbs.pop_front(); + if (!wb) { return; } - e->scheduling_group->Resume(e, std::unique_lock(e->scheduler_lock)); + + // This is fiber waiter. + if (wb->waiter) { + wb->waiter->scheduling_group->Resume(wb->waiter, std::unique_lock(wb->waiter->scheduler_lock)); + continue; + } + + // This is pthread waiter. + wb->futex.Wake(1); } } diff --git a/trpc/runtime/threadmodel/fiber/detail/waitable.h b/trpc/runtime/threadmodel/fiber/detail/waitable.h index 7a7a60f7..6716d41f 100644 --- a/trpc/runtime/threadmodel/fiber/detail/waitable.h +++ b/trpc/runtime/threadmodel/fiber/detail/waitable.h @@ -24,6 +24,7 @@ #include "trpc/util/likely.h" #include "trpc/util/object_pool/object_pool_ptr.h" #include "trpc/util/ref_ptr.h" +#include "trpc/util/thread/futex_notifier.h" #include "trpc/util/thread/spinlock.h" namespace trpc::fiber::detail { @@ -40,9 +41,12 @@ class SchedulingGroup; struct WaitBlock { FiberEntity* waiter = nullptr; // This initialization will be optimized away. trpc::DoublyLinkedListEntry chain; - std::atomic satisfied = false; + trpc::FutexNotifier futex; // For pthread context, which waiter is default nullptr. + std::atomic satisfied = false; // For fiber context, which waiter is not nullptr. }; +using WaitBlockList = trpc::DoublyLinkedList; + // Basic class for implementing waitable classes. // // Do NOT use this class directly, it's meant to be used as a building block. @@ -58,9 +62,9 @@ class Waitable { // Returns `true` if the waiter is added to the wait chain, returns // `false` if the wait is immediately satisfied. // - // To prevent wake-up loss, `FiberEntity::scheduler_lock` must be held by the - // caller. (Otherwise before you take the lock, the fiber could have been - // concurrently waken up, which is lost, by someone else.) + // For fiber context, to prevent wake-up loss, `FiberEntity::scheduler_lock` + // must be held by the caller. (Otherwise before you take the lock, the fiber + // could have been concurrently waken up, which is lost, by someone else.) bool AddWaiter(WaitBlock* waiter); // Remove a waiter. @@ -71,7 +75,7 @@ class Waitable { // Popup one waiter and schedule it. // // Returns `nullptr` if there's no waiter. - FiberEntity* WakeOne(); + WaitBlock* WakeOne(); // Set this `Waitable` as "persistently" awakened. After this call, all // further calls to `AddWaiter` will fail. @@ -83,7 +87,7 @@ class Waitable { // immediately and could have freed this `Waitable` before you touch it again. // // Normally you should call `WakeAll` after calling this method. - void SetPersistentAwakened(FiberEntityList& wbs); + void SetPersistentAwakened(WaitBlockList& wbs); // Undo `SetPersistentAwakened()`. void ResetAwakened(); @@ -95,7 +99,7 @@ class Waitable { private: Spinlock lock_; bool persistent_awakened_ = false; - trpc::DoublyLinkedList waiters_; + WaitBlockList waiters_; }; // "Waitable" timer. This `Waitable` signals all its waiters once the given time @@ -135,29 +139,33 @@ class WaitableTimer { RefPtr impl_; }; -// Mutex for fiber. +/// @brief Implementation of adaptive mutex primitive for both fiber and pthread context. class Mutex { public: bool try_lock() { - TRPC_DCHECK(IsFiberContextPresent()); - std::uint32_t expected = 0; return count_.compare_exchange_strong(expected, 1, std::memory_order_acquire); } void lock() { - TRPC_DCHECK(IsFiberContextPresent()); - if (TRPC_LIKELY(try_lock())) { return; } - LockSlow(); + + if (IsFiberContextPresent()) { + LockSlowFromFiber(); + return; + } + + LockSlowFromPthread(); } void unlock(); private: - void LockSlow(); + void LockSlowFromFiber(); + + void LockSlowFromPthread(); private: Waitable impl_; @@ -165,20 +173,17 @@ class Mutex { // Synchronizes between slow path of `lock()` and `unlock()`. Spinlock slow_path_lock_; - // Number of waiters (plus the owner). Hopefully `std::uint32_t` is large - // enough. + // Number of waiters (plus the owner). Hopefully `std::uint32_t` is large enough. std::atomic count_{0}; }; -// Condition variable for fiber. +/// @brief Adaptive condition variable primitive for both fiber and pthread context. class ConditionVariable { public: void wait(std::unique_lock& lock); template void wait(std::unique_lock& lock, F&& pred) { - TRPC_DCHECK(IsFiberContextPresent()); - while (!std::forward(pred)()) { wait(lock); } @@ -194,8 +199,6 @@ class ConditionVariable { template bool wait_until(std::unique_lock& lk, std::chrono::steady_clock::time_point timeout, F&& pred) { - TRPC_DCHECK(IsFiberContextPresent()); - while (!std::forward(pred)()) { wait_until(lk, timeout); if (ReadSteadyClock() >= timeout) { @@ -209,6 +212,13 @@ class ConditionVariable { void notify_one() noexcept; void notify_all() noexcept; + private: + bool WaitUntilFromFiber(std::unique_lock& lock, + std::chrono::steady_clock::time_point expires_at); + + bool WaitUntilFromPthread(std::unique_lock& lock, + std::chrono::steady_clock::time_point expires_at); + private: Waitable impl_; }; @@ -243,22 +253,24 @@ class ExitBarrier : public object_pool::EnableLwSharedFromThis { ConditionVariable cv_; }; -// Emulates Event in Win32 API. -// -// For internal use only. Normally you'd like to use `Mutex` + -// `ConditionVariable` instead. +/// @brief Adaptive Event primitive for both fiber and pthread context. +/// @note Emulates Event in Win32 API. +/// For internal use only. Normally you'd like to use `Mutex` + `ConditionVariable` instead. class Event { public: // Wait until `Set()` is called. If `Set()` is called before `Wait()`, this // method returns immediately. void Wait(); - // Wake up fibers blockings on `Wait()`. All subsequent calls to `Wait()` will + // Wake up fibers or pthreads blockings on `Wait()`. All subsequent calls to `Wait()` will // return immediately. - // - // It's explicitly allowed to call this method outside of fiber context. void Set(); + private: + void WaitFromFiber(); + + void WaitFromPthread(); + private: Waitable impl_; }; diff --git a/trpc/runtime/threadmodel/fiber/detail/waitable_test.cc b/trpc/runtime/threadmodel/fiber/detail/waitable_test.cc index fd7f83fd..b04d489b 100644 --- a/trpc/runtime/threadmodel/fiber/detail/waitable_test.cc +++ b/trpc/runtime/threadmodel/fiber/detail/waitable_test.cc @@ -19,6 +19,8 @@ #include "gtest/gtest.h" +#include "trpc/coroutine/fiber.h" +#include "trpc/coroutine/testing/fiber_runtime.h" #include "trpc/runtime/threadmodel/fiber/detail/fiber_entity.h" #include "trpc/runtime/threadmodel/fiber/detail/fiber_worker.h" #include "trpc/runtime/threadmodel/fiber/detail/scheduling_group.h" @@ -105,6 +107,181 @@ TEST(Waitable, MutexOnScheduling) { } } +TEST(Waitable, MutexInPthreadContext) { + static constexpr auto B = 64; + + Mutex m; + ASSERT_EQ(true, m.try_lock()); + m.unlock(); + + int value = 0; + std::thread ts[B]; + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&m, &value] { + std::scoped_lock _(m); + ++value; + }); + } + + for (auto&& t : ts) { + t.join(); + } + + ASSERT_EQ(B, value); +} + +TEST(Waitable, MutexInMixedContext) { + RunAsFiber([] { + static constexpr auto B = 64; + Mutex m; + + int value = 0; + std::thread ts[B]; + std::vector fbs(B); + for (int i = 0; i != B; ++i) { + ts[i] = std::thread([&m, &value] { + std::scoped_lock _(m); + ++value; + }); + + fbs[i] = trpc::Fiber([&m, &value] { + std::scoped_lock _(m); + ++value; + }); + } + + for (auto&& t : ts) { + t.join(); + } + + for (auto&& e : fbs) { + e.Join(); + } + + ASSERT_EQ(2 * B, value); + }); +} +TEST(Waitable, ConditionVariableInPthreadContext) { + constexpr auto N = 64; + std::atomic run{0}; + Mutex lock[N]; + ConditionVariable cv[N]; + bool set[N] = {false}; + std::vector prod(N); + std::vector cons(N); + + for (int i = 0; i != N; ++i) { + prod[i] = std::thread([&run, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++run; + }); + + cons[i] = std::thread([&run, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++run; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + ASSERT_EQ(N * 2, run); +} + +TEST(Waitable, ConditionVariableNotifyPthreadFromFiber) { + RunAsFiber([] { + constexpr auto K = 64; + std::atomic executed{0}; + Mutex lock[K]; + ConditionVariable cv[K]; + bool set[K] = {false}; + std::vector prod(K); + std::vector cons(K); + + for (int i = 0; i != K; ++i) { + prod[i] = std::thread([&executed, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++executed; + }); + + cons[i] = trpc::Fiber([&executed, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++executed; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.Joinable()); + e.Join(); + } + + ASSERT_EQ(K * 2, executed); + }); +} + +TEST(Waitable, ConditionVariableNotifyFiberFromPthread) { + RunAsFiber([] { + constexpr auto N = 64; + std::atomic run{0}; + Mutex lock[N]; + ConditionVariable cv[N]; + bool set[N] = {false}; + std::vector prod(N); + std::vector cons(N); + + for (int i = 0; i != N; ++i) { + prod[i] = trpc::Fiber([&run, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::scoped_lock _(lock[i]); + set[i] = true; + cv[i].notify_one(); + ++run; + }); + + cons[i] = std::thread([&run, i, &cv, &lock, &set] { + trpc::FiberSleepFor(trpc::Random(20) * std::chrono::milliseconds(1)); + std::unique_lock lk(lock[i]); + cv[i].wait(lk, [&] { return set[i]; }); + ++run; + }); + } + + for (auto&& e : prod) { + ASSERT_TRUE(e.Joinable()); + e.Join(); + } + + for (auto&& e : cons) { + ASSERT_TRUE(e.joinable()); + e.join(); + } + + ASSERT_EQ(N * 2, run); + }); +} + void TestConditionVariable(std::string_view scheduling_name) { constexpr auto N = 10000; @@ -350,6 +527,34 @@ void TestEvent(std::string_view scheduling_name) { } } +TEST(Waitable, EventWaitInPthreadSetFromFiber) { + RunAsFiber([]() { + for (int i = 0; i != 100; ++i) { + auto ev = std::make_unique(); + std::thread t([&] { + trpc::FiberSleepFor(trpc::Random(10) * std::chrono::milliseconds(1)); + ev->Wait(); + }); + trpc::FiberSleepFor(trpc::Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + t.join(); + } + }); +} + +TEST(Waitable, EventWaitInPthreadSetFromPthread) { + for (int i = 0; i != 100; ++i) { + auto ev = std::make_unique(); + std::thread t([&] { + trpc::FiberSleepFor(trpc::Random(10) * std::chrono::milliseconds(1)); + ev->Wait(); + }); + trpc::FiberSleepFor(trpc::Random(10) * std::chrono::milliseconds(1)); + ev->Set(); + t.join(); + } +} + TEST(Waitable, EventOnScheduling) { for (auto& name : kSchedulingNames) { TestEvent(name); diff --git a/trpc/util/thread/BUILD b/trpc/util/thread/BUILD index cb9875ac..0c7ce236 100644 --- a/trpc/util/thread/BUILD +++ b/trpc/util/thread/BUILD @@ -39,6 +39,18 @@ cc_library( cc_library( name = "futex_notifier", hdrs = ["futex_notifier.h"], + deps = [ + "//trpc/util/chrono", + ], +) + +cc_test( + name = "futex_notifier_test", + srcs = ["futex_notifier_test.cc"], + deps = [ + ":futex_notifier", + "@com_google_googletest//:gtest_main", + ], ) cc_library( diff --git a/trpc/util/thread/futex_notifier.h b/trpc/util/thread/futex_notifier.h index 1dd3debd..bc049f92 100644 --- a/trpc/util/thread/futex_notifier.h +++ b/trpc/util/thread/futex_notifier.h @@ -21,6 +21,8 @@ #include #include +#include "trpc/util/chrono/chrono.h" + namespace trpc { /// @brief Waking up/notifying between multiple threads based on futex @@ -54,6 +56,54 @@ class FutexNotifier { expected_state.val_, NULL, NULL, 0); } + /// @brief Wait to be awakened up or until timeout occured. + /// @param expected_state Expected state to block at. + /// @param abs_time Pointer to absolute time to trigger timeout. + /// @return false: timeout occured, true: awakened up. + /// @note For compatibility, upper `void Wait` interface is reserved. + bool Wait(const State& expected_state, const std::chrono::steady_clock::time_point* abs_time) { + timespec timeout; + timespec* timeout_ptr = nullptr; + + while (true) { + if (abs_time != nullptr) { + auto now = ReadSteadyClock(); + + // Already timeout. + if (*abs_time <= now) + return false; + + auto diff = (*abs_time - now) / std::chrono::nanoseconds(1); + // Convert format. + timeout.tv_sec = diff / 1000000000L; + timeout.tv_nsec = diff - timeout.tv_sec * 1000000000L; + timeout_ptr = &timeout; + } + + int ret = syscall(SYS_futex, &pending_wake_, (FUTEX_WAIT | FUTEX_PRIVATE_FLAG), + expected_state.val_, timeout_ptr, nullptr, 0); + // Timeout occured. + if ((ret != 0) && (errno == ETIMEDOUT)) + return false; + + // The value pointed to by uaddr was not equal to the expected value val at the time of the call. + // Take it as awakened up by another thread already. + if ((ret != 0) && (errno == EAGAIN)) + return true; + + // Interrupted, just continue to guarantee enough timeout. + if ((ret != 0) && (errno == EINTR)) + continue; + + // Spurious wake-up, just continue. + if ((ret == 0) && (pending_wake_ == expected_state.val_)) + continue; + + // Awakened up by others or real timeout. + return (ret == 0) ? true : false; + } + } + void Stop() { pending_wake_.fetch_or(1); syscall(SYS_futex, &pending_wake_, (FUTEX_WAKE | FUTEX_PRIVATE_FLAG), diff --git a/trpc/util/thread/futex_notifier_test.cc b/trpc/util/thread/futex_notifier_test.cc new file mode 100644 index 00000000..36db58b6 --- /dev/null +++ b/trpc/util/thread/futex_notifier_test.cc @@ -0,0 +1,139 @@ +// +// +// Tencent is pleased to support the open source community by making tRPC available. +// +// Copyright (C) 2023 THL A29 Limited, a Tencent company. +// All rights reserved. +// +// If you have downloaded a copy of the tRPC source code from Tencent, +// please note that tRPC source code is licensed under the Apache 2.0 License, +// A copy of the Apache 2.0 License is included in this file. +// +// + +#include "trpc/util/thread/futex_notifier.h" + +#include +#include + +#include "gtest/gtest.h" + +namespace trpc::testing { + +TEST(FutexNotifierTest, FutexNotifierTest) { + trpc::FutexNotifier futex_notifier; + + bool is_executed = false; + std::thread t([&futex_notifier, &is_executed] { + while (true) { + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + if (st.Stopped()) { + break; + } else { + futex_notifier.Wait(st); + is_executed = true; + } + } + }); + + while (!is_executed) { + futex_notifier.Wake(1); + } + + futex_notifier.Stop(); + + t.join(); + + ASSERT_TRUE(is_executed); +} + +TEST(FutexNotifierTest, ReturnByETIMEDOUT) { + trpc::FutexNotifier futex_notifier; + bool is_awaken = true; + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + + std::thread t([&futex_notifier, &is_awaken, &st] { + // Set timeout to 1ms. + std::chrono::steady_clock::time_point abs_time = ReadSteadyClock() + std::chrono::microseconds(1000); + is_awaken = futex_notifier.Wait(st, &abs_time); + }); + + // Trigger timeout, no need to wake. + t.join(); + ASSERT_FALSE(is_awaken); +} + +TEST(FutexNotifierTest, ReturnByEAGAIN) { + trpc::FutexNotifier futex_notifier; + bool is_awaken = false; + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + + std::thread t([&futex_notifier, &is_awaken, &st] { + // Let wake first. + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + // Set timeout to 1ms. + std::chrono::steady_clock::time_point abs_time = ReadSteadyClock() + std::chrono::microseconds(1000); + // Wait after wake, equal to awakened up. + is_awaken = futex_notifier.Wait(st, &abs_time); + }); + + futex_notifier.Wake(1); + + t.join(); + ASSERT_TRUE(is_awaken); +} + +TEST(FutexNotifierTest, ReturnByWakeWithoutTimeoutSet) { + trpc::FutexNotifier futex_notifier; + bool is_awaken = false; + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + + std::thread t([&futex_notifier, &is_awaken, &st] { + is_awaken = futex_notifier.Wait(st, NULL); + }); + + // Let wait first. + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + + futex_notifier.Wake(1); + + t.join(); + ASSERT_TRUE(is_awaken); +} + +TEST(FutexNotifierTest, ReturnByWakeWithTimeoutSet) { + trpc::FutexNotifier futex_notifier; + bool is_awaken = false; + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + + std::thread t([&futex_notifier, &is_awaken, &st] { + // Set timeout to large enough 10s. + std::chrono::steady_clock::time_point abs_time = ReadSteadyClock() + std::chrono::microseconds(10000000); + is_awaken = futex_notifier.Wait(st, &abs_time); + }); + + // Let wait first. + std::this_thread::sleep_for(std::chrono::milliseconds(2)); + + futex_notifier.Wake(1); + + t.join(); + ASSERT_TRUE(is_awaken); +} + +TEST(FutexNotifierTest, TimeoutLessThanNow) { + trpc::FutexNotifier futex_notifier; + bool is_awaken = true; + const trpc::FutexNotifier::State st = futex_notifier.GetState(); + + std::thread t([&futex_notifier, &is_awaken, &st] { + // Set timeout to less than now. + std::chrono::steady_clock::time_point abs_time = ReadSteadyClock() - std::chrono::microseconds(1000); + is_awaken = futex_notifier.Wait(st, &abs_time); + }); + + t.join(); + ASSERT_FALSE(is_awaken); +} + +} // namespace trpc::testing