diff --git a/src/scitokens.cpp b/src/scitokens.cpp index b4d98ec..1794afd 100644 --- a/src/scitokens.cpp +++ b/src/scitokens.cpp @@ -43,7 +43,7 @@ void load_config_from_environment() { bool is_int; }; - const std::array known_configs = { + const std::array known_configs = { {{"keycache.update_interval_s", "KEYCACHE_UPDATE_INTERVAL_S", true}, {"keycache.expiration_interval_s", "KEYCACHE_EXPIRATION_INTERVAL_S", true}, @@ -54,7 +54,8 @@ void load_config_from_environment() { {"monitoring.file_interval_s", "MONITORING_FILE_INTERVAL_S", true}, {"keycache.refresh_interval_ms", "KEYCACHE_REFRESH_INTERVAL_MS", true}, {"keycache.refresh_threshold_ms", "KEYCACHE_REFRESH_THRESHOLD_MS", - true}}}; + true}, + {"jwks.download_timeout_s", "JWKS_DOWNLOAD_TIMEOUT_S", true}}}; const char *prefix = "SCITOKEN_CONFIG_"; @@ -1322,6 +1323,17 @@ int scitoken_config_set_int(const char *key, int value, char **err_msg) { return 0; } + else if (_key == "jwks.download_timeout_s") { + if (value <= 0) { + if (err_msg) { + *err_msg = strdup("JWKS download timeout must be positive."); + } + return -1; + } + configurer::Configuration::set_jwks_download_timeout(value); + return 0; + } + else { if (err_msg) { *err_msg = strdup("Key not recognized."); @@ -1363,6 +1375,10 @@ int scitoken_config_get_int(const char *key, char **err_msg) { return configurer::Configuration::get_refresh_threshold(); } + else if (_key == "jwks.download_timeout_s") { + return configurer::Configuration::get_jwks_download_timeout(); + } + else { if (err_msg) { *err_msg = strdup("Key not recognized."); diff --git a/src/scitokens_internal.cpp b/src/scitokens_internal.cpp index 34f3688..f1ded9f 100644 --- a/src/scitokens_internal.cpp +++ b/src/scitokens_internal.cpp @@ -951,8 +951,14 @@ std::string Validator::get_jwks(const std::string &issuer) { bool Validator::refresh_jwks(const std::string &issuer) { picojson::value keys; - std::unique_ptr status = get_public_keys_from_web( - issuer, internal::SimpleCurlGet::extended_timeout); + int configured_timeout = + configurer::Configuration::get_jwks_download_timeout(); + unsigned timeout = + std::max(configured_timeout > 0 ? static_cast(configured_timeout) + : internal::SimpleCurlGet::default_timeout, + internal::SimpleCurlGet::extended_timeout); + std::unique_ptr status = + get_public_keys_from_web(issuer, timeout); while (!status->m_done) { status = get_public_keys_from_web_continue(std::move(status)); } @@ -994,7 +1000,7 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, try { result->m_ignore_error = true; result = get_public_keys_from_web( - issuer, internal::SimpleCurlGet::default_timeout); + issuer, configurer::Configuration::get_jwks_download_timeout()); // Hold refresh mutex in the new result result->m_refresh_lock = std::move(lock); // Mark that this is a refresh attempt for a known issuer @@ -1034,7 +1040,7 @@ Validator::get_public_key_pem(const std::string &issuer, const std::string &kid, } else { // Still no keys, fetch them from the web result = get_public_keys_from_web( - issuer, internal::SimpleCurlGet::default_timeout); + issuer, configurer::Configuration::get_jwks_download_timeout()); // Transfer ownership of the lock to the async status // The lock will be held until keys are stored in diff --git a/src/scitokens_internal.h b/src/scitokens_internal.h index ae759ec..0d73b30 100644 --- a/src/scitokens_internal.h +++ b/src/scitokens_internal.h @@ -70,6 +70,14 @@ class Configuration { return m_allow_in_memory.load(std::memory_order_relaxed); } + // JWKS download timeout configuration (in seconds) + static void set_jwks_download_timeout(int timeout_s) { + get_jwks_download_timeout_ref() = timeout_s; + } + static int get_jwks_download_timeout() { + return get_jwks_download_timeout_ref(); + } + // Background refresh configuration static void set_background_refresh_enabled(bool enabled) { m_background_refresh_enabled = enabled; @@ -96,6 +104,10 @@ class Configuration { static std::atomic_int instance{4 * 24 * 3600}; return instance; } + static std::atomic_int &get_jwks_download_timeout_ref() { + static std::atomic_int instance{4}; + return instance; + } // Thread-safe accessors for string configurations static std::mutex &get_cache_home_mutex() { diff --git a/test/main.cpp b/test/main.cpp index af2bbb5..b2006aa 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -1432,6 +1432,60 @@ TEST_F(EnvConfigTest, StringConfigFromEnv) { // temp_cache destructor will clean up the directory } +TEST_F(EnvConfigTest, JwksDownloadTimeoutConfig) { + char *err_msg = nullptr; + + // Verify default value is 4 seconds + int default_val = + scitoken_config_get_int("jwks.download_timeout_s", &err_msg); + EXPECT_EQ(default_val, 4) << (err_msg ? err_msg : ""); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Set a custom value + int test_value = 10; + auto rv = scitoken_config_set_int("jwks.download_timeout_s", test_value, + &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Verify the value was set + int retrieved = + scitoken_config_get_int("jwks.download_timeout_s", &err_msg); + EXPECT_EQ(retrieved, test_value) << (err_msg ? err_msg : ""); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Setting zero should fail + rv = scitoken_config_set_int("jwks.download_timeout_s", 0, &err_msg); + EXPECT_NE(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Setting negative should fail + rv = scitoken_config_set_int("jwks.download_timeout_s", -1, &err_msg); + EXPECT_NE(rv, 0); + if (err_msg) { + free(err_msg); + err_msg = nullptr; + } + + // Restore default + rv = scitoken_config_set_int("jwks.download_timeout_s", 4, &err_msg); + ASSERT_EQ(rv, 0) << (err_msg ? err_msg : ""); + if (err_msg) + free(err_msg); +} + // Test for thundering herd prevention with per-issuer locks TEST_F(IssuerSecurityTest, ThunderingHerdPrevention) { char *err_msg = nullptr; diff --git a/test/test_env_config.cpp b/test/test_env_config.cpp index 0acef6c..e3e21ec 100644 --- a/test/test_env_config.cpp +++ b/test/test_env_config.cpp @@ -151,6 +151,35 @@ int main() { } } + // Test 6: Check if SCITOKEN_CONFIG_JWKS_DOWNLOAD_TIMEOUT_S was loaded + const char *env_timeout = + std::getenv("SCITOKEN_CONFIG_JWKS_DOWNLOAD_TIMEOUT_S"); + if (env_timeout) { + try { + int expected = std::stoi(env_timeout); + int actual = scitoken_config_get_int("jwks.download_timeout_s", + &err_msg); + if (actual != expected) { + std::cerr + << "FAIL: jwks.download_timeout_s expected " << expected + << " but got " << actual << std::endl; + if (err_msg) { + std::cerr << "Error: " << err_msg << std::endl; + free(err_msg); + err_msg = nullptr; + } + failures++; + } else { + std::cout << "PASS: jwks.download_timeout_s = " << actual + << std::endl; + } + } catch (const std::exception &e) { + std::cerr << "FAIL: Could not parse env var value: " << e.what() + << std::endl; + failures++; + } + } + if (failures == 0) { std::cout << "\nAll environment variable configuration tests passed!" << std::endl;