Skip to content

Commit

Permalink
bugfix: memory pool shared between multiple threads.
Browse files Browse the repository at this point in the history
  • Loading branch information
lightbulb128 committed Jul 17, 2024
1 parent e3936be commit e7adcb0
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 106 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ reports/
binder/*.so
gather_code.*
.cache/
.clangd
.clangd
test/custom.cu
1 change: 1 addition & 0 deletions src/kernel_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ namespace troy {
if (status != cudaSuccess) {
runtime_error("[kernel_provider::memset] cudaMemset failed", status);
}
utils::stream_sync();
}

template <typename T>
Expand Down
12 changes: 12 additions & 0 deletions src/utils/memory_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,20 @@ namespace troy {namespace utils {
return count;
}

// If "TROY_STREAM_SYNC_AFTER_KERNEL_CALLS" is enabled,
// this will call cudaStreamSynchronize(0). Otherwise, it does nothing.
void stream_sync();

inline void stream_sync_concrete() {
cudaError_t status = cudaStreamSynchronize(0);
if (status != cudaSuccess) {
std::string msg = "[stream_sync_concrete] cudaStreamSynchronize failed: ";
msg += cudaGetErrorString(status);
throw std::runtime_error(msg);
}
}


class MemoryPool;
typedef std::shared_ptr<MemoryPool> MemoryPoolHandle;

Expand Down
29 changes: 21 additions & 8 deletions src/utils/memory_pool_safe.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "memory_pool.h"
#include <thread>

namespace troy::utils {

Expand All @@ -12,13 +13,15 @@ namespace troy::utils {
struct MemoryPool::Impl {
static const int PRESERVED_MEMORY_BYTES = 1024 * 1024 * 32;
std::shared_mutex mutex;
std::multimap<size_t, void*> unused;

// key-value pairs are (allocated memory size, (pointer address, last used thread id))
std::multimap<size_t, std::pair<void*, std::thread::id>> unused;

std::multimap<void*, size_t> allocated;
std::set<void*> zombie;
size_t total_allocated;
bool destroyed;
};


// User shouldn't call this but call create() instead
MemoryPool::MemoryPool(size_t device_index): device_index(device_index) {
Expand Down Expand Up @@ -104,7 +107,7 @@ namespace troy::utils {
throw std::runtime_error("[MemoryPool(safe)::release] The pointer is not in the allocated set.");
}
size_t size = iterator->second;
impl_->unused.insert(std::make_pair(size, ptr));
impl_->unused.insert(std::make_pair(size, std::make_pair(ptr, std::this_thread::get_id())));
}

void* MemoryPool::allocate(size_t required) {
Expand All @@ -120,7 +123,17 @@ namespace troy::utils {
lock.unlock();
return try_allocate(required);
} else {
void* ptr = iterator->second;
void* ptr = iterator->second.first;
std::thread::id last_used_id = iterator->second.second;
if (last_used_id != std::this_thread::get_id()) {
// A new thread is now taking over the memory which was last used by another thread.
// To avoid data racing, we need to execute a cudaDeviceSync
// before returning the pointer.
cudaError_t status = cudaDeviceSynchronize();
if (status != cudaSuccess) {
runtime_error("[MemoryPool(safe)::allocate] cudaDeviceSynchronize failed.", status);
}
}
impl_->unused.erase(iterator);
return ptr;
}
Expand All @@ -130,12 +143,12 @@ namespace troy::utils {
std::unique_lock lock(impl_->mutex);
for (auto it = impl_->unused.begin(); it != impl_->unused.end();) {
set_device();
cudaError_t status = cudaFree(it->second);
cudaError_t status = cudaFree(it->second.first);
if (status != cudaSuccess) {
runtime_error("[MemoryPool(safe)::release_unused] cudaFree failed.", status);
}
// remove pointer from allocated
auto it2 = impl_->allocated.find(it->second);
auto it2 = impl_->allocated.find(it->second.first);
if (it2 == impl_->allocated.end()) {
throw std::runtime_error("[MemoryPool(safe)::release_unused] The pointer is not in the allocated set.");
}
Expand All @@ -149,12 +162,12 @@ namespace troy::utils {
// first release all unused
for (auto it = impl_->unused.begin(); it != impl_->unused.end();) {
set_device();
cudaError_t status = cudaFree(it->second);
cudaError_t status = cudaFree(it->second.first);
if (status != cudaSuccess) {
runtime_error("[MemoryPool(safe)::destroy] cudaFree unused failed.", status);
}
// remove pointer from allocated
auto it2 = impl_->allocated.find(it->second);
auto it2 = impl_->allocated.find(it->second.first);
if (it2 == impl_->allocated.end()) {
throw std::runtime_error("[MemoryPool(safe)::destroy] The pointer is not in the allocated set.");
}
Expand Down
24 changes: 19 additions & 5 deletions src/utils/memory_pool_unsafe.in
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
#include "memory_pool.h"
#include <thread>

namespace troy::utils {

struct MemoryPool::Impl {
static const int PRESERVED_MEMORY_BYTES = 1024 * 1024 * 32;
std::shared_mutex mutex;
std::multimap<size_t, void*> unused;

// key-value pairs are (allocated memory size, (pointer address, last used thread id))
std::multimap<size_t, std::pair<void*, std::thread::id>> unused;

std::multimap<void*, size_t> allocated;
size_t total_allocated;
};
Expand Down Expand Up @@ -67,7 +71,7 @@ namespace troy::utils {
throw std::runtime_error("[MemoryPool(unsafe)::release] The pointer is not in the allocated set.");
}
size_t size = iterator->second;
impl_->unused.insert(std::make_pair(size, ptr));
impl_->unused.insert(std::make_pair(size, std::make_pair(ptr, std::this_thread::get_id())));
}

void* MemoryPool::allocate(size_t required) {
Expand All @@ -80,7 +84,17 @@ namespace troy::utils {
lock.unlock();
return try_allocate(required);
} else {
void* ptr = iterator->second;
void* ptr = iterator->second.first;
std::thread::id last_used_id = iterator->second.second;
if (last_used_id != std::this_thread::get_id()) {
// A new thread is now taking over the memory which was last used by another thread.
// To avoid data racing, we need to execute a cudaDeviceSync
// before returning the pointer.
cudaError_t status = cudaDeviceSynchronize();
if (status != cudaSuccess) {
runtime_error("[MemoryPool(safe)::allocate] cudaDeviceSynchronize failed.", status);
}
}
impl_->unused.erase(iterator);
return ptr;
}
Expand All @@ -90,12 +104,12 @@ namespace troy::utils {
std::unique_lock lock(impl_->mutex);
for (auto it = impl_->unused.begin(); it != impl_->unused.end();) {
set_device();
cudaError_t status = cudaFree(it->second);
cudaError_t status = cudaFree(it->second.first);
if (status != cudaSuccess) {
runtime_error("[MemoryPool(unsafe)::release_unused] cudaFree failed.", status);
}
// remove pointer from allocated
auto it2 = impl_->allocated.find(it->second);
auto it2 = impl_->allocated.find(it->second.first);
if (it2 == impl_->allocated.end()) {
throw std::runtime_error("[MemoryPool(unsafe)::release_unused] The pointer is not in the allocated set.");
}
Expand Down
8 changes: 7 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,10 @@ if(TROY_BENCH)
)
target_link_libraries(bench_ntt troy)

endif()
endif()

# add_executable(custom)
# target_sources(custom PRIVATE
# custom.cu
# )
# target_link_libraries(custom troy)
Loading

0 comments on commit e7adcb0

Please sign in to comment.