diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 7d2e09fd2..ea0c94204 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -167,4 +167,4 @@ jobs: meson compile -C builddir - name: Test Iceberg run: | - meson test -C builddir --timeout-multiplier 0 + meson test -C builddir --timeout-multiplier 0 --print-errorlogs diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 18cf70bdb..ef9b23fa0 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -101,6 +101,7 @@ set(ICEBERG_SOURCES update/update_snapshot_reference.cc update/update_sort_order.cc update/update_statistics.cc + util/base64.cc util/bucket_util.cc util/content_file_util.cc util/conversions.cc diff --git a/src/iceberg/catalog/rest/CMakeLists.txt b/src/iceberg/catalog/rest/CMakeLists.txt index b862bc869..8fb2e93c0 100644 --- a/src/iceberg/catalog/rest/CMakeLists.txt +++ b/src/iceberg/catalog/rest/CMakeLists.txt @@ -23,6 +23,7 @@ set(ICEBERG_REST_SOURCES auth/auth_properties.cc auth/auth_session.cc auth/oauth2_util.cc + auth/token_refresh_scheduler.cc catalog_properties.cc endpoint.cc error_handlers.cc diff --git a/src/iceberg/catalog/rest/auth/auth_manager.cc b/src/iceberg/catalog/rest/auth/auth_manager.cc index 47370bd3b..247c1d538 100644 --- a/src/iceberg/catalog/rest/auth/auth_manager.cc +++ b/src/iceberg/catalog/rest/auth/auth_manager.cc @@ -25,8 +25,8 @@ #include "iceberg/catalog/rest/auth/auth_properties.h" #include "iceberg/catalog/rest/auth/auth_session.h" #include "iceberg/catalog/rest/auth/oauth2_util.h" +#include "iceberg/util/base64.h" #include "iceberg/util/macros.h" -#include "iceberg/util/transform_util.h" namespace iceberg::rest::auth { @@ -83,8 +83,7 @@ class BasicAuthManager : public AuthManager { "Missing required property '{}'", AuthProperties::kBasicPassword); std::string credential = username_it->second + ":" + password_it->second; return AuthSession::MakeDefault( - {{std::string(kAuthorizationHeader), - "Basic " + TransformUtil::Base64Encode(credential)}}); + {{std::string(kAuthorizationHeader), "Basic " + Base64::Encode(credential)}}); } }; @@ -130,7 +129,8 @@ class OAuth2Manager : public AuthManager { init_token_response_.reset(); return AuthSession::MakeOAuth2(token_response, config.oauth2_server_uri(), config.client_id(), config.client_secret(), - config.scope(), client); + config.scope(), config.keep_refreshed(), + config.optional_oauth_params(), client); } // If token is provided, use it directly. @@ -143,10 +143,10 @@ class OAuth2Manager : public AuthManager { auto base_session = AuthSession::MakeDefault(AuthHeaders(config.token())); OAuthTokenResponse token_response; ICEBERG_ASSIGN_OR_RAISE(token_response, FetchToken(client, *base_session, config)); - // TODO(lishuxu): should we directly pass config to the MakeOAuth2 call? return AuthSession::MakeOAuth2(token_response, config.oauth2_server_uri(), config.client_id(), config.client_secret(), - config.scope(), client); + config.scope(), config.keep_refreshed(), + config.optional_oauth_params(), client); } return AuthSession::MakeDefault({}); diff --git a/src/iceberg/catalog/rest/auth/auth_session.cc b/src/iceberg/catalog/rest/auth/auth_session.cc index 7251dc4a9..fcb649972 100644 --- a/src/iceberg/catalog/rest/auth/auth_session.cc +++ b/src/iceberg/catalog/rest/auth/auth_session.cc @@ -19,9 +19,16 @@ #include "iceberg/catalog/rest/auth/auth_session.h" +#include +#include +#include +#include #include +#include "iceberg/catalog/rest/auth/auth_properties.h" #include "iceberg/catalog/rest/auth/oauth2_util.h" +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" +#include "iceberg/catalog/rest/http_client.h" namespace iceberg::rest::auth { @@ -44,6 +51,189 @@ class DefaultAuthSession : public AuthSession { std::unordered_map headers_; }; +/// \brief OAuth2 session with automatic token refresh. +class OAuth2AuthSession : public AuthSession, + public std::enable_shared_from_this { + public: + struct Config { + std::string token_endpoint; + std::string client_id; + std::string client_secret; + std::string scope; + std::unordered_map optional_oauth_params; + bool keep_refreshed; + }; + + /// \brief Create an OAuth2 session and optionally schedule refresh. + static std::shared_ptr Make(const OAuthTokenResponse& initial_token, + Config config, HttpClient& client) { + auto session = std::shared_ptr( + new OAuth2AuthSession(std::move(config), client)); + session->SetInitialToken(initial_token); + return session; + } + + Status Authenticate(std::unordered_map& headers) override { + std::shared_lock lock(mutex_); + for (const auto& [key, value] : headers_) { + headers.try_emplace(key, value); + } + return {}; + } + + Status Close() override { + bool expected = false; + if (!closed_.compare_exchange_strong(expected, true)) { + return {}; // Already closed + } + TokenRefreshScheduler::Instance().Cancel(scheduled_task_id_.load()); + return {}; + } + + private: + OAuth2AuthSession(Config config, HttpClient& client) + : config_(std::move(config)), client_(client) {} + + void SetInitialToken(const OAuthTokenResponse& token_response) { + token_ = token_response.access_token; + headers_ = {{std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}}; + + // Determine expiration time + if (token_response.expires_in_secs.has_value()) { + expires_at_ = std::chrono::steady_clock::now() + + std::chrono::seconds(*token_response.expires_in_secs); + } else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) { + // Convert absolute epoch millis to steady_clock time_point + auto now_sys = std::chrono::system_clock::now(); + auto now_steady = std::chrono::steady_clock::now(); + auto exp_sys = + std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms)); + expires_at_ = now_steady + (exp_sys - now_sys); + } + + if (config_.keep_refreshed && + expires_at_ != std::chrono::steady_clock::time_point{}) { + ScheduleRefresh(); + } + } + + void DoRefresh() { DoRefreshAttempt(0, std::chrono::milliseconds(200)); } + + /// \brief Single refresh attempt. On failure, schedules a retry via the + /// scheduler (non-blocking) instead of sleeping on the worker thread. + void DoRefreshAttempt(int attempt, std::chrono::milliseconds backoff) { + static constexpr int kMaxRetries = 5; + static constexpr auto kMaxBackoff = std::chrono::milliseconds(10'000); + + if (closed_.load()) return; + + // Build credential and properties once (invariant across retries) + std::string credential = config_.client_id.empty() + ? config_.client_secret + : config_.client_id + ":" + config_.client_secret; + + // Use an empty session for the refresh request (no auth headers — + // avoids circular dependency of using an expired token to refresh itself) + auto empty_session = AuthSession::MakeDefault({}); + + AuthProperties props; + props.Set(AuthProperties::kCredential, credential); + props.Set(AuthProperties::kScope, config_.scope); + props.Set(AuthProperties::kOAuth2ServerUri, config_.token_endpoint); + for (const auto& [key, value] : config_.optional_oauth_params) { + props.mutable_configs().insert_or_assign(key, value); + } + + auto result = FetchToken(client_, *empty_session, props); + if (result.has_value()) { + auto& response = result.value(); + { + std::unique_lock lock(mutex_); + token_ = response.access_token; + headers_ = { + {std::string(kAuthorizationHeader), std::string(kBearerPrefix) + token_}}; + + // Reset before deriving new expiry + expires_at_ = std::chrono::steady_clock::time_point{}; + + if (response.expires_in_secs.has_value()) { + expires_at_ = std::chrono::steady_clock::now() + + std::chrono::seconds(*response.expires_in_secs); + } else if (auto exp_ms = ExpiresAtMillis(token_); exp_ms.has_value()) { + auto now_sys = std::chrono::system_clock::now(); + auto now_steady = std::chrono::steady_clock::now(); + auto exp_sys = + std::chrono::system_clock::time_point(std::chrono::milliseconds(*exp_ms)); + expires_at_ = now_steady + (exp_sys - now_sys); + } + } + // Note: ScheduleRefresh must be called outside the lock. + ScheduleRefresh(); + return; // Success + } + + // Schedule retry with exponential backoff (non-blocking) + if (attempt + 1 < kMaxRetries) { + auto next_backoff = + std::min(std::chrono::duration_cast(backoff * 2), + kMaxBackoff); + std::weak_ptr weak_self = shared_from_this(); + TokenRefreshScheduler::Instance().Schedule( + backoff, + [weak_self = std::move(weak_self), next_attempt = attempt + 1, next_backoff] { + if (auto self = weak_self.lock()) { + self->DoRefreshAttempt(next_attempt, next_backoff); + } + }); + } + // All retries exhausted — stop refreshing silently. + // Next request will use the expired token; server returns 401. + } + + /// \brief Schedule the next token refresh based on expiration time. + /// + /// Must be called outside any lock on mutex_ (CalculateRefreshDelay + /// acquires shared_lock internally). + void ScheduleRefresh() { + if (!config_.keep_refreshed || closed_.load()) return; + + auto delay = CalculateRefreshDelay(); + if (delay <= std::chrono::milliseconds::zero()) return; + + std::weak_ptr weak_self = shared_from_this(); + auto new_id = TokenRefreshScheduler::Instance().Schedule( + delay, [weak_self = std::move(weak_self)] { + if (auto self = weak_self.lock()) { + self->DoRefresh(); + } + }); + scheduled_task_id_.store(new_id); + } + + std::chrono::milliseconds CalculateRefreshDelay() const { + std::shared_lock lock(mutex_); + auto now = std::chrono::steady_clock::now(); + if (expires_at_ <= now) return std::chrono::milliseconds::zero(); + + auto expires_in = + std::chrono::duration_cast(expires_at_ - now); + // Refresh window: 10% of remaining time, capped at 5 minutes + auto refresh_window = std::min(expires_in / 10, std::chrono::milliseconds(300'000)); + auto wait_time = expires_in - refresh_window; + return std::max(wait_time, std::chrono::milliseconds(10)); + } + + mutable std::shared_mutex mutex_; // protects token_, headers_, expires_at_ + std::string token_; + std::unordered_map headers_; + std::chrono::steady_clock::time_point expires_at_{}; + + Config config_; + HttpClient& client_; + std::atomic scheduled_task_id_{0}; + std::atomic closed_{false}; +}; + } // namespace std::shared_ptr AuthSession::MakeDefault( @@ -52,12 +242,20 @@ std::shared_ptr AuthSession::MakeDefault( } std::shared_ptr AuthSession::MakeOAuth2( - const OAuthTokenResponse& initial_token, const std::string& /*token_endpoint*/, - const std::string& /*client_id*/, const std::string& /*client_secret*/, - const std::string& /*scope*/, HttpClient& /*client*/) { - // TODO(lishuxu): Create OAuth2AuthSession with auto-refresh support. - return MakeDefault({{std::string(kAuthorizationHeader), - std::string(kBearerPrefix) + initial_token.access_token}}); + const OAuthTokenResponse& initial_token, const std::string& token_endpoint, + const std::string& client_id, const std::string& client_secret, + const std::string& scope, bool keep_refreshed, + const std::unordered_map& optional_oauth_params, + HttpClient& client) { + OAuth2AuthSession::Config config{ + .token_endpoint = token_endpoint, + .client_id = client_id, + .client_secret = client_secret, + .scope = scope, + .optional_oauth_params = optional_oauth_params, + .keep_refreshed = keep_refreshed, + }; + return OAuth2AuthSession::Make(initial_token, std::move(config), client); } } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/auth_session.h b/src/iceberg/catalog/rest/auth/auth_session.h index 26b93877b..16c04157d 100644 --- a/src/iceberg/catalog/rest/auth/auth_session.h +++ b/src/iceberg/catalog/rest/auth/auth_session.h @@ -84,14 +84,16 @@ class ICEBERG_REST_EXPORT AuthSession { /// \param client_id OAuth2 client ID for refresh requests. /// \param client_secret OAuth2 client secret for re-fetch if refresh fails. /// \param scope OAuth2 scope for refresh requests. + /// \param keep_refreshed Whether to schedule automatic token refresh. + /// \param optional_oauth_params Optional OAuth params (audience, resource) for refresh. /// \param client HTTP client for making refresh requests. /// \return A new session that manages token lifecycle automatically. - static std::shared_ptr MakeOAuth2(const OAuthTokenResponse& initial_token, - const std::string& token_endpoint, - const std::string& client_id, - const std::string& client_secret, - const std::string& scope, - HttpClient& client); + static std::shared_ptr MakeOAuth2( + const OAuthTokenResponse& initial_token, const std::string& token_endpoint, + const std::string& client_id, const std::string& client_secret, + const std::string& scope, bool keep_refreshed, + const std::unordered_map& optional_oauth_params, + HttpClient& client); }; } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/oauth2_util.cc b/src/iceberg/catalog/rest/auth/oauth2_util.cc index 3d209d2bd..b1e2402b3 100644 --- a/src/iceberg/catalog/rest/auth/oauth2_util.cc +++ b/src/iceberg/catalog/rest/auth/oauth2_util.cc @@ -28,6 +28,7 @@ #include "iceberg/catalog/rest/http_client.h" #include "iceberg/catalog/rest/json_serde_internal.h" #include "iceberg/json_serde_internal.h" +#include "iceberg/util/base64.h" #include "iceberg/util/macros.h" namespace iceberg::rest::auth { @@ -74,4 +75,49 @@ Result FetchToken(HttpClient& client, AuthSession& session, return token_response; } +std::optional ExpiresAtMillis(std::string_view token) { + if (token.empty()) { + return std::nullopt; + } + + // A JWT has exactly 3 dot-separated parts: header.payload.signature + auto first_dot = token.find('.'); + if (first_dot == std::string_view::npos) { + return std::nullopt; + } + auto second_dot = token.find('.', first_dot + 1); + if (second_dot == std::string_view::npos) { + return std::nullopt; + } + // Ensure there are exactly 3 parts (no additional dots after the signature). + // Note: JWE tokens have 5 segments — they are intentionally not supported here + // and will return nullopt (graceful degradation to not scheduling refresh). + if (token.find('.', second_dot + 1) != std::string_view::npos) { + return std::nullopt; + } + + // Extract and decode the payload (second part). + // Note: Base64::UrlDecode returns an error on invalid input, and Ok("") on empty input. + // A valid JWT payload is never empty (at minimum "{}"), so empty result reliably + // indicates the token is not a JWT we can parse. + std::string_view payload_b64 = token.substr(first_dot + 1, second_dot - first_dot - 1); + auto payload_result = Base64::UrlDecode(payload_b64); + if (!payload_result.has_value() || payload_result->empty()) { + return std::nullopt; + } + const std::string& payload = *payload_result; + + // Parse JSON and extract "exp" claim + auto json = nlohmann::json::parse(payload, nullptr, false); + if (json.is_discarded() || !json.is_object()) { + return std::nullopt; + } + auto it = json.find("exp"); + if (it == json.end() || !it->is_number()) { + return std::nullopt; + } + auto exp_seconds = static_cast(it->get()); + return exp_seconds * 1000; // Convert seconds to milliseconds +} + } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/oauth2_util.h b/src/iceberg/catalog/rest/auth/oauth2_util.h index 39dd12964..428ebc385 100644 --- a/src/iceberg/catalog/rest/auth/oauth2_util.h +++ b/src/iceberg/catalog/rest/auth/oauth2_util.h @@ -19,6 +19,8 @@ #pragma once +#include +#include #include #include #include @@ -53,4 +55,14 @@ ICEBERG_REST_EXPORT Result FetchToken( ICEBERG_REST_EXPORT std::unordered_map AuthHeaders( const std::string& token); +/// \brief Extract expiration time from a JWT token. +/// +/// Decodes the JWT payload (base64url) and reads the "exp" claim. +/// Returns std::nullopt if the token is not a valid JWT or has no "exp" claim. +/// +/// \param token A token string. If it is a JWT (three dot-separated base64url +/// segments), the "exp" claim is extracted from the payload. +/// \return Expiration time as milliseconds since epoch, or std::nullopt. +ICEBERG_REST_EXPORT std::optional ExpiresAtMillis(std::string_view token); + } // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc new file mode 100644 index 000000000..d66ed7d90 --- /dev/null +++ b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.cc @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" + +#include +#include + +namespace iceberg::rest::auth { + +TokenRefreshScheduler& TokenRefreshScheduler::Instance() { + // Intentionally leaked to avoid destruction-order races at process exit. + static auto* instance = new TokenRefreshScheduler(); + return *instance; +} + +TokenRefreshScheduler::TokenRefreshScheduler() : worker_([this] { Run(); }) {} + +TokenRefreshScheduler::~TokenRefreshScheduler() { Shutdown(); } + +uint64_t TokenRefreshScheduler::Schedule(std::chrono::milliseconds delay, + std::function callback) { + std::lock_guard lock(mutex_); + if (shutdown_) { + return 0; + } + uint64_t id = next_id_++; + tasks_.push_back(Task{.id = id, + .fire_at = std::chrono::steady_clock::now() + delay, + .callback = std::move(callback)}); + cv_.notify_one(); + return id; +} + +void TokenRefreshScheduler::Cancel(uint64_t handle) { + if (handle == 0) return; + std::lock_guard lock(mutex_); + std::erase_if(tasks_, [handle](const Task& t) { return t.id == handle; }); +} + +void TokenRefreshScheduler::Shutdown() { + { + std::lock_guard lock(mutex_); + if (shutdown_) return; + shutdown_ = true; + tasks_.clear(); + } + cv_.notify_one(); + if (worker_.joinable()) { + worker_.join(); + } +} + +void TokenRefreshScheduler::Run() { + while (true) { + std::function callback; + + { + std::unique_lock lock(mutex_); + + if (tasks_.empty() && !shutdown_) { + // Wait until a task is added or shutdown is requested + cv_.wait(lock, [this] { return !tasks_.empty() || shutdown_; }); + } + + if (shutdown_) break; + if (tasks_.empty()) continue; + + // Find the task with the earliest fire_at + auto earliest_it = std::ranges::min_element( + tasks_, [](const Task& a, const Task& b) { return a.fire_at < b.fire_at; }); + + auto fire_at = earliest_it->fire_at; + auto target_id = earliest_it->id; + + // Wait until fire_at or until woken (new task, cancel, or shutdown). + // Note: The predicate does O(n) scan on each spurious wakeup. This is + // acceptable for the expected task count (< 10). If task count grows + // significantly, consider replacing vector with a priority queue. + cv_.wait_until(lock, fire_at, [&] { + // Wake up if: shutdown, task list changed, or time is up + if (shutdown_) return true; + if (tasks_.empty()) return true; + // Check if the earliest task has changed (new task added or cancelled) + auto new_earliest = std::ranges::min_element( + tasks_, [](const Task& a, const Task& b) { return a.fire_at < b.fire_at; }); + return new_earliest->id != target_id; + }); + + if (shutdown_) break; + + // If we were woken because the earliest task changed, loop again + auto now = std::chrono::steady_clock::now(); + auto due_it = + std::ranges::find_if(tasks_, [now](const Task& t) { return t.fire_at <= now; }); + if (due_it == tasks_.end()) continue; + + callback = std::move(due_it->callback); + tasks_.erase(due_it); + } + + // Execute callback outside the lock + if (callback) { + callback(); + } + } +} + +} // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h new file mode 100644 index 000000000..02dc0e14f --- /dev/null +++ b/src/iceberg/catalog/rest/auth/token_refresh_scheduler.h @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "iceberg/catalog/rest/iceberg_rest_export.h" + +/// \file iceberg/catalog/rest/auth/token_refresh_scheduler.h +/// \brief Global scheduler for OAuth2 token refresh tasks. + +namespace iceberg::rest::auth { + +/// \brief A process-global scheduler for delayed token refresh tasks. +/// +/// Uses a single background thread that sleeps until the next task is due. +/// All OAuth2AuthSession instances share this scheduler. Tasks are lightweight +/// (a single HTTP POST to refresh a token), so one thread is sufficient. +/// +/// Thread safety: All public methods are thread-safe. +/// +/// TODO(lishuxu): Migrate to the shared thread pool abstraction once available +/// (see https://github.com/apache/iceberg-cpp/pull/646#discussion_r3304315308). +class ICEBERG_REST_EXPORT TokenRefreshScheduler { + public: + /// \brief Get the global singleton instance. + /// + /// Lazily created on first access and intentionally leaked: the worker + /// thread is reclaimed by the OS at process exit. Tests needing + /// deterministic shutdown should use a local instance. + static TokenRefreshScheduler& Instance(); + + /// \brief Schedule a callback to run after a delay. + /// + /// \param delay Time to wait before executing the callback. + /// \param callback Function to execute when the delay expires. + /// \return A unique handle that can be used to cancel the task. + uint64_t Schedule(std::chrono::milliseconds delay, std::function callback); + + /// \brief Cancel a previously scheduled task. + /// + /// If the task has already fired or does not exist, this is a no-op. + /// + /// \param handle The handle returned by Schedule(). + void Cancel(uint64_t handle); + + /// \brief Shutdown the scheduler, cancelling all pending tasks. + /// + /// After shutdown, Schedule() calls are no-ops (return 0). + /// This is called automatically on destruction. + /// + /// WARNING: Do not call this on the global Instance() unless you intend to + /// permanently stop all token refresh for the entire process. This is mainly + /// useful for testing with locally-constructed scheduler instances. + void Shutdown(); + + ~TokenRefreshScheduler(); + + // Non-copyable, non-movable + TokenRefreshScheduler(const TokenRefreshScheduler&) = delete; + TokenRefreshScheduler& operator=(const TokenRefreshScheduler&) = delete; + TokenRefreshScheduler(TokenRefreshScheduler&&) = delete; + TokenRefreshScheduler& operator=(TokenRefreshScheduler&&) = delete; + + /// \brief Construct a scheduler (prefer Instance() for production use). + /// + /// This constructor is public to allow testing with isolated instances. + /// In production code, use Instance() to get the global singleton. + TokenRefreshScheduler(); + + private: + /// \brief Worker loop that processes tasks. + void Run(); + + struct Task { + uint64_t id; + std::chrono::steady_clock::time_point fire_at; + std::function callback; + }; + + std::mutex mutex_; + std::condition_variable cv_; + std::vector tasks_; + uint64_t next_id_ = 1; // 0 is reserved as "invalid handle" + bool shutdown_ = false; + std::thread worker_; +}; + +} // namespace iceberg::rest::auth diff --git a/src/iceberg/catalog/rest/meson.build b/src/iceberg/catalog/rest/meson.build index cc05576a2..f3eae6d45 100644 --- a/src/iceberg/catalog/rest/meson.build +++ b/src/iceberg/catalog/rest/meson.build @@ -21,6 +21,7 @@ iceberg_rest_sources = files( 'auth/auth_properties.cc', 'auth/auth_session.cc', 'auth/oauth2_util.cc', + 'auth/token_refresh_scheduler.cc', 'catalog_properties.cc', 'endpoint.cc', 'error_handlers.cc', diff --git a/src/iceberg/meson.build b/src/iceberg/meson.build index a5a60b605..02ce046b3 100644 --- a/src/iceberg/meson.build +++ b/src/iceberg/meson.build @@ -123,6 +123,7 @@ iceberg_sources = files( 'update/update_snapshot_reference.cc', 'update/update_sort_order.cc', 'update/update_statistics.cc', + 'util/base64.cc', 'util/bucket_util.cc', 'util/content_file_util.cc', 'util/conversions.cc', diff --git a/src/iceberg/test/CMakeLists.txt b/src/iceberg/test/CMakeLists.txt index 997d18354..3f55c38d6 100644 --- a/src/iceberg/test/CMakeLists.txt +++ b/src/iceberg/test/CMakeLists.txt @@ -119,6 +119,7 @@ add_iceberg_test(json_serde_test add_iceberg_test(util_test USE_DATA SOURCES + base64_test.cc bucket_util_test.cc config_test.cc data_file_set_test.cc diff --git a/src/iceberg/test/auth_manager_test.cc b/src/iceberg/test/auth_manager_test.cc index bd06fee3f..b5ad56008 100644 --- a/src/iceberg/test/auth_manager_test.cc +++ b/src/iceberg/test/auth_manager_test.cc @@ -34,6 +34,7 @@ #include "iceberg/catalog/rest/json_serde_internal.h" #include "iceberg/json_serde_internal.h" #include "iceberg/test/matchers.h" +#include "iceberg/util/base64.h" namespace iceberg::rest::auth { @@ -358,4 +359,311 @@ TEST_F(AuthManagerTest, OAuthTokenResponseNATokenType) { EXPECT_EQ(result->token_type, "N_A"); } +// ---- ExpiresAtMillis tests ---- + +// Helper: build a minimal JWT with a given payload JSON string. +// JWT = base64url(header) + "." + base64url(payload) + "." + base64url(signature) +namespace { + +std::string MakeJwt(const std::string& payload_json) { + std::string header = R"({"alg":"HS256","typ":"JWT"})"; + std::string signature = "test-signature"; + return Base64::UrlEncode(header) + "." + Base64::UrlEncode(payload_json) + "." + + Base64::UrlEncode(signature); +} + +} // namespace + +// Verifies ExpiresAtMillis extracts exp claim from a valid JWT +TEST_F(AuthManagerTest, ExpiresAtMillisValidJwt) { + // exp = 1700000000 (seconds since epoch) + std::string token = MakeJwt(R"({"sub":"user","exp":1700000000})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 1700000000LL * 1000); // milliseconds +} + +// Verifies ExpiresAtMillis handles large exp values correctly +TEST_F(AuthManagerTest, ExpiresAtMillisLargeExp) { + std::string token = MakeJwt(R"({"exp":2000000000})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 2000000000LL * 1000); +} + +// Verifies ExpiresAtMillis truncates floating-point exp to integer +TEST_F(AuthManagerTest, ExpiresAtMillisFloatExp) { + std::string token = MakeJwt(R"({"exp":1700000000.5})"); + auto result = ExpiresAtMillis(token); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 1700000000LL * 1000); // truncated to int +} + +// Verifies ExpiresAtMillis returns nullopt for non-JWT token without dots +TEST_F(AuthManagerTest, ExpiresAtMillisNonJwtNoDots) { + auto result = ExpiresAtMillis("just-a-plain-token"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for token with only one dot +TEST_F(AuthManagerTest, ExpiresAtMillisOneDot) { + auto result = ExpiresAtMillis("part1.part2"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for token with too many segments +TEST_F(AuthManagerTest, ExpiresAtMillisFourParts) { + auto result = ExpiresAtMillis("a.b.c.d"); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt when JWT has no exp claim +TEST_F(AuthManagerTest, ExpiresAtMillisNoExpClaim) { + std::string token = MakeJwt(R"({"sub":"user","iat":1700000000})"); + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt when exp is not a number +TEST_F(AuthManagerTest, ExpiresAtMillisExpNotInteger) { + std::string token = MakeJwt(R"({"exp":"not-a-number"})"); + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for malformed base64 payload +TEST_F(AuthManagerTest, ExpiresAtMillisMalformedBase64) { + // Use invalid base64url characters in the payload part + std::string token = "eyJhbGciOiJIUzI1NiJ9.!!!invalid!!!.signature"; + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt for empty token +TEST_F(AuthManagerTest, ExpiresAtMillisEmptyToken) { + auto result = ExpiresAtMillis(""); + EXPECT_FALSE(result.has_value()); +} + +// Verifies ExpiresAtMillis returns nullopt when payload is not valid JSON +TEST_F(AuthManagerTest, ExpiresAtMillisInvalidJson) { + std::string header = R"({"alg":"HS256"})"; + std::string invalid_json = "this is not json"; + std::string token = + Base64::UrlEncode(header) + "." + Base64::UrlEncode(invalid_json) + "." + "sig"; + auto result = ExpiresAtMillis(token); + EXPECT_FALSE(result.has_value()); +} + +// ---- TokenRefreshScheduler tests ---- + +} // namespace iceberg::rest::auth + +#include +#include +#include +#include + +#include "iceberg/catalog/rest/auth/token_refresh_scheduler.h" +#include "iceberg/catalog/rest/types.h" + +namespace iceberg::rest::auth { + +// Verifies that a scheduled task fires after the specified delay +TEST(TokenRefreshSchedulerTest, ScheduleFiresAfterDelay) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + scheduler.Schedule(std::chrono::milliseconds(50), [&] { fired.store(true); }); + + // Should not have fired immediately + EXPECT_FALSE(fired.load()); + + // Wait enough time for it to fire + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + EXPECT_TRUE(fired.load()); + + scheduler.Shutdown(); +} + +// Verifies that cancelling a task prevents it from executing +TEST(TokenRefreshSchedulerTest, CancelPreventsExecution) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + auto handle = + scheduler.Schedule(std::chrono::milliseconds(100), [&] { fired.store(true); }); + + // Cancel before it fires + scheduler.Cancel(handle); + + // Wait past the scheduled time + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + EXPECT_FALSE(fired.load()); + + scheduler.Shutdown(); +} + +// Verifies that shutdown with pending tasks does not crash +TEST(TokenRefreshSchedulerTest, ShutdownWithPendingTasks) { + TokenRefreshScheduler scheduler; + std::atomic fired{false}; + + scheduler.Schedule(std::chrono::milliseconds(5000), [&] { fired.store(true); }); + + // Shutdown immediately — should not crash and task should not fire + scheduler.Shutdown(); + + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + EXPECT_FALSE(fired.load()); +} + +// Verifies that Schedule after shutdown returns invalid handle (0) +TEST(TokenRefreshSchedulerTest, ScheduleAfterShutdownIsNoop) { + TokenRefreshScheduler scheduler; + scheduler.Shutdown(); + + auto handle = scheduler.Schedule(std::chrono::milliseconds(10), [] {}); + EXPECT_EQ(0u, handle); +} + +// Verifies that cancelling an invalid handle does not crash +TEST(TokenRefreshSchedulerTest, CancelInvalidHandleIsNoop) { + TokenRefreshScheduler scheduler; + // Should not crash + scheduler.Cancel(0); + scheduler.Cancel(999); + scheduler.Shutdown(); +} + +// ---- OAuth2AuthSession integration tests ---- + +// Verifies that MakeOAuth2 creates a session with correct initial Bearer header +TEST(OAuth2AuthSessionTest, InitialTokenIsUsed) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "initial-token-123"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 3600; + + // Create session (refresh will fail since there's no real server, but + // initial token should work) + auto session = + AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "client_id", "client_secret", "catalog", true, {}, client); + + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer initial-token-123"); + + session->Close(); +} + +// Verifies that session without expiration does not schedule refresh +TEST(OAuth2AuthSessionTest, NoExpirationNoRefresh) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "static-token"; + token_response.token_type = "bearer"; + // No expires_in_secs set — token is not a JWT either + + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "id", "secret", "catalog", true, {}, client); + + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer static-token"); + + // Wait a bit — no crash, no refresh attempt + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + headers.clear(); + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer static-token"); + + session->Close(); +} + +// Verifies that Close prevents further refresh callbacks +TEST(OAuth2AuthSessionTest, CloseStopsRefresh) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "token-before-close"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 1; // Expires in 1 second + + auto session = AuthSession::MakeOAuth2(token_response, "http://127.0.0.1:1/tokens", + "id", "secret", "catalog", true, {}, client); + + // Close immediately — should cancel the scheduled refresh + session->Close(); + + // Wait past expiration + refresh window + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + + // Token should still be the original (no refresh happened) + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer token-before-close"); +} + +// Verifies that concurrent Authenticate calls are thread-safe +TEST(OAuth2AuthSessionTest, ConcurrentAuthenticate) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "concurrent-token"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 3600; + + auto session = AuthSession::MakeOAuth2(token_response, "http://localhost/oauth/tokens", + "id", "secret", "catalog", true, {}, client); + + // Launch multiple threads calling Authenticate concurrently + std::vector threads; + std::atomic success_count{0}; + + for (int i = 0; i < 10; ++i) { + threads.emplace_back([&] { + for (int j = 0; j < 100; ++j) { + std::unordered_map headers; + auto status = session->Authenticate(headers); + if (status.has_value()) { + success_count.fetch_add(1); + } + } + }); + } + + for (auto& t : threads) { + t.join(); + } + + EXPECT_EQ(1000, success_count.load()); + session->Close(); +} + +// Verifies that session still returns last known token after all refresh retries fail +TEST(OAuth2AuthSessionTest, RefreshFailureKeepsLastToken) { + HttpClient client({}); + OAuthTokenResponse token_response; + token_response.access_token = "original-token"; + token_response.token_type = "bearer"; + token_response.expires_in_secs = 1; // Very short — will trigger refresh soon + + auto session = + AuthSession::MakeOAuth2(token_response, "http://127.0.0.1:1/tokens", // Will fail + "id", "secret", "catalog", true, {}, client); + + // Wait for refresh to be attempted and fail (all retries) + // With non-blocking retries: 200ms + 400ms + 800ms + 1600ms ≈ 3s total + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + + // Session should still return the original token (no crash) + std::unordered_map headers; + ASSERT_THAT(session->Authenticate(headers), IsOk()); + EXPECT_EQ(headers["Authorization"], "Bearer original-token"); + + session->Close(); +} + } // namespace iceberg::rest::auth diff --git a/src/iceberg/test/base64_test.cc b/src/iceberg/test/base64_test.cc new file mode 100644 index 000000000..8dbdc71ab --- /dev/null +++ b/src/iceberg/test/base64_test.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/util/base64.h" + +#include + +#include "iceberg/test/matchers.h" + +namespace iceberg { + +TEST(Base64Test, Encode) { + // Empty string + EXPECT_EQ("", Base64::Encode("")); + + // Single character + EXPECT_EQ("YQ==", Base64::Encode("a")); + EXPECT_EQ("YWI=", Base64::Encode("ab")); + EXPECT_EQ("YWJj", Base64::Encode("abc")); + + // Multiple of 3 characters + EXPECT_EQ("YWJjZGU=", Base64::Encode("abcde")); + EXPECT_EQ("YWJjZGVm", Base64::Encode("abcdef")); + + // Common strings + EXPECT_EQ("U29tZSBkYXRhIHdpdGggY2hhcmFjdGVycw==", + Base64::Encode("Some data with characters")); + EXPECT_EQ("aGVsbG8=", Base64::Encode("hello")); + EXPECT_EQ("dGVzdCBzdHJpbmc=", Base64::Encode("test string")); + + // Unicode + EXPECT_EQ("8J+EgA==", Base64::Encode("\xF0\x9F\x84\x80")); + // Null byte + EXPECT_EQ("AA==", Base64::Encode({"\x00", 1})); +} + +TEST(Base64Test, Decode) { + // Empty string + ICEBERG_UNWRAP_OR_FAIL(auto empty, Base64::Decode("")); + EXPECT_EQ("", empty); + + // Round-trip with Base64::Encode + ICEBERG_UNWRAP_OR_FAIL(auto a, Base64::Decode("YQ==")); + EXPECT_EQ("a", a); + ICEBERG_UNWRAP_OR_FAIL(auto ab, Base64::Decode("YWI=")); + EXPECT_EQ("ab", ab); + ICEBERG_UNWRAP_OR_FAIL(auto abc, Base64::Decode("YWJj")); + EXPECT_EQ("abc", abc); + ICEBERG_UNWRAP_OR_FAIL(auto abcde, Base64::Decode("YWJjZGU=")); + EXPECT_EQ("abcde", abcde); + ICEBERG_UNWRAP_OR_FAIL(auto abcdef, Base64::Decode("YWJjZGVm")); + EXPECT_EQ("abcdef", abcdef); + ICEBERG_UNWRAP_OR_FAIL(auto hello, Base64::Decode("aGVsbG8=")); + EXPECT_EQ("hello", hello); + ICEBERG_UNWRAP_OR_FAIL(auto test_str, Base64::Decode("dGVzdCBzdHJpbmc=")); + EXPECT_EQ("test string", test_str); + + // Without padding (should still work) + ICEBERG_UNWRAP_OR_FAIL(auto a2, Base64::Decode("YQ")); + EXPECT_EQ("a", a2); + ICEBERG_UNWRAP_OR_FAIL(auto ab2, Base64::Decode("YWI")); + EXPECT_EQ("ab", ab2); + + // Invalid characters return error + EXPECT_THAT(Base64::Decode("!!!"), IsError(ErrorKind::kInvalidArgument)); +} + +TEST(Base64Test, UrlEncode) { + // Empty string + EXPECT_EQ("", Base64::UrlEncode("")); + + // No padding is emitted (unlike standard base64) + EXPECT_EQ("YQ", Base64::UrlEncode("a")); + EXPECT_EQ("YWI", Base64::UrlEncode("ab")); + EXPECT_EQ("YWJj", Base64::UrlEncode("abc")); + EXPECT_EQ("aGVsbG8", Base64::UrlEncode("hello")); + + // URL-safe characters: '-' and '_' instead of '+' and '/' + // bytes {0xFB, 0xFF, 0xFE} encode to "+//+" in standard base64, "-__-" in base64url + EXPECT_EQ("-__-", Base64::UrlEncode("\xFB\xFF\xFE")); + + // Round-trip with UrlDecode + ICEBERG_UNWRAP_OR_FAIL(auto decoded, Base64::UrlDecode(Base64::UrlEncode("hello"))); + EXPECT_EQ("hello", decoded); +} + +TEST(Base64Test, UrlDecode) { + // Empty string + ICEBERG_UNWRAP_OR_FAIL(auto empty, Base64::UrlDecode("")); + EXPECT_EQ("", empty); + + // Standard cases (same as Base64::Decode for alphanumeric) + ICEBERG_UNWRAP_OR_FAIL(auto hello, Base64::UrlDecode("aGVsbG8")); + EXPECT_EQ("hello", hello); + ICEBERG_UNWRAP_OR_FAIL(auto abc, Base64::UrlDecode("YWJj")); + EXPECT_EQ("abc", abc); + + // URL-safe characters: '-' and '_' instead of '+' and '/' + // bytes {0xFB, 0xFF, 0xFE} encode to "+//+" in standard base64, "-__-" in base64url + ICEBERG_UNWRAP_OR_FAIL(auto decoded, Base64::UrlDecode("-__-")); + EXPECT_EQ(3u, decoded.size()); + EXPECT_EQ('\xFB', decoded[0]); + EXPECT_EQ('\xFF', decoded[1]); + EXPECT_EQ('\xFE', decoded[2]); + + // Standard base64 chars '+' and '/' should be invalid in base64url + EXPECT_THAT(Base64::UrlDecode("+//+"), IsError(ErrorKind::kInvalidArgument)); + + // With padding (should handle gracefully) + ICEBERG_UNWRAP_OR_FAIL(auto hello2, Base64::UrlDecode("aGVsbG8=")); + EXPECT_EQ("hello", hello2); + + // Invalid characters return error + EXPECT_THAT(Base64::UrlDecode("!!!invalid!!!"), IsError(ErrorKind::kInvalidArgument)); +} + +} // namespace iceberg diff --git a/src/iceberg/test/meson.build b/src/iceberg/test/meson.build index b21a264b1..2f812655f 100644 --- a/src/iceberg/test/meson.build +++ b/src/iceberg/test/meson.build @@ -83,6 +83,7 @@ iceberg_tests = { }, 'util_test': { 'sources': files( + 'base64_test.cc', 'bucket_util_test.cc', 'config_test.cc', 'data_file_set_test.cc', diff --git a/src/iceberg/test/transform_util_test.cc b/src/iceberg/test/transform_util_test.cc index f5a22233a..13c96b9e1 100644 --- a/src/iceberg/test/transform_util_test.cc +++ b/src/iceberg/test/transform_util_test.cc @@ -152,30 +152,4 @@ TEST(TransformUtilTest, HumanTimestampNsWithZone) { EXPECT_EQ("1969-12-31T23:59:59.123456789+00:00", TransformUtil::HumanTimestampNsWithZone(-876543211)); } - -TEST(TransformUtilTest, Base64Encode) { - // Empty string - EXPECT_EQ("", TransformUtil::Base64Encode("")); - - // Single character - EXPECT_EQ("YQ==", TransformUtil::Base64Encode("a")); - EXPECT_EQ("YWI=", TransformUtil::Base64Encode("ab")); - EXPECT_EQ("YWJj", TransformUtil::Base64Encode("abc")); - - // Multiple of 3 characters - EXPECT_EQ("YWJjZGU=", TransformUtil::Base64Encode("abcde")); - EXPECT_EQ("YWJjZGVm", TransformUtil::Base64Encode("abcdef")); - - // Common strings - EXPECT_EQ("U29tZSBkYXRhIHdpdGggY2hhcmFjdGVycw==", - TransformUtil::Base64Encode("Some data with characters")); - EXPECT_EQ("aGVsbG8=", TransformUtil::Base64Encode("hello")); - EXPECT_EQ("dGVzdCBzdHJpbmc=", TransformUtil::Base64Encode("test string")); - - // Unicode - EXPECT_EQ("8J+EgA==", TransformUtil::Base64Encode("\xF0\x9F\x84\x80")); - // Null byte - EXPECT_EQ("AA==", TransformUtil::Base64Encode({"\x00", 1})); -} - } // namespace iceberg diff --git a/src/iceberg/transform.cc b/src/iceberg/transform.cc index c019c7ead..c915ec067 100644 --- a/src/iceberg/transform.cc +++ b/src/iceberg/transform.cc @@ -28,6 +28,7 @@ #include "iceberg/result.h" #include "iceberg/transform_function.h" #include "iceberg/type.h" +#include "iceberg/util/base64.h" #include "iceberg/util/checked_cast.h" #include "iceberg/util/macros.h" #include "iceberg/util/projection_util_internal.h" @@ -437,7 +438,7 @@ Result Transform::ToHumanString(const Literal& value) { case TypeId::kFixed: case TypeId::kBinary: { const auto& binary_data = std::get>(value.value()); - return TransformUtil::Base64Encode( + return Base64::Encode( {reinterpret_cast(binary_data.data()), binary_data.size()}); } case TypeId::kDecimal: { diff --git a/src/iceberg/util/base64.cc b/src/iceberg/util/base64.cc new file mode 100644 index 000000000..0baf1511e --- /dev/null +++ b/src/iceberg/util/base64.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/util/base64.h" + +#include +#include + +namespace iceberg { + +namespace { + +// Shared base64 encode logic. The encode alphabet maps a 6-bit value -> ASCII char. +// When `pad` is true, the output is padded with '=' to a multiple of 4 characters. +std::string Base64EncodeWithAlphabet(std::string_view input, std::string_view alphabet, + bool pad) { + std::string output; + output.reserve((input.size() + 2) / 3 * 4); + + uint32_t buffer = 0; + int bits_collected = 0; + for (unsigned char byte : input) { + buffer = (buffer << 8) | byte; + bits_collected += 8; + while (bits_collected >= 6) { + bits_collected -= 6; + output.push_back(alphabet[(buffer >> bits_collected) & 0x3F]); + } + } + if (bits_collected > 0) { + // Pad the remaining bits on the right to form the final 6-bit group. + output.push_back(alphabet[(buffer << (6 - bits_collected)) & 0x3F]); + } + if (pad) { + while (output.size() % 4 != 0) { + output.push_back('='); + } + } + return output; +} + +// Shared base64 decode logic. The decode table maps ASCII char -> 6-bit value. +// 0xFF means invalid character. +Result Base64DecodeWithTable(std::string_view input, + const std::array& table) { + // Strip trailing padding + while (!input.empty() && input.back() == '=') { + input.remove_suffix(1); + } + if (input.empty()) { + return std::string{}; + } + + std::string output; + output.reserve((input.size() * 3) / 4); + + uint32_t buffer = 0; + int bits_collected = 0; + + for (char c : input) { + uint8_t val = table[static_cast(c)]; + if (val == 0xFF) { + return InvalidArgument("Invalid base64 character: '{}'", c); + } + buffer = (buffer << 6) | val; + bits_collected += 6; + if (bits_collected >= 8) { + bits_collected -= 8; + output.push_back(static_cast((buffer >> bits_collected) & 0xFF)); + } + } + + return output; +} + +// Standard base64 alphabet: A-Z=0-25, a-z=26-51, 0-9=52-61, +=62, /=63 +constexpr std::string_view kBase64Chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + +// Base64url alphabet: same as standard but '-'=62, '_'=63 (RFC 4648 §5) +constexpr std::string_view kBase64UrlChars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +// Standard base64 decode table: A-Z=0-25, a-z=26-51, 0-9=52-61, +=62, /=63 +constexpr std::array kBase64DecodeTable = [] { + std::array table{}; + table.fill(0xFF); + for (int i = 0; i < 26; ++i) { + table[static_cast('A' + i)] = static_cast(i); + table[static_cast('a' + i)] = static_cast(26 + i); + } + for (int i = 0; i < 10; ++i) { + table[static_cast('0' + i)] = static_cast(52 + i); + } + table[static_cast('+')] = 62; + table[static_cast('/')] = 63; + return table; +}(); + +// Base64url decode table: same as standard but '-'=62, '_'=63 (RFC 4648 §5) +constexpr std::array kBase64UrlDecodeTable = [] { + auto table = kBase64DecodeTable; + table[static_cast('+')] = 0xFF; // '+' is invalid in base64url + table[static_cast('/')] = 0xFF; // '/' is invalid in base64url + table[static_cast('-')] = 62; + table[static_cast('_')] = 63; + return table; +}(); + +} // namespace + +std::string Base64::Encode(std::string_view data) { + return Base64EncodeWithAlphabet(data, kBase64Chars, /*pad=*/true); +} + +Result Base64::Decode(std::string_view encoded) { + return Base64DecodeWithTable(encoded, kBase64DecodeTable); +} + +std::string Base64::UrlEncode(std::string_view data) { + return Base64EncodeWithAlphabet(data, kBase64UrlChars, /*pad=*/false); +} + +Result Base64::UrlDecode(std::string_view encoded) { + return Base64DecodeWithTable(encoded, kBase64UrlDecodeTable); +} + +} // namespace iceberg diff --git a/src/iceberg/util/base64.h b/src/iceberg/util/base64.h new file mode 100644 index 000000000..1a0ead0f0 --- /dev/null +++ b/src/iceberg/util/base64.h @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include +#include + +#include "iceberg/iceberg_export.h" +#include "iceberg/result.h" + +/// \file iceberg/util/base64.h +/// \brief Base64 and base64url encoding/decoding (RFC 4648). + +namespace iceberg { + +/// \brief Utilities for base64 and base64url encoding and decoding. +class ICEBERG_EXPORT Base64 { + public: + /// \brief Base64 encode a string (standard alphabet: +/). + /// + /// The output is padded with '=' to a multiple of 4 characters. + /// + /// \param data The string to encode. + /// \return The base64-encoded string. + static std::string Encode(std::string_view data); + + /// \brief Base64 decode a string (standard alphabet: +/). + /// + /// Handles optional padding ('='). + /// \param encoded The base64-encoded string. + /// \return Decoded string, or an error if the input contains invalid characters. + static Result Decode(std::string_view encoded); + + /// \brief Base64url encode a string (URL-safe alphabet: -_). + /// + /// This variant uses '-' and '_' instead of '+' and '/' per RFC 4648 §5 and + /// emits no '=' padding, matching the encoding commonly used by JWTs. + /// + /// \param data The string to encode. + /// \return The base64url-encoded string (without padding). + static std::string UrlEncode(std::string_view data); + + /// \brief Base64url decode a string (URL-safe alphabet: -_). + /// + /// Handles optional padding ('='). This variant uses '-' and '_' instead of + /// '+' and '/' per RFC 4648 §5. + /// \param encoded The base64url-encoded string. + /// \return Decoded string, or an error if the input contains invalid characters. + static Result UrlDecode(std::string_view encoded); +}; + +} // namespace iceberg diff --git a/src/iceberg/util/meson.build b/src/iceberg/util/meson.build index d70855016..91c183bb6 100644 --- a/src/iceberg/util/meson.build +++ b/src/iceberg/util/meson.build @@ -17,6 +17,7 @@ install_headers( [ + 'base64.h', 'bucket_util.h', 'checked_cast.h', 'config.h', diff --git a/src/iceberg/util/transform_util.cc b/src/iceberg/util/transform_util.cc index d12449382..0a8af63cd 100644 --- a/src/iceberg/util/transform_util.cc +++ b/src/iceberg/util/transform_util.cc @@ -19,7 +19,6 @@ #include "iceberg/util/transform_util.h" -#include #include #include @@ -139,53 +138,4 @@ std::string TransformUtil::HumanTimestampNsWithZone(int64_t timestamp_nanos) { } } -std::string TransformUtil::Base64Encode(std::string_view str_to_encode) { - static constexpr std::string_view kBase64Chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - int32_t i = 0; - int32_t j = 0; - std::array char_array_3; - std::array char_array_4; - - std::string encoded; - encoded.reserve((str_to_encode.size() + 2) * 4 / 3); - - for (unsigned char byte : str_to_encode) { - char_array_3[i++] = byte; - if (i == 3) { - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - - for (j = 0; j < 4; j++) { - encoded += kBase64Chars[char_array_4[j]]; - } - - i = 0; - } - } - - if (i) { - for (j = i; j < 3; j++) { - char_array_3[j] = '\0'; - } - - char_array_4[0] = (char_array_3[0] & 0xfc) >> 2; - char_array_4[1] = ((char_array_3[0] & 0x03) << 4) + ((char_array_3[1] & 0xf0) >> 4); - char_array_4[2] = ((char_array_3[1] & 0x0f) << 2) + ((char_array_3[2] & 0xc0) >> 6); - char_array_4[3] = char_array_3[2] & 0x3f; - - for (j = 0; j < i + 1; j++) { - encoded += kBase64Chars[char_array_4[j]]; - } - - while (i++ < 3) { - encoded += '='; - } - } - - return encoded; -} - } // namespace iceberg diff --git a/src/iceberg/util/transform_util.h b/src/iceberg/util/transform_util.h index 2fbcb3e55..e0918dcf3 100644 --- a/src/iceberg/util/transform_util.h +++ b/src/iceberg/util/transform_util.h @@ -21,7 +21,6 @@ #include #include -#include #include "iceberg/iceberg_export.h" @@ -126,9 +125,6 @@ class ICEBERG_EXPORT TransformUtil { /// \param timestamp_nanos the timestamp in nanoseconds. /// \return a string representation of this timestamp. static std::string HumanTimestampNsWithZone(int64_t timestamp_nanos); - - /// \brief Base64 encode a string - static std::string Base64Encode(std::string_view str_to_encode); }; } // namespace iceberg