Skip to content

Commit

Permalink
preallocate cache for comm
Browse files Browse the repository at this point in the history
  • Loading branch information
luoxiaojian committed Sep 19, 2024
1 parent 19e0992 commit 455192d
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 19 deletions.
3 changes: 3 additions & 0 deletions examples/analytical_apps/pagerank/pagerank_vc.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class PageRankVC

void IncEval(const fragment_t& frag, context_t& ctx,
message_manager_t& messages) {
if (ctx.step == 0) {
messages.AllocateGatherBuffers<fragment_t, double>(frag);
}
++ctx.step;

double base = (1.0 - ctx.delta) / ctx.graph_vnum +
Expand Down
117 changes: 98 additions & 19 deletions grape/parallel/gather_scatter_message_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#ifndef GRAPE_PARALLEL_GATHER_SCATTER_MESSAGE_MANAGER_H_
#define GRAPE_PARALLEL_GATHER_SCATTER_MESSAGE_MANAGER_H_

#include <queue>

#include <mpi.h>

#include "grape/communication/sync_comm.h"
Expand Down Expand Up @@ -97,6 +99,39 @@ class GatherScatterMessageManager : public MessageManagerBase {

size_t GetMsgSize() const override { return sent_size_; }

template <typename FRAG_T, typename MESSAGE_T>
void AllocateGatherBuffers(const FRAG_T& frag) {
auto output_range = frag.MasterVertices();
size_t output_size = output_range.size() * sizeof(MESSAGE_T);
for (fid_t i = 1; i < fnum_; ++i) {
fid_t src_fid = (fid_ + fnum_ - i) % fnum_;
auto src_vertices = frag.GetPartitioner().get_src_vertices(src_fid);
if (output_range.IsSubsetOf(src_vertices)) {
#ifdef TRACKING_MEMORY_ALLOCATIONS
// allocate memory for received messages
MemoryTracker::GetInstance().allocate(output_size);
#endif
std::vector<char> buffer(output_size);
returnToPool(std::move(buffer));
continue;
} else {
CHECK(!output_range.OverlapWith(src_vertices));
}
auto dst_vertices = frag.GetPartitioner().get_dst_vertices(src_fid);
if (output_range.IsSubsetOf(dst_vertices)) {
#ifdef TRACKING_MEMORY_ALLOCATIONS
// allocate memory for received messages
MemoryTracker::GetInstance().allocate(output_size);
#endif
std::vector<char> buffer(output_size);
returnToPool(std::move(buffer));
continue;
} else {
CHECK(!output_range.OverlapWith(dst_vertices));
}
}
}

template <typename GRAPH_T, typename MESSAGE_T, typename AGGR_T>
void GatherMasterVertices(
const GRAPH_T& frag,
Expand Down Expand Up @@ -128,33 +163,47 @@ class GatherScatterMessageManager : public MessageManagerBase {
}
}

std::vector<std::vector<MESSAGE_T>> recv_buffers;
std::vector<MESSAGE_T*> recv_buffers;
const auto& output_range = output.GetVertexRange();
size_t output_size = output_range.size();
std::queue<std::vector<char>> from_pool;
std::queue<std::vector<char>> from_heap;
for (fid_t i = 1; i < fnum_; ++i) {
fid_t src_fid = (fid_ + fnum_ - i) % fnum_;
auto src_vertices = partitioner.get_src_vertices(src_fid);
if (output_range.IsSubsetOf(src_vertices)) {
recv_buffers.emplace_back(output_size);
#ifdef TRACKING_MEMORY_ALLOCATIONS
// allocate memory for received messages
MemoryTracker::GetInstance().allocate(output_size * sizeof(MESSAGE_T));
#endif
sync_comm::irecv_buffer(recv_buffers.back().data(), output_size,
src_fid, 0, comm_, requests);
MESSAGE_T* recv_data = nullptr;
std::vector<char> buffer;
bool is_from_pool =
takeFromPool(buffer, output_size * sizeof(MESSAGE_T));
recv_data = reinterpret_cast<MESSAGE_T*>(buffer.data());
if (is_from_pool) {
from_pool.emplace(std::move(buffer));
} else {
from_heap.emplace(std::move(buffer));
}
recv_buffers.emplace_back(recv_data);
sync_comm::irecv_buffer(recv_data, output_size, src_fid, 0, comm_,
requests);
continue;
} else {
CHECK(!output_range.OverlapWith(src_vertices));
}
auto dst_vertices = partitioner.get_dst_vertices(src_fid);
if (output_range.IsSubsetOf(dst_vertices)) {
recv_buffers.emplace_back(output_size);
#ifdef TRACKING_MEMORY_ALLOCATIONS
// allocate memory for received messages
MemoryTracker::GetInstance().allocate(output_size * sizeof(MESSAGE_T));
#endif
sync_comm::irecv_buffer(recv_buffers.back().data(), output_size,
src_fid, 0, comm_, requests);
MESSAGE_T* recv_data = nullptr;
std::vector<char> buffer;
bool is_from_pool =
takeFromPool(buffer, output_size * sizeof(MESSAGE_T));
recv_data = reinterpret_cast<MESSAGE_T*>(buffer.data());
if (is_from_pool) {
from_pool.emplace(std::move(buffer));
} else {
from_heap.emplace(std::move(buffer));
}
recv_buffers.emplace_back(recv_data);
sync_comm::irecv_buffer(recv_data, output_size, src_fid, 0, comm_,
requests);
continue;
} else {
CHECK(!output_range.OverlapWith(dst_vertices));
Expand Down Expand Up @@ -220,11 +269,15 @@ class GatherScatterMessageManager : public MessageManagerBase {

#ifdef TRACKING_MEMORY_ALLOCATIONS
// deallocate memory for received messages
for (auto& recv_buffer : recv_buffers) {
MemoryTracker::GetInstance().deallocate(recv_buffer.size() *
sizeof(MESSAGE_T));
while (!from_heap.empty()) {
MemoryTracker::GetInstance().deallocate(from_heap.front().size());
from_heap.pop();
}
#endif
while (!from_pool.empty()) {
returnToPool(std::move(from_pool.front()));
from_pool.pop();
}
#ifdef PROFILING
t1_gather_calc_ += GetCurrentTime();
#endif
Expand Down Expand Up @@ -276,7 +329,9 @@ class GatherScatterMessageManager : public MessageManagerBase {
if (master_vertices.IsSubsetOf(output_range)) {
MESSAGE_T* output_data = &output[*master_vertices.begin()];
size_t output_size = master_vertices.size();
sync_comm::recv_buffer(output_data, output_size, src_fid, 0, comm_);
// sync_comm::recv_buffer(output_data, output_size, src_fid, 0, comm_);
sync_comm::irecv_buffer(output_data, output_size, src_fid, 0, comm_,
requests);
} else {
CHECK(!master_vertices.OverlapWith(output_range));
}
Expand All @@ -296,6 +351,28 @@ class GatherScatterMessageManager : public MessageManagerBase {
}

private:
bool takeFromPool(std::vector<char>& buffer, size_t size) {
for (auto& pair : gather_pools_) {
if (pair.first >= size) {
if (!pair.second.empty()) {
buffer = std::move(pair.second.front());
pair.second.pop();
return true;
}
}
}
#ifdef TRACKING_MEMORY_ALLOCATIONS
// allocate memory for received messages
MemoryTracker::GetInstance().allocate(size);
#endif
buffer.resize(size);
return false;
}

void returnToPool(std::vector<char>&& buffer) {
gather_pools_[buffer.size()].emplace(std::move(buffer));
}

fid_t fid_;
fid_t fnum_;
CommSpec comm_spec_;
Expand All @@ -309,6 +386,8 @@ class GatherScatterMessageManager : public MessageManagerBase {
TerminateInfo terminate_info_;
bool vote_terminate_;

std::map<size_t, std::queue<std::vector<char>>> gather_pools_;

#ifdef PROFILING
double t0_gather_comm_ = 0;
double t1_gather_calc_ = 0;
Expand Down

0 comments on commit 455192d

Please sign in to comment.