From a9529362707f19c3552ed395ed0d47771407bd0c Mon Sep 17 00:00:00 2001
From: Julian Meyers
Date: Fri, 14 Nov 2025 12:04:03 -0600
Subject: [PATCH] Add credentials refresh in more locations

---
 src/include/s3fs.hpp |  29 ++++
 src/s3fs.cpp         | 379 ++++++++++++++++++++++++------------------
 2 files changed, 242 insertions(+), 166 deletions(-)

diff --git a/src/include/s3fs.hpp b/src/include/s3fs.hpp
index a7e933ea..ff076bd2 100644
--- a/src/include/s3fs.hpp
+++ b/src/include/s3fs.hpp
@@ -33,7 +33,15 @@ struct S3AuthParams {
     bool requester_pays = false;
     string oauth2_bearer_token; // OAuth2 bearer token for GCS

+    // Store FileOpener and path for credential refresh
+    optional_ptr<FileOpener> opener;
+    string path;
+
     static S3AuthParams ReadFrom(optional_ptr<FileOpener> opener, FileOpenerInfo &info);
+
+    //! Try to refresh credentials if they've expired.
+    //! Returns true if the refresh succeeded and credentials were updated.
+    bool TryRefreshCredentials();
 };

 struct AWSEnvironmentCredentialsProvider {
@@ -236,6 +244,9 @@ class S3FileSystem : public HTTPFileSystem {
     static string GetGCSAuthError(S3AuthParams &s3_auth_params);
     static HTTPException GetS3Error(S3AuthParams &s3_auth_params, const HTTPResponse &response, const string &url);

+    //! Helper method to attempt a secret refresh for a given path
+    static bool TryRefreshSecret(const string &path, optional_ptr<FileOpener> opener);
+
 protected:
     static void NotifyUploadsInProgress(S3FileHandle &file_handle);
     duckdb::unique_ptr<HTTPFileHandle> CreateHandle(const OpenFileInfo &file, FileOpenFlags flags,
@@ -245,6 +256,24 @@ class S3FileSystem : public HTTPFileSystem {
     string GetPayloadHash(char *buffer, idx_t buffer_len);

     HTTPException GetHTTPError(FileHandle &, const HTTPResponse &response, const string &url) override;
+
+private:
+    template <class RequestFunc>
+    auto ExecuteWithRefresh(S3AuthParams &auth_params, RequestFunc request_func)
+        -> decltype(request_func(auth_params)) {
+        try {
+            return request_func(auth_params);
+        } catch (std::exception &ex) {
+            ErrorData error(ex);
+            // Only attempt refresh for HTTP or IO exceptions (same logic as S3FileHandle::Initialize)
+            if (error.Type() == ExceptionType::IO || error.Type() == ExceptionType::HTTP) {
+                if (auth_params.TryRefreshCredentials()) {
+                    return request_func(auth_params); // Retry with refreshed credentials
+                }
+            }
+            throw;
+        }
+    }
 };

 // Helper class to do s3 ListObjectV2 api call https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html
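The header half of the patch is self-contained: S3AuthParams now remembers where it came from (opener and path), and ExecuteWithRefresh retries a failed request exactly once after a successful credential refresh. For readers skimming the patch, the control flow reduces to the standalone sketch below; it compiles and runs on its own, and Credentials, TryRefresh, and the 403-throwing lambda are hypothetical stand-ins for the DuckDB types, not part of the patch.

    #include <iostream>
    #include <stdexcept>

    // Hypothetical stand-ins for S3AuthParams / TryRefreshCredentials().
    struct Credentials {
        bool valid = false;
    };

    bool TryRefresh(Credentials &creds) {
        creds.valid = true; // pretend the secret provider issued new keys
        return true;
    }

    // Same shape as ExecuteWithRefresh above: run, refresh once on failure, retry.
    template <class RequestFunc>
    auto ExecuteWithRefreshSketch(Credentials &creds, RequestFunc request) -> decltype(request(creds)) {
        try {
            return request(creds);
        } catch (std::runtime_error &) { // stands in for the IO/HTTP exception check
            if (TryRefresh(creds)) {
                return request(creds); // retry once with refreshed credentials
            }
            throw;
        }
    }

    int main() {
        Credentials creds; // starts out expired
        int status = ExecuteWithRefreshSketch(creds, [](Credentials &c) {
            if (!c.valid) {
                throw std::runtime_error("403 Forbidden");
            }
            return 200;
        });
        std::cout << status << "\n"; // prints 200: the first attempt threw, the retry succeeded
    }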
diff --git a/src/s3fs.cpp b/src/s3fs.cpp
index b4d6c741..5eb0e79a 100644
--- a/src/s3fs.cpp
+++ b/src/s3fs.cpp
@@ -228,9 +228,38 @@ S3AuthParams S3AuthParams::ReadFrom(optional_ptr<FileOpener> opener, FileOpenerI
         result.endpoint = "s3.amazonaws.com";
     }

+    // Store opener and path for potential credential refresh
+    result.opener = opener;
+    result.path = info.file_path;
+
     return result;
 }

+bool S3AuthParams::TryRefreshCredentials() {
+    if (!opener) {
+        return false;
+    }
+
+    // Try to refresh the secret using the S3FileSystem helper
+    if (S3FileSystem::TryRefreshSecret(path, opener)) {
+        // Refresh succeeded, reload credentials
+        FileOpenerInfo info = {path};
+        auto refreshed_params = S3AuthParams::ReadFrom(opener, info);
+
+        // Update this object's credentials
+        this->access_key_id = refreshed_params.access_key_id;
+        this->secret_access_key = refreshed_params.secret_access_key;
+        this->session_token = refreshed_params.session_token;
+        this->region = refreshed_params.region;
+        this->endpoint = refreshed_params.endpoint;
+        this->oauth2_bearer_token = refreshed_params.oauth2_bearer_token;
+
+        return true;
+    }
+
+    return false;
+}
+
 unique_ptr<BaseSecret> CreateSecret(vector<string> &prefix_paths_p, string &type, string &provider, string &name,
                                     S3AuthParams &params) {
     auto return_value = make_uniq<KeyValueSecret>(prefix_paths_p, type, provider, name);
@@ -729,124 +758,136 @@ string ParsedS3Url::GetHTTPUrl(S3AuthParams &auth_params, const string &http_que

 unique_ptr<HTTPResponse> S3FileSystem::PostRequest(FileHandle &handle, string url, HTTPHeaders header_map,
                                                    string &result, char *buffer_in, idx_t buffer_in_len,
                                                    string http_params) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params);
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-        headers["Content-Type"] = "application/octet-stream";
-    } else {
-        // Use existing S3 authentication
-        auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len);
-        headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params, "",
-                                   "", payload_hash, "application/octet-stream");
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params);
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+            headers["Content-Type"] = "application/octet-stream";
+        } else {
+            // Use existing S3 authentication
+            auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len);
+            headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "POST", auth_params,
+                                       "", "", payload_hash, "application/octet-stream");
+        }

-    return HTTPFileSystem::PostRequest(handle, http_url, headers, result, buffer_in, buffer_in_len);
+        return HTTPFileSystem::PostRequest(handle, http_url, headers, result, buffer_in, buffer_in_len);
+    });
 }
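A subtlety worth noting in these rewrites: the old bodies copied the handle's credentials (auto auth_params = handle.Cast<S3FileHandle>().auth_params;), so a refresh could only ever have updated a dead copy. The new bodies bind s3_handle.auth_params by reference, and the lambda's S3AuthParams & parameter aliases that same object; when TryRefreshCredentials() succeeds inside ExecuteWithRefresh, the immediate retry re-signs with the new keys and every later request on the handle sees them too. The remaining request methods below are rewritten the same way.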

 unique_ptr<HTTPResponse> S3FileSystem::PutRequest(FileHandle &handle, string url, HTTPHeaders header_map,
                                                   char *buffer_in, idx_t buffer_in_len, string http_params) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params);
-    auto content_type = "application/octet-stream";
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-        headers["Content-Type"] = content_type;
-    } else {
-        // Use existing S3 authentication
-        auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len);
-        headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params, "",
-                                   "", payload_hash, content_type);
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params, http_params);
+        auto content_type = "application/octet-stream";
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+            headers["Content-Type"] = content_type;
+        } else {
+            // Use existing S3 authentication
+            auto payload_hash = GetPayloadHash(buffer_in, buffer_in_len);
+            headers = create_s3_header(parsed_s3_url.path, http_params, parsed_s3_url.host, "s3", "PUT", auth_params,
+                                       "", "", payload_hash, content_type);
+        }

-    return HTTPFileSystem::PutRequest(handle, http_url, headers, buffer_in, buffer_in_len);
+        return HTTPFileSystem::PutRequest(handle, http_url, headers, buffer_in, buffer_in_len);
+    });
 }

 unique_ptr<HTTPResponse> S3FileSystem::HeadRequest(FileHandle &handle, string s3_url, HTTPHeaders header_map) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-    } else {
-        // Use existing S3 authentication
-        headers =
-            create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", "");
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+        } else {
+            // Use existing S3 authentication
+            headers =
+                create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "HEAD", auth_params, "", "", "", "");
+        }

-    return HTTPFileSystem::HeadRequest(handle, http_url, headers);
+        return HTTPFileSystem::HeadRequest(handle, http_url, headers);
+    });
 }

 unique_ptr<HTTPResponse> S3FileSystem::GetRequest(FileHandle &handle, string s3_url, HTTPHeaders header_map) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-    } else {
-        // Use existing S3 authentication
-        headers =
-            create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", "");
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+        } else {
+            // Use existing S3 authentication
+            headers =
+                create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", "");
+        }

-    return HTTPFileSystem::GetRequest(handle, http_url, headers);
+        return HTTPFileSystem::GetRequest(handle, http_url, headers);
+    });
 }

 unique_ptr<HTTPResponse>
 S3FileSystem::GetRangeRequest(FileHandle &handle, string s3_url, HTTPHeaders header_map, idx_t file_offset,
                               char *buffer_out, idx_t buffer_out_len) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-    } else {
-        // Use existing S3 authentication
-        headers =
-            create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", "");
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+        } else {
+            // Use existing S3 authentication
+            headers =
+                create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "GET", auth_params, "", "", "", "");
+        }

-    return HTTPFileSystem::GetRangeRequest(handle, http_url, headers, file_offset, buffer_out, buffer_out_len);
+        return HTTPFileSystem::GetRangeRequest(handle, http_url, headers, file_offset, buffer_out, buffer_out_len);
+    });
 }

 unique_ptr<HTTPResponse> S3FileSystem::DeleteRequest(FileHandle &handle, string s3_url, HTTPHeaders header_map) {
-    auto auth_params = handle.Cast<S3FileHandle>().auth_params;
-    auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
-    string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
-
-    HTTPHeaders headers;
-    if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
-        // Use bearer token for GCS
-        headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
-        headers["Host"] = parsed_s3_url.host;
-    } else {
-        // Use existing S3 authentication
-        headers =
-            create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "DELETE", auth_params, "", "", "", "");
-    }
+    auto &s3_handle = handle.Cast<S3FileHandle>();
+    return ExecuteWithRefresh(s3_handle.auth_params, [&](S3AuthParams &auth_params) {
+        auto parsed_s3_url = S3UrlParse(s3_url, auth_params);
+        string http_url = parsed_s3_url.GetHTTPUrl(auth_params);
+
+        HTTPHeaders headers;
+        if (IsGCSRequest(s3_url) && !auth_params.oauth2_bearer_token.empty()) {
+            // Use bearer token for GCS
+            headers["Authorization"] = "Bearer " + auth_params.oauth2_bearer_token;
+            headers["Host"] = parsed_s3_url.host;
+        } else {
+            // Use existing S3 authentication
+            headers = create_s3_header(parsed_s3_url.path, "", parsed_s3_url.host, "s3", "DELETE", auth_params, "", "",
+                                       "", "");
+        }

-    return HTTPFileSystem::DeleteRequest(handle, http_url, headers);
+        return HTTPFileSystem::DeleteRequest(handle, http_url, headers);
+    });
 }

 unique_ptr<HTTPFileHandle> S3FileSystem::CreateHandle(const OpenFileInfo &file, FileOpenFlags flags,
@@ -865,55 +906,56 @@ unique_ptr<HTTPFileHandle> S3FileSystem::CreateHandle(const OpenFileInfo &file,
                          S3ConfigParams::ReadFrom(opener));
 }

+bool S3FileSystem::TryRefreshSecret(const string &path, optional_ptr<FileOpener> opener) {
+    if (!opener) {
+        return false;
+    }
+
+    auto context = opener->TryGetClientContext();
+    if (!context) {
+        return false;
+    }
+
+    bool refreshed_secret = false;
+    auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context);
+    for (const string type : {"s3", "r2", "gcs", "aws"}) {
+        auto res = context->db->GetSecretManager().LookupSecret(transaction, path, type);
+        if (res.HasMatch()) {
+            refreshed_secret |= CreateS3SecretFunctions::TryRefreshS3Secret(*context, *res.secret_entry);
+        }
+    }
+    return refreshed_secret;
+}
+
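The division of labor is: TryRefreshSecret asks the secret manager to refresh whichever registered secret type (s3, r2, gcs, or aws) matches the path, and TryRefreshCredentials then re-reads the updated secret back into the auth params. A hedged call-site sketch using only the APIs this patch adds (the path is a placeholder, and the snippet assumes a live DuckDB FileOpener):

    // Illustration only, not part of the patch.
    void RefreshExample(optional_ptr<FileOpener> opener) {
        const string path = "s3://my-bucket/data.parquet"; // hypothetical path
        if (S3FileSystem::TryRefreshSecret(path, opener)) {
            // A matching secret was refreshed; re-derive signing material from it.
            FileOpenerInfo info = {path};
            S3AuthParams fresh = S3AuthParams::ReadFrom(opener, info);
            // fresh.access_key_id / fresh.session_token now reflect the new secret
        }
    }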
 void S3FileHandle::Initialize(optional_ptr<FileOpener> opener) {
     try {
         HTTPFileHandle::Initialize(opener);
     } catch (std::exception &ex) {
         ErrorData error(ex);
-        bool refreshed_secret = false;
-        if (error.Type() == ExceptionType::IO || error.Type() == ExceptionType::HTTP) {
-            // legacy endpoint (no region) returns 400
-            auto context = opener->TryGetClientContext();
-            if (context) {
-                auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context);
-                for (const string type : {"s3", "r2", "gcs", "aws"}) {
-                    auto res = context->db->GetSecretManager().LookupSecret(transaction, path, type);
-                    if (res.HasMatch()) {
-                        refreshed_secret |= CreateS3SecretFunctions::TryRefreshS3Secret(*context, *res.secret_entry);
-                    }
+        auto &extra_info = error.ExtraInfo();
+        auto entry = extra_info.find("status_code");
+        if (entry != extra_info.end()) {
+            if (entry->second == "301" || entry->second == "400") {
+                auto new_region = extra_info.find("header_x-amz-bucket-region");
+                string correct_region = "";
+                if (new_region != extra_info.end()) {
+                    correct_region = new_region->second;
                 }
+                auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params, correct_region);
+                throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info);
             }
-        }
-        if (!refreshed_secret) {
-            auto &extra_info = error.ExtraInfo();
-            auto entry = extra_info.find("status_code");
-            if (entry != extra_info.end()) {
-                if (entry->second == "301" || entry->second == "400") {
-                    auto new_region = extra_info.find("header_x-amz-bucket-region");
-                    string correct_region = "";
-                    if (new_region != extra_info.end()) {
-                        correct_region = new_region->second;
-                    }
-                    auto extra_text = S3FileSystem::GetS3BadRequestError(auth_params, correct_region);
-                    throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info);
-                }
-                if (entry->second == "403") {
-                    // 403: FORBIDDEN
-                    string extra_text;
-                    if (IsGCSRequest(path)) {
-                        extra_text = S3FileSystem::GetGCSAuthError(auth_params);
-                    } else {
-                        extra_text = S3FileSystem::GetS3AuthError(auth_params);
-                    }
-                    throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info);
+            if (entry->second == "403") {
+                // 403: FORBIDDEN
+                string extra_text;
+                if (IsGCSRequest(path)) {
+                    extra_text = S3FileSystem::GetGCSAuthError(auth_params);
+                } else {
+                    extra_text = S3FileSystem::GetS3AuthError(auth_params);
                 }
+                throw Exception(error.Type(), error.RawMessage() + extra_text, extra_info);
             }
-            throw;
         }
-        // We have succesfully refreshed a secret: retry initializing with new credentials
-        FileOpenerInfo info = {path};
-        auth_params = S3AuthParams::ReadFrom(opener, info);
-        HTTPFileHandle::Initialize(opener);
+        throw;
     }

     auto &s3fs = file_system.Cast<S3FileSystem>();
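With the refresh loop lifted out of Initialize and into TryRefreshSecret plus the per-request ExecuteWithRefresh wrapper, the catch block that remains here is purely diagnostic: it never retries, it only appends a region or authentication hint before rethrowing. In isolation the branching reduces to the following sketch (hedged: a std::map stands in for ErrorData's extra info, and the hint strings are invented):

    #include <iostream>
    #include <map>
    #include <string>

    // Hedged sketch of the status-code branching kept in S3FileHandle::Initialize;
    // the real code appends these hints to the error message and rethrows.
    std::string ExtraText(const std::map<std::string, std::string> &extra_info) {
        auto entry = extra_info.find("status_code");
        if (entry == extra_info.end()) {
            return ""; // no HTTP status attached: rethrow unchanged
        }
        if (entry->second == "301" || entry->second == "400") {
            // Wrong region or legacy endpoint; surface the region header if the server sent one.
            auto region = extra_info.find("header_x-amz-bucket-region");
            return region != extra_info.end() ? " (bucket region: " + region->second + ")" : " (bad request)";
        }
        if (entry->second == "403") {
            return " (authentication failure: check the matching secret)";
        }
        return "";
    }

    int main() {
        std::cout << ExtraText({{"status_code", "403"}}) << "\n";
    }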

@@ -1069,38 +1111,43 @@ vector<OpenFileInfo> S3FileSystem::Glob(const string &glob_pattern, FileOpener *

     ReadQueryParams(parsed_s3_url.query_param, s3_auth_params);

-    // Do main listobjectsv2 request
-    vector<OpenFileInfo> s3_keys;
-    string main_continuation_token;
-
-    // Main paging loop
-    do {
-        // main listobject call, may
-        string response_str = AWSListObjectV2::Request(shared_path, *http_params, s3_auth_params,
-                                                       main_continuation_token, HTTPState::TryGetState(opener).get());
-        main_continuation_token = AWSListObjectV2::ParseContinuationToken(response_str);
-        AWSListObjectV2::ParseFileList(response_str, s3_keys);
-
-        // Repeat requests until the keys of all common prefixes are parsed.
-        auto common_prefixes = AWSListObjectV2::ParseCommonPrefix(response_str);
-        while (!common_prefixes.empty()) {
-            auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + common_prefixes.back();
-            common_prefixes.pop_back();
-
-            // TODO we could optimize here by doing a match on the prefix, if it doesn't match we can skip this prefix
-            // Paging loop for common prefix requests
-            string common_prefix_continuation_token;
-            do {
-                auto prefix_res =
-                    AWSListObjectV2::Request(prefix_path, *http_params, s3_auth_params,
-                                             common_prefix_continuation_token, HTTPState::TryGetState(opener).get());
-                AWSListObjectV2::ParseFileList(prefix_res, s3_keys);
-                auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res);
-                common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end());
-                common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res);
-            } while (!common_prefix_continuation_token.empty());
-        }
-    } while (!main_continuation_token.empty());
+    vector<OpenFileInfo> s3_keys =
+        ExecuteWithRefresh(s3_auth_params, [&](S3AuthParams &auth_params) -> vector<OpenFileInfo> {
+            vector<OpenFileInfo> s3_keys;
+            string main_continuation_token;
+
+            // Main paging loop
+            do {
+                // Main ListObjectsV2 call; may return a continuation token when the listing is paged
+                string response_str =
+                    AWSListObjectV2::Request(shared_path, *http_params, auth_params, main_continuation_token,
+                                             HTTPState::TryGetState(opener).get());
+                main_continuation_token = AWSListObjectV2::ParseContinuationToken(response_str);
+                AWSListObjectV2::ParseFileList(response_str, s3_keys);
+
+                // Repeat requests until the keys of all common prefixes are parsed.
+                auto common_prefixes = AWSListObjectV2::ParseCommonPrefix(response_str);
+                while (!common_prefixes.empty()) {
+                    auto prefix_path = parsed_s3_url.prefix + parsed_s3_url.bucket + '/' + common_prefixes.back();
+                    common_prefixes.pop_back();
+
+                    // TODO we could optimize here by doing a match on the prefix, if it doesn't match we can skip
+                    // this prefix
+                    // Paging loop for common prefix requests
+                    string common_prefix_continuation_token;
+                    do {
+                        auto prefix_res = AWSListObjectV2::Request(prefix_path, *http_params, auth_params,
+                                                                   common_prefix_continuation_token,
+                                                                   HTTPState::TryGetState(opener).get());
+                        AWSListObjectV2::ParseFileList(prefix_res, s3_keys);
+                        auto more_prefixes = AWSListObjectV2::ParseCommonPrefix(prefix_res);
+                        common_prefixes.insert(common_prefixes.end(), more_prefixes.begin(), more_prefixes.end());
+                        common_prefix_continuation_token = AWSListObjectV2::ParseContinuationToken(prefix_res);
+                    } while (!common_prefix_continuation_token.empty());
+                }
+            } while (!main_continuation_token.empty());
+
+            return s3_keys;
+        });

     vector<string> pattern_splits = StringUtil::Split(parsed_s3_url.key, "/");
     vector<OpenFileInfo> result;
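Finally, note that ExecuteWithRefresh wraps the entire listing in Glob: an authentication failure on any page throws out of the lambda, triggers one refresh, and restarts the listing from the first page rather than resuming mid-token. Stripped of the DuckDB types, the paging itself reduces to this self-contained sketch (all names hypothetical; RequestPage is a stub where the real code calls AWSListObjectV2::Request):

    #include <string>
    #include <vector>

    // Hedged sketch of ListObjectsV2-style paging as used in Glob: keep
    // requesting while a continuation token is returned, and expand any
    // common prefixes the service reports into further listings.
    struct ListPage {
        std::vector<std::string> keys;
        std::vector<std::string> common_prefixes;
        std::string continuation_token; // empty when this prefix is exhausted
    };

    // Stub standing in for AWSListObjectV2::Request plus the parse helpers.
    ListPage RequestPage(const std::string &prefix, const std::string &token) {
        return ListPage{}; // a real implementation would call the service
    }

    std::vector<std::string> ListAll(const std::string &root) {
        std::vector<std::string> keys;
        std::vector<std::string> pending = {root};
        while (!pending.empty()) {
            std::string prefix = pending.back();
            pending.pop_back();
            std::string token;
            do {
                ListPage page = RequestPage(prefix, token);
                keys.insert(keys.end(), page.keys.begin(), page.keys.end());
                pending.insert(pending.end(), page.common_prefixes.begin(), page.common_prefixes.end());
                token = page.continuation_token;
            } while (!token.empty());
        }
        return keys;
    }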