Skip to content

Commit

Permalink
[Runtime][Dist] Implementation of KV cache transfer (#17557)
Browse files Browse the repository at this point in the history
This PR introduces a KV transfer kernel and the KV cache integration used
in prefill-decode disaggregation.

Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Charlie Ruan <[email protected]>
Co-authored-by: Yingcheng Wang <[email protected]>
  • Loading branch information
4 people authored Dec 15, 2024
1 parent 4454f8d commit 567eeed
Show file tree
Hide file tree
Showing 15 changed files with 1,743 additions and 40 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/flashinfer
Submodule flashinfer updated 1 file
+2 −2 src/tvm_wrapper.cu
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ if (USE_CUDA AND USE_NVSHMEM)
if (NOT NVSHMEM_FOUND)
message(FATAL_ERROR "Cannot find NVSHMEM, USE_NVSHMEM=" ${USE_NVSHMEM})
endif()
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc)
set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
tvm_file_glob(GLOB RUNTIME_NVSHMEM_SRCS src/runtime/contrib/nvshmem/*.cc src/runtime/contrib/nvshmem/*.cu)
list(APPEND RUNTIME_SRCS ${RUNTIME_NVSHMEM_SRCS})
endif()

Expand Down
1 change: 1 addition & 0 deletions docs/how_to/tutorials/optimize_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def create_tir_paged_kv_cache(
rotary_dim=self.head_dim,
dtype=self.dtype,
target=target,
enable_disaggregation=False,
)

def get_default_spec(self):
Expand Down
33 changes: 27 additions & 6 deletions python/tvm/relax/frontend/nn/llm/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__( # pylint: disable=too-many-locals
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
Expand Down Expand Up @@ -214,6 +215,8 @@ def __init__( # pylint: disable=too-many-locals
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
enable_disaggregation : bool
Whether to enable disaggregation in the KV cache.
"""
if rope_mode == RopeMode.INLINE:
assert rotary_dim == head_dim, "FlashInfer RoPE does not support partial rotary dim."
Expand Down Expand Up @@ -259,6 +262,7 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
rx.PrimValue(enable_disaggregation),
# fmt: on
# pylint: enable=line-too-long
]
Expand Down Expand Up @@ -293,6 +297,7 @@ def __init__( # pylint: disable=too-many-locals
rope_scaling: Dict[str, Any],
rope_ext_factors: rx.Expr,
rotary_dim: int,
enable_disaggregation: bool,
dtype: str,
target: Target,
name: str = "paged_kv_cache",
Expand Down Expand Up @@ -338,6 +343,8 @@ def __init__( # pylint: disable=too-many-locals
The RoPE extension factors when "longrope" mode RoPE scaling is enabled.
rotary_dim : int
The number of dimensions in the embedding that RoPE is applied to.
enable_disaggregation : bool
Whether to enable disaggregation in the KV cache.
target : Target
The target to build the model to.
"""
Expand Down Expand Up @@ -377,6 +384,7 @@ def __init__( # pylint: disable=too-many-locals
bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask"),
bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache"),
rope_ext_factors,
rx.PrimValue(enable_disaggregation),
# fmt: on
# pylint: enable=line-too-long
]
Expand Down Expand Up @@ -409,8 +417,9 @@ def tir_kv_cache_transpose_append(
T.func_attr({"tir.noalias": T.bool(True)})
ntoken = T.SizeVar("num_tokens_excluding_cache", "int64")
num_pages = T.int64()
pages_elem_offset = T.int64()
position_map_elem_offset = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype)
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset)
k_data = T.match_buffer(var_k_data, (ntoken, num_key_value_heads, head_dim), dtype)
v_data = T.match_buffer(var_v_data, (ntoken, num_key_value_heads, head_dim), dtype)
position_map = T.match_buffer(
Expand Down Expand Up @@ -453,8 +462,9 @@ def tir_kv_cache_debug_get_kv(
seqlen = T.SizeVar("num_tokens_including_cache", "int64")
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()
pages_elem_offset = T.int64()
position_map_elem_offset = T.int64()
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype)
pages = T.match_buffer(var_pages, (num_pages, 2, num_key_value_heads, page_size, head_dim), dtype,elem_offset=pages_elem_offset)
position_map = T.match_buffer(
var_position_map, (seqlen,), "int32", elem_offset=position_map_elem_offset
)
Expand Down Expand Up @@ -594,6 +604,7 @@ def batch_prefill_paged_kv(
total_len = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
pages_elem_offset = T.int64(is_size_var=True)
q_indptr_elem_offset = T.int32(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
Expand All @@ -603,7 +614,7 @@ def batch_prefill_paged_kv(

q = T.match_buffer(var_q, (total_len, h_q, d), dtype)
q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype)
pages = T.match_buffer(var_pages, (max_num_pages, 2, h_kv, 16, d), dtype, elem_offset=pages_elem_offset)
page_indptr = T.match_buffer(var_page_indptr, (batch_size + 1,), "int32", elem_offset=page_indptr_elem_offset)
page_values = T.match_buffer(var_page_values, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
Expand Down Expand Up @@ -975,6 +986,7 @@ def batch_decode_paged_kv(
B = T.int32(is_size_var=True)
nnz_pages = T.int32(is_size_var=True)
max_num_pages = T.int32(is_size_var=True)
pages_elem_offset = T.int64(is_size_var=True)
page_indptr_elem_offset = T.int32(is_size_var=True)
page_values_elem_offset = T.int32(is_size_var=True)
k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
Expand All @@ -983,7 +995,7 @@ def batch_decode_paged_kv(

Q = T.match_buffer(Q_handle, (B, H_qo, D), qkv_dtype)
pages = T.match_buffer(
pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype
pages_handle, (max_num_pages, 2, H_kv, 16, D), qkv_dtype, elem_offset=pages_elem_offset
)
page_table_indptr = T.match_buffer(page_table_indptr_handle, (B + 1,), "int32", elem_offset=page_indptr_elem_offset)
page_table_values = T.match_buffer(page_table_values_handle, (nnz_pages,), "int32", elem_offset=page_values_elem_offset)
Expand Down Expand Up @@ -1949,7 +1961,13 @@ def copy_single_page(
):
T.func_attr({"tir.is_scheduled": 1})
num_pages = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, page_size, head_dim), dtype)
pages_elem_offset = T.int64()
pages = T.match_buffer(
var_pages,
(num_pages, 2, num_heads, page_size, head_dim),
dtype,
elem_offset=pages_elem_offset,
)

for b in T.thread_binding(
(copy_length * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
Expand Down Expand Up @@ -1993,7 +2011,10 @@ def compact_kv_copy(
total_copy_length = T.int32()
copy_length_indptr_elem_offset = T.int32()
copy_src_dst_pos_elem_offset = T.int32()
pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
pages_elem_offset = T.int64()
pages = T.match_buffer(
var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype, elem_offset=pages_elem_offset
)
copy_length_indptr = T.match_buffer(
var_copy_length_indptr,
(batch_size + 1,),
Expand Down
58 changes: 54 additions & 4 deletions src/runtime/contrib/nvshmem/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
#include <nvshmem.h>
#include <nvshmemx.h>
#include <picojson.h>
#include <tvm/runtime/disco/disco_worker.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand All @@ -38,9 +39,14 @@ ShapeTuple InitNVSHMEMUID() {
return ShapeTuple(uid_64);
}

void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
DiscoWorker* worker = DiscoWorker::ThreadLocal();
ICHECK(worker != nullptr);
void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
int worker_id;
if (worker == nullptr) {
worker_id = worker_id_start;
} else {
worker_id = worker_id_start + worker->worker_id;
}
CHECK_EQ(uid_64.size(), UNIQUEID_PADDING + 1)
<< "ValueError: The length of unique_id must be " << UNIQUEID_PADDING << ", but got "
<< uid_64.size() << ".";
Expand All @@ -52,17 +58,61 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers) {
for (int i = 0; i < UNIQUEID_PADDING; ++i) {
uid.internal[i] = static_cast<char>(uid_64[i + 1]);
}
nvshmemx_set_attr_uniqueid_args(worker->worker_id, num_workers, &uid, &attr);
// FIXME: this is a hack to avoid the issue of NVSHMEM using Multi-process-per-GPU to initialize
cudaSetDevice(worker_id);
nvshmemx_set_attr_uniqueid_args(worker_id, num_workers, &uid, &attr);
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);
int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE);
CUDA_CALL(cudaSetDevice(mype_node));
if (worker != nullptr) {
if (worker->default_device.device_type == DLDeviceType::kDLCPU) {
worker->default_device = Device{DLDeviceType::kDLCUDA, mype_node};
} else {
ICHECK(worker->default_device.device_type == DLDeviceType::kDLCUDA &&
worker->default_device.device_id == mype_node)
<< "The default device of the worker is inconsistent with the device used for NVSHMEM. "
<< "The default device is " << worker->default_device
<< ", but the device used for NVSHMEM is " << Device{DLDeviceType::kDLCUDA, mype_node}
<< ".";
}
}
LOG_INFO << "NVSHMEM init finished: mype=" << nvshmem_my_pe() << " "
<< ", npes=" << nvshmem_n_pes();
}

/*!
 * \brief JSON-string entry point for NVSHMEM initialization.
 * \param args A JSON object string with the fields:
 *   - "uid": array of int64, the NVSHMEM unique id (as produced by init_nvshmem_uid);
 *   - "npes": int64, the total number of PEs participating in the initialization;
 *   - "pe_start": int64, the worker-id offset for this group of workers.
 * Any parse failure or missing/ill-typed field raises a fatal error.
 */
void InitNVSHMEMWrapper(String args) {
  picojson::value v;
  std::string err = picojson::parse(v, args);
  if (!err.empty()) {
    LOG(FATAL) << "JSON parse error: " << err;
  }

  if (!v.is<picojson::object>()) {
    LOG(FATAL) << "JSON is not an object";
  }

  picojson::object& obj = v.get<picojson::object>();

  // picojson's operator[] silently inserts a null value for a missing key, and
  // get<T>() on a mismatched type is undefined behavior, so validate each
  // required field before reading it.
  auto get_field = [&obj](const char* key) -> const picojson::value& {
    auto it = obj.find(key);
    if (it == obj.end()) {
      LOG(FATAL) << "JSON object is missing required field \"" << key << "\"";
    }
    return it->second;
  };

  const picojson::value& uid_field = get_field("uid");
  if (!uid_field.is<picojson::array>()) {
    LOG(FATAL) << "JSON field \"uid\" must be an array of integers";
  }
  const picojson::array& uid_array = uid_field.get<picojson::array>();
  std::vector<int64_t> uid_vector;
  uid_vector.reserve(uid_array.size());
  for (const auto& elem : uid_array) {
    if (!elem.is<int64_t>()) {
      LOG(FATAL) << "JSON field \"uid\" must contain only integers";
    }
    uid_vector.push_back(elem.get<int64_t>());
  }

  ShapeTuple uid_64(uid_vector);

  const picojson::value& npes_field = get_field("npes");
  const picojson::value& pe_start_field = get_field("pe_start");
  if (!npes_field.is<int64_t>() || !pe_start_field.is<int64_t>()) {
    LOG(FATAL) << "JSON fields \"npes\" and \"pe_start\" must be integers";
  }
  int num_workers = static_cast<int>(npes_field.get<int64_t>());
  int worker_id_start = static_cast<int>(pe_start_field.get<int64_t>());

  InitNVSHMEM(uid_64, num_workers, worker_id_start);
}

// Packed-function registrations exposing the NVSHMEM init entry points.

// Returns the NVSHMEM unique id as a ShapeTuple of 64-bit values.
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

// Initializes NVSHMEM from a unique-id ShapeTuple, a worker count, and a worker-id offset.
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

// JSON-string variant of init_nvshmem; see InitNVSHMEMWrapper for the expected schema.
TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper")
    .set_body_typed(InitNVSHMEMWrapper);

} // namespace runtime
} // namespace tvm
Loading

0 comments on commit 567eeed

Please sign in to comment.