Skip to content

Commit

Permalink
CORE: add ucc task internal status (#441) (#452)
Browse files Browse the repository at this point in the history
Co-authored-by: Sergey Lebedev <[email protected]>
  • Loading branch information
shimmybalsam and Sergei-Lebedev authored Apr 4, 2022
1 parent 1e9f7af commit 20670f8
Show file tree
Hide file tree
Showing 38 changed files with 277 additions and 335 deletions.
12 changes: 6 additions & 6 deletions src/components/tl/nccl/allgatherv/allgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ ucc_status_t ucc_tl_nccl_allgatherv_p2p_start(ucc_coll_task_t *coll_task)
size_t sdt_size, rdt_size, count, displ;
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
sdt_size = ucc_dt_size(args->src.info.datatype);
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
task->super.status = UCC_INPROGRESS;
sdt_size = ucc_dt_size(args->src.info.datatype);
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
count = args->src.info.count;
Expand Down Expand Up @@ -129,7 +129,7 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcopy_start(ucc_coll_task_t *coll_task)
size_t max_count, rdt_size, sdt_size, displ, scount, rcount;
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
max_count = task->allgatherv_bcopy.max_count;
scount = args->src.info.count;
Expand Down Expand Up @@ -236,8 +236,8 @@ ucc_status_t ucc_tl_nccl_allgatherv_bcast_start(ucc_coll_task_t *coll_task)
size_t rdt_size, count, displ;
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
task->super.status = UCC_INPROGRESS;
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgatherv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
for (peer = 0; peer < size; peer++) {
Expand Down
36 changes: 15 additions & 21 deletions src/components/tl/nccl/tl_nccl_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
ucc_status_t status = UCC_OK;
CUresult cu_status;

task->host_status = task->super.super.status;
task->host_status = task->super.status;
if (ctx->cfg.sync_type == UCC_TL_NCCL_COMPLETION_SYNC_TYPE_EVENT) {
status = ucc_mc_ee_event_post(stream, task->completed,
UCC_EE_CUDA_STREAM);
Expand All @@ -166,14 +166,8 @@ ucc_status_t ucc_tl_nccl_collective_sync(ucc_tl_nccl_task_t *task,
}
}

status = task->super.progress(&task->super);
if (status == UCC_INPROGRESS) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(TASK_TEAM(task))->pq,
&task->super);
return UCC_OK;
}

return ucc_task_complete(&task->super);
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(TASK_TEAM(task))->pq,
&task->super);
}

ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
Expand All @@ -190,13 +184,13 @@ ucc_status_t ucc_tl_nccl_alltoall_start(ucc_coll_task_t *coll_task)
size_t data_size;
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
data_size = (size_t)(args->src.info.count / gsize) *
task->super.status = UCC_INPROGRESS;
data_size = (size_t)(args->src.info.count / gsize) *
ucc_dt_size(args->src.info.datatype);
ucc_assert(args->src.info.count % gsize == 0);
if (data_size == 0) {
task->super.super.status = UCC_OK;
return UCC_OK;
task->super.status = UCC_OK;
return ucc_task_complete(&task->super);
}
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoall_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
Expand Down Expand Up @@ -242,9 +236,9 @@ ucc_status_t ucc_tl_nccl_alltoallv_start(ucc_coll_task_t *coll_task)
size_t sdt_size, rdt_size, count, displ;
ucc_rank_t peer;

