From fcf4b155314184e7f9ce1fa5209ca755a80a4867 Mon Sep 17 00:00:00 2001
From: "Mads R. B. Kristensen"
Date: Tue, 22 Oct 2024 21:05:00 +0200
Subject: [PATCH] Remote IO: S3 support (#479)

Implements AWS S3 read support using libcurl:

```python
import kvikio
import cupy

with kvikio.RemoteFile.open_s3_url("s3://my-bucket/my-file") as f:
    ary = cupy.empty(f.nbytes(), dtype="uint8")
    f.read(ary)
```

Supersedes https://github.com/rapidsai/kvikio/pull/426

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Lawrence Mitchell (https://github.com/wence-)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: https://github.com/rapidsai/kvikio/pull/479
---
 .../all_cuda-118_arch-aarch64.yaml            |   2 +
 .../all_cuda-118_arch-x86_64.yaml             |   2 +
 .../all_cuda-125_arch-aarch64.yaml            |   2 +
 .../all_cuda-125_arch-x86_64.yaml             |   2 +
 cpp/include/kvikio/remote_handle.hpp          | 208 ++++++++++++++++-
 dependencies.yaml                             |   7 +
 python/kvikio/kvikio/_lib/remote_handle.pyx   |  91 +++++++-
 python/kvikio/kvikio/benchmarks/s3_io.py      | 221 ++++++++++++++++++
 python/kvikio/kvikio/remote_file.py           |  74 ++++++
 python/kvikio/pyproject.toml                  |   3 +
 python/kvikio/tests/test_benchmarks.py        |  29 +++
 python/kvikio/tests/test_http_io.py           |   4 +
 python/kvikio/tests/test_s3_io.py             | 159 +++++++++++++
 13 files changed, 791 insertions(+), 13 deletions(-)
 create mode 100644 python/kvikio/kvikio/benchmarks/s3_io.py
 create mode 100644 python/kvikio/tests/test_s3_io.py
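Reviewer note: the snippet above uses the S3-URL entry point; the same handle can also be opened by bucket and object name. A short sketch of the API surface added in `remote_file.py` below — the bucket, object, and credential values are hypothetical, and the environment variables are the ones the new docstrings require:

```python
import os

import cupy
import kvikio

# The S3 endpoint reads credentials from the standard AWS environment
# variables (see the docstrings added in remote_file.py below).
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"      # hypothetical values
os.environ["AWS_ACCESS_KEY_ID"] = "my-access-key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "my-secret-key"

# Open by bucket and object name instead of an "s3://" URL.
with kvikio.RemoteFile.open_s3("my-bucket", "my-file") as f:
    ary = cupy.empty(f.nbytes(), dtype="uint8")
    f.read(ary)  # returns the number of bytes read
```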
diff --git a/conda/environments/all_cuda-118_arch-aarch64.yaml b/conda/environments/all_cuda-118_arch-aarch64.yaml
index 0e7f4b3e21..ef1215d51b 100644
--- a/conda/environments/all_cuda-118_arch-aarch64.yaml
+++ b/conda/environments/all_cuda-118_arch-aarch64.yaml
@@ -6,6 +6,7 @@ channels:
 - conda-forge
 - nvidia
 dependencies:
+- boto3>=1.21.21
 - c-compiler
 - cmake>=3.26.4,!=3.30.0
 - cuda-python>=11.7.1,<12.0a0
@@ -18,6 +19,7 @@ dependencies:
 - doxygen=1.9.1
 - gcc_linux-aarch64=11.*
 - libcurl>=7.87.0
+- moto>=4.0.8
 - ninja
 - numcodecs !=0.12.0
 - numpy>=1.23,<3.0a0
diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml
index 293085e8f7..842b984cc6 100644
--- a/conda/environments/all_cuda-118_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-118_arch-x86_64.yaml
@@ -6,6 +6,7 @@ channels:
 - conda-forge
 - nvidia
 dependencies:
+- boto3>=1.21.21
 - c-compiler
 - cmake>=3.26.4,!=3.30.0
 - cuda-python>=11.7.1,<12.0a0
@@ -20,6 +21,7 @@ dependencies:
 - libcufile-dev=1.4.0.31
 - libcufile=1.4.0.31
 - libcurl>=7.87.0
+- moto>=4.0.8
 - ninja
 - numcodecs !=0.12.0
 - numpy>=1.23,<3.0a0
diff --git a/conda/environments/all_cuda-125_arch-aarch64.yaml b/conda/environments/all_cuda-125_arch-aarch64.yaml
index 1e4a370ff6..9a4b3e94bd 100644
--- a/conda/environments/all_cuda-125_arch-aarch64.yaml
+++ b/conda/environments/all_cuda-125_arch-aarch64.yaml
@@ -6,6 +6,7 @@ channels:
 - conda-forge
 - nvidia
 dependencies:
+- boto3>=1.21.21
 - c-compiler
 - cmake>=3.26.4,!=3.30.0
 - cuda-nvcc
@@ -19,6 +20,7 @@ dependencies:
 - gcc_linux-aarch64=11.*
 - libcufile-dev
 - libcurl>=7.87.0
+- moto>=4.0.8
 - ninja
 - numcodecs !=0.12.0
 - numpy>=1.23,<3.0a0
diff --git a/conda/environments/all_cuda-125_arch-x86_64.yaml b/conda/environments/all_cuda-125_arch-x86_64.yaml
index 44d8772a71..2b926acf29 100644
--- a/conda/environments/all_cuda-125_arch-x86_64.yaml
+++ b/conda/environments/all_cuda-125_arch-x86_64.yaml
@@ -6,6 +6,7 @@ channels:
 - conda-forge
 - nvidia
 dependencies:
+- boto3>=1.21.21
 - c-compiler
 - cmake>=3.26.4,!=3.30.0
 - cuda-nvcc
@@ -19,6 +20,7 @@ dependencies:
 - gcc_linux-64=11.*
 - libcufile-dev
 - libcurl>=7.87.0
+- moto>=4.0.8
 - ninja
 - numcodecs !=0.12.0
 - numpy>=1.23,<3.0a0
diff --git a/cpp/include/kvikio/remote_handle.hpp b/cpp/include/kvikio/remote_handle.hpp
index e036ebcb37..809500f663 100644
--- a/cpp/include/kvikio/remote_handle.hpp
+++ b/cpp/include/kvikio/remote_handle.hpp
@@ -18,6 +18,8 @@
 #include <cstddef>
 #include <cstring>
 #include <memory>
+#include <optional>
+#include <regex>
 #include <sstream>
 #include <stdexcept>
 #include <string>
@@ -89,7 +91,7 @@ inline std::size_t callback_device_memory(char* data,
                                           void* context)
 {
   auto ctx = reinterpret_cast<CallbackContext*>(context);
-  const std::size_t nbytes = size * nmemb;
+  std::size_t const nbytes = size * nmemb;
   if (ctx->size < ctx->offset + nbytes) {
     ctx->overflow_error = true;
     return CURL_WRITEFUNC_ERROR;
@@ -132,7 +134,7 @@ class RemoteEndpoint {
    *
    * @returns A string description.
    */
-  virtual std::string str() = 0;
+  virtual std::string str() const = 0;

   virtual ~RemoteEndpoint() = default;
 };
@@ -145,12 +147,203 @@ class HttpEndpoint : public RemoteEndpoint {
   std::string _url;

  public:
+  /**
+   * @brief Create an http endpoint from a url.
+   *
+   * @param url The full http url to the remote file.
+   */
   HttpEndpoint(std::string url) : _url{std::move(url)} {}
   void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
-  std::string str() override { return _url; }
+  std::string str() const override { return _url; }
   ~HttpEndpoint() override = default;
 };

+/**
+ * @brief A remote endpoint using AWS's S3 protocol.
+ */
+class S3Endpoint : public RemoteEndpoint {
+ private:
+  std::string _url;
+  std::string _aws_sigv4;
+  std::string _aws_userpwd;
+
+  /**
+   * @brief Unwrap an optional parameter, obtaining a default from the environment.
+   *
+   * If not nullopt, the optional's value is returned. Otherwise, the environment
+   * variable `env_var` is used. If that also doesn't have a value:
+   *   - if `err_msg` is empty, the empty string is returned.
+   *   - if `err_msg` is not empty, `std::invalid_argument(err_msg)` is thrown.
+   *
+   * @param aws_arg The value to unwrap.
+   * @param env_var The name of the environment variable to check if `aws_arg` isn't set.
+   * @param err_msg The error message to throw on error or the empty string.
+   * @return The parsed AWS argument or the empty string.
+   */
+  static std::string unwrap_or_default(std::optional<std::string> aws_arg,
+                                       std::string const& env_var,
+                                       std::string const& err_msg = "")
+  {
+    if (aws_arg.has_value()) { return std::move(*aws_arg); }
+
+    char const* env = std::getenv(env_var.c_str());
+    if (env == nullptr) {
+      if (err_msg.empty()) { return std::string(); }
+      throw std::invalid_argument(err_msg);
+    }
+    return std::string(env);
+  }
+
+ public:
+  /**
+   * @brief Get url from an AWS S3 bucket and object name.
+   *
+   * @throws std::invalid_argument if no region is specified and no default region is
+   * specified in the environment.
+   *
+   * @param bucket_name The name of the S3 bucket.
+   * @param object_name The name of the S3 object.
+   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
+   * `AWS_DEFAULT_REGION` environment variable is used.
+   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
+   * the scheme: "<endpoint-url>/<bucket-name>/<object-name>". If nullopt, the value of the
+   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
+   * url scheme is used: "https://<bucket-name>.s3.<region>.amazonaws.com/<object-name>".
+   */
+  static std::string url_from_bucket_and_object(std::string const& bucket_name,
+                                                std::string const& object_name,
+                                                std::optional<std::string> const& aws_region,
+                                                std::optional<std::string> aws_endpoint_url)
+  {
+    auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
+    std::stringstream ss;
+    if (endpoint_url.empty()) {
+      auto const region =
+        unwrap_or_default(std::move(aws_region),
+                          "AWS_DEFAULT_REGION",
+                          "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
+      // We default to the official AWS url scheme.
+      ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name;
+    } else {
+      ss << endpoint_url << "/" << bucket_name << "/" << object_name;
+    }
+    return ss.str();
+  }
+
+  /**
+   * @brief Given a url like "s3://<bucket>/<object>", return the name of the bucket and object.
+   *
+   * @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
+   *
+   * @param s3_url S3 url.
+   * @return Pair of strings: [bucket-name, object-name].
+   */
+  [[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
+  {
+    // Regular expression to match s3://<bucket>/<object>
+    std::regex const pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase};
+    std::smatch matches;
+    if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
+    throw std::invalid_argument("Input string does not match the expected S3 URL format.");
+  }
+
+  /**
+   * @brief Create an S3 endpoint from a url.
+   *
+   * @param url The full http url to the S3 file. NB: this should be a url starting with
+   * "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please
+   * use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object()` to convert it.
+   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
+   * `AWS_DEFAULT_REGION` environment variable is used.
+   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
+   * `AWS_ACCESS_KEY_ID` environment variable is used.
+   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
+   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
+   */
+  S3Endpoint(std::string url,
+             std::optional<std::string> aws_region = std::nullopt,
+             std::optional<std::string> aws_access_key = std::nullopt,
+             std::optional<std::string> aws_secret_access_key = std::nullopt)
+    : _url{std::move(url)}
+  {
+    // Regular expression to match http[s]://
+    std::regex pattern{R"(^https?://.*)", std::regex_constants::icase};
+    if (!std::regex_search(_url, pattern)) {
+      throw std::invalid_argument("url must start with http:// or https://");
+    }
+
+    auto const region =
+      unwrap_or_default(std::move(aws_region),
+                        "AWS_DEFAULT_REGION",
+                        "S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
+
+    auto const access_key =
+      unwrap_or_default(std::move(aws_access_key),
+                        "AWS_ACCESS_KEY_ID",
+                        "S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");
+
+    auto const secret_access_key = unwrap_or_default(
+      std::move(aws_secret_access_key),
+      "AWS_SECRET_ACCESS_KEY",
+      "S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");
+
+    // Create the CURLOPT_AWS_SIGV4 option
+    {
+      std::stringstream ss;
+      ss << "aws:amz:" << region << ":s3";
+      _aws_sigv4 = ss.str();
+    }
+    // Create the CURLOPT_USERPWD option
+    // Notice, curl uses `secret_access_key` to generate an AWS V4 signature. It is NOT included
+    // in the http header. See
+    // <https://curl.se/libcurl/c/CURLOPT_AWS_SIGV4.html>
+    {
+      std::stringstream ss;
+      ss << access_key << ":" << secret_access_key;
+      _aws_userpwd = ss.str();
+    }
+  }
+
+  /**
+   * @brief Create an S3 endpoint from a bucket and object name.
+   *
+   * @param bucket_name The name of the S3 bucket.
+   * @param object_name The name of the S3 object.
+   * @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
+   * `AWS_DEFAULT_REGION` environment variable is used.
+   * @param aws_access_key The AWS access key to use. If nullopt, the value of the
+   * `AWS_ACCESS_KEY_ID` environment variable is used.
+   * @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
+   * `AWS_SECRET_ACCESS_KEY` environment variable is used.
+   * @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
+   * the scheme: "<endpoint-url>/<bucket-name>/<object-name>". If nullopt, the value of the
+   * `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
+   * url scheme is used: "https://<bucket-name>.s3.<region>.amazonaws.com/<object-name>".
+   */
+  S3Endpoint(std::string const& bucket_name,
+             std::string const& object_name,
+             std::optional<std::string> aws_region = std::nullopt,
+             std::optional<std::string> aws_access_key = std::nullopt,
+             std::optional<std::string> aws_secret_access_key = std::nullopt,
+             std::optional<std::string> aws_endpoint_url = std::nullopt)
+    : S3Endpoint(url_from_bucket_and_object(
+                   bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
+                 std::move(aws_region),
+                 std::move(aws_access_key),
+                 std::move(aws_secret_access_key))
+  {
+  }
+
+  void setopt(CurlHandle& curl) override
+  {
+    curl.setopt(CURLOPT_URL, _url.c_str());
+    curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
+    curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
+  }
+  std::string str() const override { return _url; }
+  ~S3Endpoint() override = default;
+};
+
 /**
  * @brief Handle of remote file.
  */
@@ -211,6 +404,13 @@ class RemoteHandle {
    */
   [[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

+  /**
+   * @brief Get a const reference to the underlying remote endpoint.
+   *
+   * @return The remote endpoint.
+   */
+  [[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }
+
   /**
    * @brief Read from remote source into buffer (host or device memory).
   *
@@ -229,7 +429,7 @@ class RemoteHandle {
          << " bytes file (" << _endpoint->str() << ")";
       throw std::invalid_argument(ss.str());
     }
-    const bool is_host_mem = is_host_memory(buf);
+    bool const is_host_mem = is_host_memory(buf);
     auto curl = create_curl_handle();
     _endpoint->setopt(curl);
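Reviewer note: the `aws_endpoint_url` override documented above is what the moto-based tests and benchmark later in this patch rely on. A minimal sketch of the behaviour from Python, assuming a local S3-compatible server (address hypothetical) and the credential variables from the earlier sketch:

```python
import os

import numpy
import kvikio

# When AWS_ENDPOINT_URL is set, S3Endpoint builds the request URL as
# "<endpoint-url>/<bucket-name>/<object-name>" instead of the default
# "https://<bucket-name>.s3.<region>.amazonaws.com/<object-name>" scheme.
os.environ["AWS_ENDPOINT_URL"] = "http://127.0.0.1:9000"  # hypothetical local server

with kvikio.RemoteFile.open_s3("my-bucket", "my-file") as f:
    buf = numpy.empty(f.nbytes(), dtype="uint8")  # host-memory reads work too
    f.read(buf)
```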
diff --git a/dependencies.yaml b/dependencies.yaml
index 39ba3aaa17..85bf871150 100644
--- a/dependencies.yaml
+++ b/dependencies.yaml
@@ -345,6 +345,13 @@ dependencies:
           - pytest
           - pytest-cov
           - rangehttpserver
+          - boto3>=1.21.21
+      - output_types: [requirements, pyproject]
+        packages:
+          - moto[server]>=4.0.8
+      - output_types: conda
+        packages:
+          - moto>=4.0.8
     specific:
       - output_types: [conda, requirements, pyproject]
         matrices:
diff --git a/python/kvikio/kvikio/_lib/remote_handle.pyx b/python/kvikio/kvikio/_lib/remote_handle.pyx
index 93c6ac398a..1e0b14acb9 100644
--- a/python/kvikio/kvikio/_lib/remote_handle.pyx
+++ b/python/kvikio/kvikio/_lib/remote_handle.pyx
@@ -18,17 +18,25 @@ from kvikio._lib.future cimport IOFuture, _wrap_io_future, future

 cdef extern from "<kvikio/remote_handle.hpp>" nogil:
     cdef cppclass cpp_RemoteEndpoint "kvikio::RemoteEndpoint":
-        pass
+        string str() except +

-    cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint":
+    cdef cppclass cpp_HttpEndpoint "kvikio::HttpEndpoint"(cpp_RemoteEndpoint):
         cpp_HttpEndpoint(string url) except +

+    cdef cppclass cpp_S3Endpoint "kvikio::S3Endpoint"(cpp_RemoteEndpoint):
+        cpp_S3Endpoint(string url) except +
+        cpp_S3Endpoint(string bucket_name, string object_name) except +
+
+    pair[string, string] cpp_parse_s3_url \
+        "kvikio::S3Endpoint::parse_s3_url"(string url) except +
+
     cdef cppclass cpp_RemoteHandle "kvikio::RemoteHandle":
         cpp_RemoteHandle(
             unique_ptr[cpp_RemoteEndpoint] endpoint, size_t nbytes
         ) except +
         cpp_RemoteHandle(unique_ptr[cpp_RemoteEndpoint] endpoint) except +
         int nbytes() except +
+        const cpp_RemoteEndpoint& endpoint() except +
         size_t read(
             void* buf,
             size_t size,
@@ -48,20 +56,27 @@ cdef string _to_string(str s):
     else:
         return string()

+# Helper function to cast an endpoint to its base class `RemoteEndpoint`
+cdef extern from *:
+    """
+    template <typename T>
+    std::unique_ptr<kvikio::RemoteEndpoint> cast_to_remote_endpoint(T endpoint)
+    {
+      return std::move(endpoint);
+    }
+    """
+    cdef unique_ptr[cpp_RemoteEndpoint] cast_to_remote_endpoint[T](T handle) except +
+

 cdef class RemoteFile:
     cdef unique_ptr[cpp_RemoteHandle] _handle

-    @classmethod
-    def open_http(
-        cls,
-        url: str,
+    @staticmethod
+    cdef RemoteFile _from_endpoint(
+        unique_ptr[cpp_RemoteEndpoint] ep,
         nbytes: Optional[int],
     ):
         cdef RemoteFile ret = RemoteFile()
-        cdef unique_ptr[cpp_HttpEndpoint] ep = make_unique[cpp_HttpEndpoint](
-            _to_string(url)
-        )
         if nbytes is None:
             ret._handle = make_unique[cpp_RemoteHandle](move(ep))
             return ret
@@ -69,6 +84,64 @@ cdef class RemoteFile:
         ret._handle = make_unique[cpp_RemoteHandle](move(ep), n)
         return ret

+    @staticmethod
+    def open_http(
+        url: str,
+        nbytes: Optional[int],
+    ):
+        return RemoteFile._from_endpoint(
+            cast_to_remote_endpoint(
+                make_unique[cpp_HttpEndpoint](_to_string(url))
+            ),
+            nbytes
+        )
+
+    @staticmethod
+    def open_s3(
+        bucket_name: str,
+        object_name: str,
+        nbytes: Optional[int],
+    ):
+        return RemoteFile._from_endpoint(
+            cast_to_remote_endpoint(
+                make_unique[cpp_S3Endpoint](
+                    _to_string(bucket_name), _to_string(object_name)
+                )
+            ),
+            nbytes
+        )
+
+    @staticmethod
+    def open_s3_from_http_url(
+        url: str,
+        nbytes: Optional[int],
+    ):
+        return RemoteFile._from_endpoint(
+            cast_to_remote_endpoint(
+                make_unique[cpp_S3Endpoint](_to_string(url))
+            ),
+            nbytes
+        )
+
+    @staticmethod
+    def open_s3_from_s3_url(
+        url: str,
+        nbytes: Optional[int],
+    ):
+        cdef pair[string, string] bucket_and_object = cpp_parse_s3_url(_to_string(url))
+        return RemoteFile._from_endpoint(
+            cast_to_remote_endpoint(
+                make_unique[cpp_S3Endpoint](
+                    bucket_and_object.first, bucket_and_object.second
+                )
+            ),
+            nbytes
+        )
+
+    def __str__(self) -> str:
+        cdef string ep_str = deref(self._handle).endpoint().str()
+        return f'<{self.__class__.__name__} "{ep_str.decode()}">'
+
     def nbytes(self) -> int:
         return deref(self._handle).nbytes()
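Reviewer note: with `RemoteHandle::endpoint()` exposed, `RemoteFile.__str__` now reports the resolved endpoint, which the updated HTTP tests below assert on. Roughly (URL hypothetical):

```python
import kvikio

# str(f) embeds the endpoint URL, e.g. '<RemoteFile "http://127.0.0.1:8000/data.bin">'
with kvikio.RemoteFile.open_http("http://127.0.0.1:8000/data.bin") as f:
    print(f)
```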
" + "Alternatively, use the bundled server `--use-bundled-server`:\n", + file=sys.stderr, + flush=True, + ) + raise + return client + + +def run_numpy_like(args, xp): + # Upload data to S3 server + data = numpy.arange(args.nelem, dtype=args.dtype) + recv = xp.empty_like(data) + + client = create_client_and_bucket() + client.put_object(Bucket=args.bucket, Key="data", Body=bytes(data)) + url = f"s3://{args.bucket}/data" + + def run() -> float: + t0 = time.perf_counter() + with kvikio.RemoteFile.open_s3_url(url) as f: + res = f.read(recv) + t1 = time.perf_counter() + assert res == args.nbytes, f"IO mismatch, expected {args.nbytes} got {res}" + xp.testing.assert_array_equal(data, recv) + return t1 - t0 + + for _ in range(args.nruns): + yield run() + + +API = { + "cupy": partial(run_numpy_like, xp=cupy), + "numpy": partial(run_numpy_like, xp=numpy), +} + + +def main(args): + cupy.cuda.set_allocator(None) # Disable CuPy's default memory pool + cupy.arange(10) # Make sure CUDA is initialized + + os.environ["KVIKIO_NTHREADS"] = str(args.nthreads) + kvikio.defaults.num_threads_reset(args.nthreads) + + print("Remote S3 benchmark") + print("--------------------------------------") + print(f"nelem | {args.nelem} ({format_bytes(args.nbytes)})") + print(f"dtype | {args.dtype}") + print(f"nthreads | {args.nthreads}") + print(f"nruns | {args.nruns}") + print(f"file | s3://{args.bucket}/data") + if args.use_bundled_server: + print("--------------------------------------") + print("Using the bundled local server is slow") + print("and can be misleading. Consider using") + print("a local MinIO or official S3 server.") + print("======================================") + + # Run each benchmark using the requested APIs + for api in args.api: + res = [] + for elapsed in API[api](args): + res.append(elapsed) + + def pprint_api_res(name, samples): + samples = [args.nbytes / s for s in samples] # Convert to throughput + mean = statistics.harmonic_mean(samples) if len(samples) > 1 else samples[0] + ret = f"{api}-{name}".ljust(12) + ret += f"| {format_bytes(mean).rjust(10)}/s".ljust(14) + if len(samples) > 1: + stdev = statistics.stdev(samples) / mean * 100 + ret += " ± %5.2f %%" % stdev + ret += " (" + for sample in samples: + ret += f"{format_bytes(sample)}/s, " + ret = ret[:-2] + ")" # Replace trailing comma + return ret + + print(pprint_api_res("read", res)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Roundtrip benchmark") + parser.add_argument( + "-n", + "--nelem", + metavar="NELEM", + default="1024", + type=int, + help="Number of elements (default: %(default)s).", + ) + parser.add_argument( + "--dtype", + metavar="DATATYPE", + default="float32", + type=numpy.dtype, + help="The data type of each element (default: %(default)s).", + ) + parser.add_argument( + "--nruns", + metavar="RUNS", + default=1, + type=int, + help="Number of runs per API (default: %(default)s).", + ) + parser.add_argument( + "-t", + "--nthreads", + metavar="THREADS", + default=1, + type=int, + help="Number of threads to use (default: %(default)s).", + ) + parser.add_argument( + "--use-bundled-server", + action="store_true", + help="Launch and use a local slow S3 server (ThreadedMotoServer).", + ) + parser.add_argument( + "--bundled-server-lifetime", + metavar="SECONDS", + default=3600, + type=int, + help="Maximum lifetime of the bundled server (default: %(default)s).", + ) + parser.add_argument( + "--bucket", + metavar="NAME", + default="kvikio-s3-benchmark", + type=str, + help="Name of the AWS S3 bucket to 
use (default: %(default)s).", + ) + parser.add_argument( + "--api", + metavar="API", + default="all", + nargs="+", + choices=tuple(API.keys()) + ("all",), + help="List of APIs to use {%(choices)s} (default: %(default)s).", + ) + args = parser.parse_args() + args.nbytes = args.nelem * args.dtype.itemsize + if "all" in args.api: + args.api = tuple(API.keys()) + + ctx: ContextManager = contextlib.nullcontext() + if args.use_bundled_server: + os.environ["AWS_ENDPOINT_URL"] = f"http://127.0.0.1:{get_local_port()}" + ctx = local_s3_server(args.bundled_server_lifetime) + with ctx: + main(args) diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py index 52bbe8010f..f10f4b49f9 100644 --- a/python/kvikio/kvikio/remote_file.py +++ b/python/kvikio/kvikio/remote_file.py @@ -68,6 +68,77 @@ def open_http( """ return RemoteFile(_get_remote_module().RemoteFile.open_http(url, nbytes)) + @classmethod + def open_s3( + cls, + bucket_name: str, + object_name: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + """Open a AWS S3 file from a bucket name and object name. + + Please make sure to set the AWS environment variables: + - `AWS_DEFAULT_REGION` + - `AWS_ACCESS_KEY_ID` + - `AWS_SECRET_ACCESS_KEY` + + Additionally, to overwrite the AWS endpoint, set `AWS_ENDPOINT_URL`. + See + + Parameters + ---------- + bucket_name + The bucket name of the file. + object_name + The object name of the file. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. + """ + return RemoteFile( + _get_remote_module().RemoteFile.open_s3(bucket_name, object_name, nbytes) + ) + + @classmethod + def open_s3_url( + cls, + url: str, + nbytes: Optional[int] = None, + ) -> RemoteFile: + """Open a AWS S3 file from an URL. + + The `url` can take two forms: + - A full http url such as "http://127.0.0.1/my/file", or + - A S3 url such as "s3:///". + + Please make sure to set the AWS environment variables: + - `AWS_DEFAULT_REGION` + - `AWS_ACCESS_KEY_ID` + - `AWS_SECRET_ACCESS_KEY` + + Additionally, if `url` is a S3 url, it is possible to overwrite the AWS endpoint + by setting `AWS_ENDPOINT_URL`. + See + + Parameters + ---------- + url + Either a http url or a S3 url. + nbytes + The size of the file. If None, KvikIO will ask the server + for the file size. + """ + url = url.lower() + if url.startswith("http://") or url.startswith("https://"): + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes) + ) + if url.startswith("s3://"): + return RemoteFile( + _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes) + ) + raise ValueError(f"Unsupported protocol: {url}") + def close(self) -> None: """Close the file""" pass @@ -78,6 +149,9 @@ def __enter__(self) -> RemoteFile: def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + def __str__(self) -> str: + return str(self._handle) + def nbytes(self) -> int: """Get the file size. 
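The bundled-server pattern above doubles as a recipe for quick local experiments. A condensed sketch mirroring `local_s3_server()`, with made-up credentials and a hypothetical port (assumes `moto[server]` and `boto3` are installed):

```python
import multiprocessing
import os
import time

import boto3
import kvikio

# Fake credentials and a local endpoint, as in local_s3_server() above.
os.environ["AWS_ACCESS_KEY_ID"] = "foobar_key"
os.environ["AWS_SECRET_ACCESS_KEY"] = "foobar_secret"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"
os.environ["AWS_ENDPOINT_URL"] = "http://127.0.0.1:5555"  # hypothetical port


def serve() -> None:
    # ThreadedMotoServer is moto's in-process S3 mock.
    from moto.server import ThreadedMotoServer

    ThreadedMotoServer(ip_address="127.0.0.1", port=5555).start()
    time.sleep(60)  # keep the mock alive for one minute


p = multiprocessing.Process(target=serve)
p.start()
time.sleep(1)  # give the server a moment to come up

client = boto3.client("s3", endpoint_url=os.environ["AWS_ENDPOINT_URL"])
client.create_bucket(Bucket="my-bucket", ACL="public-read-write")
client.put_object(Bucket="my-bucket", Key="my-file", Body=b"hello")

with kvikio.RemoteFile.open_s3_url("s3://my-bucket/my-file") as f:
    buf = bytearray(f.nbytes())
    f.read(buf)
p.kill()
```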
diff --git a/python/kvikio/kvikio/remote_file.py b/python/kvikio/kvikio/remote_file.py
index 52bbe8010f..f10f4b49f9 100644
--- a/python/kvikio/kvikio/remote_file.py
+++ b/python/kvikio/kvikio/remote_file.py
@@ -68,6 +68,77 @@ def open_http(
         """
         return RemoteFile(_get_remote_module().RemoteFile.open_http(url, nbytes))

+    @classmethod
+    def open_s3(
+        cls,
+        bucket_name: str,
+        object_name: str,
+        nbytes: Optional[int] = None,
+    ) -> RemoteFile:
+        """Open an AWS S3 file from a bucket name and object name.
+
+        Please make sure to set the AWS environment variables:
+          - `AWS_DEFAULT_REGION`
+          - `AWS_ACCESS_KEY_ID`
+          - `AWS_SECRET_ACCESS_KEY`
+
+        Additionally, to overwrite the AWS endpoint, set `AWS_ENDPOINT_URL`.
+        See <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html>
+
+        Parameters
+        ----------
+        bucket_name
+            The bucket name of the file.
+        object_name
+            The object name of the file.
+        nbytes
+            The size of the file. If None, KvikIO will ask the server
+            for the file size.
+        """
+        return RemoteFile(
+            _get_remote_module().RemoteFile.open_s3(bucket_name, object_name, nbytes)
+        )
+
+    @classmethod
+    def open_s3_url(
+        cls,
+        url: str,
+        nbytes: Optional[int] = None,
+    ) -> RemoteFile:
+        """Open an AWS S3 file from a URL.
+
+        The `url` can take two forms:
+          - A full http url such as "http://127.0.0.1/my/file", or
+          - An S3 url such as "s3://<bucket>/<object>".
+
+        Please make sure to set the AWS environment variables:
+          - `AWS_DEFAULT_REGION`
+          - `AWS_ACCESS_KEY_ID`
+          - `AWS_SECRET_ACCESS_KEY`
+
+        Additionally, if `url` is an S3 url, it is possible to overwrite the AWS endpoint
+        by setting `AWS_ENDPOINT_URL`.
+        See <https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-envvars.html>
+
+        Parameters
+        ----------
+        url
+            Either an http url or an S3 url.
+        nbytes
+            The size of the file. If None, KvikIO will ask the server
+            for the file size.
+        """
+        scheme = url.lower()
+        if scheme.startswith("http://") or scheme.startswith("https://"):
+            return RemoteFile(
+                _get_remote_module().RemoteFile.open_s3_from_http_url(url, nbytes)
+            )
+        if scheme.startswith("s3://"):
+            return RemoteFile(
+                _get_remote_module().RemoteFile.open_s3_from_s3_url(url, nbytes)
+            )
+        raise ValueError(f"Unsupported protocol: {url}")
+
     def close(self) -> None:
         """Close the file"""
         pass
@@ -78,6 +149,9 @@ def __enter__(self) -> RemoteFile:
     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         self.close()

+    def __str__(self) -> str:
+        return str(self._handle)
+
     def nbytes(self) -> int:
         """Get the file size.
diff --git a/python/kvikio/pyproject.toml b/python/kvikio/pyproject.toml
index 04f04cfa6f..25a961a858 100644
--- a/python/kvikio/pyproject.toml
+++ b/python/kvikio/pyproject.toml
@@ -39,8 +39,10 @@ classifiers = [

 [project.optional-dependencies]
 test = [
+    "boto3>=1.21.21",
     "cuda-python>=11.7.1,<12.0a0",
     "dask>=2022.05.2",
+    "moto[server]>=4.0.8",
     "pytest",
     "pytest-cov",
     "rangehttpserver",
@@ -140,4 +142,5 @@ regex = "(?P<value>.*)"
 filterwarnings = [
     "error",
     "ignore:Jitify is performing a one-time only warm-up to populate the persistent cache",
+    "ignore::DeprecationWarning:botocore.*",
 ]
diff --git a/python/kvikio/tests/test_benchmarks.py b/python/kvikio/tests/test_benchmarks.py
index 5b5602e53a..307b0b258d 100644
--- a/python/kvikio/tests/test_benchmarks.py
+++ b/python/kvikio/tests/test_benchmarks.py
@@ -109,3 +109,32 @@ def test_http_io(run_cmd, api):
         cwd=benchmarks_path,
     )
     assert retcode == 0
+
+
+@pytest.mark.parametrize(
+    "api",
+    [
+        "cupy",
+        "numpy",
+    ],
+)
+def test_s3_io(run_cmd, api):
+    """Test benchmarks/s3_io.py"""
+
+    if not kvikio.is_remote_file_available():
+        pytest.skip(
+            "RemoteFile not available, please build KvikIO "
+            "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)"
+        )
+    retcode = run_cmd(
+        cmd=[
+            sys.executable,
+            "s3_io.py",
+            "-n",
+            "1000",
+            "--api",
+            api,
+        ],
+        cwd=benchmarks_path,
+    )
+    assert retcode == 0
diff --git a/python/kvikio/tests/test_http_io.py b/python/kvikio/tests/test_http_io.py
index 70abec71b6..5c2c3888cd 100644
--- a/python/kvikio/tests/test_http_io.py
+++ b/python/kvikio/tests/test_http_io.py
@@ -47,6 +47,7 @@ def test_read(http_server, tmpdir, xp, size, nthreads, tasksize):
         with kvikio.defaults.set_task_size(tasksize):
             with kvikio.RemoteFile.open_http(f"{http_server}/a") as f:
                 assert f.nbytes() == a.nbytes
+                assert f"{http_server}/a" in str(f)
                 b = xp.empty_like(a)
                 assert f.read(b) == a.nbytes
                 xp.testing.assert_array_equal(a, b)
@@ -60,6 +61,7 @@ def test_large_read(http_server, tmpdir, xp, nthreads):
     with kvikio.defaults.set_num_threads(nthreads):
         with kvikio.RemoteFile.open_http(f"{http_server}/a") as f:
             assert f.nbytes() == a.nbytes
+            assert f"{http_server}/a" in str(f)
             b = xp.empty_like(a)
             assert f.read(b) == a.nbytes
             xp.testing.assert_array_equal(a, b)
@@ -71,6 +73,7 @@ def test_error_too_small_file(http_server, tmpdir, xp):
     a.tofile(tmpdir / "a")
     with kvikio.RemoteFile.open_http(f"{http_server}/a") as f:
         assert f.nbytes() == a.nbytes
+        assert f"{http_server}/a" in str(f)
         with pytest.raises(
             ValueError, match=r"cannot read 0\+100 bytes into a 10 bytes file"
         ):
@@ -88,6 +91,7 @@ def test_no_range_support(http_server, tmpdir, xp):
     b = xp.empty_like(a)
     with kvikio.RemoteFile.open_http(f"{http_server}/a") as f:
         assert f.nbytes() == a.nbytes
+        assert f"{http_server}/a" in str(f)
         with pytest.raises(
             OverflowError, match="maybe the server doesn't support file ranges?"
         ):
diff --git a/python/kvikio/tests/test_s3_io.py b/python/kvikio/tests/test_s3_io.py
new file mode 100644
index 0000000000..1f2bae95d0
--- /dev/null
+++ b/python/kvikio/tests/test_s3_io.py
@@ -0,0 +1,159 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+# See file LICENSE for terms.
+
+import multiprocessing as mp
+import socket
+import time
+from contextlib import contextmanager
+
+import pytest
+
+import kvikio
+import kvikio.defaults
+
+pytestmark = pytest.mark.skipif(
+    not kvikio.is_remote_file_available(),
+    reason=(
+        "RemoteFile not available, please build KvikIO "
+        "with libcurl (-DKvikIO_REMOTE_SUPPORT=ON)"
+    ),
+)
+
+# Notice, we import boto and moto after the `is_remote_file_available` check.
+import boto3  # noqa: E402
+import moto  # noqa: E402
+import moto.server  # noqa: E402
+
+
+@pytest.fixture(scope="session")
+def endpoint_ip():
+    return "127.0.0.1"
+
+
+@pytest.fixture(scope="session")
+def endpoint_port():
+    # Return a free port per worker session.
+    sock = socket.socket()
+    sock.bind(("127.0.0.1", 0))
+    port = sock.getsockname()[1]
+    sock.close()
+    return port
+
+
+def start_s3_server(ip_address, port):
+    server = moto.server.ThreadedMotoServer(ip_address=ip_address, port=port)
+    server.start()
+    time.sleep(600)
+    print("ThreadedMotoServer shutting down because of timeout (10min)")
+
+
+@pytest.fixture(scope="session")
+def s3_base(endpoint_ip, endpoint_port):
+    """Fixture to set up moto server in separate process"""
+    with pytest.MonkeyPatch.context() as monkeypatch:
+        # Use fake aws credentials
+        monkeypatch.setenv("AWS_ACCESS_KEY_ID", "foobar_key")
+        monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "foobar_secret")
+        monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
+        monkeypatch.setenv("AWS_ENDPOINT_URL", f"http://{endpoint_ip}:{endpoint_port}")
+
+        p = mp.Process(target=start_s3_server, args=(endpoint_ip, endpoint_port))
+        p.start()
+        yield f"http://{endpoint_ip}:{endpoint_port}"
+        p.kill()
+
+
+@contextmanager
+def s3_context(s3_base, bucket, files=None):
+    if files is None:
+        files = {}
+    client = boto3.client("s3", endpoint_url=s3_base)
+    client.create_bucket(Bucket=bucket, ACL="public-read-write")
+    for f, data in files.items():
+        client.put_object(Bucket=bucket, Key=f, Body=data)
+    yield s3_base
+    for f, data in files.items():
+        try:
+            client.delete_object(Bucket=bucket, Key=f)
+        except Exception:
+            pass
+
+
+def test_read_access(s3_base):
+    bucket_name = "bucket"
+    object_name = "data"
+    data = b"file content"
+    with s3_context(
+        s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(data)}
+    ) as server_address:
+        with kvikio.RemoteFile.open_s3_url(f"s3://{bucket_name}/{object_name}") as f:
+            assert f.nbytes() == len(data)
+            got = bytearray(len(data))
+            assert f.read(got) == len(got)
+
+        with kvikio.RemoteFile.open_s3(bucket_name, object_name) as f:
+            assert f.nbytes() == len(data)
+            got = bytearray(len(data))
+            assert f.read(got) == len(got)
+
+        with kvikio.RemoteFile.open_s3_url(
+            f"{server_address}/{bucket_name}/{object_name}"
+        ) as f:
+            assert f.nbytes() == len(data)
+            got = bytearray(len(data))
+            assert f.read(got) == len(got)
+
+        with pytest.raises(ValueError, match="Unsupported protocol"):
+            kvikio.RemoteFile.open_s3_url(f"unknown://{bucket_name}/{object_name}")
+
+        with pytest.raises(RuntimeError, match="URL returned error: 404"):
+            kvikio.RemoteFile.open_s3("unknown-bucket", object_name)
+
+        with pytest.raises(RuntimeError, match="URL returned error: 404"):
+            kvikio.RemoteFile.open_s3(bucket_name, "unknown-file")
+
+
+@pytest.mark.parametrize("size", [10, 100, 1000])
+@pytest.mark.parametrize("nthreads", [1, 3])
+@pytest.mark.parametrize("tasksize", [99, 999])
+@pytest.mark.parametrize("buffer_size", [101, 1001])
+def test_read(s3_base, xp, size, nthreads, tasksize, buffer_size):
+    bucket_name = "test_read"
object_name = "a1" + a = xp.arange(size) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + with kvikio.defaults.set_num_threads(nthreads): + with kvikio.defaults.set_task_size(tasksize): + with kvikio.defaults.set_bounce_buffer_size(buffer_size): + with kvikio.RemoteFile.open_s3_url( + f"{server_address}/{bucket_name}/{object_name}" + ) as f: + assert f.nbytes() == a.nbytes + b = xp.empty_like(a) + assert f.read(buf=b) == a.nbytes + xp.testing.assert_array_equal(a, b) + + +@pytest.mark.parametrize( + "start,end", + [ + (0, 10 * 4096), + (1, int(1.3 * 4096)), + (int(2.1 * 4096), int(5.6 * 4096)), + (42, int(2**20)), + ], +) +def test_read_with_file_offset(s3_base, xp, start, end): + bucket_name = "test_read_with_file_offset" + object_name = "a1" + a = xp.arange(end, dtype=xp.int64) + with s3_context( + s3_base=s3_base, bucket=bucket_name, files={object_name: bytes(a)} + ) as server_address: + url = f"{server_address}/{bucket_name}/{object_name}" + with kvikio.RemoteFile.open_s3_url(url) as f: + b = xp.zeros(shape=(end - start,), dtype=xp.int64) + assert f.read(b, file_offset=start * a.itemsize) == b.nbytes + xp.testing.assert_array_equal(a[start:end], b)