diff --git a/CMakeLists.txt b/CMakeLists.txt index debac7e..0383d85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) # ---- Project ---- diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 37e7662..272e86f 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) project(ThreadPoolBenchmarks LANGUAGES CXX) diff --git a/documentation/CMakeLists.txt b/documentation/CMakeLists.txt index 5a9789f..d743bb5 100644 --- a/documentation/CMakeLists.txt +++ b/documentation/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) project(ThreadPoolDocs) diff --git a/examples/mandelbrot/CMakeLists.txt b/examples/mandelbrot/CMakeLists.txt index 2d8c087..45a0789 100644 --- a/examples/mandelbrot/CMakeLists.txt +++ b/examples/mandelbrot/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) project(Mandelbrot LANGUAGES CXX) diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h index 0a5229b..2062151 100644 --- a/include/thread_pool/thread_pool.h +++ b/include/thread_pool/thread_pool.h @@ -117,7 +117,12 @@ namespace dp { auto task = [func = std::move(f), ... largs = std::move(args), promise = std::move(promise)]() mutable { try { - promise.set_value(func(largs...)); + if constexpr (std::is_same_v) { + func(largs...); + promise.set_value(); + } else { + promise.set_value(func(largs...)); + } } catch (...) { promise.set_exception(std::current_exception()); } @@ -139,7 +144,13 @@ namespace dp { auto task = [func = std::move(f), ... largs = std::move(args), promise = shared_promise]() { try { - promise->set_value(func(largs...)); + if constexpr (std::is_same_v) { + func(largs...); + promise->set_value(); + } else { + promise->set_value(func(largs...)); + } + } catch (...) { promise->set_exception(std::current_exception()); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bc79d98..94d7618 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) project(ThreadPoolTests LANGUAGES CXX) diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 0b0ac8c..7e261f9 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -47,6 +48,14 @@ TEST_CASE("Pass raw reference to pool") { CHECK_EQ(x, 2); } +TEST_CASE("Support enqueue with void return type") { + dp::thread_pool pool; + auto value = 8; + auto future = pool.enqueue([](int& x) { x *= 2; }, std::ref(value)); + future.wait(); + CHECK_EQ(value, 16); +} + TEST_CASE("Ensure input params are properly passed") { dp::thread_pool pool(4); constexpr auto total_tasks = 30; @@ -245,3 +254,62 @@ TEST_CASE("Ensure work completes with fewer threads than expected.") { CHECK_EQ(counter.load(), total_tasks); } + +void recursive_sequential_sum(std::atomic_int32_t& counter, int count, dp::thread_pool<>& pool) { + counter.fetch_add(count); + if (count > 1) { + pool.enqueue_detach(recursive_sequential_sum, std::ref(counter), count - 1, std::ref(pool)); + } +} + +TEST_CASE("Recursive enqueue calls work correctly") { + std::atomic_int32_t counter = 0; + constexpr auto start = 1000; + { + dp::thread_pool pool(4); + recursive_sequential_sum(counter, start, pool); + } + + auto expected_sum = 0; + for (int i = 0; i <= start; i++) { + expected_sum += i; + } + CHECK_EQ(expected_sum, counter.load()); +} + +void recursive_parallel_sort(int* begin, int* end, int split_level, dp::thread_pool<>& pool) { + if (split_level < 2 || end - begin < 2) { + std::sort(begin, end); + } else { + const auto mid = begin + (end - begin) / 2; + if (split_level == 2) { + const auto future = + pool.enqueue(recursive_parallel_sort, begin, mid, split_level / 2, std::ref(pool)); + std::sort(mid, end); + future.wait(); + } else { + const auto left = + pool.enqueue(recursive_parallel_sort, begin, mid, split_level / 2, std::ref(pool)); + const auto right = + pool.enqueue(recursive_parallel_sort, mid, end, split_level / 2, std::ref(pool)); + + left.wait(); + right.wait(); + } + std::inplace_merge(begin, mid, end); + } +} + +TEST_CASE("Recursive parallel sort") { + std::vector data(10000); + // std::ranges::iota is a C++23 feature + std::iota(data.begin(), data.end(), 0); + std::ranges::shuffle(data, std::mt19937{std::random_device{}()}); + + { + dp::thread_pool pool(4); + recursive_parallel_sort(data.data(), data.data() + data.size(), 4, pool); + } + + CHECK(std::ranges::is_sorted(data)); +}