Skip to content

Commit

Permalink
TL/UCP: Allow self copy in allgather using network loopback
Browse files Browse the repository at this point in the history
  • Loading branch information
yaeliyac committed Dec 22, 2024
1 parent 73651ea commit caa90f8
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 81 deletions.
17 changes: 17 additions & 0 deletions src/components/tl/ucp/allgather/allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,20 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team)
UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num);
return str;
}

ucc_status_t loopback_self_copy(void* rbuf, void* sbuf, size_t data_size,
ucc_memory_type_t rmem, ucc_memory_type_t smem,
ucc_rank_t rank, ucc_tl_ucp_team_t *team, ucc_tl_ucp_task_t *task) {
ucc_status_t status;
status = ucc_tl_ucp_send_nb(sbuf, data_size, smem, rank, team, task);
if (UCC_OK != status) {
task->super.status = status;
return task->super.status;
}
status = ucc_tl_ucp_recv_nb(rbuf, data_size, rmem, rank, team, task);
if (UCC_OK != status) {
task->super.status = status;
return task->super.status;
}
return UCC_OK;
}
4 changes: 4 additions & 0 deletions src/components/tl/ucp/allgather/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define ALLGATHER_H_
#include "../tl_ucp.h"
#include "../tl_ucp_coll.h"
#include "tl_ucp_sendrecv.h"

enum {
UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL,
Expand Down Expand Up @@ -38,6 +39,9 @@ static inline int ucc_tl_ucp_allgather_alg_from_str(const char *str)

ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task);

ucc_status_t loopback_self_copy(void* rbuf, void* sbuf, size_t data_size, ucc_memory_type_t rmem, ucc_memory_type_t smem,
ucc_rank_t rank, ucc_tl_ucp_team_t *team, ucc_tl_ucp_task_t *task);

/* Ring */
ucc_status_t ucc_tl_ucp_allgather_ring_init(ucc_base_coll_args_t *coll_args,
ucc_base_team_t *team,
Expand Down
82 changes: 50 additions & 32 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "coll_patterns/sra_knomial.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include "allgather.h"

#define SAVE_STATE(_phase) \
do { \
Expand Down Expand Up @@ -54,22 +55,23 @@

void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task,
ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
uint8_t node_type = task->allgather_kn.p.node_type;
ucc_knomial_pattern_t *p = &task->allgather_kn.p;
void *rbuf = GET_DST(args);
ucc_memory_type_t mem_type = GET_MT(args);
size_t dt_size = ucc_dt_size(GET_DT(args));
ucc_rank_t size = task->subset.map.ep_num;
size_t data_size = GET_TOTAL_COUNT(args, size);
ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ?
args->root : 0;
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
uint8_t node_type = task->allgather_kn.p.node_type;
ucc_knomial_pattern_t *p = &task->allgather_kn.p;
void *rbuf = GET_DST(args);
ucc_memory_type_t mem_type = GET_MT(args);
size_t dt_size = ucc_dt_size(GET_DT(args));
ucc_rank_t size = task->subset.map.ep_num;
size_t data_size = GET_TOTAL_COUNT(args, size);
ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ?
args->root : 0;
ucc_rank_t rank = VRANK(task->subset.myrank, broot, size);
size_t local = GET_LOCAL_COUNT(args, size, rank);
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
void *sbuf;
ptrdiff_t peer_seg_offset, local_seg_offset;
ucc_rank_t peer, peer_dist;
Expand All @@ -78,8 +80,13 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
ucc_status_t status;
size_t extra_count;

EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test",
task->allgather_kn.etask);
if (use_loopback) {
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)){
return;
}
} else {
EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", task->allgather_kn.etask);
}
task->allgather_kn.etask = NULL;
UCC_KN_GOTO_PHASE(task->allgather_kn.phase);
if (KN_NODE_EXTRA == node_type) {
Expand Down Expand Up @@ -209,6 +216,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
ct == UCC_COLL_TYPE_BCAST ?
args->root : 0, size);
ucc_ee_executor_task_args_t eargs = {0};
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
ucc_status_t status;
ptrdiff_t offset;
ucc_ee_executor_t *exec;
Expand All @@ -225,21 +233,31 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
ucc_dt_size(args->dst.info.datatype);
rbuf = args->dst.info.buffer;
if (!UCC_IS_INPLACE(*args)) {
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset);
eargs.copy.src = args->src.info.buffer;
eargs.copy.len = args->src.info.count *
ucc_dt_size(args->src.info.datatype);
status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
if (use_loopback) {
status = loopback_self_copy(PTR_OFFSET(args->dst.info.buffer, offset),
args->src.info.buffer, args->src.info.count * ucc_dt_size(args->src.info.datatype),
args->dst.info.mem_type, args->src.info.mem_type, rank, team, task);
if (ucc_unlikely(status != UCC_OK)) {
return status;
}
} else {
/* Executer */
status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset);
eargs.copy.src = args->src.info.buffer;
eargs.copy.len = args->src.info.count *
ucc_dt_size(args->src.info.datatype);
status = ucc_ee_executor_task_post(exec, &eargs,
&task->allgather_kn.etask);
if (ucc_unlikely(status != UCC_OK)) {
task->super.status = status;
return status;
}
}
}
} else if (ct == UCC_COLL_TYPE_ALLGATHERV) {
Expand Down
33 changes: 20 additions & 13 deletions src/components/tl/ucp/allgather/allgather_neighbor.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,18 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task)

ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
ucc_status_t status;
ucc_rank_t neighbor;
void *tmprecv, *tmpsend;
Expand All @@ -150,8 +151,14 @@ ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task)
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);