task->super.super.status = UCC_INPROGRESS;
sdt_size = ucc_dt_size(args->src.info_v.datatype);
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
task->super.status = UCC_INPROGRESS;
sdt_size = ucc_dt_size(args->src.info_v.datatype);
rdt_size = ucc_dt_size(args->dst.info_v.datatype);
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_alltoallv_start", 0);
NCCLCHECK_GOTO(ncclGroupStart(), exit_coll, status, UCC_TL_TEAM_LIB(team));
for (peer = 0; peer < UCC_TL_TEAM_SIZE(team); peer++) {
Expand Down Expand Up @@ -304,7 +298,7 @@ ucc_status_t ucc_tl_nccl_allreduce_start(ucc_coll_task_t *coll_task)
ncclDataType_t dt;

dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)];
task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task,
args->coll_type == UCC_COLL_TYPE_BARRIER
? "nccl_barrier_start"
Expand Down Expand Up @@ -356,7 +350,7 @@ ucc_status_t ucc_tl_nccl_allgather_start(ucc_coll_task_t *coll_task)
src = (void *)((ptrdiff_t)args->dst.info.buffer + (count / size) *
ucc_dt_size(args->dst.info.datatype) * rank);
}
task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_allgather_start", 0);
NCCLCHECK_GOTO(ncclAllGather(src, dst, count / size, dt,
team->nccl_comm, stream),
Expand Down Expand Up @@ -411,7 +405,7 @@ ucc_status_t ucc_tl_nccl_bcast_start(ucc_coll_task_t *coll_task)
ncclDataType_t dt;

dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->src.info.datatype)];
task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_bcast_start", 0);
NCCLCHECK_GOTO(ncclBroadcast(src, src, count, dt, root, team->nccl_comm,
stream),
Expand Down Expand Up @@ -449,7 +443,7 @@ ucc_status_t ucc_tl_nccl_reduce_scatter_start(ucc_coll_task_t *coll_task)
ncclDataType_t dt;

dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(args->dst.info.datatype)];
task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_reduce_scatter_start", 0);
if (UCC_IS_INPLACE(*args)) {
count /= UCC_TL_TEAM_SIZE(team);
Expand Down Expand Up @@ -507,7 +501,7 @@ ucc_status_t ucc_tl_nccl_reduce_start(ucc_coll_task_t *coll_task)
}
}
nccl_dt = ucc_to_nccl_dtype[UCC_DT_PREDEFINED_ID(ucc_dt)];
task->super.super.status = UCC_INPROGRESS;
task->super.status = UCC_INPROGRESS;
NCCLCHECK_GOTO(ncclReduce(src, dst, count, nccl_dt, op, args->root,
team->nccl_comm, stream),
exit_coll, status, UCC_TL_TEAM_LIB(team));
Expand Down
16 changes: 7 additions & 9 deletions src/components/tl/nccl/tl_nccl_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,33 +10,31 @@
#include "core/ucc_ee.h"


ucc_status_t ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task)
void ucc_tl_nccl_event_collective_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);
ucc_status_t status;

ucc_assert(task->completed != NULL);
status = ucc_mc_ee_event_test(task->completed, UCC_EE_CUDA_STREAM);
coll_task->super.status = status;
coll_task->status = status;
#ifdef HAVE_PROFILING_TL_NCCL
if (coll_task->super.status == UCC_OK) {
if (coll_task->status == UCC_OK) {
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_coll_done", 0);
}
#endif
return coll_task->super.status;
}

ucc_status_t ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task)
void ucc_tl_nccl_driver_collective_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_nccl_task_t *task = ucc_derived_of(coll_task, ucc_tl_nccl_task_t);

coll_task->super.status = task->host_status;
coll_task->status = task->host_status;
#ifdef HAVE_PROFILING_TL_NCCL
if (coll_task->super.status == UCC_OK) {
if (coll_task->status == UCC_OK) {
UCC_TL_NCCL_PROFILE_REQUEST_EVENT(coll_task, "nccl_coll_done", 0);
}
#endif
return coll_task->super.status;
}

static void ucc_tl_nccl_req_mpool_obj_init(ucc_mpool_t *mp, void *obj,
Expand Down Expand Up @@ -82,7 +80,7 @@ static void ucc_tl_nccl_req_mapped_mpool_obj_init(ucc_mpool_t *mp, void *obj,
st = cudaHostGetDevicePointer((void **)(&req->dev_status),
(void *)&req->host_status, 0);
if (st != cudaSuccess) {
req->super.super.status = UCC_ERR_NO_MESSAGE;
req->super.status = UCC_ERR_NO_MESSAGE;
}
req->super.progress = ucc_tl_nccl_driver_collective_progress;
}
Expand Down
32 changes: 10 additions & 22 deletions src/components/tl/sharp/tl_sharp_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ ucc_tl_sharp_mem_deregister(ucc_tl_sharp_context_t *ctx,
return UCC_OK;
}

ucc_status_t ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
void ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_sharp_task_t *task = ucc_derived_of(coll_task, ucc_tl_sharp_task_t);
int completed;
Expand All @@ -125,18 +125,18 @@ ucc_status_t ucc_tl_sharp_collective_progress(ucc_coll_task_t *coll_task)
if (completed) {
if (TASK_ARGS(task).coll_type == UCC_COLL_TYPE_ALLREDUCE) {
if (!UCC_IS_INPLACE(TASK_ARGS(task))) {
ucc_tl_sharp_mem_deregister(TASK_CTX(task), task->allreduce.s_mem_h);
ucc_tl_sharp_mem_deregister(TASK_CTX(task),
task->allreduce.s_mem_h);
}
ucc_tl_sharp_mem_deregister(TASK_CTX(task), task->allreduce.r_mem_h);
ucc_tl_sharp_mem_deregister(TASK_CTX(task),
task->allreduce.r_mem_h);
}
sharp_coll_req_free(task->req_handle);
coll_task->super.status = UCC_OK;
coll_task->status = UCC_OK;
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task,
"sharp_collective_done", 0);
}
}

return coll_task->super.status;
}

