Skip to content

Commit

Permalink
Remote IO: S3 support (#479)
Browse files Browse the repository at this point in the history
Implements AWS S3 read support using libcurl:
```python
import kvikio
import cupy

with kvikio.RemoteFile.from_s3_url("s://my-bucket/my-file") as f:
    ary = cupy.empty(f.nbytes, dtype="uint8")
    f.read(ary)
```

Supersedes #426

Authors:
  - Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Lawrence Mitchell (https://github.com/wence-)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #479
  • Loading branch information
madsbk authored Oct 22, 2024
1 parent ed2d6d0 commit fcf4b15
Show file tree
Hide file tree
Showing 13 changed files with 791 additions and 13 deletions.
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -18,6 +19,7 @@ dependencies:
- doxygen=1.9.1
- gcc_linux-aarch64=11.*
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -20,6 +21,7 @@ dependencies:
- libcufile-dev=1.4.0.31
- libcufile=1.4.0.31
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-aarch64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
208 changes: 204 additions & 4 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -89,7 +91,7 @@ inline std::size_t callback_device_memory(char* data,
void* context)
{
auto ctx = reinterpret_cast<CallbackContext*>(context);
const std::size_t nbytes = size * nmemb;
std::size_t const nbytes = size * nmemb;
if (ctx->size < ctx->offset + nbytes) {
ctx->overflow_error = true;
return CURL_WRITEFUNC_ERROR;
Expand Down Expand Up @@ -132,7 +134,7 @@ class RemoteEndpoint {
*
* @returns A string description.
*/
virtual std::string str() = 0;
virtual std::string str() const = 0;

virtual ~RemoteEndpoint() = default;
};
Expand All @@ -145,12 +147,203 @@ class HttpEndpoint : public RemoteEndpoint {
std::string _url;

public:
/**
* @brief Create an http endpoint from a url.
*
* @param url The full http url to the remote file.
*/
HttpEndpoint(std::string url) : _url{std::move(url)} {}
void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
std::string str() override { return _url; }
std::string str() const override { return _url; }
~HttpEndpoint() override = default;
};

/**
* @brief A remote endpoint using AWS's S3 protocol.
*/
class S3Endpoint : public RemoteEndpoint {
private:
std::string _url;
std::string _aws_sigv4;
std::string _aws_userpwd;

/**
* @brief Unwrap an optional parameter, obtaining a default from the environment.
*
* If not nullopt, the optional's value is returned. Otherwise, the environment
* variable `env_var` is used. If that also doesn't have a value:
* - if `err_msg` is empty, the empty string is returned.
* - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown.
*
* @param value The value to unwrap.
* @param env_var The name of the environment variable to check if `value` isn't set.
* @param err_msg The error message to throw on error or the empty string.
* @return The parsed AWS argument or the empty string.
*/
static std::string unwrap_or_default(std::optional<std::string> aws_arg,
std::string const& env_var,
std::string const& err_msg = "")
{
if (aws_arg.has_value()) { return std::move(*aws_arg); }

char const* env = std::getenv(env_var.c_str());
if (env == nullptr) {
if (err_msg.empty()) { return std::string(); }
throw std::invalid_argument(err_msg);
}
return std::string(env);
}

public:
/**
* @brief Get url from a AWS S3 bucket and object name.
*
* @throws std::invalid_argument if no region is specified and no default region is
* specified in the environment.
*
* @param bucket_name The name of the S3 bucket.
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
* the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
static std::string url_from_bucket_and_object(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> const& aws_region,
std::optional<std::string> aws_endpoint_url)
{
auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
std::stringstream ss;
if (endpoint_url.empty()) {
auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
// We default to the official AWS url scheme.
ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name;
} else {
ss << endpoint_url << "/" << bucket_name << "/" << object_name;
}
return ss.str();
}

/**
* @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object.
*
* @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
*
* @param s3_url S3 url.
* @return Pair of strings: [bucket-name, object-name].
*/
[[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
{
// Regular expression to match s3://<bucket>/<object>
std::regex const pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase};
std::smatch matches;
if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
throw std::invalid_argument("Input string does not match the expected S3 URL format.");
}

/**
* @brief Create a S3 endpoint from a url.
*
* @param url The full http url to the S3 file. NB: this should be an url starting with
* "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please
* use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object() to convert it.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
*/
S3Endpoint(std::string url,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt)
: _url{std::move(url)}
{
// Regular expression to match http[s]://
std::regex pattern{R"(^https?://.*)", std::regex_constants::icase};
if (!std::regex_search(_url, pattern)) {
throw std::invalid_argument("url must start with http:// or https://");
}

auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");

auto const access_key =
unwrap_or_default(std::move(aws_access_key),
"AWS_ACCESS_KEY_ID",
"S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");

auto const secret_access_key = unwrap_or_default(
std::move(aws_secret_access_key),
"AWS_SECRET_ACCESS_KEY",
"S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");

// Create the CURLOPT_AWS_SIGV4 option
{
std::stringstream ss;
ss << "aws:amz:" << region << ":s3";
_aws_sigv4 = ss.str();
}
// Create the CURLOPT_USERPWD option
// Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT included
// in the http header. See
// <https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html>
{
std::stringstream ss;
ss << access_key << ":" << secret_access_key;
_aws_userpwd = ss.str();
}
}

/**
* @brief Create a S3 endpoint from a bucket and object name.
*
* @param bucket_name The name of the S3 bucket.
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
* the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
S3Endpoint(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt,
std::optional<std::string> aws_endpoint_url = std::nullopt)
: S3Endpoint(url_from_bucket_and_object(
bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
std::move(aws_region),
std::move(aws_access_key),
std::move(aws_secret_access_key))
{
}

void setopt(CurlHandle& curl) override
{
curl.setopt(CURLOPT_URL, _url.c_str());
curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
}
std::string str() const override { return _url; }
~S3Endpoint() override = default;
};

/**
* @brief Handle of remote file.
*/
Expand Down Expand Up @@ -211,6 +404,13 @@ class RemoteHandle {
*/
[[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

/**
* @brief Get a const reference to the underlying remote endpoint.
*
* @return The remote endpoint.
*/
[[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }

/**
* @brief Read from remote source into buffer (host or device memory).
*
Expand All @@ -229,7 +429,7 @@ class RemoteHandle {
<< " bytes file (" << _endpoint->str() << ")";
throw std::invalid_argument(ss.str());
}
const bool is_host_mem = is_host_memory(buf);
bool const is_host_mem = is_host_memory(buf);
auto curl = create_curl_handle();
_endpoint->setopt(curl);

Expand Down
7 changes: 7 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ dependencies:
- pytest
- pytest-cov
- rangehttpserver
- boto3>=1.21.21
- output_types: [requirements, pyproject]
packages:
- moto[server]>=4.0.8
- output_types: conda
packages:
- moto>=4.0.8
specific:
- output_types: [conda, requirements, pyproject]
matrices:
Expand Down
Loading

0 comments on commit fcf4b15

Please sign in to comment.