From 38c18f382577ee4a5587fd1876d2e2b546d004b8 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Sat, 13 Jan 2024 15:42:32 -0800 Subject: [PATCH] [c10d] Add a timeout check interval variable for timeout dump (#117093) The current timeout check frequency is relied on monitoring thread's timeout thread which can be too long (even if we set it to 2mins) so let's use a separate timeout variable which users can configure it. And we only only let default PG to check TCPStore so even more frequent check should be fine. (Our stress test is performed on every half second). Pull Request resolved: https://github.com/pytorch/pytorch/pull/117093 Approved by: https://github.com/wconstab, https://github.com/kwen2501 --- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 4 +++- torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp | 9 +++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 12b1b3b5e2c57..8e3e82b9be0a1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -770,6 +770,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( getCvarInt(TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC, 60 * 10 /*10 Mins*/); waitTimeoutDumpInMilSec_ = getCvarInt(TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC, 2000); + coordCheckIntervalMilSec_ = getCvarInt(TORCH_NCCL_COORD_CHECK_MILSEC, 1000); ncclTraceBufferSize_ = getCvarInt(TORCH_NCCL_TRACE_BUFFER_SIZE, 0); enableCollecticeHashDebug_ = (dist_debug_level_ >= DebugLevel::Detail); // store_ usually is wrapped with PrefixStore and the prefix is different @@ -859,6 +860,7 @@ ProcessGroupNCCL::ProcessGroupNCCL( << monitorThreadEnabled_.load() << ", TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC: " << heartbeatTimeoutInSec_ << ", TORCH_NCCL_TRACE_BUFFER_SIZE: " << ncclTraceBufferSize_ + << ", TORCH_NCCL_COORD_CHECK_MILSEC: " << coordCheckIntervalMilSec_ << ", ID=" << this->getID(); if (options_->global_ranks_in_group.empty()) { @@ -1549,7 +1551,7 @@ void ProcessGroupNCCL::watchdogHandler() { (currentTime - lastTimePollStore)) .count(); if (timeSinceLastWorkListUpdate >= kWatchdogThreadSleepMillis && - timeSinceLastPollStore >= heartbeatTimeoutInSec_ * 1000) { + timeSinceLastPollStore >= coordCheckIntervalMilSec_) { lastTimePollStore = currentTime; if (globalStore_->check({std::string(TIMEOUT_DUMP)}) && !optAsyncDebugDump) { diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 4a88b174c8ea4..72a2aae950ef4 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -89,6 +89,11 @@ static std::vector TORCH_NCCL_TRACE_BUFFER_SIZE = { static std::vector TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC = { "TORCH_NCCL_WAIT_TIMEOUT_DUMP_MILSEC"}; +// Control the interval inside the watchdog thread to check the coordinated +// signal from other ranks, e.g. to dump the debugging information. +static std::vector TORCH_NCCL_COORD_CHECK_MILSEC = { + "TORCH_NCCL_COORD_CHECK_MILSEC"}; + constexpr const char* NCCL_BACKEND_NAME = "nccl"; constexpr const char* TIMEOUT_DUMP = "timeout_dump"; @@ -853,6 +858,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Extra time of sleep when waiting for timeout dump to finish. int waitTimeoutDumpInMilSec_; + // Interval of check coordinated signals in ProcessGroupNCCL from other ranks + // e.g., trigger the dump of the debugging info for timeout when notified. + int coordCheckIntervalMilSec_; + // Size of ring buffer where we store NCCL Traces for debugging. int ncclTraceBufferSize_;