Skip to content

Commit

Permalink
Fix UB in cancelled tasks again
Browse files Browse the repository at this point in the history
As reported in #577, #578 didn't fix the UB in cancelled tasks. There
were two bugs in that "fix":
 1. we were writing the coroutine handle to resume *after*
    synchronizing between the two racing operation states, leading to a
    data race when the deferred stop source completed second; and
 2. it's wrong to call `continuation_.done()` to store the coroutine
    handle to be resumed later because `done()` completes the operation
    with `set_done()` as a side effect.

This diff fixes both issues. Instead of storing a coroutine handle to be
resumed, we store an enum value that describes which coroutine handle to
resume (so we don't eagerly invoke `done()`), and we do the store before
synchronizing to eliminate the data race.

I've added a unit test that fails with a TSAN-detected data race without
the fix, and the fix silences the TSAN error.

I also clang-formatted `task.hpp` and the modified test file.
  • Loading branch information
ispeters committed Oct 27, 2023
1 parent cd8b3a6 commit 288c425
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 73 deletions.
80 changes: 56 additions & 24 deletions include/unifex/task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,9 @@ struct _result_and_unhandled_exception final {
* This is used to share the implementation of result handling between
* type-specific promise types.
*/
struct type { //parametize on noexcept, noexcept terminates in body
void unhandled_exception() noexcept { // will be invoked in the catch block of the coroutine
struct type { // parametize on noexcept, noexcept terminates in body
void unhandled_exception() noexcept { // will be invoked in the catch block
// of the coroutine
if constexpr (nothrow) {
std::terminate();
} else {
Expand Down Expand Up @@ -270,7 +271,7 @@ struct _return_value_or_void {
* Provides a type-specific return_value() method to meet a promise type's
* requirements.
*/
// todo: consider if this should be nothrow or not
// todo: consider if this should be nothrow or not
struct type : _result_and_unhandled_exception<T, nothrow>::type {
template(typename Value = T) //
(requires convertible_to<Value, T> AND constructible_from<T, Value>) //
Expand All @@ -287,9 +288,7 @@ struct _return_value_or_void<void, nothrow> {
* Provides a return_void() method to meet a promise type's requirements.
*/
struct type : _result_and_unhandled_exception<void, nothrow>::type {
void return_void() noexcept {
this->set_value();
}
void return_void() noexcept { this->set_value(); }
};
};

Expand Down Expand Up @@ -337,7 +336,8 @@ struct _promise final {
// NOTE: Magic rescheduling is not currently supported by nothrow tasks
decltype(auto) await_transform(Value&& value) {
if constexpr (is_sender_for_v<remove_cvref_t<Value>, schedule>) {
static_assert(!nothrow, "Magic rescheduling isn't supported by no-throw tasks");
static_assert(
!nothrow, "Magic rescheduling isn't supported by no-throw tasks");
// TODO: rip this out and replace it with something more explicit

// If we are co_await'ing a sender that is the result of calling
Expand Down Expand Up @@ -379,7 +379,8 @@ struct _promise final {
transform_schedule_sender_impl_(get_scheduler(snd));

// Return the inner sender, appropriately wrapped in an awaitable:
return unifex::await_transform(*this, std::forward<ScheduleSender>(snd).base());
return unifex::await_transform(
*this, std::forward<ScheduleSender>(snd).base());
}
};
};
Expand All @@ -388,10 +389,11 @@ struct _sr_thunk_promise_base : _promise_base {
coro::coroutine_handle<> unhandled_done() noexcept {
callback_.destruct();

whoToContinue_ = continuation::DONE;

if (refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
return continuation_.done();
} else {
handleToResume_ = continuation_.done();
return coro::noop_coroutine();
}
}
Expand Down Expand Up @@ -423,8 +425,12 @@ struct _sr_thunk_promise_base : _promise_base {

void set_value(bool) noexcept {
if (self->refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
UNIFEX_ASSERT(self->handleToResume_ != coro::coroutine_handle<>{});
self->handleToResume_.resume();
if (self->whoToContinue_ == continuation::PRIMARY) {
self->continuation_.handle().resume();
} else {
UNIFEX_ASSERT(self->whoToContinue_ == continuation::DONE);
self->continuation_.done().resume();
}
}
}
void set_error(std::exception_ptr) noexcept { std::terminate(); }
Expand All @@ -451,17 +457,24 @@ struct _sr_thunk_promise_base : _promise_base {
using stop_callback_t =
typename inplace_stop_token::callback_type<stop_callback>;

coro::coroutine_handle<> handleToResume_{};
manual_lifetime<stop_callback_t> callback_;

std::atomic<uint8_t> refCount_{1};

enum class continuation : uint8_t {
UNSET,
PRIMARY,
DONE,
};

continuation whoToContinue_{continuation::UNSET};

void register_stop_callback() noexcept {
callback_.construct(stoken_, stop_callback{this});
}
};

//TODO: determine if this should also be nothrow
// TODO: determine if this should also be nothrow
template <typename T>
struct _sr_thunk_promise final {
/**
Expand Down Expand Up @@ -499,15 +512,19 @@ struct _sr_thunk_promise final {

p.callback_.destruct();

// this needs to be written before we decrement the refcount to ensure
// that we synchronize this write with the corresponding read in the
// deferred stop callback's completion
p.whoToContinue_ = continuation::PRIMARY;

// if we're last to complete, continue our continuation; otherwise do
// nothing and wait for the async stop request to do it
if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
return h.promise().continuation_.handle().resume();
} else {
p.handleToResume_ = h.promise().continuation_.handle();
// don't resume anything here; wait for the deferred stop request
// to resume our continuation
}

// don't resume anything here; wait for the deferred stop request to
// resume our continuation
}
#else
coro::coroutine_handle<>
Expand All @@ -516,12 +533,16 @@ struct _sr_thunk_promise final {

p.callback_.destruct();

// this needs to be written before we decrement the refcount to ensure
// that we synchronize this write with the corresponding read in the
// deferred stop callback's completion
p.whoToContinue_ = continuation::PRIMARY;

// if we're last to complete, continue our continuation; otherwise do
// nothing and wait for the async stop request to do it
if (p.refCount_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
return h.promise().continuation_.handle();
} else {
p.handleToResume_ = h.promise().continuation_.handle();
return coro::noop_coroutine();
}
}
Expand Down Expand Up @@ -596,7 +617,8 @@ struct _awaiter final {
return thisCoro;
}

result_type await_resume() noexcept(noexcept(UNIFEX_DECLVAL(ThisPromise&).result())) {
result_type
await_resume() noexcept(noexcept(UNIFEX_DECLVAL(ThisPromise&).result())) {
if constexpr (needs_stop_token_t::value)
stopTokenAdapter_.unsubscribe();
if constexpr (needs_scheduler_t::value)
Expand Down Expand Up @@ -703,8 +725,13 @@ struct _task<T, nothrow>::type
type& operator=(type&& t) noexcept = default;

template <typename Fn, typename... Args>
friend type
tag_invoke(tag_t<co_invoke>, type_identity<type>, Fn fn, Args... args) noexcept(false) /* even if nothrow is true, ramp of a coroutine can still throw */ {
friend type tag_invoke(
tag_t<co_invoke>,
type_identity<type>,
Fn fn,
Args... args) noexcept(false) /* even if nothrow is true, ramp of a
coroutine can still throw */
{
co_return co_await std::invoke((Fn &&) fn, (Args &&) args...);
}

Expand All @@ -713,7 +740,9 @@ struct _task<T, nothrow>::type
: coro_holder(h) {}

template <typename Promise>
friend auto tag_invoke(tag_t<unifex::await_transform>, Promise& p, type&& t) noexcept(false) /* calls inject_stop_request_thunk which might throw */ {
friend auto
tag_invoke(tag_t<unifex::await_transform>, Promise& p, type&& t) noexcept(
false) /* calls inject_stop_request_thunk which might throw */ {
// we don't know whether our consumer will enforce the scheduler-affinity
// invariants so we need to ensure that stop requests are delivered on the
// right scheduler
Expand All @@ -722,7 +751,9 @@ struct _task<T, nothrow>::type
}

template <typename Receiver>
friend auto tag_invoke(tag_t<unifex::connect>, type&& t, Receiver&& r) noexcept(false) /* will ultimately call a coroutine that might throw */ {
friend auto
tag_invoke(tag_t<unifex::connect>, type&& t, Receiver&& r) noexcept(
false) /* will ultimately call a coroutine that might throw */ {
using stoken_t = stop_token_type_t<Receiver>;

if constexpr (is_stop_never_possible_v<stoken_t>) {
Expand Down Expand Up @@ -776,7 +807,8 @@ struct _sa_task<T, nothrow>::type final : public _task<T, nothrow>::type {

template <typename Receiver>
friend auto
tag_invoke(tag_t<unifex::connect>, type&& t, Receiver&& r) noexcept(false) /* ultimately calls a coroutine which may throw */ {
tag_invoke(tag_t<unifex::connect>, type&& t, Receiver&& r) noexcept(
false) /* ultimately calls a coroutine which may throw */ {
return connect_awaitable(std::move(t), static_cast<Receiver&&>(r));
}
};
Expand Down
Loading

0 comments on commit 288c425

Please sign in to comment.