From 288c4253476bb036d5832b954836d3bd4c1e881d Mon Sep 17 00:00:00 2001 From: Ian Petersen Date: Thu, 26 Oct 2023 15:53:47 -0700 Subject: [PATCH] Fix UB in cancelled tasks again As reported in #577, #578 didn't fix the UB in cancelled tasks. There were two bugs in that "fix": 1. we were writing the coroutine handle to resume *after* synchronizing between the two racing operation states, leading to a data race when the deferred stop source completed second; and 2. it's wrong to call `continuation_.done()` to store the coroutine handle to be resumed later because `done()` completes the operation with `set_done()` as a side effect. This diff fixes both issues. Instead of storing a coroutine handle to be resumed, we store an enum value that describes which coroutine handle to resume (so we don't eagerly invoke `done()`), and we do the store before synchronizing to eliminate the data race. I've added a unit test that fails with a TSAN-detected data race without the fix, and the fix silences the TSAN error. I also clang-formatted `task.hpp` and the modified test file. --- include/unifex/task.hpp | 80 ++++++++++++------ test/task_scheduler_affinity_test.cpp | 115 +++++++++++++++----------- 2 files changed, 122 insertions(+), 73 deletions(-) diff --git a/include/unifex/task.hpp b/include/unifex/task.hpp index 8c2b90ba9..fb3db892c 100644 --- a/include/unifex/task.hpp +++ b/include/unifex/task.hpp @@ -222,8 +222,9 @@ struct _result_and_unhandled_exception final { * This is used to share the implementation of result handling between * type-specific promise types. */ - struct type { //parametize on noexcept, noexcept terminates in body - void unhandled_exception() noexcept { // will be invoked in the catch block of the coroutine + struct type { // parametize on noexcept, noexcept terminates in body + void unhandled_exception() noexcept { // will be invoked in the catch block + // of the coroutine if constexpr (nothrow) { std::terminate(); } else { @@ -270,7 +271,7 @@ struct _return_value_or_void { * Provides a type-specific return_value() method to meet a promise type's * requirements. */ - // todo: consider if this should be nothrow or not + // todo: consider if this should be nothrow or not struct type : _result_and_unhandled_exception::type { template(typename Value = T) // (requires convertible_to AND constructible_from) // @@ -287,9 +288,7 @@ struct _return_value_or_void { * Provides a return_void() method to meet a promise type's requirements. */ struct type : _result_and_unhandled_exception::type { - void return_void() noexcept { - this->set_value(); - } + void return_void() noexcept { this->set_value(); } }; }; @@ -337,7 +336,8 @@ struct _promise final { // NOTE: Magic rescheduling is not currently supported by nothrow tasks decltype(auto) await_transform(Value&& value) { if constexpr (is_sender_for_v, schedule>) { - static_assert(!nothrow, "Magic rescheduling isn't supported by no-throw tasks"); + static_assert( + !nothrow, "Magic rescheduling isn't supported by no-throw tasks"); // TODO: rip this out and replace it with something more explicit // If we are co_await'ing a sender that is the result of calling @@ -379,7 +379,8 @@ struct _promise final { transform_schedule_sender_impl_(get_scheduler(snd)); // Return the inner sender, appropriately wrapped in an awaitable: - return unifex::await_transform(*this, std::forward(snd).base()); + return unifex::await_transform( + *this, std::forward(snd).base()); } }; }; @@ -388,10 +389,11 @@ struct _sr_thunk_promise_base : _promise_base { coro::coroutine_handle<> unhandled_done() noexcept { callback_.destruct(); + whoToContinue_ = continuation::DONE; + if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { return continuation_.done(); } else { - handleToResume_ = continuation_.done(); return coro::noop_coroutine(); } } @@ -423,8 +425,12 @@ struct _sr_thunk_promise_base : _promise_base { void set_value(bool) noexcept { if (self->refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { - UNIFEX_ASSERT(self->handleToResume_ != coro::coroutine_handle<>{}); - self->handleToResume_.resume(); + if (self->whoToContinue_ == continuation::PRIMARY) { + self->continuation_.handle().resume(); + } else { + UNIFEX_ASSERT(self->whoToContinue_ == continuation::DONE); + self->continuation_.done().resume(); + } } } void set_error(std::exception_ptr) noexcept { std::terminate(); } @@ -451,17 +457,24 @@ struct _sr_thunk_promise_base : _promise_base { using stop_callback_t = typename inplace_stop_token::callback_type; - coro::coroutine_handle<> handleToResume_{}; manual_lifetime callback_; std::atomic refCount_{1}; + enum class continuation : uint8_t { + UNSET, + PRIMARY, + DONE, + }; + + continuation whoToContinue_{continuation::UNSET}; + void register_stop_callback() noexcept { callback_.construct(stoken_, stop_callback{this}); } }; -//TODO: determine if this should also be nothrow +// TODO: determine if this should also be nothrow template struct _sr_thunk_promise final { /** @@ -499,15 +512,19 @@ struct _sr_thunk_promise final { p.callback_.destruct(); + // this needs to be written before we decrement the refcount to ensure + // that we synchronize this write with the corresponding read in the + // deferred stop callback's completion + p.whoToContinue_ = continuation::PRIMARY; + // if we're last to complete, continue our continuation; otherwise do // nothing and wait for the async stop request to do it if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { return h.promise().continuation_.handle().resume(); - } else { - p.handleToResume_ = h.promise().continuation_.handle(); - // don't resume anything here; wait for the deferred stop request - // to resume our continuation } + + // don't resume anything here; wait for the deferred stop request to + // resume our continuation } #else coro::coroutine_handle<> @@ -516,12 +533,16 @@ struct _sr_thunk_promise final { p.callback_.destruct(); + // this needs to be written before we decrement the refcount to ensure + // that we synchronize this write with the corresponding read in the + // deferred stop callback's completion + p.whoToContinue_ = continuation::PRIMARY; + // if we're last to complete, continue our continuation; otherwise do // nothing and wait for the async stop request to do it if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) { return h.promise().continuation_.handle(); } else { - p.handleToResume_ = h.promise().continuation_.handle(); return coro::noop_coroutine(); } } @@ -596,7 +617,8 @@ struct _awaiter final { return thisCoro; } - result_type await_resume() noexcept(noexcept(UNIFEX_DECLVAL(ThisPromise&).result())) { + result_type + await_resume() noexcept(noexcept(UNIFEX_DECLVAL(ThisPromise&).result())) { if constexpr (needs_stop_token_t::value) stopTokenAdapter_.unsubscribe(); if constexpr (needs_scheduler_t::value) @@ -703,8 +725,13 @@ struct _task::type type& operator=(type&& t) noexcept = default; template - friend type - tag_invoke(tag_t, type_identity, Fn fn, Args... args) noexcept(false) /* even if nothrow is true, ramp of a coroutine can still throw */ { + friend type tag_invoke( + tag_t, + type_identity, + Fn fn, + Args... args) noexcept(false) /* even if nothrow is true, ramp of a + coroutine can still throw */ + { co_return co_await std::invoke((Fn &&) fn, (Args &&) args...); } @@ -713,7 +740,9 @@ struct _task::type : coro_holder(h) {} template - friend auto tag_invoke(tag_t, Promise& p, type&& t) noexcept(false) /* calls inject_stop_request_thunk which might throw */ { + friend auto + tag_invoke(tag_t, Promise& p, type&& t) noexcept( + false) /* calls inject_stop_request_thunk which might throw */ { // we don't know whether our consumer will enforce the scheduler-affinity // invariants so we need to ensure that stop requests are delivered on the // right scheduler @@ -722,7 +751,9 @@ struct _task::type } template - friend auto tag_invoke(tag_t, type&& t, Receiver&& r) noexcept(false) /* will ultimately call a coroutine that might throw */ { + friend auto + tag_invoke(tag_t, type&& t, Receiver&& r) noexcept( + false) /* will ultimately call a coroutine that might throw */ { using stoken_t = stop_token_type_t; if constexpr (is_stop_never_possible_v) { @@ -776,7 +807,8 @@ struct _sa_task::type final : public _task::type { template friend auto - tag_invoke(tag_t, type&& t, Receiver&& r) noexcept(false) /* ultimately calls a coroutine which may throw */ { + tag_invoke(tag_t, type&& t, Receiver&& r) noexcept( + false) /* ultimately calls a coroutine which may throw */ { return connect_awaitable(std::move(t), static_cast(r)); } }; diff --git a/test/task_scheduler_affinity_test.cpp b/test/task_scheduler_affinity_test.cpp index 9d3646202..7f9f8e772 100644 --- a/test/task_scheduler_affinity_test.cpp +++ b/test/task_scheduler_affinity_test.cpp @@ -18,20 +18,21 @@ #if !UNIFEX_NO_COROUTINES -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include +# include + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include using namespace unifex; @@ -41,18 +42,17 @@ struct TaskSchedulerAffinityTest : testing::Test { }; UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task> child(Scheduler s) { +(requires scheduler) + task> child(Scheduler s) { auto that_id = - co_await then(schedule(s), []{ return std::this_thread::get_id(); }); + co_await then(schedule(s), [] { return std::this_thread::get_id(); }); // Should have automatically transitioned back to the original thread: auto this_id = std::this_thread::get_id(); co_return std::make_pair(this_id, that_id); } UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task inner(Scheduler s) { +(requires scheduler)task inner(Scheduler s) { // Transition to the scheduler's context: co_await schedule(s); // Should return the new context @@ -60,8 +60,8 @@ task inner(Scheduler s) { } UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task> outer(Scheduler s) { +(requires scheduler) + task> outer(Scheduler s) { // Call a nested coroutine that transitions context: auto that_id = co_await inner(s); // Should have automatically transitioned back to the correct context @@ -72,8 +72,7 @@ task> outer(Scheduler s) { // Test that after a co_await schedule(), the coroutine's current // scheduler has changed: UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task test_current_scheduler(Scheduler s) { +(requires scheduler)task test_current_scheduler(Scheduler s) { auto before = co_await current_scheduler(); co_await schedule(s); auto after = co_await current_scheduler(); @@ -83,15 +82,17 @@ task test_current_scheduler(Scheduler s) { // Test that after a co_await schedule(), the coroutine's current // scheduler is inherited by child tasks: UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task> test_current_scheduler_is_inherited_impl(Scheduler s) { +(requires scheduler)task> test_current_scheduler_is_inherited_impl(Scheduler s) { any_scheduler s2 = co_await current_scheduler(); bool sameScheduler = (s2 == s); co_return std::make_pair(sameScheduler, std::this_thread::get_id()); } UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task> test_current_scheduler_is_inherited(Scheduler s) { +(requires scheduler) + task> test_current_scheduler_is_inherited( + Scheduler s) { co_await schedule(s); co_return co_await test_current_scheduler_is_inherited_impl(s); } @@ -99,49 +100,48 @@ task> test_current_scheduler_is_inherited(Sched // Test that we properly transition back to the right context when // the task is cancelled. UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task test_context_restored_on_cancel_2(Scheduler s) { +(requires scheduler) + task test_context_restored_on_cancel_2(Scheduler s) { co_await schedule(s); co_await stop(); ADD_FAILURE() << "Coroutine did not stop!"; co_return; } UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task test_context_restored_on_cancel(Scheduler s) { +(requires scheduler) + task test_context_restored_on_cancel(Scheduler s) { // swallow the cancellation signal: - (void) co_await let_done( - test_context_restored_on_cancel_2(s), - []() noexcept { return just(); }); + (void)co_await let_done( + test_context_restored_on_cancel_2(s), []() noexcept { return just(); }); co_return std::this_thread::get_id(); } // Test that we properly transition back to the right context when // the task fails. UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task test_context_restored_on_error_2(Scheduler s) { +(requires scheduler) + task test_context_restored_on_error_2(Scheduler s) { co_await schedule(s); throw std::runtime_error("whoops"); } UNIFEX_TEMPLATE(typename Scheduler) - (requires scheduler) -task test_context_restored_on_error(Scheduler s) { +(requires scheduler) + task test_context_restored_on_error(Scheduler s) { std::thread::id id; // swallow the cancellation signal: try { co_await test_context_restored_on_error_2(s); ADD_FAILURE() << "Was expecting a throw"; - } catch(...) { + } catch (...) { id = std::this_thread::get_id(); } co_return id; } -} // anonymous namespace +} // anonymous namespace TEST_F(TaskSchedulerAffinityTest, TransformSenderOnSeparateThread) { - if(auto opt = sync_wait(child(thread_ctx.get_scheduler()))) { + if (auto opt = sync_wait(child(thread_ctx.get_scheduler()))) { auto [this_id, that_id] = *opt; ASSERT_EQ(this_id, std::this_thread::get_id()); ASSERT_EQ(that_id, thread_ctx.get_thread_id()); @@ -151,7 +151,7 @@ TEST_F(TaskSchedulerAffinityTest, TransformSenderOnSeparateThread) { } TEST_F(TaskSchedulerAffinityTest, InlineThreadHopInCoroutine) { - if(auto opt = sync_wait(outer(thread_ctx.get_scheduler()))) { + if (auto opt = sync_wait(outer(thread_ctx.get_scheduler()))) { auto [this_id, that_id] = *opt; ASSERT_EQ(this_id, std::this_thread::get_id()); ASSERT_EQ(that_id, thread_ctx.get_thread_id()); @@ -161,7 +161,8 @@ TEST_F(TaskSchedulerAffinityTest, InlineThreadHopInCoroutine) { } TEST_F(TaskSchedulerAffinityTest, CurrentSchedulerTest) { - if(auto opt = sync_wait(test_current_scheduler(thread_ctx.get_scheduler()))) { + if (auto opt = + sync_wait(test_current_scheduler(thread_ctx.get_scheduler()))) { ASSERT_TRUE(opt.has_value()); EXPECT_TRUE(*opt); } else { @@ -170,7 +171,8 @@ TEST_F(TaskSchedulerAffinityTest, CurrentSchedulerTest) { } TEST_F(TaskSchedulerAffinityTest, CurrentSchedulerIsInheritedTest) { - if(auto opt = sync_wait(test_current_scheduler_is_inherited(thread_ctx.get_scheduler()))) { + if (auto opt = sync_wait( + test_current_scheduler_is_inherited(thread_ctx.get_scheduler()))) { ASSERT_TRUE(opt.has_value()); auto [success, thread_id] = *opt; EXPECT_TRUE(success); @@ -181,7 +183,8 @@ TEST_F(TaskSchedulerAffinityTest, CurrentSchedulerIsInheritedTest) { } TEST_F(TaskSchedulerAffinityTest, ContextRestoredOnCancelTest) { - if(auto opt = sync_wait(test_context_restored_on_cancel(thread_ctx.get_scheduler()))) { + if (auto opt = sync_wait( + test_context_restored_on_cancel(thread_ctx.get_scheduler()))) { ASSERT_TRUE(opt.has_value()); EXPECT_EQ(*opt, std::this_thread::get_id()); } else { @@ -190,7 +193,8 @@ TEST_F(TaskSchedulerAffinityTest, ContextRestoredOnCancelTest) { } TEST_F(TaskSchedulerAffinityTest, ContextRestoredOnErrrorTest) { - if(auto opt = sync_wait(test_context_restored_on_error(thread_ctx.get_scheduler()))) { + if (auto opt = sync_wait( + test_context_restored_on_error(thread_ctx.get_scheduler()))) { ASSERT_TRUE(opt.has_value()); EXPECT_EQ(*opt, std::this_thread::get_id()); } else { @@ -302,8 +306,6 @@ unifex::task reportCancellationThreadId() { }); } -} // namespace - TEST_F(TaskSchedulerAffinityTest, StopRequestsDeliveredOnExpectedScheduler) { unifex::single_thread_context ctx; @@ -315,4 +317,19 @@ TEST_F(TaskSchedulerAffinityTest, StopRequestsDeliveredOnExpectedScheduler) { EXPECT_EQ(std::this_thread::get_id(), *ret); } -#endif // !UNIFEX_NO_COROUTINES +TEST_F(TaskSchedulerAffinityTest, NoRacesInCancellation) { + unifex::static_thread_pool pool{2}; + + unifex::sync_wait(unifex::on( + pool.get_scheduler(), + unifex::let_value_with_stop_source([](auto& source) noexcept { + source.request_stop(); + return []() noexcept -> unifex::task { + co_return co_await unifex::just(42); + }(); + }))); +} + +} // namespace + +#endif // !UNIFEX_NO_COROUTINES