From ed0a3b5887fc7213a9b685de52566bff29799dc0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?J=C3=A1n=20Ondru=C5=A1ek?=
Date: Tue, 23 May 2023 14:03:11 -0700
Subject: [PATCH] add cancellation support into `io_uring_context`

* `open_listening_socket`
* `async_read_only_file`
* `async_write_only_file`
---
 include/unifex/linux/io_uring_context.hpp | 243 +++++++++++++++++++++-
 test/io_uring_2_test.cpp                  | 145 +++++++++++++
 2 files changed, 382 insertions(+), 6 deletions(-)
 create mode 100644 test/io_uring_2_test.cpp

diff --git a/include/unifex/linux/io_uring_context.hpp b/include/unifex/linux/io_uring_context.hpp
index 8b45e806d..0cefba1fd 100644
--- a/include/unifex/linux/io_uring_context.hpp
+++ b/include/unifex/linux/io_uring_context.hpp
@@ -441,7 +441,8 @@ class io_uring_context::read_sender {
     void start_io() noexcept {
       UNIFEX_ASSERT(context_.is_running_on_io_thread());
-
+      stopCallback_.construct(
+          get_stop_token(receiver_), cancel_callback{*this});
       auto populateSqe = [this](io_uring_sqe& sqe) noexcept {
         sqe.opcode = IORING_OP_READV;
         sqe.fd = fd_;
@@ -460,9 +461,57 @@
       }
     }

+    void request_stop() noexcept {
+      if (char expected = 1; !refCount_.compare_exchange_strong(
+              expected, 2, std::memory_order_relaxed)) {
+        // lost race with on_read_complete
+        UNIFEX_ASSERT(expected == 0);
+        return;
+      }
+      if (context_.is_running_on_io_thread()) {
+        request_stop_local();
+      } else {
+        request_stop_remote();
+      }
+    }
+
+    void request_stop_local() noexcept {
+      UNIFEX_ASSERT(context_.is_running_on_io_thread());
+      auto populateSqe = [this](io_uring_sqe& sqe) noexcept {
+        sqe.opcode = IORING_OP_ASYNC_CANCEL;
+        sqe.fd = -1;
+        sqe.off = 0;
+        auto op = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(this));
+        // sqe.addr is the user_data to look for and cancel
+        sqe.addr = op;
+        sqe.len = 0;
+        auto cop = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(&cop_));
+        sqe.user_data = cop;
+        cop_.execute_ = &cancel_operation::on_stop_complete;
+      };
+
+      if (!context_.try_submit_io(populateSqe)) {
+        cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+        context_.schedule_pending_io(&cop_);
+      }
+    }
+
+    void request_stop_remote() noexcept {
+      cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+      context_.schedule_remote(&cop_);
+    }
+
     static void on_read_complete(operation_base* op) noexcept {
       auto& self = *static_cast<operation*>(op);
-      if (self.result_ >= 0) {
+      if (self.refCount_.fetch_sub(1, std::memory_order_acq_rel) != 1) {
+        // stop callback is running; the cancel path will complete the op
+        return;
+      }
+      self.stopCallback_.destruct();
+      if (get_stop_token(self.receiver_).stop_requested()) {
+        unifex::set_done(std::move(self.receiver_));
+      } else if (self.result_ >= 0) {
         if constexpr (noexcept(unifex::set_value(std::move(self.receiver_), ssize_t(self.result_)))) {
           unifex::set_value(std::move(self.receiver_), ssize_t(self.result_));
         } else {
@@ -481,11 +530,39 @@
       }
     }

+    struct cancel_operation final : completion_base {
+      operation& op_;
+
+      explicit cancel_operation(operation& op) noexcept : op_(op) {}
+
+      // The intrusive completion list breaks if the same operation is
+      // submitted twice; break the cycle: `on_stop_complete` delegates to
+      // the parent operation.
+      static void on_stop_complete(operation_base* op) noexcept {
+        operation::on_read_complete(&static_cast<cancel_operation*>(op)->op_);
+      }
+
+      static void on_schedule_stop_complete(operation_base* op) noexcept {
+        static_cast<cancel_operation*>(op)->op_.request_stop_local();
+      }
+    };
+
+    struct cancel_callback final {
+      operation& op_;
+
+      void operator()() noexcept {
+        op_.request_stop();
+      }
+    };
+
     io_uring_context& context_;
     int fd_;
     offset_t offset_;
     iovec buffer_[1];
     Receiver receiver_;
+    manual_lifetime<typename stop_token_type_t<
+        Receiver>::template callback_type<cancel_callback>>
+        stopCallback_;
+    std::atomic_char refCount_{1};
+    cancel_operation cop_{*this};
   };

 public:
@@ -555,7 +632,8 @@ class io_uring_context::write_sender {
     void start_io() noexcept {
       UNIFEX_ASSERT(context_.is_running_on_io_thread());
-
+      stopCallback_.construct(
+          get_stop_token(receiver_), cancel_callback{*this});
       auto populateSqe = [this](io_uring_sqe& sqe) noexcept {
         sqe.opcode = IORING_OP_WRITEV;
         sqe.fd = fd_;
@@ -574,9 +652,57 @@
       }
     }

+    void request_stop() noexcept {
+      if (char expected = 1; !refCount_.compare_exchange_strong(
+              expected, 2, std::memory_order_relaxed)) {
+        // lost race with on_write_complete
+        UNIFEX_ASSERT(expected == 0);
+        return;
+      }
+      if (context_.is_running_on_io_thread()) {
+        request_stop_local();
+      } else {
+        request_stop_remote();
+      }
+    }
+
+    void request_stop_local() noexcept {
+      UNIFEX_ASSERT(context_.is_running_on_io_thread());
+      auto populateSqe = [this](io_uring_sqe& sqe) noexcept {
+        sqe.opcode = IORING_OP_ASYNC_CANCEL;
+        sqe.fd = -1;
+        sqe.off = 0;
+        auto op = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(this));
+        // sqe.addr is the user_data to look for and cancel
+        sqe.addr = op;
+        sqe.len = 0;
+        auto cop = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(&cop_));
+        sqe.user_data = cop;
+        cop_.execute_ = &cancel_operation::on_stop_complete;
+      };
+
+      if (!context_.try_submit_io(populateSqe)) {
+        cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+        context_.schedule_pending_io(&cop_);
+      }
+    }
+
+    void request_stop_remote() noexcept {
+      cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+      context_.schedule_remote(&cop_);
+    }
+
     static void on_write_complete(operation_base* op) noexcept {
       auto& self = *static_cast<operation*>(op);
-      if (self.result_ >= 0) {
+      if (self.refCount_.fetch_sub(1, std::memory_order_acq_rel) != 1) {
+        // stop callback is running; the cancel path will complete the op
+        return;
+      }
+      self.stopCallback_.destruct();
+      if (get_stop_token(self.receiver_).stop_requested()) {
+        unifex::set_done(std::move(self.receiver_));
+      } else if (self.result_ >= 0) {
         if constexpr (noexcept(unifex::set_value(std::move(self.receiver_), ssize_t(self.result_)))) {
           unifex::set_value(std::move(self.receiver_), ssize_t(self.result_));
         } else {
@@ -595,11 +721,39 @@
       }
     }

+    struct cancel_operation final : completion_base {
+      operation& op_;
+
+      explicit cancel_operation(operation& op) noexcept : op_(op) {}
+
+      // The intrusive completion list breaks if the same operation is
+      // submitted twice; break the cycle: `on_stop_complete` delegates to
+      // the parent operation.
+      static void on_stop_complete(operation_base* op) noexcept {
+        operation::on_write_complete(&static_cast<cancel_operation*>(op)->op_);
+      }
+
+      static void on_schedule_stop_complete(operation_base* op) noexcept {
+        static_cast<cancel_operation*>(op)->op_.request_stop_local();
+      }
+    };
+
+    struct cancel_callback final {
+      operation& op_;
+
+      void operator()() noexcept {
+        op_.request_stop();
+      }
+    };
+
     io_uring_context& context_;
     int fd_;
     offset_t offset_;
     iovec buffer_[1];
     Receiver receiver_;
+    manual_lifetime<typename stop_token_type_t<
+        Receiver>::template callback_type<cancel_callback>>
+        stopCallback_;
+    std::atomic_char refCount_{1};
+    cancel_operation cop_{*this};
   };

 public:
@@ -989,7 +1143,8 @@ class io_uring_context::accept_sender {
     void start_io() noexcept {
       UNIFEX_ASSERT(context_.is_running_on_io_thread());
-
+      stopCallback_.construct(
+          get_stop_token(receiver_), cancel_callback{*this});
       auto populateSqe =
           [this](io_uring_sqe& sqe) noexcept {
             sqe.opcode = IORING_OP_ACCEPT;
             sqe.accept_flags = SOCK_NONBLOCK;
@@ -1007,9 +1162,57 @@
       }
     }

+    void request_stop() noexcept {
+      if (char expected = 1; !refCount_.compare_exchange_strong(
+              expected, 2, std::memory_order_relaxed)) {
+        // lost race with on_accept
+        UNIFEX_ASSERT(expected == 0);
+        return;
+      }
+      if (context_.is_running_on_io_thread()) {
+        request_stop_local();
+      } else {
+        request_stop_remote();
+      }
+    }
+
+    void request_stop_local() noexcept {
+      UNIFEX_ASSERT(context_.is_running_on_io_thread());
+      auto populateSqe = [this](io_uring_sqe& sqe) noexcept {
+        sqe.opcode = IORING_OP_ASYNC_CANCEL;
+        sqe.fd = -1;
+        sqe.off = 0;
+        auto op = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(this));
+        // sqe.addr is the user_data to look for and cancel
+        sqe.addr = op;
+        sqe.len = 0;
+        auto cop = reinterpret_cast<std::uintptr_t>(
+            static_cast<completion_base*>(&cop_));
+        sqe.user_data = cop;
+        cop_.execute_ = &cancel_operation::on_stop_complete;
+      };
+
+      if (!context_.try_submit_io(populateSqe)) {
+        cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+        context_.schedule_pending_io(&cop_);
+      }
+    }
+
+    void request_stop_remote() noexcept {
+      cop_.execute_ = &cancel_operation::on_schedule_stop_complete;
+      context_.schedule_remote(&cop_);
+    }
+
     static void on_accept(operation_base* op) noexcept {
       auto& self = *static_cast<operation*>(op);
-      if (self.result_ >= 0) {
+      if (self.refCount_.fetch_sub(1, std::memory_order_acq_rel) != 1) {
+        // stop callback is running; the cancel path will complete the op
+        return;
+      }
+      self.stopCallback_.destruct();
+      if (get_stop_token(self.receiver_).stop_requested()) {
+        unifex::set_done(std::move(self.receiver_));
+      } else if (self.result_ >= 0) {
         if constexpr (noexcept(unifex::set_value(
                           std::move(self.receiver_),
                           async_read_write_file{self.context_, self.result_}))) {
           unifex::set_value(
               std::move(self.receiver_),
               async_read_write_file{self.context_, self.result_});
@@ -1031,9 +1234,37 @@
       }
     }

+    struct cancel_operation final : completion_base {
+      operation& op_;
+
+      explicit cancel_operation(operation& op) noexcept : op_(op) {}
+
+      // The intrusive completion list breaks if the same operation is
+      // submitted twice; break the cycle: `on_stop_complete` delegates to
+      // the parent operation.
+      static void on_stop_complete(operation_base* op) noexcept {
+        operation::on_accept(&static_cast<cancel_operation*>(op)->op_);
+      }
+
+      static void on_schedule_stop_complete(operation_base* op) noexcept {
+        static_cast<cancel_operation*>(op)->op_.request_stop_local();
+      }
+    };
+
+    struct cancel_callback final {
+      operation& op_;
+
+      void operator()() noexcept {
+        op_.request_stop();
+      }
+    };
+
     io_uring_context& context_;
     int fd_;
     Receiver receiver_;
+    manual_lifetime<typename stop_token_type_t<
+        Receiver>::template callback_type<cancel_callback>>
+        stopCallback_;
+    std::atomic_char refCount_{1};
+    cancel_operation cop_{*this};
   };

 public:
diff --git a/test/io_uring_2_test.cpp b/test/io_uring_2_test.cpp
new file mode 100644
index 000000000..98f9d6fd6
--- /dev/null
+++ b/test/io_uring_2_test.cpp
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License Version 2.0 with LLVM Exceptions
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *   https://llvm.org/LICENSE.txt
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <unifex/config.hpp>
+
+#if !UNIFEX_NO_LIBURING
+# if !UNIFEX_NO_COROUTINES
+
+# include <unifex/linux/io_uring_context.hpp>
+
+# include <unifex/finally.hpp>
+# include <unifex/inplace_stop_token.hpp>
+# include <unifex/never.hpp>
+# include <unifex/scheduler_concepts.hpp>
+# include <unifex/stop_when.hpp>
+# include <unifex/stream_concepts.hpp>
+# include <unifex/sync_wait.hpp>
+# include <unifex/task.hpp>
+# include <unifex/then.hpp>
+# include <unifex/when_all.hpp>
+
+# include <array>
+# include <chrono>
+# include <cstdio>
+# include <fcntl.h>
+# include <string>
+# include <thread>
+# include <unistd.h>
+
+# include <gtest/gtest.h>
+
+using namespace unifex;
+using namespace unifex::linuxos;
+
+namespace {
+constexpr std::chrono::milliseconds stopAfter{42};
+const char* fdPath = "/proc/self/fd/";
+
+struct IOUringTest : testing::Test {
+  void SetUp() override {
+    ASSERT_NE(pipe(pipes_), -1) << "unable to create pipe";
+    close_ = true;
+  }
+
+  ~IOUringTest() {
+    if (close_) {
+      close(pipes_[0]);
+      close(pipes_[1]);
+    }
+    stopSource_.request_stop();
+    t_.join();
+  }
+
+private:
+  bool close_{false};
+
+protected:
+  int pipes_[2];
+  io_uring_context ctx_;
+  inplace_stop_source stopSource_;
+  std::thread t_{[&] {
+    ctx_.run(stopSource_.get_token());
+  }};
+
+  task<void> accept(io_uring_context::scheduler sched) {
+    // open on a random port, will hang forever
+    auto stream = open_listening_socket(sched, 0);
+    co_await finally(unifex::next(stream), unifex::cleanup(stream));
+    ADD_FAILURE() << "should cancel and unroll";
+  }
+
+  task<void> read(io_uring_context::scheduler sched) {
+    auto in = open_file_read_only(sched, fdPath + std::to_string(pipes_[0]));
+    std::array<char, 1024> buffer;
+    // will hang forever
+    co_await async_read_some_at(
+        in, 0, as_writable_bytes(span{buffer.data(), buffer.size()}));
+    ADD_FAILURE() << "should cancel and unroll";
+  }
+
+  auto bloat() const {
+    // the pipe blocks writers once full (what we want); its size is
+    // environment specific
+    auto size = fcntl(pipes_[1], F_GETPIPE_SZ);
+    EXPECT_GT(size, 0);
+    std::printf("Pipe size: %d\n", size);
+    return std::string(static_cast<size_t>(size), '?');
+  }
+
+  task<void> write(io_uring_context::scheduler sched) {
+    auto data = bloat();
+    const auto buffer = as_bytes(span{data.data(), data.size()});
+    auto out = open_file_write_only(sched, fdPath + std::to_string(pipes_[1]));
+    // Start 8 concurrent writes to the file at different offsets.
+    co_await when_all(
+        // Calls the 'async_write_some_at()' CPO on the file object
+        // returned from 'open_file_write_only()'.
+        async_write_some_at(out, 0, buffer),
+        async_write_some_at(out, 1 * buffer.size(), buffer),
+        async_write_some_at(out, 2 * buffer.size(), buffer),
+        async_write_some_at(out, 3 * buffer.size(), buffer),
+        async_write_some_at(out, 4 * buffer.size(), buffer),
+        async_write_some_at(out, 5 * buffer.size(), buffer),
+        async_write_some_at(out, 6 * buffer.size(), buffer),
+        async_write_some_at(out, 7 * buffer.size(), buffer));
+    ADD_FAILURE() << "should cancel and unroll";
+  }
+};
+
+task<void>
+stopTrigger(std::chrono::milliseconds ms, io_uring_context::scheduler sched) {
+  co_await stop_when(
+      schedule_at(sched, now(sched) + ms) |
+          then([ms] { std::printf("Timeout after %ldms\n", ms.count()); }),
+      never_sender());
+}
+}  // namespace
+
+TEST_F(IOUringTest, AsyncReadCancel) {
+  auto scheduler = ctx_.get_scheduler();
+  // cancel the read from the *nix pipe
+  sync_wait(stop_when(read(scheduler), stopTrigger(stopAfter, scheduler)));
+}
+
+TEST_F(IOUringTest, AsyncWriteCancel) {
+  auto scheduler = ctx_.get_scheduler();
+  // cancel the write into the *nix pipe
+  sync_wait(stop_when(write(scheduler), stopTrigger(stopAfter, scheduler)));
+}
+
+TEST_F(IOUringTest, AcceptCancel) {
+  auto scheduler = ctx_.get_scheduler();
+  // cancel the accept stream
+  sync_wait(stop_when(accept(scheduler), stopTrigger(stopAfter, scheduler)));
+}
+
+# endif  // UNIFEX_NO_COROUTINES
+#endif  // UNIFEX_NO_LIBURING
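Note for reviewers: all three senders implement the same stop/complete handshake, so here is a distilled, standalone sketch of just the `refCount_` protocol. This is illustration only, not code from the patch: the `operation`, `request_stop()`, and `on_complete()` names mirror the diff, while the `main()` driver and file name are invented. `refCount_` starts at 1 (owned by the pending I/O); `request_stop()` CASes 1 -> 2 to claim the right to submit `IORING_OP_ASYNC_CANCEL`, every completion event decrements the count, and only the decrement that reaches zero may touch the receiver:

```cpp
// cancel_handshake_sketch.cpp -- hypothetical illustration of the handshake.
// Build: g++ -std=c++17 cancel_handshake_sketch.cpp
#include <atomic>
#include <cstdio>

struct operation {
  // 1: only the pending I/O owns the op; 2: a stop request is in flight.
  std::atomic_char refCount_{1};

  // Stop-callback path. Returns true if this call won the race and should
  // submit an IORING_OP_ASYNC_CANCEL SQE (whose CQE re-enters on_complete).
  bool request_stop() noexcept {
    char expected = 1;
    if (!refCount_.compare_exchange_strong(
            expected, 2, std::memory_order_relaxed)) {
      // Lost the race: on_complete already delivered the result.
      return false;
    }
    return true;
  }

  // Runs once for the original CQE and, if a cancel was submitted, once more
  // via the cancel path. Only the call that drops the count to zero completes
  // the receiver, so completion happens exactly once.
  void on_complete() noexcept {
    if (refCount_.fetch_sub(1, std::memory_order_acq_rel) != 1) {
      // A stop is in flight; the cancel completion finishes the op.
      return;
    }
    std::puts("receiver completed exactly once");
  }
};

int main() {
  {  // no stop requested: the single completion delivers the result
    operation op;
    op.on_complete();
  }
  {  // stop wins the race: two completion events, the second delivers
    operation op;
    if (op.request_stop()) {
      op.on_complete();  // original CQE: 2 -> 1, defers
      op.on_complete();  // cancel CQE:   1 -> 0, completes
    }
  }
  {  // completion wins the race: the stop request backs off
    operation op;
    op.on_complete();         // 1 -> 0, completes
    (void)op.request_stop();  // CAS fails (count is 0), nothing to cancel
  }
}
```

The `acq_rel` decrement is what lets whichever side completes last observe the other side's writes before completing the receiver; the claiming CAS can stay relaxed because a loser simply backs off without touching shared state.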