Skip to content

Commit

Permalink
Mtmg updates for rmm (#4031)
Browse files Browse the repository at this point in the history
Discovered RAII capabilities in RMM while reviewing issues related to MTMG testing.  This PR modifies the MTMG implementation to use RMM's RAII facility (`rmm::cuda_set_device_raii`) to temporarily switch the current CUDA device to another device id, restoring the previous device automatically when the scope exits.

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)

URL: #4031
  • Loading branch information
ChuckHastings authored Dec 1, 2023
1 parent 5eaae7d commit c6aa981
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 37 deletions.
8 changes: 4 additions & 4 deletions cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());
CUGRAPH_EXPECTS(pos == objects_.end(), "Cannot overwrite wrapped object");

objects_.insert(std::make_pair(handle.get_local_rank(), std::move(obj)));
objects_.insert(std::make_pair(handle.get_rank(), std::move(obj)));
}

/**
Expand Down Expand Up @@ -90,7 +90,7 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());
CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");

return pos->second;
Expand All @@ -106,7 +106,7 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());

CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");

Expand Down
21 changes: 7 additions & 14 deletions cpp/include/cugraph/mtmg/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,19 @@ namespace mtmg {
*
*/
class handle_t {
handle_t(handle_t const&) = delete;
handle_t operator=(handle_t const&) = delete;

public:
/**
* @brief Constructor
*
* @param raft_handle Raft handle for the resources
* @param thread_rank Rank for this thread
* @param device_id Device id for the device this handle operates on
*/
handle_t(raft::handle_t const& raft_handle, int thread_rank, size_t device_id)
: raft_handle_(raft_handle),
thread_rank_(thread_rank),
local_rank_(raft_handle.get_comms().get_rank()), // FIXME: update for multi-node
device_id_(device_id)
handle_t(raft::handle_t const& raft_handle, int thread_rank, rmm::cuda_device_id device_id)
: raft_handle_(raft_handle), thread_rank_(thread_rank), device_id_raii_(device_id)
{
}

Expand Down Expand Up @@ -118,18 +119,10 @@ class handle_t {
*/
int get_rank() const { return raft_handle_.get_comms().get_rank(); }

/**
* @brief Get local gpu rank
*
* @return local gpu rank
*/
int get_local_rank() const { return local_rank_; }

private:
raft::handle_t const& raft_handle_;
int thread_rank_;
int local_rank_;
size_t device_id_;
rmm::cuda_set_device_raii device_id_raii_;
};

} // namespace mtmg
Expand Down
10 changes: 2 additions & 8 deletions cpp/include/cugraph/mtmg/instance_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,10 @@ class instance_manager_t {

~instance_manager_t()
{
int current_device{};
RAFT_CUDA_TRY(cudaGetDevice(&current_device));

for (size_t i = 0; i < nccl_comms_.size(); ++i) {
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[i].value()));
rmm::cuda_set_device_raii local_set_device(device_ids_[i]);
RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i]));
}

RAFT_CUDA_TRY(cudaSetDevice(current_device));
}

/**
Expand All @@ -75,8 +70,7 @@ class instance_manager_t {
int gpu_id = local_id % raft_handle_.size();
int thread_id = local_id / raft_handle_.size();

RAFT_CUDA_TRY(cudaSetDevice(device_ids_[gpu_id].value()));
return handle_t(*raft_handle_[gpu_id], thread_id, static_cast<size_t>(gpu_id));
return handle_t(*raft_handle_[gpu_id], thread_id, device_ids_[gpu_id]);
}

/**
Expand Down
11 changes: 3 additions & 8 deletions cpp/include/cugraph/mtmg/resource_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class resource_manager_t {

local_rank_map_.insert(std::pair(global_rank, local_device_id));

RAFT_CUDA_TRY(cudaSetDevice(local_device_id.value()));
rmm::cuda_set_device_raii local_set_device(local_device_id);

// FIXME: There is a bug in the cuda_memory_resource that results in a Hang.
// using the pool resource as a work-around.
Expand Down Expand Up @@ -182,14 +182,12 @@ class resource_manager_t {
--gpu_row_comm_size;
}

int current_device{};
RAFT_CUDA_TRY(cudaGetDevice(&current_device));
RAFT_NCCL_TRY(ncclGroupStart());

for (size_t i = 0; i < local_ranks_to_include.size(); ++i) {
int rank = local_ranks_to_include[i];
auto pos = local_rank_map_.find(rank);
RAFT_CUDA_TRY(cudaSetDevice(pos->second.value()));
rmm::cuda_set_device_raii local_set_device(pos->second);

nccl_comms.push_back(std::make_unique<ncclComm_t>());
handles.push_back(
Expand All @@ -204,7 +202,6 @@ class resource_manager_t {
handles[i].get(), *nccl_comms[i], ranks_to_include.size(), rank);
}
RAFT_NCCL_TRY(ncclGroupEnd());
RAFT_CUDA_TRY(cudaSetDevice(current_device));

std::vector<std::thread> running_threads;

Expand All @@ -217,9 +214,7 @@ class resource_manager_t {
&device_ids,
&nccl_comms,
&handles]() {
int rank = local_ranks_to_include[idx];
RAFT_CUDA_TRY(cudaSetDevice(device_ids[idx].value()));

rmm::cuda_set_device_raii local_set_device(device_ids[idx]);
cugraph::partition_manager::init_subcomm(*handles[idx], gpu_row_comm_size);
});
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/structure/detail/structure_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ std::tuple<size_t, rmm::device_uvector<uint32_t>> mark_entries(raft::handle_t co
return word;
});

// FIXME: use detail::count_set_bits
size_t bit_count = thrust::transform_reduce(
handle.get_thrust_policy(),
marked_entries.begin(),
Expand Down
21 changes: 18 additions & 3 deletions cpp/tests/mtmg/threaded_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,25 @@ class Tests_Multithreaded
input_usecase.template construct_edgelist<vertex_t, weight_t>(
handle, multithreaded_usecase.test_weighted, false, false);

rmm::device_uvector<vertex_t> d_unique_vertices(2 * d_src_v.size(), handle.get_stream());
thrust::copy(
handle.get_thrust_policy(), d_src_v.begin(), d_src_v.end(), d_unique_vertices.begin());
thrust::copy(handle.get_thrust_policy(),
d_dst_v.begin(),
d_dst_v.end(),
d_unique_vertices.begin() + d_src_v.size());
thrust::sort(handle.get_thrust_policy(), d_unique_vertices.begin(), d_unique_vertices.end());

d_unique_vertices.resize(thrust::distance(d_unique_vertices.begin(),
thrust::unique(handle.get_thrust_policy(),
d_unique_vertices.begin(),
d_unique_vertices.end())),
handle.get_stream());

auto h_src_v = cugraph::test::to_host(handle, d_src_v);
auto h_dst_v = cugraph::test::to_host(handle, d_dst_v);
auto h_weights_v = cugraph::test::to_host(handle, d_weights_v);
auto unique_vertices = cugraph::test::to_host(handle, d_vertices_v);
auto unique_vertices = cugraph::test::to_host(handle, d_unique_vertices);

// Load edgelist from different threads. We'll use more threads than GPUs here
for (int i = 0; i < num_threads; ++i) {
Expand Down Expand Up @@ -293,13 +308,13 @@ class Tests_Multithreaded
num_threads]() {
auto thread_handle = instance_manager->get_handle();

auto number_of_vertices = unique_vertices->size();
auto number_of_vertices = unique_vertices.size();

std::vector<vertex_t> my_vertex_list;
my_vertex_list.reserve((number_of_vertices + num_threads - 1) / num_threads);

for (size_t j = i; j < number_of_vertices; j += num_threads) {
my_vertex_list.push_back((*unique_vertices)[j]);
my_vertex_list.push_back(unique_vertices[j]);
}

rmm::device_uvector<vertex_t> d_my_vertex_list(my_vertex_list.size(),
Expand Down

0 comments on commit c6aa981

Please sign in to comment.