From 7a0cd26737bfa520249ade17a82e3fa7866b5b10 Mon Sep 17 00:00:00 2001 From: Ian Petersen Date: Mon, 26 Aug 2024 14:05:57 -0700 Subject: [PATCH] Add async stack support to coroutines This change extends the work in #616 to support async stack frames in `task<>` coroutines, including those that invoke `at_coroutine_exit()`. In `task<>`, when `UNIFEX_NO_ASYNC_STACKS` is falsey, the awaiter returned from `task<>`'s customization of `unifex::await_transform` stores an `AsyncStackFrame`. The awaiter pushes its frame onto the current async stack in `await_suspend()` and pops it again in `await_resume()`; since `await_resume()` is only invoked for value and error completions, this arrangement leaves it up to the waiting task to pop the awaiter's frame when the awaited task completes with done. This can be expressed as a new rule: - when a coroutine completes with a value or an error, it is responsible for popping its own `AsyncStackFrame`; but - when a coroutine completes with done, the *caller* is responsible for popping the callee's `AsyncStackFrame` as a part of the caller's `unhandled_done()` coroutine. To support this new requirement of `unhandled_done()` (that it is responsible for popping the callee's stack frame), this change introduces `popAsyncStackFrameFromCaller`, which takes the caller's stack frame by reference so that it can assert that, after popping the current async frame (whatever it is), the new top frame is the caller's frame. A `task<>` promise has an `AsyncStackFrame*` that, when it's not `nullptr`, points to the `AsyncStackFrame` in the awaiter waiting for the task. This pointer exists even when `UNIFEX_NO_ASYNC_STACKS` is truthy to help mitigate against ODR violations; linking together two TUs with `UNIFEX_NO_ASYNC_STACKS` set differently is not explicitly supported but, by ensuring this pointer always exists, some ODR problems are avoided. 
When a `task<>` is awaited from a TU with async stack support enabled, the awaited task's awaiter sets the promise's `AsyncStackFrame*` to point to the awaiter's frame; when a `task<>` is awaited from a TU with async stack support disabled, this assignment never happens and the promise's pointer remains null. The above description of `task<>`'s async stack maintenance only covers the recursive case of one coroutine awaiting another. The base case is handled in `connect_awaitable()`, where an `AsyncStackRoot` is set up before starting the connected awaitable. `stop_if_requested` used to model both `sender` and `awaitable` so that `co_await stop_if_requested();` could take advantage of symmetric transfer. The `stop_if_requested` sender now customizes `await_transform` to express its participation in async stack management. This means of expressing async stack awareness is unsatisfying but I don't have any better ideas right now. Lastly, `unifex::await_transform()` now wraps naturally-awaitable arguments in an `awaitable_wrapper` that ensures the `coroutine_handle<>` passed to the wrapped awaitable is one that establishes an active `AsyncStackRoot` before resuming the real waiting coroutine. 
--- include/unifex/at_coroutine_exit.hpp | 130 +++++--- include/unifex/await_transform.hpp | 367 +++++++++++++++++++-- include/unifex/connect_awaitable.hpp | 109 ++++-- include/unifex/stop_if_requested.hpp | 60 +++- include/unifex/task.hpp | 262 +++++++++++---- include/unifex/tracing/async_stack-inl.hpp | 10 + include/unifex/tracing/async_stack.hpp | 5 +- source/task.cpp | 5 +- 8 files changed, 774 insertions(+), 174 deletions(-) diff --git a/include/unifex/at_coroutine_exit.hpp b/include/unifex/at_coroutine_exit.hpp index 775fc0b9..ae9f0339 100644 --- a/include/unifex/at_coroutine_exit.hpp +++ b/include/unifex/at_coroutine_exit.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #if UNIFEX_NO_COROUTINES # error "Coroutine support is required to use this header" @@ -57,10 +58,25 @@ inline constexpr struct _fn { } // namespace _xchg_cont using _xchg_cont::exchange_continuation; +template struct _cleanup_promise_base { struct final_awaitable { bool await_ready() const noexcept { return false; } + template + coro::coroutine_handle<> await_suspend_impl( + coro::coroutine_handle h) const noexcept { + if constexpr (WithAsyncStackSupport) { + if (h.promise().parentFrame_ != nullptr) { + popAsyncStackFrameCallee(h.promise().frame_); + } + } + + auto continuation = h.promise().next(); + h.destroy(); // The cleanup action has finished executing. Destroy it. + return continuation; + } + // Clang before clang-12 has a bug with coroutines that self-destruct in an // await_suspend that uses symmetric transfer. It appears that MSVC has the // same bug, while Emscripten, the WebAssembly compiler just doesn't support @@ -81,18 +97,14 @@ struct _cleanup_promise_base { # endif void await_suspend(coro::coroutine_handle h) const noexcept { - auto continuation = h.promise().next(); - h.destroy(); // The cleanup action has finished executing. Destroy it. - continuation.resume(); + await_suspend_impl(h).resume(); } #else // No bugs here! OK to use symmetric transfer. 
template coro::coroutine_handle<> await_suspend(coro::coroutine_handle h) const noexcept { - auto continuation = h.promise().next(); - h.destroy(); // The cleanup action has finished executing. Destroy it. - return continuation; + return await_suspend_impl(h); } #endif @@ -135,10 +147,24 @@ struct _cleanup_promise_base { return p.sched_; } + template(typename Promise) // + (requires WithAsyncStackSupport AND + convertible_to) // + friend constexpr AsyncStackFrame* tag_invoke( + tag_t, const Promise& p) noexcept { + return &p.frame_; + } + inline static constexpr inline_scheduler _default_scheduler{}; continuation_handle<> continuation_{}; any_scheduler sched_{_default_scheduler}; bool isUnhandledDone_{false}; + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + parentFrame_{}; + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + frame_; }; // The die_on_done algorithm implemented here could be implemented in terms of @@ -233,16 +259,16 @@ struct _die_on_done_fn { } }; -template +template struct _cleanup_task; -template -struct _cleanup_promise : _cleanup_promise_base { +template +struct _cleanup_promise : _cleanup_promise_base { template explicit _cleanup_promise(Action&&, Ts&... ts) : args_(ts...) 
{} - _cleanup_task get_return_object() noexcept { - return _cleanup_task( + _cleanup_task get_return_object() noexcept { + return _cleanup_task( coro::coroutine_handle<_cleanup_promise>::from_promise(*this)); } @@ -253,6 +279,12 @@ struct _cleanup_promise : _cleanup_promise_base { template decltype(auto) await_transform(Value&& value) noexcept(noexcept( unifex::await_transform(*this, _die_on_done_fn{}((Value&&)value)))) { + if constexpr (WithAsyncStackSupport) { + if (this->parentFrame_ != nullptr) { + pushAsyncStackFrameCallerCallee(*this->parentFrame_, this->frame_); + } + } + return unifex::await_transform(*this, _die_on_done_fn{}((Value&&)value)); } @@ -261,15 +293,15 @@ struct _cleanup_promise : _cleanup_promise_base { // Record that we are processing an unhandled done signal. This is checked // in the final_suspend of the cleanup action to know which subsequent // continuation to resume. - isUnhandledDone_ = true; + this->isUnhandledDone_ = true; // On unhandled_done, run the cleanup action: return coro::coroutine_handle<_cleanup_promise>::from_promise(*this); }); }; -template +template struct [[nodiscard]] _cleanup_task { - using promise_type = _cleanup_promise; + using promise_type = _cleanup_promise; explicit _cleanup_task(coro::coroutine_handle coro) noexcept : continuation_(coro) {} @@ -279,29 +311,46 @@ struct [[nodiscard]] _cleanup_task { ~_cleanup_task() { UNIFEX_ASSERT(!continuation_); } - bool await_ready() const noexcept { return false; } + struct awaiter { + bool await_ready() const noexcept { return false; } - template - bool await_suspend_impl_(Promise& parent) noexcept { - continuation_.promise().continuation_ = - exchange_continuation(parent, continuation_); - continuation_.promise().sched_ = get_scheduler(parent); - return false; - } + template + bool await_suspend_impl_( + Promise& parent, + [[maybe_unused]] instruction_ptr returnAddress = + instruction_ptr::read_return_address()) noexcept { + continuation_.promise().continuation_ = + 
exchange_continuation(parent, continuation_); + continuation_.promise().sched_ = get_scheduler(parent); + if constexpr (WithAsyncStackSupport) { + continuation_.promise().parentFrame_ = get_async_stack_frame(parent); + continuation_.promise().frame_.setReturnAddress(returnAddress); + } + return false; + } - template - bool await_suspend(coro::coroutine_handle parent) noexcept { - return await_suspend_impl_(parent.promise()); - } + template + UNIFEX_NO_INLINE bool + await_suspend(coro::coroutine_handle parent) noexcept { + return await_suspend_impl_(parent.promise()); + } - std::tuple await_resume() noexcept { - return std::move(std::exchange(continuation_, {}).promise().args_); - } + std::tuple await_resume() noexcept { + return std::move(std::exchange(continuation_, {}).promise().args_); + } + + // TODO: how do we address always-inline awaitables + friend constexpr auto tag_invoke(tag_t, const awaiter&) noexcept { + return blocking_kind::always_inline; + } - // TODO: how do we address always-inline awaitables - friend constexpr auto - tag_invoke(tag_t, const _cleanup_task&) noexcept { - return blocking_kind::always_inline; + continuation_handle continuation_; + }; + + template + friend awaiter + tag_invoke(tag_t, Promise&, _cleanup_task task) noexcept { + return awaiter{std::exchange(task.continuation_, {})}; } private: @@ -311,18 +360,23 @@ struct [[nodiscard]] _cleanup_task { namespace _at_coroutine_exit { inline constexpr struct _fn { private: - template - static _cleanup_task at_coroutine_exit(Action action, Ts... ts) { + template + static _cleanup_task + at_coroutine_exit(Action action, Ts... ts) { co_await std::move(action)(std::move(ts)...); } public: - template(typename Action, typename... Ts) // + template( + typename Action, + typename... Ts, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires std:: is_invocable_v, std::decay_t...>) // - _cleanup_task...> + _cleanup_task...> operator()(Action&& action, Ts&&... 
ts) const { - return _fn::at_coroutine_exit((Action&&)action, (Ts&&)ts...); + return _fn::at_coroutine_exit( + (Action&&)action, (Ts&&)ts...); } } at_coroutine_exit{}; } // namespace _at_coroutine_exit diff --git a/include/unifex/await_transform.hpp b/include/unifex/await_transform.hpp index 10530968..e0d08ed2 100644 --- a/include/unifex/await_transform.hpp +++ b/include/unifex/await_transform.hpp @@ -68,18 +68,18 @@ struct _expected { namespace _await_tfx { using namespace _util; -template +template struct _awaitable_base { struct type; }; -template +template struct _awaitable { struct type; }; -template -struct _awaitable_base::type { +template +struct _awaitable_base::type { struct _rec { public: explicit _rec( @@ -92,6 +92,19 @@ struct _awaitable_base::type { : result_(std::exchange(r.result_, nullptr)) , continuation_(std::move(r.continuation_)) {} + void complete() noexcept { + if constexpr (WithAsyncStackSupport) { + if (auto* frame = get_async_stack_frame(continuation_.promise())) { + detail::ScopedAsyncStackRoot root; + root.activateFrame(*frame); + return continuation_.resume(); + } + } + + // run this when stacks are disabled and when the parent hasn't got one + continuation_.resume(); + } + template(class... 
Us) // (requires( constructible_from || @@ -101,13 +114,13 @@ struct _awaitable_base::type { std::is_void_v) { unifex::activate_union_member(result_->value_, (Us&&)us...); result_->state_ = _state::value; - continuation_.resume(); + complete(); } void set_error(std::exception_ptr eptr) && noexcept { unifex::activate_union_member(result_->exception_, std::move(eptr)); result_->state_ = _state::exception; - continuation_.resume(); + complete(); } void set_error(std::error_code code) && noexcept { @@ -117,6 +130,23 @@ struct _awaitable_base::type { void set_done() && noexcept { result_->state_ = _state::done; + + if constexpr (WithAsyncStackSupport) { + if (auto* parentFrame = + get_async_stack_frame(continuation_.promise())) { + // we need a dummy frame for the waiting coroutine's unhandled_done() + // to pop for us + AsyncStackFrame frame; + frame.setParentFrame(*parentFrame); + + detail::ScopedAsyncStackRoot root; + root.activateFrame(frame); + + return continuation_.resume_done(); + } + } + + // run this when stacks are disabled and when the parent hasn't got one continuation_.resume_done(); } @@ -158,18 +188,21 @@ struct _awaitable_base::type { _expected result_; }; -template +template using _awaitable_base_t = typename _awaitable_base< Promise, - sender_single_value_return_type_t>>::type; + sender_single_value_return_type_t>, + WithAsyncStackSupport>::type; -template -using _receiver_t = typename _awaitable_base_t::_rec; +template +using _receiver_t = + typename _awaitable_base_t::_rec; -template -struct _awaitable::type : _awaitable_base_t { +template +struct _awaitable::type + : _awaitable_base_t { private: - using _rec = _receiver_t; + using _rec = _receiver_t; connect_result_t op_; public: @@ -177,13 +210,288 @@ struct _awaitable::type : _awaitable_base_t { is_nothrow_connectable_v) : op_(unifex::connect((Sender&&)sender, _rec{&this->result_, h})) {} - void await_suspend(coro::coroutine_handle) noexcept { + void await_suspend(coro::coroutine_handle handle) 
noexcept { + if constexpr (WithAsyncStackSupport) { + auto* frame = get_async_stack_frame(handle.promise()); + if (frame) { + deactivateAsyncStackFrame((*frame)); + } + } unifex::start(op_); } }; -template -using _as_awaitable = typename _awaitable::type; +template +using _as_awaitable = + typename _awaitable::type; + +template +struct is_resumer_promise : std::false_type {}; + +template +struct is_resumer_promise : std::true_type {}; + +template +constexpr bool is_resumer_promise_v = is_resumer_promise::value; + +template +struct _coro_resumer final { + struct type; +}; + +template +struct _coro_resumer::type final { + struct promise_type { + using resumer_promise_t = void; + + static_assert(!is_resumer_promise_v); + + promise_type(coro::coroutine_handle& h) noexcept : handle_(h) {} + + type get_return_object() noexcept { + return type{coro::coroutine_handle::from_promise(*this)}; + } + + coro::suspend_always initial_suspend() noexcept { return {}; } + + [[noreturn]] coro::suspend_always final_suspend() noexcept { + std::terminate(); + } + + // TODO: unhandled_done()? 
+ + [[noreturn]] void return_void() noexcept { std::terminate(); } + + [[noreturn]] void unhandled_exception() noexcept { std::terminate(); } + + struct awaiter { + coro::coroutine_handle h; + + bool await_ready() noexcept { return false; } + + void await_suspend(coro::coroutine_handle<>) noexcept { + auto* frame = get_async_stack_frame(h.promise()); + if (frame) { + detail::ScopedAsyncStackRoot root; + root.activateFrame(*frame); + + h.resume(); + + root.ensureFrameDeactivated(frame); + } else { + h.resume(); + } + } + + [[noreturn]] void await_resume() noexcept { std::terminate(); } + }; + + awaiter await_transform(coro::coroutine_handle h) noexcept { + return awaiter{h}; + } + + template(typename CPO) // + (requires is_receiver_query_cpo_v) // + friend auto tag_invoke(CPO cpo, const promise_type& self) noexcept( + is_nothrow_tag_invocable_v) + -> tag_invoke_result_t { + return tag_invoke(std::move(cpo), std::as_const(self.handle_.promise())); + } + + continuation_handle handle_; + }; + + type() noexcept = default; + + type(type&& other) noexcept : h_(std::exchange(other.h_, {})) {} + + ~type() { + if (h_) { + h_.destroy(); + } + } + + type& operator=(type rhs) noexcept { + std::swap(h_, rhs.h_); + return *this; + } + + coro::coroutine_handle handle() && noexcept { + return std::exchange(h_, {}); + } + +private: + explicit type(coro::coroutine_handle h) noexcept : h_(h) {} + + coro::coroutine_handle h_; +}; + +template +using coro_resumer = typename _coro_resumer::type; + +template +coro_resumer +resume_with_stack_root(coro::coroutine_handle h) { + co_await h; +} + +template +struct _awaitable_wrapper final { + class type; +}; + +template +class _awaitable_wrapper::type final { + using awaiter_t = awaiter_type_t; + + Awaitable&& awaitable_; + awaiter_t awaiter_; + coro::coroutine_handle<> coro_; + +public: + using awaitable_wrapper_t = void; + + explicit type(Awaitable&& awaitable) + : awaitable_(std::forward(awaitable)) + , 
awaiter_(get_awaiter(std::forward(awaitable))) {} + + type(type&& other) noexcept(std::is_nothrow_move_constructible_v) + : awaitable_(std::move(other.awaitable_)) + , awaiter_(std::move(other.awaiter_)) + , coro_(std::exchange(other.coro_, {})) { + // we should only be move-constructed before being awaited + UNIFEX_ASSERT(!coro_); + } + + ~type() { + if (coro_) { + coro_.destroy(); + } + } + + bool await_ready() noexcept(noexcept(awaiter_.await_ready())) { + return awaiter_.await_ready(); + } + + template + using resume_coro_handle_t = + coro::coroutine_handle::promise_type>; + + template + using _suspend_result_t = decltype(awaiter_.await_suspend( + resume_coro_handle_t::from_address(nullptr))); + + template + using suspend_result_t = std::conditional_t< + convertible_to<_suspend_result_t, coro::coroutine_handle<>>, + coro::coroutine_handle<>, + _suspend_result_t>; + + template(typename Promise) // + (requires same_as>) // + bool await_suspend_impl( + coro::coroutine_handle h, AsyncStackFrame* frame) { + auto* root = frame->getStackRoot(); + + auto resumer = resume_with_stack_root(h).handle(); + + // save for later destruction + coro_ = resumer; + + // ensure that it's safe for the resumer coroutine to activate h's stack + // frame on resumption + deactivateAsyncStackFrame(*frame); + + if (awaiter_.await_suspend(resumer)) { + // suspend + return true; + } else { + // we're not actually suspending so undo the stack manipulation we just + // did + activateAsyncStackFrame(*root, *frame); + + // proactively destroy the unneeded coro_resumer + std::exchange(coro_, {}).destroy(); + + // resume the caller + return false; + } + } + + template(typename Promise) // + (requires(!same_as>)) // + suspend_result_t await_suspend_impl( + coro::coroutine_handle h, AsyncStackFrame* frame) { + auto resumer = resume_with_stack_root(h).handle(); + + // save for later destruction + coro_ = resumer; + + // ensure that it's safe for the resumer coroutine to activate h's stack + // frame 
on resumption + deactivateAsyncStackFrame(*frame); + + return awaiter_.await_suspend(resumer); + } + + template + suspend_result_t await_suspend(coro::coroutine_handle h) { + if (auto* frame = get_async_stack_frame(h.promise())) { + return await_suspend_impl(h, frame); + } + + using awaiter_suspend_result_t = decltype(awaiter_.await_suspend(h)); + + // Note: it's technically possible for an awaitable's implementation of + // await_suspend() to return different types depending on its argument + // type. This is easily handled if the "different types" are different + // coroutine_handle<> types: just convert them all to + // coro::coroutine_handle<>; but it's a pain if the different return + // types mix-and-match between void, bool, and coroutine handles. If + // any reports ever come in that these static asserts are breaking + // builds, we can handle it by forcing *our* return type to always be + // coro::coroutine_handle<> and just map the void and bool cases to + // the appropriate handle, but let's avoid that complexity until it's + // proven necessary. 
+ if constexpr (same_as>) { + static_assert(same_as); + } else if constexpr (same_as>) { + static_assert(same_as); + } else { + static_assert( + convertible_to>); + } + + return awaiter_.await_suspend(h); + } + + auto await_resume() noexcept(noexcept(awaiter_.await_resume())) + -> decltype(awaiter_.await_resume()) { + return awaiter_.await_resume(); + } + + template(typename CPO) // + (requires same_as, CPO> AND + std::is_invocable_v) // + friend auto tag_invoke(CPO cpo, const type& self) noexcept( + std::is_nothrow_invocable_v) + -> std::invoke_result_t { + return std::move(cpo)(std::as_const(self.awaitable)); + } +}; + +template +using awaitable_wrapper = typename _awaitable_wrapper::type; + +template +struct is_awaitable_wrapper : std::false_type {}; + +template +struct is_awaitable_wrapper + : std::true_type {}; + +template +constexpr bool is_awaitable_wrapper_v = is_awaitable_wrapper::value; struct _fn { // Call custom implementation if present. @@ -201,26 +509,41 @@ struct _fn { } // Default implementation for naturally awaitable types - template(typename Promise, typename Value) // + template( + typename Promise, + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires(!tag_invocable<_fn, Promise&, Value>) AND detail::_awaitable) // - Value&& + decltype(auto) operator()(Promise&, Value&& value) const noexcept { - return std::forward(value); + if constexpr ( + WithAsyncStackSupport && + !is_awaitable_wrapper_v>) { + return awaitable_wrapper{std::forward(value)}; + } else { + return std::forward(value); + } } // Default implementation for non-awaitable senders - template(typename Promise, typename Value) // + template( + typename Promise, + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS) // (requires(!tag_invocable<_fn, Promise&, Value>) AND(!detail::_awaitable) AND unifex::sender) // decltype(auto) operator()(Promise& promise, Value&& value) const { static_assert( - unifex::sender_to>, + unifex::sender_to< 
+ Value, + _receiver_t>, "This sender is not awaitable in this coroutine type."); auto h = coro::coroutine_handle::from_promise(promise); - return _as_awaitable{(Value&&)value, h}; + return _as_awaitable{ + (Value&&)value, h}; } // Fall back to returning the argument if none of the above conditions are met diff --git a/include/unifex/connect_awaitable.hpp b/include/unifex/connect_awaitable.hpp index 3cb21f91..507beb61 100644 --- a/include/unifex/connect_awaitable.hpp +++ b/include/unifex/connect_awaitable.hpp @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -36,19 +38,28 @@ namespace unifex { namespace _await { -template +template struct _sender_task { class type; }; -template -using sender_task = typename _sender_task::type; +template +using sender_task = + typename _sender_task::type; -template -class _sender_task::type { +template +class _sender_task::type { public: struct promise_type { template - explicit promise_type(Awaitable&, Receiver& r) noexcept : receiver_(r) {} + explicit promise_type( + Awaitable&, + Receiver& r, + [[maybe_unused]] instruction_ptr returnAddress) noexcept + : receiver_(r) { + if constexpr (WithAsyncStackSupport) { + frame_.setReturnAddress(returnAddress); + } + } type get_return_object() noexcept { return type{coro::coroutine_handle::from_promise(*this)}; @@ -68,8 +79,12 @@ class _sender_task::type { struct awaiter { Func&& func_; bool await_ready() noexcept { return false; } - void await_suspend(coro::coroutine_handle) noexcept( + void await_suspend(coro::coroutine_handle h) noexcept( std::is_nothrow_invocable_v) { + if constexpr (WithAsyncStackSupport) { + deactivateAsyncStackFrame(h.promise().frame_); + } + ((Func&&)func_)(); } [[noreturn]] void await_resume() noexcept { std::terminate(); } @@ -99,12 +114,27 @@ class _sender_task::type { friend auto tag_invoke(CPO cpo, const promise_type& p) noexcept( std::is_nothrow_invocable_v) -> std::invoke_result_t { - return cpo(std::as_const(p.receiver_)); 
+ if constexpr ( + WithAsyncStackSupport && same_as>) { + return &p.frame_; + } else { + return std::move(cpo)(std::as_const(p.receiver_)); + } } Receiver& receiver_; - done_coro doneCoro_ = unifex::unhandled_done( - [this]() noexcept { unifex::set_done(std::move(receiver_)); }); + done_coro doneCoro_ = unifex::unhandled_done([this]() noexcept { + if constexpr (WithAsyncStackSupport) { + popAsyncStackFrameFromCaller(frame_); + deactivateAsyncStackFrame(frame_); + } + + unifex::set_done(std::move(receiver_)); + }); + + UNIFEX_NO_UNIQUE_ADDRESS mutable std:: + conditional_t> + frame_; }; coro::coroutine_handle coro_; @@ -119,7 +149,24 @@ class _sender_task::type { coro_.destroy(); } - void start() & noexcept { coro_.resume(); } + void start() & noexcept { + if constexpr (WithAsyncStackSupport) { + detail::ScopedAsyncStackRoot root; + + auto* frame = &coro_.promise().frame_; + if (auto parentFrame = get_async_stack_frame(coro_.promise().receiver_)) { + frame->setParentFrame(*parentFrame); + } + + root.activateFrame(*frame); + + coro_.resume(); + + root.ensureFrameDeactivated(frame); + } else { + coro_.resume(); + } + } }; } // namespace _await @@ -138,9 +185,10 @@ inline const struct _fn { operator unit() const noexcept { return {}; } }; - template - static auto connect_impl(Awaitable awaitable, Receiver receiver) - -> _await::sender_task { + template + static auto + connect_impl(Awaitable awaitable, Receiver receiver, instruction_ptr) + -> _await::sender_task { #if !UNIFEX_NO_EXCEPTIONS std::exception_ptr ex; try { @@ -149,7 +197,8 @@ inline const struct _fn { // The _sender_task's promise type has an await_transform that passes the // awaitable through unifex::await_transform. 
So take that into // consideration when computing the result type: - using promise_type = typename _await::sender_task::promise_type; + using promise_type = typename _await:: + sender_task::promise_type; using awaitable_type = std::invoke_result_t< tag_t, promise_type&, @@ -165,12 +214,17 @@ inline const struct _fn { // after the coroutine is suspended so that it is safe // for the receiver to destroy the coroutine. co_yield [&](result_type&& result) { - return [&] { - if constexpr (std::is_void_v>) { - unifex::set_value(std::move(receiver)); - } else { - unifex::set_value( - std::move(receiver), static_cast(result)); + return [&]() noexcept { + UNIFEX_TRY { + if constexpr (std::is_void_v>) { + unifex::set_value(std::move(receiver)); + } else { + unifex::set_value( + std::move(receiver), static_cast(result)); + } + } + UNIFEX_CATCH(...) { + unifex::set_error(std::move(receiver), std::current_exception()); } }; // The _comma_hack here makes this well-formed when the co_await @@ -189,10 +243,16 @@ inline const struct _fn { } public: - template + template < + typename Awaitable, + typename Receiver, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> auto operator()(Awaitable&& awaitable, Receiver&& receiver) const - -> _await::sender_task> { - return connect_impl((Awaitable&&)awaitable, (Receiver&&)receiver); + -> _await::sender_task, WithAsyncStackSupport> { + return connect_impl( + (Awaitable&&)awaitable, + (Receiver&&)receiver, + instruction_ptr::read_return_address()); } } connect_awaitable{}; } // namespace _await_cpo @@ -294,6 +354,7 @@ struct _fn { (requires detail::_awaitable) // _sender> operator()(Awaitable&& awaitable) const { + // TODO: this is going to generate an unfortunate return address return _sender>{ (Awaitable&&)awaitable, instruction_ptr::read_return_address()}; } diff --git a/include/unifex/stop_if_requested.hpp b/include/unifex/stop_if_requested.hpp index 733a6064..edb25960 100644 --- a/include/unifex/stop_if_requested.hpp +++ 
b/include/unifex/stop_if_requested.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License Version 2.0 with LLVM Exceptions * (the "License"); you may not use this file except in compliance with @@ -16,11 +16,14 @@ #pragma once #include +#include #include #include #include #include +#include #include +#include #include @@ -37,35 +40,60 @@ struct _fn { void start() & noexcept { UNIFEX_TRY { if (get_stop_token(std::as_const(rec_)).stop_requested()) { - unifex::set_done((Receiver &&) rec_); + unifex::set_done((Receiver&&)rec_); } else { - unifex::set_value((Receiver &&) rec_); + unifex::set_value((Receiver&&)rec_); } } UNIFEX_CATCH(...) { - unifex::set_error((Receiver &&) rec_, std::current_exception()); + unifex::set_error((Receiver&&)rec_, std::current_exception()); } } }; }; - public: #if !UNIFEX_NO_COROUTINES + template + struct awaiter { + UNIFEX_NO_UNIQUE_ADDRESS + std::conditional_t< + WithAsyncStackSupport, + AsyncStackFrame, + detail::_empty<0>> + frame; + + bool await_ready() const noexcept { return false; } + + template + coro::coroutine_handle<> + await_suspend(coro::coroutine_handle coro) noexcept { + if (get_stop_token(coro.promise()).stop_requested()) { + if constexpr (WithAsyncStackSupport) { + frame.setReturnAddress(); + if (auto parentFrame = get_async_stack_frame(coro.promise())) { + pushAsyncStackFrameCallerCallee(*parentFrame, frame); + } + } + + return coro.promise().unhandled_done(); + } + return coro; // don't suspend + } + void await_resume() const noexcept {} + }; + // Provide an awaiter interface in addition to the sender interface // because as an awaiter we can take advantage of symmetric transfer // to save stack space: - bool await_ready() const noexcept { return false; } - template - coro::coroutine_handle<> - await_suspend(coro::coroutine_handle coro) const noexcept { - if (get_stop_token(coro.promise()).stop_requested()) { 
- return coro.promise().unhandled_done(); - } - return coro; // don't suspend + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter + tag_invoke(tag_t, Promise&, _sender) noexcept { + return {}; } - void await_resume() const noexcept {} #endif - + public: template < template class Variant, @@ -84,7 +112,7 @@ struct _fn { (requires receiver_of) // auto connect(Receiver&& rec) const -> typename _op>::type { - return typename _op>::type{(Receiver &&) rec}; + return typename _op>::type{(Receiver&&)rec}; } }; diff --git a/include/unifex/task.hpp b/include/unifex/task.hpp index 1a73373f..e0b36f45 100644 --- a/include/unifex/task.hpp +++ b/include/unifex/task.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include #include #include @@ -33,6 +34,8 @@ #include #include #include +#include +#include #include #include #include @@ -181,6 +184,11 @@ struct _promise_base { return std::exchange(p.continuation_, std::move(action)); } + friend constexpr AsyncStackFrame* + tag_invoke(tag_t, const _promise_base& p) noexcept { + return p.frame_; + } + #ifdef UNIFEX_ENABLE_CONTINUATION_VISITATIONS template friend void @@ -199,6 +207,10 @@ struct _promise_base { inplace_stop_token stoken_; // the coroutine to resume when a child awaitable completes with done done_coro doneCoro_; + // the async stack frame corresponding to this coroutine + // null until this coroutine is awaited; stays null when async stack support + // is disabled + AsyncStackFrame* frame_{}; }; /** @@ -206,8 +218,13 @@ struct _promise_base { */ struct _task_promise_base : _promise_base { _task_promise_base() - : _promise_base([this]() noexcept { return continuation_.done_handle(); }) { - } + : _promise_base([this]() noexcept { + if (frame_) { + popAsyncStackFrameFromCaller(*frame_); + } + + return continuation_.done_handle(); + }) {} // the implementation of the magic of co_await schedule(s); this is to be // ripped out and replaced with something more explicit 
@@ -342,7 +359,9 @@ struct _promise final { return awaiter{}; } - template + template < + typename Value, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> // todo: consider if this should be nothrow or not // NOTE: Magic rescheduling is not currently supported by nothrow tasks decltype(auto) await_transform(Value&& value) { @@ -358,12 +377,13 @@ struct _promise final { return unifex::await_transform( *this, with_scheduler_affinity(static_cast(value), this->sched_)); - } else if constexpr ( - tag_invocable, type&, Value> || - detail::_awaitable) { - // Either await_transform has been customized or Value is an awaitable. - // Either way, we can dispatch to the await_transform CPO, then insert a - // transition back to the correct execution context if necessary. + } else if constexpr (std::is_invocable_v< + tag_t, + type&, + Value>) { + // await_transform has been customized so we can dispatch to the + // await_transform CPO, then insert a transition back to the correct + // execution context if necessary. 
return with_scheduler_affinity( *this, unifex::await_transform(*this, static_cast(value)), @@ -396,14 +416,28 @@ struct _promise final { }; }; +struct _frame_state { + _frame_state() noexcept = default; + + explicit _frame_state(AsyncStackFrame& frame, AsyncStackRoot& root) noexcept + : frame_(&frame) + , root_(&root) {} + + void restore_frame_state() const noexcept { + if (frame_) { + activateAsyncStackFrame(*root_, *frame_); + } + } + +private: + AsyncStackFrame* frame_{}; + AsyncStackRoot* root_; // only conditionally initialized +}; + struct _sr_thunk_promise_base : _promise_base { _sr_thunk_promise_base() : _promise_base([this]() noexcept -> coro::coroutine_handle<> { - callback_.destruct(); - - whoToContinue_ = continuation::DONE; - - if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + if (complete_and_choose_continuation(continuation_.done_handle())) { return continuation_.done_handle(); } else { return coro::noop_coroutine(); @@ -439,11 +473,15 @@ struct _sr_thunk_promise_base : _promise_base { void set_value(bool) noexcept { if (self->refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - if (self->whoToContinue_ == continuation::PRIMARY) { - self->continuation_.resume(); + UNIFEX_ASSERT(self->whoToContinue_); + + if (self->frame_) { + unifex::detail::ScopedAsyncStackRoot root; + root.activateFrame(*self->frame_); + + self->whoToContinue_.resume(); } else { - UNIFEX_ASSERT(self->whoToContinue_ == continuation::DONE); - self->continuation_.resume_done(); + self->whoToContinue_.resume(); } } } @@ -475,17 +513,74 @@ struct _sr_thunk_promise_base : _promise_base { std::atomic refCount_{1}; - enum class continuation : uint8_t { - UNSET, - PRIMARY, - DONE, - }; - - continuation whoToContinue_{continuation::UNSET}; + coro::coroutine_handle<> whoToContinue_{}; void register_stop_callback() noexcept { callback_.construct(stoken_, stop_callback{this}); } + + _frame_state ensure_frame_deactivated() noexcept { + if (frame_ != nullptr) { + if 
(whoToContinue_ == continuation_.done_handle()) { + popAsyncStackFrameFromCaller(*frame_); + } + + auto* root = frame_->getStackRoot(); + // this asserts that root is not null + deactivateAsyncStackFrame(*frame_); + + return _frame_state(*frame_, *root); + } + + return {}; + } + + // performs the final steps of completing this coroutine: + // - destroy (and thus synchronize with) the stop callback if it exists + // - record the continuation (normal or done) that should be resumed + // - ensure the async stack state is correct + // - decrement the refcount + // + // returns true when the caller should proceed with resuming the continuation + // and false when the caller should suspend and allow the deferred stop + // request to shoulder that responsibility + // + // Note: we don't return a coroutine handle for the caller to resume to work + // around a symmetric transfer bug in Clang 11 on Windows + bool complete_and_choose_continuation( + coro::coroutine_handle<> whoToContinue) noexcept { + UNIFEX_ASSERT( + whoToContinue == continuation_.handle() || + whoToContinue == continuation_.done_handle()); + + callback_.destruct(); + + // whoToContinue_ needs to be written before we decrement the refcount + // to ensure that we synchronize this write with the corresponding + // read in the deferred stop callback's completion + whoToContinue_ = whoToContinue; + + // deactivate our async stack frame before decrementing the refcount + // + // Once the refcount has been decremented, it's possible for the + // deferred stop callback to resume our continuation and it must + // activate our frame on a new stack root before doing so; for that to be + // safe, it can't be active on any other stack root. If it turns out + // *we* are going to resume our continuation then we have to + // reactivate our frame to undo this proactive deactivation. 
+ const auto frameState = ensure_frame_deactivated(); + + // if we're last to complete, continue our continuation; otherwise do + // nothing and wait for the async stop request to do it + if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { + frameState.restore_frame_state(); + + return true; + } else { + // the deferred stop callback will reactivate this frame + return false; + } + } }; // TODO: determine if this should also be nothrow @@ -518,48 +613,28 @@ struct _sr_thunk_promise final { auto final_suspend() noexcept { struct awaiter final : _final_suspend_awaiter_base { -#if (defined(_MSC_VER) && !defined(__clang__)) || defined(__EMSCRIPTEN__) - // MSVC doesn't seem to like symmetric transfer in this final awaiter - // and the Emscripten (WebAssembly) compiler doesn't support tail-calls - void await_suspend(coro::coroutine_handle h) noexcept { + coro::coroutine_handle<> + await_suspend_impl(coro::coroutine_handle h) noexcept { auto& p = h.promise(); - p.callback_.destruct(); - - // this needs to be written before we decrement the refcount to ensure - // that we synchronize this write with the corresponding read in the - // deferred stop callback's completion - p.whoToContinue_ = continuation::PRIMARY; - - // if we're last to complete, continue our continuation; otherwise do - // nothing and wait for the async stop request to do it - if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return h.promise().continuation_.handle().resume(); + if (p.complete_and_choose_continuation(p.continuation_.handle())) { + return p.continuation_.handle(); + } else { + return coro::noop_coroutine(); } + } - // don't resume anything here; wait for the deferred stop request to - // resume our continuation +#if (defined(_MSC_VER) && !defined(__clang__)) || defined(__EMSCRIPTEN__) + // MSVC doesn't seem to like symmetric transfer in this final awaiter + // and the Emscripten (WebAssembly) compiler doesn't support tail-calls + void 
await_suspend(coro::coroutine_handle h) noexcept { + await_suspend_impl(h).resume(); } #else coro::coroutine_handle<> await_suspend(coro::coroutine_handle h) noexcept { - auto& p = h.promise(); - - p.callback_.destruct(); - - // this needs to be written before we decrement the refcount to ensure - // that we synchronize this write with the corresponding read in the - // deferred stop callback's completion - p.whoToContinue_ = continuation::PRIMARY; - - // if we're last to complete, continue our continuation; otherwise do - // nothing and wait for the async stop request to do it - if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - return h.promise().continuation_.handle(); - } else { - return coro::noop_coroutine(); - } - } + return await_suspend_impl(h); + }; #endif }; @@ -573,7 +648,10 @@ struct _sr_thunk_promise final { }; }; -template +template < + typename ThisPromise, + typename OtherPromise, + bool WithAsyncStackSupport> struct _awaiter final { /** * An awaitable type that knows how to await a task<>, sa_task<>, or @@ -628,6 +706,9 @@ struct _awaiter final { promise.register_stop_callback(); + maybePushAsyncStackFrame( + promise, h.promise(), instruction_ptr::read_return_address()); + return thisCoro; } @@ -640,6 +721,9 @@ struct _awaiter final { auto thisCoro = coro::coroutine_handle::from_address( (void*)std::exchange(--coro_, 0)); coro_holder destroyOnExit{thisCoro}; + + maybePopAsyncStackFrame(); + return thisCoro.promise().result(); } @@ -650,6 +734,28 @@ struct _awaiter final { std::bool_constant>; using needs_stop_token_t = std::bool_constant>; + using needs_async_stack_frame_t = std::bool_constant; + + void maybePushAsyncStackFrame( + [[maybe_unused]] ThisPromise& callee, + [[maybe_unused]] OtherPromise& caller, + [[maybe_unused]] instruction_ptr returnAddress) noexcept { + if constexpr (WithAsyncStackSupport) { + if (auto* callerFrame = get_async_stack_frame(caller)) { + frame_.setReturnAddress(returnAddress); + callee.frame_ = 
&frame_; + pushAsyncStackFrameCallerCallee(*callerFrame, frame_); + } + } + } + + void maybePopAsyncStackFrame() noexcept { + if constexpr (WithAsyncStackSupport) { + if (frame_.getParentFrame() != nullptr) { + popAsyncStackFrameCallee(frame_); + } + } + } // Only store the scheduler and the stop_token in the awaiter if we need to // type erase them. Otherwise, these members are "empty" and should take up @@ -668,6 +774,12 @@ struct _awaiter final { inplace_stop_token_adapter, detail::_empty<1>> stopTokenAdapter_; + UNIFEX_NO_UNIQUE_ADDRESS + conditional_t< + needs_async_stack_frame_t::value, + AsyncStackFrame, + detail::_empty<2>> + frame_; }; }; @@ -681,16 +793,20 @@ struct _sr_thunk_task::type final : coro_holder { friend promise_type; private: - template - using awaiter = typename _awaiter::type; + template + using awaiter = + typename _awaiter:: + type; explicit type(coro::coroutine_handle h) noexcept : coro_holder(h) {} - template - friend awaiter + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter tag_invoke(tag_t, Promise&, type&& t) noexcept { - return awaiter{std::exchange(t.coro_, {})}; + return awaiter{std::exchange(t.coro_, {})}; } }; @@ -805,18 +921,22 @@ struct _sa_task::type final : public _task::type { type(base&& t) noexcept : base(std::move(t)) {} - template - using awaiter = - typename _awaiter::type; + template + using awaiter = typename _awaiter< + typename base::promise_type, + OtherPromise, + WithAsyncStackSupport>::type; // given that we're awaited in a scheduler-affine context, we are ourselves // scheduler-affine static constexpr bool is_always_scheduler_affine = true; - template - friend awaiter + template < + typename Promise, + bool WithAsyncStackSupport = !UNIFEX_NO_ASYNC_STACKS> + friend awaiter tag_invoke(tag_t, Promise&, type&& t) noexcept { - return awaiter{std::exchange(t.coro_, {})}; + return awaiter{std::exchange(t.coro_, {})}; } template diff --git 
a/include/unifex/tracing/async_stack-inl.hpp b/include/unifex/tracing/async_stack-inl.hpp index 787ddc9a..add04bde 100644 --- a/include/unifex/tracing/async_stack-inl.hpp +++ b/include/unifex/tracing/async_stack-inl.hpp @@ -67,6 +67,16 @@ popAsyncStackFrameCallee(unifex::AsyncStackFrame& calleeFrame) noexcept { calleeFrame.stackRoot = nullptr; } +inline void popAsyncStackFrameFromCaller( + [[maybe_unused]] unifex::AsyncStackFrame& callerFrame) noexcept { + auto root = tryGetCurrentAsyncStackRoot(); + assert(root != nullptr); + auto topFrame = root->getTopFrame(); + assert(topFrame != nullptr); + assert(topFrame->getParentFrame() == &callerFrame); + popAsyncStackFrameCallee(*topFrame); +} + inline std::size_t getAsyncStackTraceFromInitialFrame( unifex::AsyncStackFrame* initialFrame, std::uintptr_t* addresses, diff --git a/include/unifex/tracing/async_stack.hpp b/include/unifex/tracing/async_stack.hpp index 18c747d0..ca52d394 100644 --- a/include/unifex/tracing/async_stack.hpp +++ b/include/unifex/tracing/async_stack.hpp @@ -221,6 +221,9 @@ void pushAsyncStackFrameCallerCallee( // the current AsyncStackRoot. void popAsyncStackFrameCallee(unifex::AsyncStackFrame& calleeFrame) noexcept; +void popAsyncStackFrameFromCaller( + unifex::AsyncStackFrame& callerFrame) noexcept; + // Get a pointer to a special frame that can be used as the root-frame // for a chain of AsyncStackFrame that does not chain onto a normal // call-stack. @@ -516,7 +519,7 @@ class ScopedAsyncStackRoot { assert(tryGetCurrentAsyncStackRoot() == &root_); [[maybe_unused]] auto topFrame = root_.topFrame.exchange(nullptr, std::memory_order_relaxed); - assert(topFrame == possiblyDeadFrame); + assert(topFrame == nullptr || topFrame == possiblyDeadFrame); } private: diff --git a/source/task.cpp b/source/task.cpp index acbc9401..94b4970e 100644 --- a/source/task.cpp +++ b/source/task.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Facebook, Inc. and its affiliates. + * Copyright (c) Meta Platforms, Inc. 
and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,8 @@ void _task_promise_base::transform_schedule_sender_impl_( // correct scheduler, do so now: if (!std::exchange(this->rescheduled_, true)) { // Create a cleanup action that transitions back onto the current scheduler: - auto cleanupTask = at_coroutine_exit(schedule, this->sched_); + auto cleanupTask = + await_transform(*this, at_coroutine_exit(schedule, this->sched_)); // Insert the cleanup action into the head of the continuation chain by // making direct calls to the cleanup task's awaiter member functions. See // type _cleanup_task in at_coroutine_exit.hpp: