Skip to content

Commit

Permalink
Merge pull request #6739 from Akshay-Venkatesh/topic/cuda-reg-whole-a…
Browse files Browse the repository at this point in the history
…lloc

UCS/MEMORY: register whole cuda allocations
  • Loading branch information
yosefe authored Jun 7, 2021
2 parents d6d412e + 919a95e commit 2b81e34
Show file tree
Hide file tree
Showing 20 changed files with 124 additions and 54 deletions.
21 changes: 16 additions & 5 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ static ucs_config_field_t ucp_config_table[] = {
"MD whose distance is queried when evaluating transport selection score",
ucs_offsetof(ucp_config_t, selection_cmp), UCS_CONFIG_TYPE_STRING},

{"MEMTYPE_REG_WHOLE_ALLOC_TYPES", "",
"Memory types which have whole allocations registered.\n"
"Allowed memory types: cuda, rocm, rocm-managed",
ucs_offsetof(ucp_config_t, ctx.reg_whole_alloc_bitmap),
UCS_CONFIG_TYPE_BITMAP(ucs_memory_type_names)},

{"WARN_INVALID_CONFIG", "y",
"Issue a warning in case of invalid device and/or transport configuration.",
ucs_offsetof(ucp_config_t, warn_invalid_config), UCS_CONFIG_TYPE_BOOL},
Expand Down Expand Up @@ -1722,24 +1728,29 @@ void ucp_memory_detect_slowpath(ucp_context_h context, const void *address,
uct_md_h md;

mem_attr.field_mask = UCT_MD_MEM_ATTR_FIELD_MEM_TYPE |
UCT_MD_MEM_ATTR_FIELD_BASE_ADDRESS |
UCT_MD_MEM_ATTR_FIELD_ALLOC_LENGTH |
UCT_MD_MEM_ATTR_FIELD_SYS_DEV;

for (i = 0; i < context->num_mem_type_detect_mds; ++i) {
md = context->tl_mds[context->mem_type_detect_mds[i]].md;
status = uct_md_mem_query(md, address, length, &mem_attr);
if (status == UCS_OK) {
mem_info->type = mem_attr.mem_type;
mem_info->sys_dev = mem_attr.sys_dev;
mem_info->type = mem_attr.mem_type;
mem_info->sys_dev = mem_attr.sys_dev;
mem_info->base_address = mem_attr.base_address;
mem_info->alloc_length = mem_attr.alloc_length;
if (context->memtype_cache != NULL) {
ucs_memtype_cache_update(context->memtype_cache, address,
length, mem_info);
ucs_memtype_cache_update(context->memtype_cache,
mem_attr.base_address,
mem_attr.alloc_length, mem_info);
}
return;
}
}

/* Memory type not detected by any memtype MD - assume it is host memory */
ucp_memory_info_set_host(mem_info);
ucs_memory_info_set_host(mem_info);
}

void
Expand Down
24 changes: 20 additions & 4 deletions src/ucp/core/ucp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ucp_thread.h"

#include <ucp/api/ucp.h>
#include <ucp/dt/dt.h>
#include <ucp/proto/proto.h>
#include <uct/api/uct.h>
#include <ucs/datastruct/mpool.h>
Expand Down Expand Up @@ -114,6 +115,8 @@ typedef struct ucp_context_config {
unsigned keepalive_num_eps;
/** Enable indirect IDs to object pointers in wire protocols */
ucs_on_off_auto_value_t proto_indirect_id;
/** Bitmap of memory types whose allocations are registered fully */
unsigned reg_whole_alloc_bitmap;
/** Error handler delay */
ucs_time_t err_handler_delay;
} ucp_context_config_t;
Expand Down Expand Up @@ -450,15 +453,15 @@ static UCS_F_ALWAYS_INLINE int ucp_memory_type_cache_is_empty(ucp_context_h cont
}

static UCS_F_ALWAYS_INLINE void
ucp_memory_info_set_host(ucs_memory_info_t *mem_info)
ucp_memory_info_set_host(ucp_memory_info_t *mem_info)
{
mem_info->type = UCS_MEMORY_TYPE_HOST;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
}

static UCS_F_ALWAYS_INLINE void
ucp_memory_detect(ucp_context_h context, const void *address, size_t length,
ucs_memory_info_t *mem_info)
ucp_memory_detect_internal(ucp_context_h context, const void *address,
size_t length, ucs_memory_info_t *mem_info)
{
ucs_status_t status;

Expand Down Expand Up @@ -494,7 +497,20 @@ ucp_memory_detect(ucp_context_h context, const void *address, size_t length,
return;

out_host_mem:
ucp_memory_info_set_host(mem_info);
/* Memory type cache lookup failed - assume it is host memory */
ucs_memory_info_set_host(mem_info);
}

static UCS_F_ALWAYS_INLINE void
ucp_memory_detect(ucp_context_h context, const void *address, size_t length,
ucp_memory_info_t *mem_info)
{
ucs_memory_info_t mem_info_internal;

ucp_memory_detect_internal(context, address, length, &mem_info_internal);

mem_info->type = mem_info_internal.type;
mem_info->sys_dev = mem_info_internal.sys_dev;
}


Expand Down
22 changes: 17 additions & 5 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,
unsigned md_index;
ucs_status_t status;
ucs_log_level_t level;
ucs_memory_info_t mem_info;
size_t reg_length;
void *base_address;

if (reg_md_map == *md_map_p) {
return UCS_OK; /* shortcut - no changes required */
Expand Down Expand Up @@ -112,13 +115,22 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,
continue;
}

base_address = address;
reg_length = length;

if (context->config.ext.reg_whole_alloc_bitmap & UCS_BIT(mem_type)) {
ucp_memory_detect_internal(context, address, length, &mem_info);
base_address = mem_info.base_address;
reg_length = mem_info.alloc_length;
}

/* MD supports registration, register new memh on it */
status = uct_md_mem_reg(context->tl_mds[md_index].md, address,
length, uct_flags, &uct_memh[memh_index]);
status = uct_md_mem_reg(context->tl_mds[md_index].md, base_address,
reg_length, uct_flags, &uct_memh[memh_index]);
if (status == UCS_OK) {
ucs_trace("registered address %p length %zu on md[%d]"
" memh[%d]=%p",
address, length, md_index, memh_index,
base_address, reg_length, md_index, memh_index,
uct_memh[memh_index]);
new_md_map |= UCS_BIT(md_index);
++memh_index;
Expand All @@ -131,7 +143,7 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,
ucs_log(level,
"failed to register address %p mem_type bit 0x%lx length %zu on "
"md[%d]=%s: %s (md reg_mem_types 0x%"PRIx64")",
address, UCS_BIT(mem_type), length, md_index,
base_address, UCS_BIT(mem_type), reg_length, md_index,
context->tl_mds[md_index].rsc.md_name,
ucs_status_string(status),
md_attr->cap.reg_mem_types);
Expand Down Expand Up @@ -387,7 +399,7 @@ ucs_status_t ucp_mem_map(ucp_context_h context, const ucp_mem_map_params_t *para
ucp_mem_h *memh_p)
{
ucs_memory_type_t memory_type;
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
ucs_status_t status;
unsigned flags;
void *address;
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_request.inl
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ static UCS_F_ALWAYS_INLINE ucs_memory_type_t
ucp_request_get_memory_type(ucp_context_h context, const void *address,
size_t length, const ucp_request_param_t *param)
{
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;

if (!(param->op_attr_mask & UCP_OP_ATTR_FIELD_MEMORY_TYPE) ||
(param->memory_type == UCS_MEMORY_TYPE_UNKNOWN)) {
Expand Down
4 changes: 2 additions & 2 deletions src/ucp/core/ucp_rkey.c
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_uct,
(context, md_map, memh, mem_info, sys_dev_map, sys_distance,
buffer),
ucp_context_h context, ucp_md_map_t md_map,
const uct_mem_h *memh, const ucs_memory_info_t *mem_info,
const uct_mem_h *memh, const ucp_memory_info_t *mem_info,
uint64_t sys_dev_map,
const ucs_sys_dev_distance_t *sys_distance, void *buffer)
{
Expand Down Expand Up @@ -184,7 +184,7 @@ UCS_PROFILE_FUNC(ssize_t, ucp_rkey_pack_uct,
ucs_status_t ucp_rkey_pack(ucp_context_h context, ucp_mem_h memh,
void **rkey_buffer_p, size_t *size_p)
{
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
ucs_status_t status;
ssize_t packed_size;
void *rkey_buffer;
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_rkey.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ void ucp_rkey_packed_copy(ucp_context_h context, ucp_md_map_t md_map,

ssize_t
ucp_rkey_pack_uct(ucp_context_h context, ucp_md_map_t md_map,
const uct_mem_h *memh, const ucs_memory_info_t *mem_info,
const uct_mem_h *memh, const ucp_memory_info_t *mem_info,
uint64_t sys_dev_map,
const ucs_sys_dev_distance_t *sys_distance, void *buffer);

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/dt/datatype_iter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
*/
typedef struct {
ucp_dt_class_t dt_class; /* Datatype class (contig/iov/...) */
ucs_memory_info_t mem_info; /* Memory type and locality, needed to
ucp_memory_info_t mem_info; /* Memory type and locality, needed to
pack/unpack */
size_t length; /* Total packed flat length */
size_t offset; /* Current flat offset */
Expand Down
9 changes: 9 additions & 0 deletions src/ucp/dt/dt.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@ typedef struct ucp_dt_state {
} ucp_dt_state_t;


/**
* UCP layer memory information
*/
typedef struct {
uint8_t type; /**< Memory type, use uint8 for compact size */
ucs_sys_device_t sys_dev; /**< System device index */
} ucp_memory_info_t;


extern const char *ucp_datatype_class_names[];

size_t ucp_dt_pack(ucp_worker_h worker, ucp_datatype_t datatype,
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/proto/proto_common.inl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ ucp_proto_request_send_op(ucp_ep_h ep, ucp_proto_select_t *proto_select,
const void *buffer, size_t count, ucp_datatype_t datatype,
size_t contig_length, const ucp_request_param_t *param)
{
ucp_worker_h worker = ep->worker;
ucp_worker_h worker = ep->worker;
ucp_proto_select_param_t sel_param;
ucs_status_t status;
uint8_t sg_count;
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/proto/proto_select.c
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ ucp_proto_select_short_init(ucp_worker_h worker, ucp_proto_select_t *proto_selec
const ucp_proto_threshold_elem_t *thresh;
ucp_proto_select_param_t select_param;
const ucp_proto_single_priv_t *spriv;
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
uint32_t op_attr;

ucp_memory_info_set_host(&mem_info);
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/proto/proto_select.inl
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ static UCS_F_ALWAYS_INLINE void
ucp_proto_select_param_init(ucp_proto_select_param_t *select_param,
ucp_operation_id_t op_id, uint32_t op_attr_mask,
ucp_dt_class_t dt_class,
const ucs_memory_info_t *mem_info, uint8_t sg_count)
const ucp_memory_info_t *mem_info, uint8_t sg_count)
{
if (dt_class == UCP_DATATYPE_CONTIG) {
ucs_assert(sg_count == 1);
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/rndv/proto_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ ucp_proto_rndv_ctrl_init(const ucp_proto_rndv_ctrl_init_params_t *params)
ucp_proto_perf_range_t *perf_range;
const uct_iface_attr_t *iface_attr;
ucs_linear_func_t send_overheads;
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
ucp_md_index_t md_index;
ucp_proto_caps_t *caps;
ucs_status_t status;
Expand Down
2 changes: 1 addition & 1 deletion src/ucp/rndv/proto_rndv.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ typedef struct {
double perf_bias;

/* Memory type of the transfer */
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;

/* Minimal data length */
size_t min_length;
Expand Down
4 changes: 2 additions & 2 deletions src/ucp/rndv/rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ size_t ucp_rndv_rts_pack(ucp_request_t *sreq, ucp_rndv_rts_hdr_t *rndv_rts_hdr,
ucp_rndv_rts_opcode_t opcode)
{
ucp_worker_h worker = sreq->send.ep->worker;
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
ssize_t packed_rkey_size;
void *rkey_buf;

Expand Down Expand Up @@ -141,7 +141,7 @@ static size_t ucp_rndv_rtr_pack(void *dest, void *arg)
ucp_rndv_rtr_hdr_t *rndv_rtr_hdr = dest;
ucp_request_t *rreq = ucp_request_get_super(rndv_req);
ucp_ep_h ep = rndv_req->send.ep;
ucs_memory_info_t mem_info;
ucp_memory_info_t mem_info;
ssize_t packed_rkey_size;

/* Request ID of sender side (remote) */
Expand Down
3 changes: 2 additions & 1 deletion src/ucs/memory/memory_type.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ const char *ucs_memory_type_names[] = {
[UCS_MEMORY_TYPE_CUDA_MANAGED] = "cuda-managed",
[UCS_MEMORY_TYPE_ROCM] = "rocm",
[UCS_MEMORY_TYPE_ROCM_MANAGED] = "rocm-managed",
[UCS_MEMORY_TYPE_LAST] = "unknown"
[UCS_MEMORY_TYPE_LAST] = "unknown",
[UCS_MEMORY_TYPE_LAST + 1] = NULL
};

const char *ucs_memory_type_descs[] = {
Expand Down
50 changes: 30 additions & 20 deletions src/ucs/memory/memtype_cache.c
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,18 @@ typedef enum {
static UCS_F_ALWAYS_INLINE void
ucs_memory_info_set_unknown(ucs_memory_info_t *mem_info)
{
mem_info->type = UCS_MEMORY_TYPE_UNKNOWN;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
mem_info->type = UCS_MEMORY_TYPE_UNKNOWN;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
mem_info->base_address = NULL;
mem_info->alloc_length = -1;
}

void ucs_memory_info_set_host(ucs_memory_info_t *mem_info)
{
mem_info->type = UCS_MEMORY_TYPE_HOST;
mem_info->sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN;
mem_info->base_address = NULL;
mem_info->alloc_length = -1;
}

static ucs_pgt_dir_t *ucs_memtype_cache_pgt_dir_alloc(const ucs_pgtable_t *pgtable)
Expand Down Expand Up @@ -91,11 +101,13 @@ static void ucs_memtype_cache_insert(ucs_memtype_cache_t *memtype_cache,
return;
}

ucs_trace("memtype_cache: insert " UCS_PGT_REGION_FMT " mem_type %s dev %s",
ucs_trace("memtype_cache: insert " UCS_PGT_REGION_FMT " mem_type %s dev %s"
" base_addr %p alloc_length %ld",
UCS_PGT_REGION_ARG(&region->super),
ucs_memory_type_names[mem_info->type],
ucs_topo_sys_device_bdf_name(mem_info->sys_dev, dev_name,
sizeof(dev_name)));
sizeof(dev_name)),
mem_info->base_address, mem_info->alloc_length);
}

static void ucs_memtype_cache_region_collect_callback(const ucs_pgtable_t *pgtable,
Expand Down Expand Up @@ -127,23 +139,17 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
start = ucs_align_down_pow2((uintptr_t)address, UCS_PGT_ADDR_ALIGN);
end = ucs_align_up_pow2 ((uintptr_t)address + size, UCS_PGT_ADDR_ALIGN);

ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s",
ucs_trace("%s: [0x%lx..0x%lx] mem_type %s dev %s"
" base_addr %p alloc_length %ld",
(action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) ? "update" :
"remove",
start, end, ucs_memory_type_names[mem_info->type],
ucs_topo_sys_device_bdf_name(mem_info->sys_dev, dev_name,
sizeof(dev_name)));
sizeof(dev_name)),
mem_info->base_address, mem_info->alloc_length);

if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) {
/* try to find regions that are contiguous and instersected
* with current one */
search_start = start - 1;
search_end = end;
} else {
/* try to find regions that are instersected with current one */
search_start = start;
search_end = end - 1;
}
search_start = start;
search_end = end - 1;

pthread_rwlock_wrlock(&memtype_cache->lock);

Expand Down Expand Up @@ -176,11 +182,13 @@ UCS_PROFILE_FUNC_VOID(ucs_memtype_cache_update_internal,
goto out_unlock;
}

ucs_trace("memtype_cache: removed " UCS_PGT_REGION_FMT " %s dev %s",
ucs_trace("memtype_cache: removed " UCS_PGT_REGION_FMT " %s dev %s"
" base_addr %p alloc_length %ld",
UCS_PGT_REGION_ARG(&region->super),
ucs_memory_type_names[region->mem_info.type],
ucs_topo_sys_device_bdf_name(region->mem_info.sys_dev,
dev_name, sizeof(dev_name)));
dev_name, sizeof(dev_name)),
mem_info->base_address, mem_info->alloc_length);
}

if (action == UCS_MEMTYPE_CACHE_ACTION_SET_MEMTYPE) {
Expand Down Expand Up @@ -232,8 +240,10 @@ static void ucs_memtype_cache_event_callback(ucm_event_type_t event_type,
{
ucs_memtype_cache_t *memtype_cache = arg;
ucs_memory_info_t mem_info = {
.type = event->mem_type.mem_type,
.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN
.type = event->mem_type.mem_type,
.sys_dev = UCS_SYS_DEVICE_ID_UNKNOWN,
.base_address = event->mem_type.address,
.alloc_length = event->mem_type.size,
};
ucs_memtype_cache_action_t action;

Expand Down
Loading

0 comments on commit 2b81e34

Please sign in to comment.