From 3e228690bedef21d7f042efaabed3c7aaafeb098 Mon Sep 17 00:00:00 2001
From: Paul T
Date: Fri, 29 Sep 2023 14:56:44 -0400
Subject: [PATCH] feature: wip implementation of chase-lev deque

work in progress implementation of a chase-lev, lock free, work stealing
deque.
---
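Usage sketch: the deque is owner/thief split. Only the thread that owns the
deque may call push_bottom()/take_bottom(); any other thread may call
pop_top() to steal. A minimal illustration against the API added below (the
variable and lambda names here are illustrative only, not part of the diff):

    #include <thread_pool/work_stealing_deque.h>
    #include <thread>

    int main() {
        dp::work_stealing_deque<int> deque{};

        // owner thread works the bottom (LIFO) end
        deque.push_bottom(1);
        deque.push_bottom(2);
        const auto newest = deque.take_bottom();  // yields 2

        // any other thread steals from the top (FIFO) end
        std::jthread thief{[&deque] {
            if (const auto oldest = deque.pop_top()) {
                // oldest == 1 unless the owner drained the deque first
            }
        }};
        return newest.has_value() ? 0 : 1;
    }
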
 include/thread_pool/thread_pool.h         |  29 ++-
 include/thread_pool/work_stealing_deque.h | 194 ++++++++++++++++++++
 test/source/work_stealing_deque.cpp       | 206 ++++++++++++++++++++++
 3 files changed, 413 insertions(+), 16 deletions(-)
 create mode 100644 include/thread_pool/work_stealing_deque.h
 create mode 100644 test/source/work_stealing_deque.cpp

diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h
index 08a4735..5a7bdd4 100644
--- a/include/thread_pool/thread_pool.h
+++ b/include/thread_pool/thread_pool.h
@@ -17,11 +17,13 @@
 #endif
 
 #include "thread_pool/thread_safe_queue.h"
+#include "thread_pool/work_stealing_deque.h"
 
 namespace dp {
     namespace details {
-#ifdef __cpp_lib_move_only_function
+        // TODO: use move only function, work stealing deque can't use move only types
+#if 0  // __cpp_lib_move_only_function
         using default_function_type = std::move_only_function<void()>;
 #else
         using default_function_type = std::function<void()>;
@@ -48,7 +50,7 @@ namespace dp {
 
                     do {
                         // invoke the task
-                        while (auto task = tasks_[id].tasks.pop_front()) {
+                        while (auto task = tasks_[id].tasks.pop_top()) {
                             try {
                                 pending_tasks_.fetch_sub(1, std::memory_order_release);
                                 std::invoke(std::move(task.value()));
@@ -55,16 +57,11 @@ namespace dp {
                             } catch (...) {
                             }
                         }
-                        // try to steal a task
-                        for (std::size_t j = 1; j < tasks_.size(); ++j) {
-                            const std::size_t index = (id + j) % tasks_.size();
-                            if (auto task = tasks_[index].tasks.steal()) {
-                                // steal a task
-                                pending_tasks_.fetch_sub(1, std::memory_order_release);
-                                std::invoke(std::move(task.value()));
-                                // stop stealing once we have invoked a stolen task
-                                break;
-                            }
+                        // try to steal a task from our donor
+                        auto donor_index = (id + 1) % tasks_.size();
+                        if (auto task = tasks_[donor_index].tasks.pop_top()) {
+                            pending_tasks_.fetch_sub(1, std::memory_order_release);
+                            std::invoke(std::move(task.value()));
                         }
                     } while (pending_tasks_.load(std::memory_order_acquire) > 0);
 
@@ -116,8 +113,8 @@ namespace dp {
                   typename ReturnType = std::invoke_result_t<Function &&, Args &&...>>
             requires std::invocable<Function, Args...>
         [[nodiscard]] std::future<ReturnType> enqueue(Function f, Args... args) {
-#ifdef __cpp_lib_move_only_function
-            // we can do this in C++23 because we now have support for move only functions
+#if 0  // __cpp_lib_move_only_function
+            // we can do this in C++23 because we now have support for move only functions
             std::promise<ReturnType> promise;
             auto future = promise.get_future();
             auto task = [func = std::move(f), ... largs = std::move(args),
@@ -204,12 +201,12 @@ namespace dp {
             }
             auto i = *(i_opt);
             pending_tasks_.fetch_add(1, std::memory_order_relaxed);
-            tasks_[i].tasks.push_back(std::forward<Function>(f));
+            tasks_[i].tasks.push_bottom(std::forward<Function>(f));
             tasks_[i].signal.release();
         }
 
         struct task_item {
-            dp::thread_safe_queue<FunctionType> tasks{};
+            dp::work_stealing_deque<FunctionType> tasks{};
             std::binary_semaphore signal{0};
         };
 
diff --git a/include/thread_pool/work_stealing_deque.h b/include/thread_pool/work_stealing_deque.h
new file mode 100644
index 0000000..44ffaa1
--- /dev/null
+++ b/include/thread_pool/work_stealing_deque.h
@@ -0,0 +1,194 @@
+#pragma once
+
+#include <atomic>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <memory>
+#include <new>
+#include <optional>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace dp {
+
+#ifdef __cpp_lib_hardware_interference_size
+    using std::hardware_destructive_interference_size;
+#else
+    // 64 bytes on x86-64 │ L1_CACHE_BYTES │ L1_CACHE_SHIFT │ __cacheline_aligned │ ...
+    inline constexpr std::size_t hardware_destructive_interference_size =
+        2 * sizeof(std::max_align_t);
+#endif
+
+    /**
+     * @brief Chase-Lev work stealing deque
+     * @details Supports a single producer and multiple consumers. The producer owns the
+     * bottom of the deque; consumers steal from the top. The deque is "lock-free" in the
+     * sense that it does not directly use mutexes or locks.
+     *
+     * This is an implementation of the deque described in "Dynamic Circular
+     * Work-Stealing Deque" by Chase and Lev, and in "Correct and Efficient
+     * Work-Stealing for Weak Memory Models" by Lê et al.
+     */
+    template <typename T>
+        requires std::is_destructible_v<T>
+    class work_stealing_deque final {
+        /**
+         * @brief Simple circular array buffer that can regrow.
+         * TODO: Leverage std::pmr facilities to automatically allocate/reclaim memory?
+         */
+        class circular_buffer final {
+          public:
+            explicit circular_buffer(const std::int64_t size) : size_(size), mask_(size - 1) {
+                // size must be a power of 2 for the index mask below to work
+                assert(size > 0 && (size & (size - 1)) == 0);
+
+                buffer_ = std::make_unique_for_overwrite<T[]>(size_);
+                pointer_.store(buffer_.get(), release);
+            }
+
+            [[nodiscard]] std::int64_t capacity() const noexcept { return size_; }
+
+            void store(const std::size_t index, T value,
+                       std::memory_order order = acquire) noexcept
+                requires std::is_move_assignable_v<T>
+            {
+                auto* buf = pointer_.load(order);
+                buf[index & mask_] = std::move(value);
+            }
+
+            T load(const std::size_t index, std::memory_order order = acquire) noexcept {
+                auto* buf = pointer_.load(order);
+                return buf[index & mask_];
+            }
+
+            /**
+             * @brief Resize the internal buffer. Copies [start, end) to the new buffer.
+             * @param start The start index
+             * @param end The end index
+             */
+            circular_buffer* resize(const std::size_t start, const std::size_t end) {
+                auto* temp = new circular_buffer(size_ * 2);
+                for (std::size_t i = start; i != end; ++i) {
+                    temp->store(i, load(i));
+                }
+                return temp;
+            }
+
+          private:
+            std::int64_t size_;
+            std::int64_t mask_;
+            std::atomic<T*> pointer_;
+            std::unique_ptr<T[]> buffer_;
+        };
+
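+        // Layout note: the owner thread writes bottom_, thieves CAS top_, and both
+        // read buffer_; each member is padded out to a presumed cache-line boundary
+        // (hardware_destructive_interference_size) to avoid false sharing.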
+        constexpr static std::size_t default_count = 1024;
+
+        alignas(hardware_destructive_interference_size) std::atomic_int64_t top_;
+        alignas(hardware_destructive_interference_size) std::atomic_int64_t bottom_;
+        alignas(hardware_destructive_interference_size) std::atomic<circular_buffer*> buffer_;
+
+        std::vector<std::unique_ptr<circular_buffer>> garbage_{32};
+
+        static constexpr std::memory_order relaxed = std::memory_order_relaxed;
+        static constexpr std::memory_order acquire = std::memory_order_acquire;
+        static constexpr std::memory_order consume = std::memory_order_consume;
+        static constexpr std::memory_order release = std::memory_order_release;
+        static constexpr std::memory_order seq_cst = std::memory_order_seq_cst;
+
+      public:
+        explicit work_stealing_deque(const std::size_t& capacity = default_count)
+            : top_(0), bottom_(0), buffer_(new circular_buffer(capacity)) {}
+
+        ~work_stealing_deque() { delete buffer_.load(relaxed); }
+
+        // the deque is non-copyable
+        work_stealing_deque(const work_stealing_deque&) = delete;
+        work_stealing_deque& operator=(const work_stealing_deque&) = delete;
+
+        [[nodiscard]] std::size_t capacity() const { return buffer_.load(relaxed)->capacity(); }
+        [[nodiscard]] std::size_t size() const {
+            const auto bottom = bottom_.load(relaxed);
+            const auto top = top_.load(relaxed);
+            return static_cast<std::size_t>(bottom >= top ? bottom - top : 0);
+        }
+        [[nodiscard]] bool empty() const { return size() == 0; }
+
+        template <typename... Args>
+        void push_bottom(Args&&... args) {
+            // construct first in case the constructor throws
+            T value(std::forward<Args>(args)...);
+            push_bottom(std::move(value));
+        }
+
+        void push_bottom(T value) {
+            auto bottom = bottom_.load(relaxed);
+            auto top = top_.load(acquire);
+            auto* buffer = buffer_.load(relaxed);
+
+            if (buffer->capacity() < (bottom - top) + 1) {
+                // the buffer is full: grow it, and retire the old buffer to the garbage
+                // list since concurrent thieves may still be reading from it
+                garbage_.emplace_back(std::exchange(buffer, buffer->resize(top, bottom)));
+                buffer_.store(buffer, release);
+            }
+
+            buffer->store(bottom, std::move(value));
+
+            // the stored task must be visible before the new bottom is published;
+            // this fence pairs with the fences in take_bottom and pop_top
+            std::atomic_thread_fence(release);
+
+            bottom_.store(bottom + 1, relaxed);
+        }
+
+        std::optional<T> take_bottom() {
+            auto bottom = bottom_.load(relaxed) - 1;
+            auto* buffer = buffer_.load(relaxed);
+
+            // publish the decremented bottom first to prevent stealing
+            bottom_.store(bottom, relaxed);
+
+            // full fence: the store to bottom_ above must be visible before top_ is
+            // read below (a store-load ordering that release/acquire cannot provide)
+            std::atomic_thread_fence(seq_cst);
+
+            auto top = top_.load(relaxed);
+            if (top <= bottom) {
+                // deque isn't empty
+                if (top == bottom) {
+                    // only 1 item is left in the deque; a thief may be racing us for
+                    // it, so the CAS on top_ must succeed for us to claim it
+                    if (!top_.compare_exchange_strong(top, top + 1, seq_cst, relaxed)) {
+                        // lost the race to a thief
+                        bottom_.store(bottom + 1, relaxed);
+                        return std::nullopt;
+                    }
+                    bottom_.store(bottom + 1, relaxed);
+                }
+                // either more than one item remains or we won the race for the last one
+                return buffer->load(bottom);
+            }
+            // deque is empty, reset bottom
+            bottom_.store(bottom + 1, relaxed);
+            return std::nullopt;
+        }
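+
+        // Steal protocol: read top_, then race every other thief (and possibly the
+        // owner) with a single CAS on top_. A loser does not retry; it returns
+        // std::nullopt and lets the caller decide whether to spin, as the worker
+        // loop in thread_pool.h does.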
+        std::optional<T> pop_top() {
+            auto top = top_.load(acquire);
+            // full fence: top_ must be read before bottom_ so that we observe a
+            // consistent snapshot; pairs with the fence in take_bottom
+            std::atomic_thread_fence(seq_cst);
+            const auto bottom = bottom_.load(acquire);
+
+            if (top < bottom) {
+                // non-empty deque: read the item before the CAS, since the owner may
+                // overwrite the slot as soon as top_ is incremented
+                auto* buffer = buffer_.load(acquire);
+                auto temp = buffer->load(top, acquire);
+                if (!top_.compare_exchange_strong(top, top + 1, seq_cst, relaxed)) {
+                    // lost the race
+                    return std::nullopt;
+                }
+                return temp;
+            } else {
+                // deque is empty
+                return std::nullopt;
+            }
+        }
+    };
+}  // namespace dp
diff --git a/test/source/work_stealing_deque.cpp b/test/source/work_stealing_deque.cpp
new file mode 100644
index 0000000..8f11043
--- /dev/null
+++ b/test/source/work_stealing_deque.cpp
@@ -0,0 +1,206 @@
+#include <doctest/doctest.h>
+#include <thread_pool/work_stealing_deque.h>
+
+#include <deque>
+#include <functional>
+#include <semaphore>
+#include <thread>
+
+TEST_CASE("Construct queue") { dp::work_stealing_deque<int> queue{}; }
+
+TEST_CASE("Construct and grow queue") {
+    dp::work_stealing_deque<int> queue{2};
+
+    queue.push_bottom(1);
+    queue.push_bottom(2);
+    queue.push_bottom(3);
+
+    REQUIRE_EQ(queue.capacity(), 4);
+}
+
+TEST_CASE("Take bottom while queue is empty") {
+    dp::work_stealing_deque<int> queue{};
+
+    REQUIRE_EQ(queue.take_bottom(), std::nullopt);
+}
+
+TEST_CASE("Take bottom while queue is not empty") {
+    dp::work_stealing_deque<int> queue{};
+
+    queue.push_bottom(1);
+    queue.push_bottom(2);
+    queue.push_bottom(3);
+
+    REQUIRE_EQ(queue.take_bottom(), 3);
+    REQUIRE_EQ(queue.take_bottom(), 2);
+    REQUIRE_EQ(queue.take_bottom(), 1);
+    REQUIRE_EQ(queue.take_bottom(), std::nullopt);
+}
+
+TEST_CASE("Multiple threads steal single item") {
+    dp::work_stealing_deque<std::uint64_t> queue{};
+
+    queue.push_bottom(23567);
+    std::uint64_t value = 0;
+
+    {
+        auto thread_task = [&queue, &value]() {
+            if (const auto temp = queue.pop_top()) {
+                value = temp.value();
+            }
+        };
+        std::jthread t1{thread_task};
+        std::jthread t2{thread_task};
+        std::jthread t3{thread_task};
+        std::jthread t4{thread_task};
+    }
+
+    REQUIRE_EQ(value, 23567);
+}
+
+TEST_CASE("Steal std::function while pushing") {
+    dp::work_stealing_deque<std::function<void()>> deque{};
+    std::atomic_uint64_t count{0};
+    constexpr auto max = 64'000;
+    auto expected_sum = 0;
+    std::atomic_uint64_t pending_tasks{0};
+    std::deque<std::binary_semaphore> signals;
+    signals.emplace_back(0);
+    signals.emplace_back(0);
+    signals.emplace_back(0);
+    signals.emplace_back(0);
+
+    auto supply_task = [&] {
+        for (auto i = 0; i < max; i++) {
+            deque.push_bottom([&count, i]() { count += i; });
+            expected_sum += i;
+            pending_tasks.fetch_add(1, std::memory_order_release);
+            // wake all threads
+            if ((i + 1) % 8000 == 0) {
+                for (auto& signal : signals) signal.release();
+            }
+        }
+    };
+
+    auto task = [&](int id) {
+        signals[id].acquire();
+        while (pending_tasks.load(std::memory_order_acquire) > 0) {
+            auto value = deque.pop_top();
+            if (value.has_value()) {
+                auto temp = std::move(value.value());
+                std::invoke(temp);
+                pending_tasks.fetch_sub(1, std::memory_order_release);
+            }
+        }
+    };
+
+    {
+        std::jthread supplier(supply_task);
+        std::jthread t1(task, 0);
+        std::jthread t2(task, 1);
+        std::jthread t3(task, 2);
+        std::jthread t4(task, 3);
+    }
+
+    REQUIRE_EQ(count.load(), expected_sum);
+}
+
+// class move_only {
+//     int private_value_ = 2;
+//
+// public:
+//     move_only() = default;
+//     ~move_only() = default;
+//     move_only(const move_only&) = delete;
+//     move_only(move_only&& other) noexcept { private_value_ = other.private_value_ * 2; }
+//     move_only& operator=(const move_only&) = delete;
+//     move_only& operator=(move_only&& other) noexcept {
+//         private_value_ = other.private_value_ * 2;
+//         return *this;
+//     }
+//     [[nodiscard]] int secret() const { return private_value_; }
+// };
+//
+// TEST_CASE("Store move only types") {
+//     move_only mv_only{};
+//     dp::work_stealing_deque<move_only> deque{};
+//     deque.push_bottom(std::move(mv_only));
+//
+//     const auto value = deque.take_bottom();
+//     REQUIRE(value.has_value());
+//     REQUIRE_NE(value->secret(), 2);
+// }
+//
+// TEST_CASE("Steal move only type") {
+//     move_only mv_only{};
+//     dp::work_stealing_deque<move_only> queue{};
+//     queue.push_bottom(std::move(mv_only));
+//     std::optional<move_only> value = std::nullopt;
+//     {
+//         auto thread_task = [&queue, &value]() {
+//             if (auto temp = queue.pop_top()) {
+//                 value.emplace(std::move(temp.value()));
+//             }
+//         };
+//
+//         std::jthread t1{thread_task};
+//         std::jthread t2{thread_task};
+//         std::jthread t3{thread_task};
+//         std::jthread t4{thread_task};
+//     }
+//
+//     REQUIRE(value.has_value());
+//     REQUIRE_NE(value->secret(), 2);
+// }
+//
+// #if __cpp_lib_move_only_function
+//
+// TEST_CASE("Steal std::move_only_function while pushing") {
+//     dp::work_stealing_deque<std::move_only_function<void()>> deque{};
+//     std::atomic_uint64_t count{0};
+//     constexpr auto max = 64'000;
+//     auto expected_sum = 0;
+//     std::atomic_uint64_t pending_tasks{0};
+//     std::deque<std::binary_semaphore> signals;
+//     signals.emplace_back(0);
+//     signals.emplace_back(0);
+//     signals.emplace_back(0);
+//     signals.emplace_back(0);
+//
+//     auto supply_task = [&] {
+//         for (auto i = 0; i < max; i++) {
+//             deque.push_bottom([&count, i]() { count += i; });
+//             expected_sum += i;
+//             pending_tasks.fetch_add(1, std::memory_order_release);
+//             // wake all threads
+//             if (i % 1000 == 0) {
+//                 for (auto& signal : signals) signal.release();
+//             }
+//         }
+//     };
+//
+//     auto task = [&](int id) {
+//         signals[id].acquire();
+//         while (pending_tasks.load(std::memory_order_acquire) > 0) {
+//             auto value = deque.pop_top();
+//             if (value.has_value()) {
+//                 auto temp = std::move(value.value());
+//                 if (temp) {
+//                     std::invoke(temp);
+//                     pending_tasks.fetch_sub(1, std::memory_order_release);
+//                 }
+//             }
+//         }
+//     };
+//
+//     {
+//         std::jthread supplier(supply_task);
+//         std::jthread t1(task, 0);
+//         std::jthread t2(task, 1);
+//         std::jthread t3(task, 2);
+//         std::jthread t4(task, 3);
+//     }
+//
+//     REQUIRE_EQ(count.load(), expected_sum);
+// }
+// #endif
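+// NOTE: the move-only tests above are disabled for now: circular_buffer::load and
+// pop_top copy values out of the buffer, so move-only types are not supported yet
+// (see the TODO about std::move_only_function in thread_pool.h).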