Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remote IO: S3 support #479

Merged
merged 27 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -18,6 +19,7 @@ dependencies:
- doxygen=1.9.1
- gcc_linux-aarch64=11.*
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-python>=11.7.1,<12.0a0
Expand All @@ -20,6 +21,7 @@ dependencies:
- libcufile-dev=1.4.0.31
- libcufile=1.4.0.31
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-aarch64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-aarch64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
2 changes: 2 additions & 0 deletions conda/environments/all_cuda-125_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
- conda-forge
- nvidia
dependencies:
- boto3>=1.21.21
- c-compiler
- cmake>=3.26.4,!=3.30.0
- cuda-nvcc
Expand All @@ -19,6 +20,7 @@ dependencies:
- gcc_linux-64=11.*
- libcufile-dev
- libcurl>=7.87.0
- moto>=4.0.8
- ninja
- numcodecs !=0.12.0
- numpy>=1.23,<3.0a0
Expand Down
208 changes: 204 additions & 4 deletions cpp/include/kvikio/remote_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -89,7 +91,7 @@ inline std::size_t callback_device_memory(char* data,
void* context)
{
auto ctx = reinterpret_cast<CallbackContext*>(context);
const std::size_t nbytes = size * nmemb;
std::size_t const nbytes = size * nmemb;
if (ctx->size < ctx->offset + nbytes) {
ctx->overflow_error = true;
return CURL_WRITEFUNC_ERROR;
Expand Down Expand Up @@ -132,7 +134,7 @@ class RemoteEndpoint {
*
* @returns A string description.
*/
virtual std::string str() = 0;
virtual std::string str() const = 0;

virtual ~RemoteEndpoint() = default;
};
Expand All @@ -145,12 +147,203 @@ class HttpEndpoint : public RemoteEndpoint {
std::string _url;

public:
/**
* @brief Create an http endpoint from a url.
*
* @param url The full http url to the remote file.
*/
HttpEndpoint(std::string url) : _url{std::move(url)} {}
void setopt(CurlHandle& curl) override { curl.setopt(CURLOPT_URL, _url.c_str()); }
std::string str() override { return _url; }
std::string str() const override { return _url; }
~HttpEndpoint() override = default;
};

/**
* @brief A remote endpoint using AWS's S3 protocol.
*/
class S3Endpoint : public RemoteEndpoint {
private:
std::string _url;
std::string _aws_sigv4;
std::string _aws_userpwd;

/**
* @brief Unwrap an optional parameter, obtaining a default from the environment.
*
* If not nullopt, the optional's value is returned otherwise the environment
madsbk marked this conversation as resolved.
Show resolved Hide resolved
* variable `env_var` is used. If that also doesn't have a value:
* - if `err_msg` is empty, the empty string is returned.
* - if `err_msg` is not empty, `std::invalid_argument(`err_msg`)` is thrown.
*
* @param value The value to unwrap.
* @param env_var The name of the environment variable to check if `value` isn't set.
* @param err_msg The error message to throw on error or the empty string.
vyasr marked this conversation as resolved.
Show resolved Hide resolved
* @return The parsed AWS argument or the empty string.
*/
static std::string unwrap_or_default(std::optional<std::string> aws_arg,
std::string const& env_var,
std::string const& err_msg = "")
{
if (aws_arg.has_value()) { return std::move(*aws_arg); }

char const* env = std::getenv(env_var.c_str());
if (env == nullptr) {
if (err_msg.empty()) { return std::string(); }
throw std::invalid_argument(err_msg);
}
return std::string(env);
}

public:
/**
* @brief Get url from a AWS S3 bucket and object name.
*
* @throws std::invalid_argument if no region is specified and no default region is
* specified in the environment.
*
* @param bucket_name The name of the S3 bucket.
madsbk marked this conversation as resolved.
Show resolved Hide resolved
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
* the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
static std::string url_from_bucket_and_object(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> const& aws_region,
std::optional<std::string> aws_endpoint_url)
{
auto const endpoint_url = unwrap_or_default(std::move(aws_endpoint_url), "AWS_ENDPOINT_URL");
std::stringstream ss;
if (endpoint_url.empty()) {
auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");
// We default to the official AWS url scheme.
ss << "https://" << bucket_name << ".s3." << region << ".amazonaws.com/" << object_name;
} else {
ss << endpoint_url << "/" << bucket_name << "/" << object_name;
}
return ss.str();
}

/**
* @brief Given an url like "s3://<bucket>/<object>", return the name of the bucket and object.
*
* @throws std::invalid_argument if url is ill-formed or is missing the bucket or object name.
*
* @param s3_url S3 url.
* @return Pair of strings: [bucket-name, object-name].
*/
[[nodiscard]] static std::pair<std::string, std::string> parse_s3_url(std::string const& s3_url)
{
// Regular expression to match s3://<bucket>/<object>
std::regex pattern{R"(^s3://([^/]+)/(.+))", std::regex_constants::icase};
madsbk marked this conversation as resolved.
Show resolved Hide resolved
std::smatch matches;
if (std::regex_match(s3_url, matches, pattern)) { return {matches[1].str(), matches[2].str()}; }
throw std::invalid_argument("Input string does not match the expected S3 URL format.");
}

/**
* @brief Create a S3 endpoint from a url.
*
* @param url The full http url to the S3 file. NB: this should be an url starting with
* "http://" or "https://". If you have an S3 url of the form "s3://<bucket>/<object>", please
* use `S3Endpoint::parse_s3_url()` and `S3Endpoint::url_from_bucket_and_object() to convert it.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
*/
S3Endpoint(std::string url,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt)
: _url{std::move(url)}
{
// Regular expression to match http[s]://
std::regex pattern{R"(^https?://.*)", std::regex_constants::icase};
if (!std::regex_search(_url, pattern)) {
throw std::invalid_argument("url must start with http:// or https://");
}

auto const region =
unwrap_or_default(std::move(aws_region),
"AWS_DEFAULT_REGION",
"S3: must provide `aws_region` if AWS_DEFAULT_REGION isn't set.");

auto const access_key =
unwrap_or_default(std::move(aws_access_key),
"AWS_ACCESS_KEY_ID",
"S3: must provide `aws_access_key` if AWS_ACCESS_KEY_ID isn't set.");

auto const secret_access_key = unwrap_or_default(
std::move(aws_secret_access_key),
"AWS_SECRET_ACCESS_KEY",
"S3: must provide `aws_secret_access_key` if AWS_SECRET_ACCESS_KEY isn't set.");

// Create the CURLOPT_AWS_SIGV4 option
{
std::stringstream ss;
ss << "aws:amz:" << region << ":s3";
_aws_sigv4 = ss.str();
}
// Create the CURLOPT_USERPWD option
// Notice, curl uses `secret_access_key` to generate a AWS V4 signature. It is NOT set
// over the wire. See
// <https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_sigv-create-signed-request.html>
{
std::stringstream ss;
ss << access_key << ":" << secret_access_key;
_aws_userpwd = ss.str();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have std::format in C++20...

}
}

/**
* @brief Create a S3 endpoint from a bucket and object name.
*
* @param bucket_name The name of the S3 bucket.
* @param object_name The name of the S3 object.
* @param aws_region The AWS region, such as "us-east-1", to use. If nullopt, the value of the
* `AWS_DEFAULT_REGION` environment variable is used.
* @param aws_access_key The AWS access key to use. If nullopt, the value of the
* `AWS_ACCESS_KEY_ID` environment variable is used.
* @param aws_secret_access_key The AWS secret access key to use. If nullopt, the value of the
* `AWS_SECRET_ACCESS_KEY` environment variable is used.
* @param aws_endpoint_url Overwrite the endpoint url (including the protocol part) by using
* the scheme: "<aws_endpoint_url>/<bucket_name>/<object_name>". If nullopt, the value of the
* `AWS_ENDPOINT_URL` environment variable is used. If this is also not set, the regular AWS
* url scheme is used: "https://<bucket_name>.s3.<region>.amazonaws.com/<object_name>".
*/
S3Endpoint(std::string const& bucket_name,
std::string const& object_name,
std::optional<std::string> aws_region = std::nullopt,
std::optional<std::string> aws_access_key = std::nullopt,
std::optional<std::string> aws_secret_access_key = std::nullopt,
std::optional<std::string> aws_endpoint_url = std::nullopt)
: S3Endpoint(url_from_bucket_and_object(
bucket_name, object_name, aws_region, std::move(aws_endpoint_url)),
std::move(aws_region),
std::move(aws_access_key),
std::move(aws_secret_access_key))
{
}

void setopt(CurlHandle& curl) override
{
curl.setopt(CURLOPT_URL, _url.c_str());
curl.setopt(CURLOPT_AWS_SIGV4, _aws_sigv4.c_str());
curl.setopt(CURLOPT_USERPWD, _aws_userpwd.c_str());
}
std::string str() const override { return _url; }
~S3Endpoint() override = default;
};

/**
* @brief Handle of remote file.
*/
Expand Down Expand Up @@ -211,6 +404,13 @@ class RemoteHandle {
*/
[[nodiscard]] std::size_t nbytes() const noexcept { return _nbytes; }

/**
* @brief Get a const reference to the underlying remote endpoint.
*
* @return The remote endpoint.
*/
[[nodiscard]] RemoteEndpoint const& endpoint() const noexcept { return *_endpoint; }

/**
* @brief Read from remote source into buffer (host or device memory).
*
Expand All @@ -229,7 +429,7 @@ class RemoteHandle {
<< " bytes file (" << _endpoint->str() << ")";
throw std::invalid_argument(ss.str());
}
const bool is_host_mem = is_host_memory(buf);
bool const is_host_mem = is_host_memory(buf);
auto curl = create_curl_handle();
_endpoint->setopt(curl);

Expand Down
7 changes: 7 additions & 0 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ dependencies:
- pytest
- pytest-cov
- rangehttpserver
- boto3>=1.21.21
- output_types: [requirements, pyproject]
packages:
- moto[server]>=4.0.8
- output_types: conda
packages:
- moto>=4.0.8
specific:
- output_types: [conda, requirements, pyproject]
matrices:
Expand Down
Loading