if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf,
data_size, rmem, smem);
if (use_loopback) {
status = loopback_self_copy(PTR_OFFSET(rbuf, data_size * trank),
sbuf, data_size, rmem, smem, trank, team, task);
} else {
/* Use cuda copy */
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf,
data_size, rmem, smem);
}
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
Expand Down
57 changes: 33 additions & 24 deletions src/components/tl/ucp/allgather/allgather_ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,16 @@ static ucc_rank_t ucc_tl_ucp_allgather_ring_get_recv_block(ucc_subset_t *subset,

void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
size_t count = TASK_ARGS(task).dst.info.count;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
ucc_rank_t sendto, recvfrom, sblock, rblock;
int step;
void *buf;
Expand All @@ -49,9 +50,9 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
}
sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize);
recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize);
step = use_loopback ? task->tagged.send_posted - 1 : task->tagged.send_posted;

while (task->tagged.send_posted < tsize - 1) {
step = task->tagged.send_posted;
while (step < tsize - 1) {
sblock = task->allgather_ring.get_send_block(&task->subset, trank,
tsize, step);
rblock = task->allgather_ring.get_recv_block(&task->subset, trank,
Expand All @@ -67,6 +68,7 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
return;
}
step = use_loopback ? task->tagged.send_posted - 1 : task->tagged.send_posted;
}
ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task));
task->super.status = UCC_OK;
Expand All @@ -76,17 +78,19 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task)

ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = task->subset.myrank;
ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num;
ucc_rank_t rank = ucc_ep_map_eval(task->subset.map, trank);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
ucc_status_t status;
ucc_rank_t block;

Expand All @@ -96,13 +100,18 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task)
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
block = task->allgather_ring.get_send_block(&task->subset, trank, tsize,
0);
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
if (use_loopback) {
status = loopback_self_copy(PTR_OFFSET(rbuf, data_size * block),
sbuf, data_size, rmem, smem, rank, team, task);
} else {
/* Use cuda copy */
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block),
sbuf, data_size, rmem, smem);
}
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
}

return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

Expand Down
31 changes: 19 additions & 12 deletions src/components/tl/ucp/allgather/allgather_sparbit.c
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,18 @@ void ucc_tl_ucp_allgather_sparbit_progress(ucc_coll_task_t *coll_task)

ucc_status_t ucc_tl_ucp_allgather_sparbit_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
size_t count = TASK_ARGS(task).dst.info.count;
void *sbuf = TASK_ARGS(task).src.info.buffer;
void *rbuf = TASK_ARGS(task).dst.info.buffer;
ucc_memory_type_t smem = TASK_ARGS(task).src.info.mem_type;
ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type;
ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype;
ucc_rank_t trank = UCC_TL_TEAM_RANK(team);
ucc_rank_t tsize = UCC_TL_TEAM_SIZE(team);
size_t data_size = (count / tsize) * ucc_dt_size(dt);
int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback;
ucc_status_t status;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_sparbit_start",
Expand All @@ -131,8 +132,14 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_start(ucc_coll_task_t *coll_task)
task->allgather_sparbit.data_expected = 1;

if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf,
if (use_loopback) {
status = loopback_self_copy(PTR_OFFSET(rbuf, data_size * trank),
sbuf, data_size, rmem, smem, trank, team, task);
} else {
/* Use cuda copy */
status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf,
data_size, rmem, smem);
}
if (ucc_unlikely(UCC_OK != status)) {
return status;
}
Expand Down
4 changes: 4 additions & 0 deletions src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_kn_radix),
UCC_CONFIG_TYPE_UINT},

{"ALLGATHER_USE_LOOPBACK", "0", "If set to 1 performs network loopback for self copy, otherwise uses mc cuda copy",
ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_use_loopback),
UCC_CONFIG_TYPE_BOOL},

{"BCAST_KN_RADIX", "4", "Radix of the recursive-knomial bcast algorithm",
ucc_offsetof(ucc_tl_ucp_lib_config_t, bcast_kn_radix),
UCC_CONFIG_TYPE_UINT},
Expand Down
1 change: 1 addition & 0 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ typedef struct ucc_tl_ucp_lib_config {
ucc_mrange_uint_t allreduce_sra_kn_radix;
uint32_t reduce_scatter_kn_radix;
uint32_t allgather_kn_radix;
int allgather_use_loopback;
uint32_t bcast_kn_radix;
ucc_mrange_uint_t bcast_sag_kn_radix;
uint32_t reduce_kn_radix;
Expand Down

0 comments on commit caa90f8

Please sign in to comment.