Skip to content

Commit

Permalink
read to device memory
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Jul 31, 2024
1 parent 33169c5 commit 9e4887d
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 16 deletions.
36 changes: 30 additions & 6 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>

#include <kvikio/posix_io.hpp>
#include <kvikio/utils.hpp>

namespace kvikio {
namespace detail {

Expand Down Expand Up @@ -122,13 +125,10 @@ class RemoteHandle {
*/
[[nodiscard]] inline std::size_t nbytes() const { return _nbytes; }

std::size_t read(void* buf,
std::size_t size,
std::size_t file_offset = 0,
std::size_t task_size = defaults::task_size())
std::size_t read_to_host(void* buf, std::size_t size, std::size_t file_offset = 0)
{
std::cout << "RemoteHandle::read() - buf: " << buf << ", size: " << size
<< ", file_offset: " << file_offset << ", task_size: " << task_size << std::endl;
std::cout << "RemoteHandle::read_to_host() - buf: " << buf << ", size: " << size
<< ", file_offset: " << file_offset << std::endl;

Aws::S3::Model::GetObjectRequest req;
req.SetBucket(_bucket_name.c_str());
Expand All @@ -155,6 +155,30 @@ class RemoteHandle {
outcome.GetResult().GetBody().read(static_cast<char*>(buf), size);
return n;
}

std::size_t read(void* buf, std::size_t size, std::size_t file_offset = 0)
{
if (is_host_memory(buf)) { return read_to_host(buf, size, file_offset); }

auto alloc = detail::AllocRetain::instance().get(); // Host memory allocation
CUdeviceptr devPtr = convert_void2deviceptr(buf);
CUstream stream = detail::StreamsByThread::get();

std::size_t cur_file_offset = convert_size2off(file_offset);
std::size_t byte_remaining = convert_size2off(size);

while (byte_remaining > 0) {
const std::size_t nbytes_requested = std::min(posix_bounce_buffer_size, byte_remaining);
std::size_t nbytes_got = nbytes_requested;
nbytes_got = read_to_host(alloc.get(), nbytes_requested, cur_file_offset);
CUDA_DRIVER_TRY(cudaAPI::instance().MemcpyHtoDAsync(devPtr, alloc.get(), nbytes_got, stream));
CUDA_DRIVER_TRY(cudaAPI::instance().StreamSynchronize(stream));
cur_file_offset += nbytes_got;
devPtr += nbytes_got;
byte_remaining -= nbytes_got;
}
return size;
}
};

} // namespace kvikio
24 changes: 14 additions & 10 deletions python/kvikio/tests/test_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import time
from contextlib import contextmanager

import numpy as np
import pytest

import kvikio
Expand Down Expand Up @@ -92,28 +91,33 @@ def s3_context(s3_base, bucket, files=None):
pass


def test_read(s3_base):
def test_read(s3_base, xp):
bucket_name = "test_read"
object_name = "a1"
a = np.arange(1000)
a = xp.arange(1000)
with s3_context(s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)}):
with kvikio.RemoteFile(bucket_name, object_name) as f:
assert f.nbytes() == a.nbytes
b = np.empty_like(a)
b = xp.empty_like(a)
assert f.read(buf=b) == a.nbytes
assert all(a == b)
xp.testing.assert_array_equal(a, b)


@pytest.mark.parametrize(
"start,end",
[(0, 10 * 4096), (1, int(1.3 * 4096)), (int(2.1 * 4096), int(5.6 * 4096))],
[
(0, 10 * 4096),
(1, int(1.3 * 4096)),
(int(2.1 * 4096), int(5.6 * 4096)),
(42, int(2**23)),
],
)
def test_read_with_file_offset(s3_base, start, end):
def test_read_with_file_offset(s3_base, xp, start, end):
bucket_name = "test_read"
object_name = "a1"
a = np.arange(10 * 4096, dtype=np.int64) # 10 page-sizes
a = xp.arange(end, dtype=xp.int64)
with s3_context(s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)}):
with kvikio.RemoteFile(bucket_name, object_name) as f:
b = np.zeros(shape=(end - start,), dtype=np.int64)
b = xp.zeros(shape=(end - start,), dtype=xp.int64)
assert f.read(b, file_offset=start * a.itemsize) == b.nbytes
assert all(a[start:end] == b)
xp.testing.assert_array_equal(a[start:end], b)

0 comments on commit 9e4887d

Please sign in to comment.