diff --git a/include/itlib/generator.hpp b/include/itlib/generator.hpp index fc08b90..b925122 100644 --- a/include/itlib/generator.hpp +++ b/include/itlib/generator.hpp @@ -61,35 +61,33 @@ namespace itlib { -namespace genimpl { -// tempting to include expected here so we could have optional of ref and -// ditch the T& specialization -// ... but we promised to make standalone libs +// why std::optional still doesn't have T& specialization is beyond me +// it's tempting to include expected here so we could have optional of ref and +// ditch our T& specialization, but we promised to make standalone libs template -class val_holder : public std::optional {}; +class generator_value : public std::optional {}; template -class val_holder { - T* val = nullptr; +class generator_value { + T* m_val = nullptr; public: - void emplace(T& v) noexcept { val = &v; } - void reset() noexcept { val = nullptr; } - T& operator*() noexcept { return *val; } - bool has_value() const noexcept { return val != nullptr; } + void emplace(T& v) noexcept { m_val = &v; } + void reset() noexcept { m_val = nullptr; } + T& operator*() noexcept { return *m_val; } + bool has_value() const noexcept { return m_val != nullptr; } explicit operator bool() const noexcept { return has_value(); } }; -} // namespace genimpl - template class generator { public: + // return const ref in case we're generating values, otherwise keep the ref type using value_ret_t = std::conditional_t, T, const T&>; - class promise_type { - genimpl::val_holder m_val; - public: + struct promise_type { + generator_value m_val; + promise_type() noexcept = default; ~promise_type() = default; @@ -99,25 +97,89 @@ class generator { std::suspend_always initial_suspend() noexcept { return {}; } std::suspend_always final_suspend() noexcept { return {}; } std::suspend_always yield_value(T value) noexcept { // assume T is noexcept move constructible - m_val = std::move(value); + if constexpr (std::is_reference_v) { + m_val.emplace(value); + } + else { + m_val.emplace(std::move(value)); + } return {}; } void return_void() noexcept {} void unhandled_exception() { throw; } - value_ret_t val() const noexcept { + value_ret_t val() & noexcept { return *m_val; } + T&& val() && noexcept { + return std::move(*m_val); + } + void clear_value() noexcept { + m_val.reset(); + } }; + using handle_t = std::coroutine_handle; + ~generator() { if (m_handle) m_handle.destroy(); } - // std::optional interface - genimpl::val_holder next() {} + // next (optional-based) interface + + // NOTE: this won't return true until next() has returned an empty optional at least once + bool done() const noexcept { + return m_handle.done(); + } + + generator_value next() { + if (done()) return {}; + m_handle.promise().clear_value(); + m_handle.resume(); + return std::move(m_handle.promise().m_val); + } + + // iterator-like/range-for interface + + // emphasize that this is not a real iterator + class pseudo_iterator { + handle_t m_handle; + public: + using value_type = std::decay_t; + using reference = value_ret_t; + + pseudo_iterator() noexcept = default; + explicit pseudo_iterator(handle_t handle) noexcept : m_handle(handle) {} + + reference operator*() const noexcept { + return m_handle.promise().val(); + } + + pseudo_iterator& operator++() { + m_handle.promise().clear_value(); + m_handle.resume(); + return *this; + } + + struct end_t {}; + + // we're not really an iterator, but we can pretend to be one + friend bool operator==(const pseudo_iterator& i, end_t) noexcept { return i.m_handle.done(); } + friend bool operator==(end_t, const pseudo_iterator& i) noexcept { return i.m_handle.done(); } + friend bool operator!=(const pseudo_iterator& i, end_t) noexcept { return !i.m_handle.done(); } + friend bool operator!=(end_t, const pseudo_iterator& i) noexcept { return !i.m_handle.done(); } + }; + + pseudo_iterator begin() { + m_handle.resume(); + return pseudo_iterator{m_handle}; + } + + pseudo_iterator::end_t end() { + return {}; + } + private: - using handle_t = std::coroutine_handle; handle_t m_handle; explicit generator(handle_t handle) noexcept : m_handle(handle) {} }; diff --git a/test/t-generator-20.cpp b/test/t-generator-20.cpp index f12f8f1..046b913 100644 --- a/test/t-generator-20.cpp +++ b/test/t-generator-20.cpp @@ -2,4 +2,133 @@ // SPDX-License-Identifier: MIT // #include -#include \ No newline at end of file + +#include +#include + +#include +#include +#include + +itlib::generator range(int begin, int end) { + for (int i = begin; i < end; ++i) { + if (i == 103) throw std::runtime_error("test exception"); + co_yield i; + } +} + +TEST_CASE("simple") { + int i = 50; + + // range for + for (int x : range(i, i+10)) { + CHECK(x == i); + ++i; + } + CHECK(i == 60); + + // next + auto r = range(1, 5); + CHECK(*r.next() == 1); + CHECK_FALSE(r.done()); + CHECK(*r.next() == 2); + CHECK(r.next().value() == 3); + CHECK(r.next().value() == 4); + CHECK_FALSE(r.next().has_value()); + CHECK_FALSE(r.next().has_value()); // check once again to make sure it's safe + CHECK(r.done()); + + // exceptions + + i = 0; + CHECK_THROWS_WITH_AS( + [&]() { + for (int x : range(100, 105)) { + i = x; + } + }(), + "test exception", + std::runtime_error + ); + CHECK(i == 102); + + auto tr = range(101, 105); + CHECK_NOTHROW(tr.next()); + CHECK_NOTHROW(tr.next()); + CHECK_THROWS_WITH_AS(tr.next(), "test exception", std::runtime_error); +} + +template +itlib::generator ref_gen(std::span vals) { + for (T& v : vals) { + co_yield v; + } +} + +TEST_CASE("ref") { + std::vector ints = {1, 2, 3, 4, 5}; + auto g = ref_gen(std::span(ints)); + for (int& i : g) { + i += 10; + } + CHECK(ints == std::vector{11, 12, 13, 14, 15}); + + auto cg = ref_gen(std::span(ints)); + const int& a = *cg.next(); + const int& b = *cg.next(); + for (const int& i : cg) { + CHECK(i > 12); + CHECK(i < 16); + } + CHECK(cg.done()); + CHECK(a == 11); + CHECK(&a == ints.data()); + CHECK(b == 12); + CHECK(&b == ints.data() + 1); +} + +struct value : doctest::util::lifetime_counter +{ + value() = default; + explicit value(int i) : val(i) {} + int val = 0; +}; + +itlib::generator value_range(int begin, int end) { + for (int i = begin; i < end; ++i) { + co_yield value(i); + } +} + +TEST_CASE("lifetime") { + doctest::util::lifetime_counter_sentry lcsentry(value::root_lifetime_stats()); + + int i = 0; + { + value::lifetime_stats ls; + + auto r = value_range(0, 5); + for (value v : r) { + CHECK(v.val == i); + ++i; + } + CHECK(i == 5); + + CHECK(ls.living == 0); + CHECK(ls.copies == 5); + CHECK(ls.m_ctr == 5); + } + + { + value::lifetime_stats ls; + + auto r = value_range(0, 3); + auto v1 = r.next(); + auto v2 = r.next(); + auto v3 = r.next(); + auto vend = r.next(); + CHECK(ls.living == 3); + CHECK(ls.copies == 0); + CHECK(ls.m_ctr == 6); + } +}