Skip to content

Commit

Permalink
make remote IO optional
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Aug 12, 2024
1 parent ed1b7a5 commit f76ff32
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 39 deletions.
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ rapids_find_package(
)

rapids_find_package(
AWSSDK REQUIRED COMPONENTS s3
AWSSDK COMPONENTS s3
BUILD_EXPORT_SET kvikio-exports
INSTALL_EXPORT_SET kvikio-exports
)
Expand Down
2 changes: 1 addition & 1 deletion python/kvikio/kvikio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kvikio._lib import buffer, driver_properties # type: ignore
from kvikio._version import __git_commit__, __version__ # noqa: F401
from kvikio.cufile import CuFile # noqa: F401
from kvikio.remote_file import RemoteFile # noqa: F401
from kvikio.remote_file import RemoteFile, is_remote_file_available # noqa: F401


def memory_register(buf) -> None:
Expand Down
9 changes: 8 additions & 1 deletion python/kvikio/kvikio/_lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@

# Set the list of Cython files to build, one .so per file
set(cython_modules arr.pyx libnvcomp.pyx libnvcomp_ll.pyx file_handle.pyx driver_properties.pyx
future.pyx buffer.pyx defaults.pyx remote_handle.pyx
future.pyx buffer.pyx defaults.pyx
)

if(AWSSDK_FOUND)
message(STATUS "Building remote_handle.pyx (aws-cpp-sdk-s3 found)")
list(APPEND cython_modules remote_handle.pyx)
else()
message(WARNING "Skipping remote_handle.pyx (aws-cpp-sdk-s3 not found)")
endif()

rapids_cython_create_modules(
CXX
SOURCE_FILES "${cython_modules}"
Expand Down
34 changes: 1 addition & 33 deletions python/kvikio/kvikio/cufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import pathlib
from typing import Optional, Union

from typing_extensions import Self

from kvikio._lib import file_handle, remote_handle # type: ignore
from kvikio._lib import file_handle # type: ignore


class IOFutureStream:
Expand Down Expand Up @@ -432,33 +430,3 @@ def raw_write(
to be a multiple of 4096 bytes. When GDS isn't used, this is less critical.
"""
return self._handle.write(buf, size, file_offset, dev_offset)


class RemoteFile:
"""File handle for Remote files"""

def __init__(self, bucket_name: str, object_name: str):
self._handle = remote_handle.RemoteFile.from_bucket_and_object(
bucket_name, object_name
)

@classmethod
def from_url(cls, url: str) -> Self:
ret = object.__new__(cls)
ret._handle = remote_handle.RemoteFile.from_url(url)
return ret

def __enter__(self) -> "RemoteFile":
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
pass

def nbytes(self) -> int:
return self._handle.nbytes()

def pread(self, buf, size: Optional[int] = None, file_offset: int = 0) -> IOFuture:
return IOFuture(self._handle.pread(buf, size, file_offset))

def read(self, buf, size: Optional[int] = None, file_offset: int = 0) -> int:
return self.pread(buf, size, file_offset).get()
24 changes: 21 additions & 3 deletions python/kvikio/kvikio/remote_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,40 @@

from typing_extensions import Self

from kvikio._lib import remote_handle # type: ignore
from kvikio.cufile import IOFuture


def is_remote_file_available() -> bool:
try:
import kvikio._lib.remote_handle # noqa: F401
except ImportError:
return False
else:
return True


def _get_remote_remote_file_class():
if not is_remote_file_available():
raise RuntimeError(
"RemoteFile not available, please build KvikIO with AWS S3 support"
)
import kvikio._lib.remote_handle

return kvikio._lib.remote_handle.RemoteFile


class RemoteFile:
"""File handle of a remote file"""

def __init__(self, bucket_name: str, object_name: str):
self._handle = remote_handle.RemoteFile.from_bucket_and_object(
self._handle = _get_remote_remote_file_class().from_bucket_and_object(
bucket_name, object_name
)

@classmethod
def from_url(cls, url: str) -> Self:
ret = object.__new__(cls)
ret._handle = remote_handle.RemoteFile.from_url(url)
ret._handle = _get_remote_remote_file_class().from_url(url)
return ret

def __enter__(self) -> RemoteFile:
Expand Down
6 changes: 6 additions & 0 deletions python/kvikio/tests/test_aws_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
moto = pytest.importorskip("moto", minversion="3.1.6")
boto3 = pytest.importorskip("boto3")

if not kvikio.is_remote_file_available():
pytest.skip(
"cannot test remote IO, please build KvikIO with with AWS S3 support",
allow_module_level=True,
)

ThreadedMotoServer = pytest.importorskip("moto.server").ThreadedMotoServer


Expand Down
7 changes: 7 additions & 0 deletions python/kvikio/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import pytest

import kvikio

benchmarks_path = (
Path(os.path.realpath(__file__)).parent.parent / "kvikio" / "benchmarks"
)
Expand Down Expand Up @@ -92,6 +94,11 @@ def test_zarr_io(run_cmd, tmp_path, api):
def test_aws_s3_io(run_cmd, api):
"""Test benchmarks/aws_s3_io.py"""

if not kvikio.is_remote_file_available():
pytest.skip(
"cannot test remote IO, please build KvikIO with with AWS S3 support",
allow_module_level=True,
)
pytest.importorskip("boto3")
pytest.importorskip("moto")
if "cudf" in api:
Expand Down

0 comments on commit f76ff32

Please sign in to comment.