diff --git a/folly/coro/Collect-inl.h b/folly/coro/Collect-inl.h index 50ac63d0d2f..a39b388a1c7 100644 --- a/folly/coro/Collect-inl.h +++ b/folly/coro/Collect-inl.h @@ -235,8 +235,17 @@ auto makeUnorderedAsyncGeneratorImpl( return [](AsyncScope& scopeParam, InputRange awaitablesParam) -> AsyncGenerator { auto [results, pipe] = AsyncPipe::create(); - const CancellationSource cancelSource; - auto guard = folly::makeGuard([&] { cancelSource.requestCancellation(); }); + struct SharedState { + explicit SharedState(AsyncPipe&& p) : pipe(std::move(p)) {} + + AsyncPipe pipe; + const CancellationSource cancelSource; + }; + auto sharedState = std::make_shared(std::move(pipe)); + auto cancelToken = sharedState->cancelSource.getToken(); + + auto guard = folly::makeGuard( + [&] { sharedState->cancelSource.requestCancellation(); }); auto ex = co_await co_current_executor; size_t expected = 0; // Save the initial context and restore it after starting each task @@ -246,24 +255,19 @@ auto makeUnorderedAsyncGeneratorImpl( const auto context = RequestContext::saveContext(); for (auto&& semiAwaitable : static_cast(awaitablesParam)) { - auto task = [](auto semiAwaitableParam, - auto& cancelSourceParam, - auto& p) -> Task { + auto task = [](auto semiAwaitableParam, auto state) -> Task { auto result = co_await co_awaitTry(std::move(semiAwaitableParam)); if (!result.hasValue() && !IsTry::value) { - cancelSourceParam.requestCancellation(); + state->cancelSource.requestCancellation(); } - p.write(std::move(result)); - }(static_cast(semiAwaitable), - cancelSource, - pipe); + state->pipe.write(std::move(result)); + }(static_cast(semiAwaitable), sharedState); if constexpr (std::is_same_v) { scopeParam.add( - co_withCancellation(cancelSource.getToken(), std::move(task)) - .scheduleOn(ex)); + co_withCancellation(cancelToken, std::move(task)).scheduleOn(ex)); } else { static_assert(std::is_same_v); - scopeParam.add(std::move(task).scheduleOn(ex), cancelSource.getToken()); + scopeParam.add(std::move(task).scheduleOn(ex), cancelToken); } ++expected; RequestContext::setContext(context); @@ -272,7 +276,7 @@ auto makeUnorderedAsyncGeneratorImpl( while (expected > 0) { CancellationCallback cancelCallback( co_await co_current_cancellation_token, - [&]() noexcept { cancelSource.requestCancellation(); }); + [&]() noexcept { sharedState->cancelSource.requestCancellation(); }); if constexpr (!IsTry::value) { auto result = co_await co_awaitTry(results.next()); diff --git a/folly/coro/test/CollectTest.cpp b/folly/coro/test/CollectTest.cpp index 5031f9c93e1..2a5ec6723fa 100644 --- a/folly/coro/test/CollectTest.cpp +++ b/folly/coro/test/CollectTest.cpp @@ -3457,3 +3457,31 @@ TEST_F( co_await scope.joinAsync(); }()); } + +TEST(MakeUnorderedAsyncGeneratorTest, GeneratorEarlyDestroy) { + folly::coro::blockingWait([]() -> folly::coro::Task { + folly::coro::AsyncScope scope; + folly::CPUThreadPoolExecutor executor(2); + + std::vector> tasks; + + tasks.push_back(folly::coro::co_invoke([]() -> folly::coro::Task { + co_await folly::coro::co_reschedule_on_current_executor; + std::this_thread::sleep_for(std::chrono::seconds{2}); + co_return 42; + }).scheduleOn(&executor)); + tasks.push_back(folly::coro::co_invoke([]() -> folly::coro::Task { + co_await folly::coro::co_reschedule_on_current_executor; + std::this_thread::sleep_for(std::chrono::seconds{1}); + co_return 43; + }).scheduleOn(&executor)); + + { + auto gen = + folly::coro::makeUnorderedAsyncGenerator(scope, std::move(tasks)); + EXPECT_EQ(43, *(co_await gen.next())); + } + + co_await scope.joinAsync(); + }()); +}