diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml index cee9bc495a..743bbdd4f2 100644 --- a/conda/environments/all_cuda-118_arch-aarch64.yaml +++ b/conda/environments/all_cuda-118_arch-aarch64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-python>=11.7.1,<12.0a0 @@ -18,6 +19,7 @@ dependencies: - doxygen=1.9.1 - gcc_linux-aarch64=11.* - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 6aa0a40289..10a26eff42 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-python>=11.7.1,<12.0a0 @@ -20,6 +21,7 @@ dependencies: - libcufile-dev=1.4.0.31 - libcufile=1.4.0.31 - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml index 321431dbed..ca7c36e8cd 100644 --- a/conda/environments/all_cuda-125_arch-aarch64.yaml +++ b/conda/environments/all_cuda-125_arch-aarch64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-nvcc @@ -19,6 +20,7 @@ dependencies: - gcc_linux-aarch64=11.* - libcufile-dev - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml index 5123bc012d..cfbb9b3154 100644 --- a/conda/environments/all_cuda-125_arch-x86_64.yaml +++ b/conda/environments/all_cuda-125_arch-x86_64.yaml @@ -6,6 +6,7 @@ channels: - conda-forge - nvidia dependencies: +- boto3>=1.21.21 - c-compiler - cmake>=3.26.4,!=3.30.0 - cuda-nvcc @@ -19,6 +20,7 @@ dependencies: - gcc_linux-64=11.* - libcufile-dev - libcurl>=7.87.0 +- moto>=4.0.8 - ninja - numcodecs !=0.12.0 - numpy>=1.23,<3.0a0 diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp index 436d07c91d..77f08a30de 100644 --- a/cpp/include/kvikio/remote_handle.hpp +++ b/cpp/include/kvikio/remote_handle.hpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -145,6 +146,141 @@ class HttpEndpoint : public RemoteEndpoint { ~HttpEndpoint() override = default; }; +/** + * @brief + */ +class S3Endpoint : public RemoteEndpoint { + private: + std::string _url; + std::string _aws_sigv4; + std::string _aws_userpwd; + + static std::string parse_aws_argument(std::optional aws_arg, + const std::string& env_var, + const std::string& err_msg, + bool allow_empty = false) + { + if (aws_arg.has_value()) { return std::move(*aws_arg); } + + char const* env = std::getenv(env_var.c_str()); + if (env == nullptr) { + if (allow_empty) { return std::string(); } + throw std::invalid_argument(err_msg); + } + return std::string(env); + } + + static std::string url_from_bucket_and_object(const std::string& bucket_name, + const std::string& object_name, + const std::optional& aws_region, + std::optional aws_endpoint_url) + { + std::string endpoint_url = + parse_aws_argument(std::move(aws_endpoint_url), + "AWS_ENDPOINT_URL", + "S3: must provide `aws_endpoint_url` if AWS_ENDPOINT_URL isn't set.", + true); + std::stringstream ss; + if (endpoint_url.empty()) { + std::string region = + parse_aws_argument(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + // We default to the official AWS url scheme. + ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name; + } else { + ss << endpoint_url << "/" << bucket_name << "/" << object_name; + } + return ss.str(); + } + + public: + /** + * @brief Given an url like "s3:///", return the name of the bucket and object. + * + * @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name. + * + * @param s3_url S3 url. + * @return Pair of strings: [bucket-name, object-name]. + */ + static std::pair parse_s3_url(std::string const& s3_url) + { + if (s3_url.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } + if (s3_url.size() < 5 || s3_url.substr(0, 5) != "s3://") { + throw std::invalid_argument("The S3 url must start with the S3 scheme (\"s3://\")."); + } + std::string p = s3_url.substr(5); + if (p.empty()) { throw std::invalid_argument("The S3 url cannot be an empty string."); } + size_t pos = p.find_first_of('/'); + std::string bucket_name = p.substr(0, pos); + if (bucket_name.empty()) { + throw std::invalid_argument("The S3 url does not contain a bucket name."); + } + std::string object_name = (pos == std::string::npos) ? "" : p.substr(pos + 1); + if (object_name.empty()) { + throw std::invalid_argument("The S3 url does not contain an object name."); + } + return std::make_pair(std::move(bucket_name), std::move(object_name)); + } + + S3Endpoint(std::string url, + std::optional aws_region = std::nullopt, + std::optional aws_access_key = std::nullopt, + std::optional aws_secret_access_key = std::nullopt) + : _url{std::move(url)} + { + std::string region = + parse_aws_argument(std::move(aws_region), + "AWS_DEFAULT_REGION", + "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set."); + + std::string access_key = + parse_aws_argument(std::move(aws_access_key), + "AWS_ACCESS_KEY_ID", + "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set."); + + std::string secret_access_key = parse_aws_argument( + std::move(aws_secret_access_key), + "AWS_SECRET_ACCESS_KEY", + "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set."); + + // Create the CURLOPT_AWS_SIGV4 option + { + std::stringstream ss; + ss << "aws:amz:" << region << ":s3"; + _aws_sigv4 = ss.str(); + } + // Create the CURLOPT_USERPWD option + { + std::stringstream ss; + ss << access_key << ":" << secret_access_key; + _aws_userpwd = ss.str(); + } + } + S3Endpoint(const std::string& bucket_name, + const std::string& object_name, + std::optional aws_region = std::nullopt, + std::optional aws_access_key = std::nullopt, + std::optional aws_secret_access_key = std::nullopt, + std::optional aws_endpoint_url = std::nullopt) + : S3Endpoint(url_from_bucket_and_object( + bucket_name, object_name, aws_region, std::move(aws_endpoint_url)), + std::move(aws_region), + std::move(aws_access_key), + std::move(aws_secret_access_key)) + { + } + + void setopt(CurlHandle& curl) override + { + curl.setopt(CURLOPT_URL, _url.c_str()); + curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str()); + curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str()); + } + std::string str() override { return _url; } + ~S3Endpoint() override = default; +}; + /** * @brief Handle of remote file. */ diff --git a/dependencies.yaml b/dependencies.yaml index dbad8de059..44ef5cc200 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -327,6 +327,8 @@ dependencies: - &dask dask>=2022.05.2 - pytest - pytest-cov + - moto>=4.0.8 + - boto3>=1.21.21 - output_types: [requirements, pyproject] packages: - rangehttpserver diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx index 5e58da32f0..357a965595 100644 --- a/python/kvikio/kvikio/_lib/remote_handle.pyx +++ b/python/kvikio/kvikio/_lib/remote_handle.pyx @@ -23,6 +23,15 @@ cdef extern from "" nogil: cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint": cpp_HttpEndpoint(string url) except + + cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": + cpp_S3Endpoint(string url) except + + + cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint": + cpp_S3Endpoint(string bucket_name, string object_name) except + + + pair[string, string] cpp_parse_s3_url \ + "kvikio::S3Endpoint::parse_s3_url"(string url) except + + cdef cppclass cpp_RemoteHandle "kvikio::RemoteHandle": cpp_RemoteHandle( unique_ptr[cpp_RemoteEndpoint] endpoint, size_t nbytes @@ -67,6 +76,59 @@ cdef class RemoteFile: ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) return ret + @classmethod + def open_s3_from_http_url( + cls, + url: str, + nbytes: Optional[int], + ): + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + _to_string(url) + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + + @classmethod + def open_s3( + cls, + bucket_name: str, + object_name: str, + nbytes: Optional[int], + ): + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + _to_string(bucket_name), _to_string(object_name) + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + + @classmethod + def open_s3_from_s3_url( + cls, + url: str, + nbytes: Optional[int], + ): + cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url)) + cdef RemoteFile ret = RemoteFile() + cdef unique_ptr[cpp_S3Endpoint] ep = make_unique[cpp_S3Endpoint]( + bucket_and_object.first, bucket_and_object.second + ) + if nbytes is None: + ret._handle = make_unique[cpp_RemoteHandle](move(ep)) + return ret + cdef size_t n = nbytes + ret._handle = make_unique[cpp_RemoteHandle](move(ep), n) + return ret + def nbytes(self) -> int: return deref(self._handle).nbytes() diff --git a/python/kvikio/kvikio/benchmarks/s3_io.py b/python/kvikio/kvikio/benchmarks/s3_io.py new file mode 100644 index 0000000000..41585cf83b --- /dev/null +++ b/python/kvikio/kvikio/benchmarks/s3_io.py @@ -0,0 +1,245 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import argparse +import contextlib +import multiprocessing +import os +import socket +import statistics +import sys +import time +from functools import partial +from typing import ContextManager +from urllib.parse import urlparse + +import boto3 +import cupy +import numpy +from dask.utils import format_bytes + +import kvikio +import kvikio.defaults + + +def get_local_port() -> int: + """Return an available port""" + sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def start_s3_server(lifetime: int): + """Start a server and run it for `lifetime` minutes. + NB: to stop before `lifetime`, kill the process/thread running this function. + """ + from moto.server import ThreadedMotoServer + + # Silence the activity info from ThreadedMotoServer + sys.stderr = open(os.devnull, "w") + url = urlparse(os.environ["AWS_ENDPOINT_URL"]) + server = ThreadedMotoServer(ip_address=url.hostname, port=url.port) + server.start() + time.sleep(lifetime) + + +@contextlib.contextmanager +def local_s3_server(lifetime: int): + """Start a server and run it for `lifetime` minutes or kill it on context exit""" + # Use fake aws credentials + os.environ["AWS_ACCESS_KEY_ID"] = "foobar_key" + os.environ["AWS_SECRET_ACCESS_KEY"] = "foobar_secret" + os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + p = multiprocessing.Process(target=start_s3_server, args=(lifetime,)) + p.start() + yield + p.kill() + + +def create_client_and_bucket(): + client = boto3.client("s3", endpoint_url=os.getenv("AWS_ENDPOINT_URL", None)) + try: + client.create_bucket(Bucket=args.bucket, ACL="public-read-write") + except ( + client.exceptions.BucketAlreadyOwnedByYou, + client.exceptions.BucketAlreadyExists, + ): + pass + except Exception: + print( + "Problem accessing the S3 server? using wrong credentials? Try setting " + "AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, and/or AWS_ENDPOINT_URL. " + "Alternatively, use the bundled server `--use-bundled-server`\n", + file=sys.stderr, + flush=True, + ) + raise + return client + + +def run_numpy_like(args, xp): + # Upload data to S3 server + data = numpy.arange(args.nelem, dtype=args.dtype) + recv = xp.empty_like(data) + + client = create_client_and_bucket() + client.put_object(Bucket=args.bucket, Key="data", Body=bytes(data)) + server_address = os.environ["AWS_ENDPOINT_URL"] + url = f"{server_address}/{args.bucket}/data" + + def run() -> float: + t0 = time.perf_counter() + with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + res = f.read(recv) + t1 = time.perf_counter() + assert res == args.nbytes, f"IO mismatch, expected {args.nbytes} got {res}" + xp.testing.assert_array_equal(data, recv) + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +def run_cudf(args, libcudf_s3_io: bool): + import cudf + + cudf.set_option("libcudf_s3_io", libcudf_s3_io) + + # Upload data to S3 server + create_client_and_bucket() + data = cupy.random.rand(args.nelem).astype(args.dtype) + df = cudf.DataFrame({"a": data}) + df.to_parquet(f"s3://{args.bucket}/data1") + + def run() -> float: + t0 = time.perf_counter() + cudf.read_parquet(f"s3://{args.bucket}/data1") + t1 = time.perf_counter() + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +API = { + "cupy-kvikio": partial(run_numpy_like, xp=cupy), + "numpy-kvikio": partial(run_numpy_like, xp=numpy), + "cudf-kvikio": partial(run_cudf, libcudf_s3_io=True), + "cudf-fsspec": partial(run_cudf, libcudf_s3_io=False), +} + + +def main(args): + cupy.cuda.set_allocator(None) # Disable CuPy's default memory pool + cupy.arange(10) # Make sure CUDA is initialized + + kvikio.defaults.num_threads_reset(args.nthreads) + print("Roundtrip benchmark") + print("--------------------------------------") + print(f"nelem | {args.nelem} ({format_bytes(args.nbytes)})") + print(f"dtype | {args.dtype}") + print(f"nthreads | {args.nthreads}") + print(f"nruns | {args.nruns}") + print(f"server | {os.getenv('AWS_ENDPOINT_URL', 'http://*.amazonaws.com')}") + if args.use_bundled_server: + print("--------------------------------------") + print("Using the bundled local server is slow") + print("and can be misleading. Consider using") + print("a local MinIO or official S3 server.") + print("======================================") + + # Run each benchmark using the requested APIs + for api in args.api: + res = [] + for elapsed in API[api](args): + res.append(elapsed) + + def pprint_api_res(name, samples): + samples = [args.nbytes / s for s in samples] # Convert to throughput + mean = statistics.harmonic_mean(samples) if len(samples) > 1 else samples[0] + ret = f"{api}-{name}".ljust(18) + ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) + if len(samples) > 1: + stdev = statistics.stdev(samples) / mean * 100 + ret += " ± %5.2f %%" % stdev + ret += " (" + for sample in samples: + ret += f"{format_bytes(sample)}/s, " + ret = ret[:-2] + ")" # Replace trailing comma + return ret + + print(pprint_api_res("read", res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Roundtrip benchmark") + parser.add_argument( + "-n", + "--nelem", + metavar="NELEM", + default="1024", + type=int, + help="Number of elements (default: %(default)s).", + ) + parser.add_argument( + "--dtype", + metavar="DATATYPE", + default="float32", + type=numpy.dtype, + help="The data type of each element (default: %(default)s).", + ) + parser.add_argument( + "--nruns", + metavar="RUNS", + default=1, + type=int, + help="Number of runs per API (default: %(default)s).", + ) + parser.add_argument( + "-t", + "--nthreads", + metavar="THREADS", + default=1, + type=int, + help="Number of threads to use (default: %(default)s).", + ) + parser.add_argument( + "--use-bundled-server", + action="store_true", + help="Launch and use a local slow S3 server (ThreadedMotoServer).", + ) + parser.add_argument( + "--bundled-server-lifetime", + metavar="SECONDS", + default=3600, + type=int, + help="Maximum lifetime of the bundled server (default: %(default)s).", + ) + parser.add_argument( + "--bucket", + metavar="NAME", + default="kvikio-s3-benchmark", + type=str, + help="Name of the AWS S3 bucket to use (default: %(default)s).", + ) + parser.add_argument( + "--api", + metavar="API", + default=list(API.keys())[0], # defaults to the first API + nargs="+", + choices=tuple(API.keys()) + ("all",), + help="List of APIs to use {%(choices)s} (default: %(default)s).", + ) + args = parser.parse_args() + args.nbytes = args.nelem * args.dtype.itemsize + if "all" in args.api: + args.api = tuple(API.keys()) + + ctx: ContextManager = contextlib.nullcontext() + if args.use_bundled_server: + os.environ["AWS_ENDPOINT_URL"] = f"http://127.0.0.1:{get_local_port()}" + ctx = local_s3_server(args.bundled_server_lifetime) + with ctx: + main(args) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 0b2e886f0b..39666b6642 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -68,6 +68,50 @@ def open_http( """ return RemoteFile(_get_remote_module().RemoteFile.open_http(url, nbytes)) + @classmethod + def open_s3( + cls, + bucket_name: str, + object_name: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return RemoteFile( + _get_remote_module().RemoteFile.open_s3(bucket_name, object_name, nbytes) + ) + + @classmethod + def open_s3_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + url = url.lower() + if url.startswith("http://") or url.startswith("https://"): + return cls.open_s3_from_http_url(url, nbytes) + if url.startswith("s://"): + return cls.open_s3_from_s3_url(url, nbytes) + raise ValueError(f"Unsupported protocol in url: {url}") + + @classmethod + def open_s3_from_http_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) + ) + + @classmethod + def open_s3_from_s3_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) + ) + def __enter__(self) -> RemoteFile: return self diff --git a/python/kvikio/pyproject.toml b/python/kvikio/pyproject.toml index e59a19cd30..bc547c3f35 100644 --- a/python/kvikio/pyproject.toml +++ b/python/kvikio/pyproject.toml @@ -38,8 +38,10 @@ classifiers = [ [project.optional-dependencies] test = [ + "boto3>=1.21.21", "cuda-python>=11.7.1,<12.0a0", "dask>=2022.05.2", + "moto>=4.0.8", "pytest", "pytest-cov", "rangehttpserver", @@ -139,4 +141,5 @@ regex = "(?P.*)" filterwarnings = [ "error", "ignore:Jitify is performing a one-time only warm-up to populate the persistent cache", + "ignore::DeprecationWarning:botocore.*", ] diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py new file mode 100644 index 0000000000..2daab28700 --- /dev/null +++ b/python/kvikio/tests/test_s3_io.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# See file LICENSE for terms. + +import multiprocessing as mp +import socket +import time +from contextlib import contextmanager + +import pytest + +import kvikio +import kvikio.defaults + +pytestmark = pytest.mark.skipif( + not kvikio.is_remote_file_available(), + reason=( + "RemoteFile not available, please build KvikIO " + "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)" + ), +) + +# Notice, we import boto and moto after the `is_remote_file_available` check. +import boto3 # noqa: E402 +import moto # noqa: E402 +import moto.server # noqa: E402 + + +@pytest.fixture(scope="session") +def endpoint_ip(): + return "127.0.0.1" + + +@pytest.fixture(scope="session") +def endpoint_port(): + # Return a free port per worker session. + sock = socket.socket() + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.close() + return port + + +def start_s3_server(ip_address, port): + server = moto.server.ThreadedMotoServer(ip_address=ip_address, port=port) + server.start() + time.sleep(600) + print("ThreadedMotoServer shutting down because of timeout (10min)") + + +@pytest.fixture(scope="session") +def s3_base(endpoint_ip, endpoint_port): + """Fixture to set up moto server in separate process""" + with pytest.MonkeyPatch.context() as monkeypatch: + # Use fake aws credentials + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "foobar_key") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "foobar_secret") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + monkeypatch.setenv("AWS_ENDPOINT_URL", f"http://{endpoint_ip}:{endpoint_port}") + + p = mp.Process(target=start_s3_server, args=(endpoint_ip, endpoint_port)) + p.start() + yield f"http://{endpoint_ip}:{endpoint_port}" + p.kill() + + +@contextmanager +def s3_context(s3_base, bucket, files=None): + if files is None: + files = {} + client = boto3.client("s3", endpoint_url=s3_base) + client.create_bucket(Bucket=bucket, ACL="public-read-write") + for f, data in files.items(): + client.put_object(Bucket=bucket, Key=f, Body=data) + yield s3_base + for f, data in files.items(): + try: + client.delete_object(Bucket=bucket, Key=f) + except Exception: + pass + + +@pytest.mark.parametrize("size", [10, 100, 1000]) +@pytest.mark.parametrize("nthreads", [1, 3]) +@pytest.mark.parametrize("tasksize", [99, 999]) +@pytest.mark.parametrize("buffer_size", [101, 1001]) +def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size): + bucket_name = "test_read" + object_name = "a1" + a = xp.arange(size) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + with kvikio.defaults.set_num_threads(nthreads): + with kvikio.defaults.set_task_size(tasksize): + with kvikio.defaults.set_bounce_buffer_size(buffer_size): + with kvikio.RemoteFile.open_s3_url( + f"{server_address}/{bucket_name}/{object_name}" + ) as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(buf=b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + with kvikio.RemoteFile.open_s3(bucket_name, object_name) as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(buf=b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + +@pytest.mark.parametrize( + "start,end", + [ + (0, 10 * 4096), + (1, int(1.3 * 4096)), + (int(2.1 * 4096), int(5.6 * 4096)), + (42, int(2**20)), + ], +) +def test_read_with_file_offset(s3_base, xp, start, end): + bucket_name = "test_read_with_file_offset" + object_name = "a1" + a = xp.arange(end, dtype=xp.int64) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + url = f"{server_address}/{bucket_name}/{object_name}" + with kvikio.RemoteFile.open_s3_from_http_url(url) as f: + b = xp.zeros(shape=(end - start,), dtype=xp.int64) + assert f.read(b, file_offset=start * a.itemsize) == b.nbytes + xp.testing.assert_array_equal(a[start:end], b)