From c6aa9814bff8902d3efae1574fb9d1de633f1b24 Mon Sep 17 00:00:00 2001 From: Chuck Hastings <45364586+ChuckHastings@users.noreply.github.com> Date: Fri, 1 Dec 2023 14:03:44 -0500 Subject: [PATCH] Mtmg updates for rmm (#4031) Discovered RAII capabilities in RMM while reviewing issues related to MTMG testing. This PR modifies the MTMG implementation to use the RAII capabilities for setting the device id temporarily to another device. Authors: - Chuck Hastings (https://github.com/ChuckHastings) Approvers: - Seunghwa Kang (https://github.com/seunghwak) URL: https://github.com/rapidsai/cugraph/pull/4031 --- .../mtmg/detail/device_shared_wrapper.hpp | 8 +++---- cpp/include/cugraph/mtmg/handle.hpp | 21 +++++++------------ cpp/include/cugraph/mtmg/instance_manager.hpp | 10 ++------- cpp/include/cugraph/mtmg/resource_manager.hpp | 11 +++------- cpp/src/structure/detail/structure_utils.cuh | 1 + cpp/tests/mtmg/threaded_test.cu | 21 ++++++++++++++++--- 6 files changed, 35 insertions(+), 37 deletions(-) diff --git a/cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp b/cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp index c4cacb401af..3e4b2513a8d 100644 --- a/cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp +++ b/cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp @@ -57,10 +57,10 @@ class device_shared_wrapper_t { { std::lock_guard lock(lock_); - auto pos = objects_.find(handle.get_local_rank()); + auto pos = objects_.find(handle.get_rank()); CUGRAPH_EXPECTS(pos == objects_.end(), "Cannot overwrite wrapped object"); - objects_.insert(std::make_pair(handle.get_local_rank(), std::move(obj))); + objects_.insert(std::make_pair(handle.get_rank(), std::move(obj))); } /** @@ -90,7 +90,7 @@ class device_shared_wrapper_t { { std::lock_guard lock(lock_); - auto pos = objects_.find(handle.get_local_rank()); + auto pos = objects_.find(handle.get_rank()); CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object"); return pos->second; @@ -106,7 +106,7 @@ class device_shared_wrapper_t { { std::lock_guard lock(lock_); - auto pos = objects_.find(handle.get_local_rank()); + auto pos = objects_.find(handle.get_rank()); CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object"); diff --git a/cpp/include/cugraph/mtmg/handle.hpp b/cpp/include/cugraph/mtmg/handle.hpp index 6223de1781d..0b02091a3cc 100644 --- a/cpp/include/cugraph/mtmg/handle.hpp +++ b/cpp/include/cugraph/mtmg/handle.hpp @@ -32,18 +32,19 @@ namespace mtmg { * */ class handle_t { + handle_t(handle_t const&) = delete; + handle_t operator=(handle_t const&) = delete; + public: /** * @brief Constructor * * @param raft_handle Raft handle for the resources * @param thread_rank Rank for this thread + * @param device_id Device id for the device this handle operates on */ - handle_t(raft::handle_t const& raft_handle, int thread_rank, size_t device_id) - : raft_handle_(raft_handle), - thread_rank_(thread_rank), - local_rank_(raft_handle.get_comms().get_rank()), // FIXME: update for multi-node - device_id_(device_id) + handle_t(raft::handle_t const& raft_handle, int thread_rank, rmm::cuda_device_id device_id) + : raft_handle_(raft_handle), thread_rank_(thread_rank), device_id_raii_(device_id) { } @@ -118,18 +119,10 @@ class handle_t { */ int get_rank() const { return raft_handle_.get_comms().get_rank(); } - /** - * @brief Get local gpu rank - * - * @return local gpu rank - */ - int get_local_rank() const { return local_rank_; } - private: raft::handle_t const& raft_handle_; int thread_rank_; - int local_rank_; - size_t device_id_; + rmm::cuda_set_device_raii device_id_raii_; }; } // namespace mtmg diff --git a/cpp/include/cugraph/mtmg/instance_manager.hpp b/cpp/include/cugraph/mtmg/instance_manager.hpp index f819a5a0abe..f60063c4101 100644 --- a/cpp/include/cugraph/mtmg/instance_manager.hpp +++ b/cpp/include/cugraph/mtmg/instance_manager.hpp @@ -47,15 +47,10 @@ class instance_manager_t { ~instance_manager_t() { - int current_device{}; - RAFT_CUDA_TRY(cudaGetDevice(¤t_device)); - for (size_t i = 0; i < nccl_comms_.size(); ++i) { - RAFT_CUDA_TRY(cudaSetDevice(device_ids_[i].value())); + rmm::cuda_set_device_raii local_set_device(device_ids_[i]); RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i])); } - - RAFT_CUDA_TRY(cudaSetDevice(current_device)); } /** @@ -75,8 +70,7 @@ class instance_manager_t { int gpu_id = local_id % raft_handle_.size(); int thread_id = local_id / raft_handle_.size(); - RAFT_CUDA_TRY(cudaSetDevice(device_ids_[gpu_id].value())); - return handle_t(*raft_handle_[gpu_id], thread_id, static_cast(gpu_id)); + return handle_t(*raft_handle_[gpu_id], thread_id, device_ids_[gpu_id]); } /** diff --git a/cpp/include/cugraph/mtmg/resource_manager.hpp b/cpp/include/cugraph/mtmg/resource_manager.hpp index 127944cf7ba..bc312c9ae77 100644 --- a/cpp/include/cugraph/mtmg/resource_manager.hpp +++ b/cpp/include/cugraph/mtmg/resource_manager.hpp @@ -89,7 +89,7 @@ class resource_manager_t { local_rank_map_.insert(std::pair(global_rank, local_device_id)); - RAFT_CUDA_TRY(cudaSetDevice(local_device_id.value())); + rmm::cuda_set_device_raii local_set_device(local_device_id); // FIXME: There is a bug in the cuda_memory_resource that results in a Hang. // using the pool resource as a work-around. @@ -182,14 +182,12 @@ class resource_manager_t { --gpu_row_comm_size; } - int current_device{}; - RAFT_CUDA_TRY(cudaGetDevice(¤t_device)); RAFT_NCCL_TRY(ncclGroupStart()); for (size_t i = 0; i < local_ranks_to_include.size(); ++i) { int rank = local_ranks_to_include[i]; auto pos = local_rank_map_.find(rank); - RAFT_CUDA_TRY(cudaSetDevice(pos->second.value())); + rmm::cuda_set_device_raii local_set_device(pos->second); nccl_comms.push_back(std::make_unique()); handles.push_back( @@ -204,7 +202,6 @@ class resource_manager_t { handles[i].get(), *nccl_comms[i], ranks_to_include.size(), rank); } RAFT_NCCL_TRY(ncclGroupEnd()); - RAFT_CUDA_TRY(cudaSetDevice(current_device)); std::vector running_threads; @@ -217,9 +214,7 @@ class resource_manager_t { &device_ids, &nccl_comms, &handles]() { - int rank = local_ranks_to_include[idx]; - RAFT_CUDA_TRY(cudaSetDevice(device_ids[idx].value())); - + rmm::cuda_set_device_raii local_set_device(device_ids[idx]); cugraph::partition_manager::init_subcomm(*handles[idx], gpu_row_comm_size); }); } diff --git a/cpp/src/structure/detail/structure_utils.cuh b/cpp/src/structure/detail/structure_utils.cuh index c49b62e4543..7630d5855a0 100644 --- a/cpp/src/structure/detail/structure_utils.cuh +++ b/cpp/src/structure/detail/structure_utils.cuh @@ -524,6 +524,7 @@ std::tuple> mark_entries(raft::handle_t co return word; }); + // FIXME: use detail::count_set_bits size_t bit_count = thrust::transform_reduce( handle.get_thrust_policy(), marked_entries.begin(), diff --git a/cpp/tests/mtmg/threaded_test.cu b/cpp/tests/mtmg/threaded_test.cu index bc4d8cfef6a..1a6a17eaa18 100644 --- a/cpp/tests/mtmg/threaded_test.cu +++ b/cpp/tests/mtmg/threaded_test.cu @@ -155,10 +155,25 @@ class Tests_Multithreaded input_usecase.template construct_edgelist( handle, multithreaded_usecase.test_weighted, false, false); + rmm::device_uvector d_unique_vertices(2 * d_src_v.size(), handle.get_stream()); + thrust::copy( + handle.get_thrust_policy(), d_src_v.begin(), d_src_v.end(), d_unique_vertices.begin()); + thrust::copy(handle.get_thrust_policy(), + d_dst_v.begin(), + d_dst_v.end(), + d_unique_vertices.begin() + d_src_v.size()); + thrust::sort(handle.get_thrust_policy(), d_unique_vertices.begin(), d_unique_vertices.end()); + + d_unique_vertices.resize(thrust::distance(d_unique_vertices.begin(), + thrust::unique(handle.get_thrust_policy(), + d_unique_vertices.begin(), + d_unique_vertices.end())), + handle.get_stream()); + auto h_src_v = cugraph::test::to_host(handle, d_src_v); auto h_dst_v = cugraph::test::to_host(handle, d_dst_v); auto h_weights_v = cugraph::test::to_host(handle, d_weights_v); - auto unique_vertices = cugraph::test::to_host(handle, d_vertices_v); + auto unique_vertices = cugraph::test::to_host(handle, d_unique_vertices); // Load edgelist from different threads. We'll use more threads than GPUs here for (int i = 0; i < num_threads; ++i) { @@ -293,13 +308,13 @@ class Tests_Multithreaded num_threads]() { auto thread_handle = instance_manager->get_handle(); - auto number_of_vertices = unique_vertices->size(); + auto number_of_vertices = unique_vertices.size(); std::vector my_vertex_list; my_vertex_list.reserve((number_of_vertices + num_threads - 1) / num_threads); for (size_t j = i; j < number_of_vertices; j += num_threads) { - my_vertex_list.push_back((*unique_vertices)[j]); + my_vertex_list.push_back(unique_vertices[j]); } rmm::device_uvector d_my_vertex_list(my_vertex_list.size(),