From e7adcb05a6323941343fcd3286e8d08d15413138 Mon Sep 17 00:00:00 2001 From: Lightbulb128 Date: Wed, 17 Jul 2024 17:55:45 +0800 Subject: [PATCH] bugfix: memory pool shared between multiple threads. --- .gitignore | 3 +- src/kernel_provider.h | 1 + src/utils/memory_pool.h | 12 ++++ src/utils/memory_pool_safe.in | 29 ++++++--- src/utils/memory_pool_unsafe.in | 24 ++++++-- test/CMakeLists.txt | 8 ++- test/multithread.cu | 106 +++++++++++++++++++++++++++----- test/temp.cu | 76 ----------------------- 8 files changed, 153 insertions(+), 106 deletions(-) delete mode 100644 test/temp.cu diff --git a/.gitignore b/.gitignore index 8788ab3..bcca3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ reports/ binder/*.so gather_code.* .cache/ -.clangd \ No newline at end of file +.clangd +test/custom.cu \ No newline at end of file diff --git a/src/kernel_provider.h b/src/kernel_provider.h index 70719b4..7080750 100644 --- a/src/kernel_provider.h +++ b/src/kernel_provider.h @@ -69,6 +69,7 @@ namespace troy { if (status != cudaSuccess) { runtime_error("[kernel_provider::memset] cudaMemset failed", status); } + utils::stream_sync(); } template diff --git a/src/utils/memory_pool.h b/src/utils/memory_pool.h index 7afd946..d817741 100644 --- a/src/utils/memory_pool.h +++ b/src/utils/memory_pool.h @@ -21,8 +21,20 @@ namespace troy {namespace utils { return count; } + // If "TROY_STREAM_SYNC_AFTER_KERNEL_CALLS" is enabled, + // this will call cudaStreamSynchronize(0). Otherwise, it does nothing. void stream_sync(); + inline void stream_sync_concrete() { + cudaError_t status = cudaStreamSynchronize(0); + if (status != cudaSuccess) { + std::string msg = "[stream_sync_concrete] cudaStreamSynchronize failed: "; + msg += cudaGetErrorString(status); + throw std::runtime_error(msg); + } + } + + class MemoryPool; typedef std::shared_ptr MemoryPoolHandle; diff --git a/src/utils/memory_pool_safe.in b/src/utils/memory_pool_safe.in index d7aeb48..c4f803f 100644 --- a/src/utils/memory_pool_safe.in +++ b/src/utils/memory_pool_safe.in @@ -1,4 +1,5 @@ #include "memory_pool.h" +#include namespace troy::utils { @@ -12,13 +13,15 @@ namespace troy::utils { struct MemoryPool::Impl { static const int PRESERVED_MEMORY_BYTES = 1024 * 1024 * 32; std::shared_mutex mutex; - std::multimap unused; + + // key-value pairs are (allocated memory size, (pointer address, last used thread id)) + std::multimap> unused; + std::multimap allocated; std::set zombie; size_t total_allocated; bool destroyed; }; - // User shouldn't call this but call create() instead MemoryPool::MemoryPool(size_t device_index): device_index(device_index) { @@ -104,7 +107,7 @@ namespace troy::utils { throw std::runtime_error("[MemoryPool(safe)::release] The pointer is not in the allocated set."); } size_t size = iterator->second; - impl_->unused.insert(std::make_pair(size, ptr)); + impl_->unused.insert(std::make_pair(size, std::make_pair(ptr, std::this_thread::get_id()))); } void* MemoryPool::allocate(size_t required) { @@ -120,7 +123,17 @@ namespace troy::utils { lock.unlock(); return try_allocate(required); } else { - void* ptr = iterator->second; + void* ptr = iterator->second.first; + std::thread::id last_used_id = iterator->second.second; + if (last_used_id != std::this_thread::get_id()) { + // A new thread is now taking over the memory which was last used by another thread. + // To avoid data racing, we need to execute a cudaDeviceSync + // before returning the pointer. + cudaError_t status = cudaDeviceSynchronize(); + if (status != cudaSuccess) { + runtime_error("[MemoryPool(safe)::allocate] cudaDeviceSynchronize failed.", status); + } + } impl_->unused.erase(iterator); return ptr; } @@ -130,12 +143,12 @@ namespace troy::utils { std::unique_lock lock(impl_->mutex); for (auto it = impl_->unused.begin(); it != impl_->unused.end();) { set_device(); - cudaError_t status = cudaFree(it->second); + cudaError_t status = cudaFree(it->second.first); if (status != cudaSuccess) { runtime_error("[MemoryPool(safe)::release_unused] cudaFree failed.", status); } // remove pointer from allocated - auto it2 = impl_->allocated.find(it->second); + auto it2 = impl_->allocated.find(it->second.first); if (it2 == impl_->allocated.end()) { throw std::runtime_error("[MemoryPool(safe)::release_unused] The pointer is not in the allocated set."); } @@ -149,12 +162,12 @@ namespace troy::utils { // first release all unused for (auto it = impl_->unused.begin(); it != impl_->unused.end();) { set_device(); - cudaError_t status = cudaFree(it->second); + cudaError_t status = cudaFree(it->second.first); if (status != cudaSuccess) { runtime_error("[MemoryPool(safe)::destroy] cudaFree unused failed.", status); } // remove pointer from allocated - auto it2 = impl_->allocated.find(it->second); + auto it2 = impl_->allocated.find(it->second.first); if (it2 == impl_->allocated.end()) { throw std::runtime_error("[MemoryPool(safe)::destroy] The pointer is not in the allocated set."); } diff --git a/src/utils/memory_pool_unsafe.in b/src/utils/memory_pool_unsafe.in index d4d5a3b..5be786c 100644 --- a/src/utils/memory_pool_unsafe.in +++ b/src/utils/memory_pool_unsafe.in @@ -1,11 +1,15 @@ #include "memory_pool.h" +#include namespace troy::utils { struct MemoryPool::Impl { static const int PRESERVED_MEMORY_BYTES = 1024 * 1024 * 32; std::shared_mutex mutex; - std::multimap unused; + + // key-value pairs are (allocated memory size, (pointer address, last used thread id)) + std::multimap> unused; + std::multimap allocated; size_t total_allocated; }; @@ -67,7 +71,7 @@ namespace troy::utils { throw std::runtime_error("[MemoryPool(unsafe)::release] The pointer is not in the allocated set."); } size_t size = iterator->second; - impl_->unused.insert(std::make_pair(size, ptr)); + impl_->unused.insert(std::make_pair(size, std::make_pair(ptr, std::this_thread::get_id()))); } void* MemoryPool::allocate(size_t required) { @@ -80,7 +84,17 @@ namespace troy::utils { lock.unlock(); return try_allocate(required); } else { - void* ptr = iterator->second; + void* ptr = iterator->second.first; + std::thread::id last_used_id = iterator->second.second; + if (last_used_id != std::this_thread::get_id()) { + // A new thread is now taking over the memory which was last used by another thread. + // To avoid data racing, we need to execute a cudaDeviceSync + // before returning the pointer. + cudaError_t status = cudaDeviceSynchronize(); + if (status != cudaSuccess) { + runtime_error("[MemoryPool(safe)::allocate] cudaDeviceSynchronize failed.", status); + } + } impl_->unused.erase(iterator); return ptr; } @@ -90,12 +104,12 @@ namespace troy::utils { std::unique_lock lock(impl_->mutex); for (auto it = impl_->unused.begin(); it != impl_->unused.end();) { set_device(); - cudaError_t status = cudaFree(it->second); + cudaError_t status = cudaFree(it->second.first); if (status != cudaSuccess) { runtime_error("[MemoryPool(unsafe)::release_unused] cudaFree failed.", status); } // remove pointer from allocated - auto it2 = impl_->allocated.find(it->second); + auto it2 = impl_->allocated.find(it->second.first); if (it2 == impl_->allocated.end()) { throw std::runtime_error("[MemoryPool(unsafe)::release_unused] The pointer is not in the allocated set."); } diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b276c6c..14dd405 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -83,4 +83,10 @@ if(TROY_BENCH) ) target_link_libraries(bench_ntt troy) -endif() \ No newline at end of file +endif() + +# add_executable(custom) +# target_sources(custom PRIVATE +# custom.cu +# ) +# target_link_libraries(custom troy) diff --git a/test/multithread.cu b/test/multithread.cu index ef98910..445e653 100644 --- a/test/multithread.cu +++ b/test/multithread.cu @@ -106,17 +106,17 @@ namespace multithread { } - TEST(MultithreadTest, HostSinglePoolMultiThread) { + TEST(MultithreadTest, HostSharedPoolSimple) { GeneralHeContext ghe(false, SchemeType::BFV, 32, 20, { 60, 40, 40, 60 }, false, 0x123, 0); test_single_pool_multi_thread(ghe, 64, 4); } - TEST(MultithreadTest, DeviceSinglePoolMultiThread) { + TEST(MultithreadTest, DeviceSharedPoolSimple) { GeneralHeContext ghe(true, SchemeType::BFV, 32, 20, { 60, 40, 40, 60 }, false, 0x123, 0); test_single_pool_multi_thread(ghe, 64, 4); utils::MemoryPool::Destroy(); } - bool test_multiple_pools(const GeneralHeContext& context, size_t device_index, size_t thread_index = 0) { + bool test_troublesome_pools(const GeneralHeContext& context, size_t device_index, size_t thread_index = 0, bool check_pool = true, bool shared_pool = false) { auto context_pool = context.context()->pool(); @@ -128,7 +128,8 @@ namespace multithread { IF_FALSE_PRINT_RETURN(context_pool == nullptr, "context_pool"); } - auto good_pool = [device, context_pool](MemoryPoolHandle pool, MemoryPoolHandle expect) { + auto good_pool = [device, context_pool, check_pool](MemoryPoolHandle pool, MemoryPoolHandle expect) { + if (!check_pool) return true; if (device) { bool all_on_device = pool != nullptr && expect != nullptr && context_pool != nullptr; bool same = pool != context_pool && pool == expect; @@ -139,7 +140,8 @@ namespace multithread { } }; - auto create_new_memory_pool = [device, device_index]() { + auto create_new_memory_pool = [device, device_index, shared_pool, &context]() { + if (shared_pool) return context.pool(); return device ? MemoryPool::create(device_index) : nullptr; }; @@ -857,7 +859,7 @@ namespace multithread { Encryptor encryptor_other = Encryptor(context.context()); encryptor_other.set_secret_key(secret_key_other, context_pool_other); KSwitchKeys kswitch_key = context.key_generator().create_keyswitching_key(secret_key_other, false, context_pool_other); - if (device) context_pool_other->deny(); + if (device && check_pool && !shared_pool) context_pool_other->deny(); MemoryPoolHandle pool = create_new_memory_pool(); GeneralVector message = context.random_simd_full(); @@ -1171,11 +1173,81 @@ namespace multithread { } + return true; + } - return true; + void test_shared_pool(size_t threads, bool device, SchemeType scheme, size_t n, size_t log_t, vector log_qi, + bool expand_mod_chain, uint64_t seed, double input_max = 0, double scale = 0, double tolerance = 1e-4, + bool to_device_after_keygeneration = false, bool use_special_prime_for_encryption = false + ) { + + GeneralHeContext context( + device, scheme, n, log_t, log_qi, + expand_mod_chain, seed, input_max, scale, tolerance, + to_device_after_keygeneration, use_special_prime_for_encryption + ); + + auto test_thread = [&context](int thread) { + return test_troublesome_pools(context, 0, thread, false, true); + }; + + utils::stream_sync(); + vector> thread_instances; + for (size_t i = 0; i < threads; i++) { + thread_instances.push_back(std::async(test_thread, i)); + } + + for (size_t i = 0; i < threads; i++) { + ASSERT_TRUE(thread_instances[i].get()); + } } + + static constexpr size_t SHARED_POOL_THREADS = 16; + + TEST(MultithreadTest, HostBFVSharedPool) { + test_shared_pool(4, false, + SchemeType::BFV, 32, 35, + { 60, 40, 40, 60 }, true, 0x123, 0 + ); + } + TEST(MultithreadTest, DeviceBFVSharedPool) { + test_shared_pool(SHARED_POOL_THREADS, true, + SchemeType::BFV, 32, 35, + { 60, 40, 40, 60 }, true, 0x123, 0 + ); + MemoryPool::Destroy(); + } + TEST(MultithreadTest, HostBGVSharedPool) { + test_shared_pool(4, false, + SchemeType::BGV, 32, 35, + { 60, 40, 40, 60 }, true, 0x123, 0 + ); + } + TEST(MultithreadTest, DeviceBGVSharedPool) { + test_shared_pool(SHARED_POOL_THREADS, true, + SchemeType::BGV, 32, 35, + { 60, 40, 40, 60 }, true, 0x123, 0 + ); + MemoryPool::Destroy(); + } + TEST(MultithreadTest, HostCKKSSharedPool) { + test_shared_pool(4, false, + SchemeType::CKKS, 32, 0, + { 60, 40, 40, 60 }, true, 0x123, + 10, 1ull<<20, 1e-2 + ); + } + TEST(MultithreadTest, DeviceCKKSSharedPool) { + test_shared_pool(SHARED_POOL_THREADS, true, + SchemeType::CKKS, 32, 0, + { 60, 40, 40, 60 }, true, 0x123, + 10, 1ull<<20, 1e-2 + ); + MemoryPool::Destroy(); + } + void test_shared_context_multiple_pools(size_t threads, bool device, SchemeType scheme, size_t n, size_t log_t, vector log_qi, bool expand_mod_chain, uint64_t seed, double input_max = 0, double scale = 0, double tolerance = 1e-4, @@ -1205,7 +1277,7 @@ namespace multithread { auto test_thread = [ context_pool, scheme, &context ](int thread) { - return test_multiple_pools(context, 0, thread); + return test_troublesome_pools(context, 0, thread); }; utils::stream_sync(); @@ -1220,6 +1292,8 @@ namespace multithread { } + static constexpr size_t DEVICE_THREADS = 4; + TEST(MultithreadTest, HostBFVSharedContextMultiPools) { test_shared_context_multiple_pools(4, false, SchemeType::BFV, 32, 35, @@ -1227,7 +1301,7 @@ namespace multithread { ); } TEST(MultithreadTest, DeviceBFVSharedContextMultiPools) { - test_shared_context_multiple_pools(4, true, + test_shared_context_multiple_pools(DEVICE_THREADS, true, SchemeType::BFV, 32, 35, { 60, 40, 40, 60 }, true, 0x123, 0 ); @@ -1240,7 +1314,7 @@ namespace multithread { ); } TEST(MultithreadTest, DeviceBGVSharedContextMultiPools) { - test_shared_context_multiple_pools(4, true, + test_shared_context_multiple_pools(DEVICE_THREADS, true, SchemeType::BGV, 32, 35, { 60, 40, 40, 60 }, true, 0x123, 0 ); @@ -1254,7 +1328,7 @@ namespace multithread { ); } TEST(MultithreadTest, DeviceCKKSSharedContextMultiPools) { - test_shared_context_multiple_pools(4, true, + test_shared_context_multiple_pools(DEVICE_THREADS, true, SchemeType::CKKS, 32, 0, { 60, 40, 40, 60 }, true, 0x123, 10, 1ull<<20, 1e-2 @@ -1296,7 +1370,7 @@ namespace multithread { utils::stream_sync(); auto test_thread = [=](int thread, std::shared_ptr context) { size_t device_index = thread % device_count; - return test_multiple_pools(*context, device_index, thread); + return test_troublesome_pools(*context, device_index, thread); }; vector> thread_instances; @@ -1310,6 +1384,8 @@ namespace multithread { } + static constexpr size_t DEVICE_THREADS_MULTIPLE_OF_DEVICES = 4; + TEST(MultithreadTest, HostBFVMultiDevices) { test_multi_devices(8, 4, false, SchemeType::BFV, 32, 35, @@ -1322,7 +1398,7 @@ namespace multithread { if (success != cudaSuccess || device_count <= 1) { GTEST_SKIP_("No multiple devices available"); } - test_multi_devices(device_count * 2 + 1, device_count, true, + test_multi_devices(device_count * DEVICE_THREADS_MULTIPLE_OF_DEVICES + 1, device_count, true, SchemeType::BFV, 32, 35, { 60, 40, 40, 60 }, true, 0x123, 0 ); @@ -1339,7 +1415,7 @@ namespace multithread { if (success != cudaSuccess || device_count <= 1) { GTEST_SKIP_("No multiple devices available"); } - test_multi_devices(device_count * 2 + 1, device_count, true, + test_multi_devices(device_count * DEVICE_THREADS_MULTIPLE_OF_DEVICES + 1, device_count, true, SchemeType::BGV, 32, 35, { 60, 40, 40, 60 }, true, 0x123, 0 ); @@ -1357,7 +1433,7 @@ namespace multithread { if (success != cudaSuccess || device_count <= 1) { GTEST_SKIP_("No multiple devices available"); } - test_multi_devices(device_count * 2 + 1, device_count, true, + test_multi_devices(device_count * DEVICE_THREADS_MULTIPLE_OF_DEVICES + 1, device_count, true, SchemeType::CKKS, 32, 0, { 60, 40, 40, 60 }, true, 0x123, 10, 1ull<<20, 1e-2 diff --git a/test/temp.cu b/test/temp.cu deleted file mode 100644 index 60f032c..0000000 --- a/test/temp.cu +++ /dev/null @@ -1,76 +0,0 @@ -#include "test_adv.h" -#include "gtest/gtest.h" -#include -#include -#include - -namespace temp { - -#define IF_FALSE_RETURN(condition) if (!(condition)) { return false; } -#define IF_FALSE_PRINT_RETURN(condition, message) \ - if (!(condition)) { \ - std::cerr << "[" << thread_index << "] File " << __FILE__ << ", Line " << __LINE__ << ": " << message << std::endl; \ - return false; \ - } -#define CHECKPOINT(message) std::cerr << "[" << thread_index << "] " << message << std::endl; - - using namespace std; - using namespace troy; - using troy::utils::Array; - using tool::GeneralHeContext; - using tool::GeneralVector; - - void test_single_pool_multi_thread(const GeneralHeContext& context, size_t threads, size_t repeat) { - - uint64_t t = context.t(); - double scale = context.scale(); - double tolerance = context.tolerance(); - - GeneralVector m = context.random_simd_full(); - Plaintext p = context.encoder().encode_simd(m, std::nullopt, scale); - std::cout << "p = " << p.data().slice(0, 4) << std::endl; - - auto test_thread = [t, scale, repeat, tolerance, &context, &m, &p](int thread) { - - for (size_t rep = 0; rep < repeat; rep++) { - bool succ = true; - - Array copied(32, true); copied.slice(0, p.coeff_count()).copy_from_slice(p.const_poly()); - - { - Array h = Array::create_and_copy_from_slice(copied.const_slice(0, 1), false); - if (h[0] != 10572) { - std::cerr << "ckpt 3 h[0] = " << h[0] << std::endl; - succ = false; - } - } - - auto decoded = context.encoder().batch().decode_new(p); - if (!succ) { - return false; - } - } - - return true; - - }; - - utils::stream_sync(); - vector> thread_instances; - for (size_t i = 0; i < threads; i++) { - thread_instances.push_back(std::async(test_thread, i)); - } - - for (size_t i = 0; i < threads; i++) { - ASSERT_TRUE(thread_instances[i].get()); - } - - } - - TEST(Temp, Temp) { - GeneralHeContext ghe(true, SchemeType::BFV, 32, 20, { 60, 40, 40, 60 }, false, 0x123, 0); - test_single_pool_multi_thread(ghe, 64, 4); - utils::MemoryPool::Destroy(); - } - -} \ No newline at end of file