ucc_status_t ucc_tl_sharp_barrier_start(ucc_coll_task_t *coll_task)
Expand All @@ -145,23 +145,17 @@ ucc_status_t ucc_tl_sharp_barrier_start(ucc_coll_task_t *coll_task)
ucc_tl_sharp_team_t *team = TASK_TEAM(task);
int ret;

task->super.super.status = UCC_INPROGRESS;
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_barrier_start", 0);

ret = sharp_coll_do_barrier_nb(team->sharp_comm, &task->req_handle);
if (ret != SHARP_COLL_SUCCESS) {
tl_error(UCC_TASK_LIB(task), "sharp_coll_do_barrier_nb failed:%s",
sharp_coll_strerror(ret));
coll_task->super.status = UCC_ERR_NO_RESOURCE;
coll_task->status = UCC_ERR_NO_RESOURCE;
return ucc_task_complete(coll_task);
}

if (UCC_INPROGRESS == ucc_tl_sharp_collective_progress(coll_task)) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}

return ucc_task_complete(coll_task);
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
Expand All @@ -177,7 +171,6 @@ ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
size_t data_size;
int ret;

task->super.super.status = UCC_INPROGRESS;
UCC_TL_SHARP_PROFILE_REQUEST_EVENT(coll_task, "sharp_allreduce_start", 0);

sharp_type = ucc_to_sharp_dtype[UCC_DT_PREDEFINED_ID(dt)];
Expand Down Expand Up @@ -217,16 +210,11 @@ ucc_status_t ucc_tl_sharp_allreduce_start(ucc_coll_task_t *coll_task)
if (ret != SHARP_COLL_SUCCESS) {
tl_error(UCC_TASK_LIB(task), "sharp_coll_do_allreduce_nb failed:%s",
sharp_coll_strerror(ret));
coll_task->super.status = UCC_ERR_NO_RESOURCE;
coll_task->status = UCC_ERR_NO_RESOURCE;
return ucc_task_complete(coll_task);
}

if (UCC_INPROGRESS == ucc_tl_sharp_collective_progress(coll_task)) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}

return ucc_task_complete(coll_task);
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_sharp_allreduce_init(ucc_tl_sharp_task_t *task)
Expand Down
3 changes: 0 additions & 3 deletions src/components/tl/ucp/allgather/allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
#include "tl_ucp.h"
#include "allgather.h"

ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *task);
ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task)
{
if ((!UCC_DT_IS_PREDEFINED((TASK_ARGS(task)).dst.info.datatype)) ||
Expand Down
4 changes: 3 additions & 1 deletion src/components/tl/ucp/allgather/allgather.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include "../tl_ucp_coll.h"

ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task);
ucc_status_t ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);

void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *task);

ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *task);

/* Uses allgather_kn_radix from config */
Expand Down
25 changes: 9 additions & 16 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
task->allgather_kn.phase = _phase; \
} while (0)

ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
Expand Down Expand Up @@ -64,7 +64,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
if (KN_NODE_EXTRA == node_type) {
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
SAVE_STATE(UCC_KN_PHASE_EXTRA);
return task->super.super.status;
return;
}
goto out;
}
Expand Down Expand Up @@ -111,7 +111,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
UCC_KN_PHASE_LOOP:
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
SAVE_STATE(UCC_KN_PHASE_LOOP);
return task->super.super.status;
return;
}
ucc_knomial_pattern_next_iteration_backward(p);
}
Expand All @@ -128,29 +128,28 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
UCC_KN_PHASE_PROXY:
if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) {
SAVE_STATE(UCC_KN_PHASE_PROXY);
return task->super.super.status;
return;
}

out:
task->super.super.status = UCC_OK;
task->super.status = UCC_OK;
UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0);
return task->super.super.status;
}

ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
{
ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t);
ucc_coll_args_t *args = &TASK_ARGS(task);
ucc_tl_ucp_team_t *team = TASK_TEAM(task);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_rank_t rank = UCC_TL_TEAM_RANK(team);
ucc_rank_t size = UCC_TL_TEAM_SIZE(team);
ucc_kn_radix_t radix = task->allgather_kn.p.radix;
ucc_rank_t broot = 0;
ucc_status_t status;
ptrdiff_t offset;

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
ucc_tl_ucp_task_reset(task);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
if (coll_task->bargs.args.coll_type == UCC_COLL_TYPE_BCAST) {
broot = coll_task->bargs.args.root;
rank = VRANK(rank, broot, size);
Expand All @@ -173,13 +172,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)
}
}
task->allgather_kn.sbuf = PTR_OFFSET(args->dst.info.buffer, offset);

status = ucc_tl_ucp_allgather_knomial_progress(&task->super);
if (UCC_INPROGRESS == status) {
ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
return UCC_OK;
}
return ucc_task_complete(coll_task);
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
}

ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
Expand Down
Loading

0 comments on commit 20670f8

Please sign in to comment.