diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index aab687c1c8..a75da583d2 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -315,10 +315,15 @@ pub trait OAuthHttpProvider: Send + Sync { config: &OAuthConfig, code: &str, verifier: Option<&str>, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result; /// Creates an HTTP client with provider-specific headers and behavior. - fn build_http_client(&self, config: &OAuthConfig) -> anyhow::Result; + fn build_http_client( + &self, + config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, + ) -> anyhow::Result; } /// Authentication strategy trait diff --git a/crates/forge_config/src/decimal.rs b/crates/forge_config/src/decimal.rs index 50af68c5e5..512714b804 100644 --- a/crates/forge_config/src/decimal.rs +++ b/crates/forge_config/src/decimal.rs @@ -21,7 +21,7 @@ impl std::fmt::Debug for Decimal { impl Clone for Decimal { fn clone(&self) -> Self { - Self(self.0) + *self } } diff --git a/crates/forge_config/src/percentage.rs b/crates/forge_config/src/percentage.rs index 532e96ca6d..ad299e3fc8 100644 --- a/crates/forge_config/src/percentage.rs +++ b/crates/forge_config/src/percentage.rs @@ -43,7 +43,7 @@ impl std::fmt::Debug for Percentage { impl Clone for Percentage { fn clone(&self) -> Self { - Self(self.0) + *self } } diff --git a/crates/forge_domain/src/context.rs b/crates/forge_domain/src/context.rs index d3018a56ce..01a2ce9570 100644 --- a/crates/forge_domain/src/context.rs +++ b/crates/forge_domain/src/context.rs @@ -556,6 +556,7 @@ impl Context { /// are supported and uses the appropriate format. For models that don't /// support tools, use the TransformToolCalls transformer to convert the /// context afterward. + #[allow(clippy::too_many_arguments)] // Each parameter is a distinct, meaningful field from the model response; grouping them would add a wrapper struct with no semantic benefit. pub fn append_message( self, content: impl ToString, diff --git a/crates/forge_infra/src/auth/http/anthropic.rs b/crates/forge_infra/src/auth/http/anthropic.rs index 2b34d23c90..5868af0124 100644 --- a/crates/forge_infra/src/auth/http/anthropic.rs +++ b/crates/forge_infra/src/auth/http/anthropic.rs @@ -66,6 +66,7 @@ impl OAuthHttpProvider for AnthropicHttpProvider { config: &OAuthConfig, code: &str, verifier: Option<&str>, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { // Anthropic-specific token exchange let (code, state) = if code.contains('#') { @@ -87,7 +88,7 @@ impl OAuthHttpProvider for AnthropicHttpProvider { code_verifier: verifier.to_string(), }; - let client = self.build_http_client(config)?; + let client = self.build_http_client(config, http_config)?; let response = client .post(config.token_url.as_str()) .header("Content-Type", "application/json") @@ -105,8 +106,12 @@ impl OAuthHttpProvider for AnthropicHttpProvider { } /// Create HTTP client with provider-specific headers/behavior - fn build_http_client(&self, config: &OAuthConfig) -> anyhow::Result { - build_http_client(config.custom_headers.as_ref()) + fn build_http_client( + &self, + config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, + ) -> anyhow::Result { + build_http_client(config.custom_headers.as_ref(), http_config) } } diff --git a/crates/forge_infra/src/auth/http/github.rs b/crates/forge_infra/src/auth/http/github.rs index 3b3209febb..725fe35b91 100644 --- a/crates/forge_infra/src/auth/http/github.rs +++ b/crates/forge_infra/src/auth/http/github.rs @@ -19,17 +19,22 @@ impl OAuthHttpProvider for GithubHttpProvider { config: &OAuthConfig, code: &str, verifier: Option<&str>, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { // Use standard exchange - quirks handled in HTTP client via // github_compliant_http_request StandardHttpProvider - .exchange_code(config, code, verifier) + .exchange_code(config, code, verifier, http_config) .await } - fn build_http_client(&self, config: &OAuthConfig) -> anyhow::Result { + fn build_http_client( + &self, + config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, + ) -> anyhow::Result { // GitHub quirk: HTTP 200 responses may contain errors // This is handled by the github_compliant_http_request function - build_http_client(config.custom_headers.as_ref()) + build_http_client(config.custom_headers.as_ref(), http_config) } } diff --git a/crates/forge_infra/src/auth/http/standard.rs b/crates/forge_infra/src/auth/http/standard.rs index 8df1acd14e..eb97e348e6 100644 --- a/crates/forge_infra/src/auth/http/standard.rs +++ b/crates/forge_infra/src/auth/http/standard.rs @@ -57,6 +57,7 @@ impl OAuthHttpProvider for StandardHttpProvider { config: &OAuthConfig, code: &str, verifier: Option<&str>, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { use oauth2::{AuthUrl, ClientId, TokenUrl}; @@ -69,7 +70,7 @@ impl OAuthHttpProvider for StandardHttpProvider { client = client.set_redirect_uri(oauth2::RedirectUrl::new(redirect_uri.clone())?); } - let http_client = self.build_http_client(config)?; + let http_client = self.build_http_client(config, http_config)?; let mut request = client.exchange_code(OAuth2AuthCode::new(code.to_string())); @@ -82,8 +83,12 @@ impl OAuthHttpProvider for StandardHttpProvider { } /// Create HTTP client with provider-specific headers/behavior - fn build_http_client(&self, config: &OAuthConfig) -> anyhow::Result { - build_http_client(config.custom_headers.as_ref()) + fn build_http_client( + &self, + config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, + ) -> anyhow::Result { + build_http_client(config.custom_headers.as_ref(), http_config) } } diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 1062581faf..9994dbee33 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -62,11 +62,17 @@ pub struct OAuthCodeStrategy { provider_id: ProviderId, config: OAuthConfig, adapter: T, + http_config: forge_domain::HttpConfig, } impl OAuthCodeStrategy { - pub fn new(adapter: T, provider_id: ProviderId, config: OAuthConfig) -> Self { - Self { config, provider_id, adapter } + pub fn new( + adapter: T, + provider_id: ProviderId, + config: OAuthConfig, + http_config: forge_domain::HttpConfig, + ) -> Self { + Self { config, provider_id, adapter, http_config } } } @@ -99,6 +105,7 @@ impl AuthStrategy for OAuthCodeStrategy { &self.config, ctx.response.code.as_str(), ctx.request.pkce_verifier.as_ref().map(|v| v.as_str()), + &self.http_config, ) .await .map_err(|e| { @@ -124,6 +131,7 @@ impl AuthStrategy for OAuthCodeStrategy { &self.config, chrono::Duration::hours(1), false, // No API key exchange + &self.http_config, ) .await } @@ -133,11 +141,16 @@ impl AuthStrategy for OAuthCodeStrategy { pub struct OAuthDeviceStrategy { provider_id: ProviderId, config: OAuthConfig, + http_config: forge_domain::HttpConfig, } impl OAuthDeviceStrategy { - pub fn new(provider_id: ProviderId, config: OAuthConfig) -> Self { - Self { provider_id, config } + pub fn new( + provider_id: ProviderId, + config: OAuthConfig, + http_config: forge_domain::HttpConfig, + ) -> Self { + Self { provider_id, config, http_config } } } @@ -162,9 +175,10 @@ impl AuthStrategy for OAuthDeviceStrategy { } // Build HTTP client with custom headers - let http_client = build_http_client(self.config.custom_headers.as_ref()).map_err(|e| { - AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) - })?; + let http_client = build_http_client(self.config.custom_headers.as_ref(), &self.http_config) + .map_err(|e| { + AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) + })?; let http_fn = |req| github_compliant_http_request(http_client.clone(), req); @@ -202,6 +216,7 @@ impl AuthStrategy for OAuthDeviceStrategy { &self.config, Duration::from_secs(600), false, + &self.http_config, ) .await?; @@ -222,6 +237,7 @@ impl AuthStrategy for OAuthDeviceStrategy { &self.config, chrono::Duration::days(30), false, // No API key exchange + &self.http_config, ) .await } @@ -232,16 +248,21 @@ pub struct OAuthWithApiKeyStrategy { provider_id: ProviderId, oauth_config: OAuthConfig, api_key_exchange_url: Url, + http_config: forge_domain::HttpConfig, } impl OAuthWithApiKeyStrategy { - pub fn new(provider_id: ProviderId, oauth_config: OAuthConfig) -> anyhow::Result { + pub fn new( + provider_id: ProviderId, + oauth_config: OAuthConfig, + http_config: forge_domain::HttpConfig, + ) -> anyhow::Result { let api_key_exchange_url = oauth_config .token_refresh_url .clone() .ok_or_else(|| AuthError::InitiationFailed("Missing token_refresh_url".to_string()))?; - Ok(Self { provider_id, oauth_config, api_key_exchange_url }) + Ok(Self { provider_id, oauth_config, api_key_exchange_url, http_config }) } } @@ -265,9 +286,10 @@ impl AuthStrategy for OAuthWithApiKeyStrategy { } let http_client = - build_http_client(self.oauth_config.custom_headers.as_ref()).map_err(|e| { - AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) - })?; + build_http_client(self.oauth_config.custom_headers.as_ref(), &self.http_config) + .map_err(|e| { + AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) + })?; let http_fn = |req| github_compliant_http_request(http_client.clone(), req); @@ -305,6 +327,7 @@ impl AuthStrategy for OAuthWithApiKeyStrategy { &self.oauth_config, Duration::from_secs(600), true, + &self.http_config, ) .await?; @@ -313,6 +336,7 @@ impl AuthStrategy for OAuthWithApiKeyStrategy { &token_response.access_token, &self.api_key_exchange_url, &self.oauth_config, + &self.http_config, ) .await?; @@ -339,6 +363,7 @@ impl AuthStrategy for OAuthWithApiKeyStrategy { &self.oauth_config, chrono::Duration::hours(1), // Unused for API key flow true, // WITH API key exchange + &self.http_config, ) .await } @@ -456,11 +481,16 @@ impl AuthStrategy for GoogleAdcStrategy { pub struct CodexDeviceStrategy { provider_id: ProviderId, config: OAuthConfig, + http_config: forge_domain::HttpConfig, } impl CodexDeviceStrategy { - pub fn new(provider_id: ProviderId, config: OAuthConfig) -> Self { - Self { provider_id, config } + pub fn new( + provider_id: ProviderId, + config: OAuthConfig, + http_config: forge_domain::HttpConfig, + ) -> Self { + Self { provider_id, config, http_config } } } @@ -517,9 +547,10 @@ fn extract_chatgpt_account_id(token: &str) -> Option { #[async_trait::async_trait] impl AuthStrategy for CodexDeviceStrategy { async fn init(&self) -> anyhow::Result { - let http_client = build_http_client(self.config.custom_headers.as_ref()).map_err(|e| { - AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) - })?; + let http_client = build_http_client(self.config.custom_headers.as_ref(), &self.http_config) + .map_err(|e| { + AuthError::InitiationFailed(format!("Failed to build HTTP client: {e}")) + })?; // Step 1: Request device authorization from OpenAI's custom endpoint let response = http_client @@ -568,7 +599,8 @@ impl AuthStrategy for CodexDeviceStrategy { match context_response { AuthContextResponse::DeviceCode(ctx) => { // Poll for authorization code using the custom OpenAI endpoint - let token_response = codex_poll_for_tokens(&ctx.request, &self.config).await?; + let token_response = + codex_poll_for_tokens(&ctx.request, &self.config, &self.http_config).await?; // Extract ChatGPT account ID from the access token JWT. // This is used for the optional `ChatGPT-Account-Id` request @@ -597,7 +629,14 @@ impl AuthStrategy for CodexDeviceStrategy { } async fn refresh(&self, credential: &AuthCredential) -> anyhow::Result { - refresh_oauth_credential(credential, &self.config, chrono::Duration::hours(1), false).await + refresh_oauth_credential( + credential, + &self.config, + chrono::Duration::hours(1), + false, + &self.http_config, + ) + .await } } @@ -607,6 +646,7 @@ async fn refresh_oauth_credential( config: &OAuthConfig, expiry_duration: chrono::Duration, with_api_key_exchange: bool, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { // Extract tokens (works for OAuth and OAuthWithApiKey) let tokens = extract_oauth_tokens(credential)?; @@ -616,7 +656,8 @@ async fn refresh_oauth_credential( if let Some(refresh_token) = &tokens.refresh_token { // If we have a refresh token, refresh the OAuth access token first tracing::debug!("Refreshing OAuth access token using refresh token"); - let token_response = refresh_access_token(config, refresh_token.as_str()).await?; + let token_response = + refresh_access_token(config, refresh_token.as_str(), http_config).await?; ( token_response.access_token.clone(), token_response.refresh_token, @@ -636,7 +677,8 @@ async fn refresh_oauth_credential( let url = config.token_refresh_url.as_ref().ok_or_else(|| { AuthError::RefreshFailed("Missing token_refresh_url for API key exchange".to_string()) })?; - let (key, expiry) = exchange_oauth_for_api_key(&oauth_access_token, url, config).await?; + let (key, expiry) = + exchange_oauth_for_api_key(&oauth_access_token, url, config, http_config).await?; (Some(key), expiry) } else { let expiry = calculate_token_expiry(None, expiry_duration); @@ -667,8 +709,9 @@ async fn poll_for_tokens( config: &OAuthConfig, timeout: Duration, github_compatible: bool, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { - let http_client = build_http_client(config.custom_headers.as_ref()) + let http_client = build_http_client(config.custom_headers.as_ref(), http_config) .map_err(|e| AuthError::PollFailed(format!("Failed to build HTTP client: {e}")))?; let start_time = tokio::time::Instant::now(); @@ -792,8 +835,9 @@ async fn poll_for_tokens( async fn codex_poll_for_tokens( request: &DeviceCodeRequest, config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { - let http_client = build_http_client(config.custom_headers.as_ref()) + let http_client = build_http_client(config.custom_headers.as_ref(), http_config) .map_err(|e| AuthError::PollFailed(format!("Failed to build HTTP client: {e}")))?; let timeout = Duration::from_secs(request.expires_in); @@ -836,11 +880,10 @@ async fn codex_poll_for_tokens( })?; // Exchange the authorization code for OAuth tokens via standard - // endpoint. Use a clean HTTP client without custom headers since the - // standard OAuth token endpoint rejects unknown headers. - let clean_client = reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build() + // endpoint. Use a client without custom headers (the standard OAuth + // token endpoint rejects unknown headers) but still respect the + // caller's TLS/proxy configuration. + let clean_client = build_http_client(None, http_config) .map_err(|e| AuthError::PollFailed(format!("Failed to build HTTP client: {e}")))?; let token_response = clean_client @@ -905,6 +948,7 @@ async fn exchange_oauth_for_api_key( oauth_token: &str, api_key_exchange_url: &Url, config: &OAuthConfig, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result<(ApiKey, chrono::DateTime)> { // Build request headers let mut headers = reqwest::header::HeaderMap::new(); @@ -918,7 +962,7 @@ async fn exchange_oauth_for_api_key( // Add custom headers from config inject_custom_headers(&mut headers, &config.custom_headers); - let response = build_http_client(config.custom_headers.as_ref()) + let response = build_http_client(config.custom_headers.as_ref(), http_config) .map_err(|e| AuthError::CompletionFailed(format!("Failed to build HTTP client: {e}")))? .get(api_key_exchange_url.as_str()) .headers(headers) @@ -1015,17 +1059,19 @@ impl AuthStrategy for AnyAuthStrategy { } /// Factory for creating authentication strategies -pub struct ForgeAuthStrategyFactory {} +pub struct ForgeAuthStrategyFactory { + http_config: forge_domain::HttpConfig, +} impl Default for ForgeAuthStrategyFactory { fn default() -> Self { - Self::new() + Self::new(forge_domain::HttpConfig::default()) } } impl ForgeAuthStrategyFactory { - pub fn new() -> Self { - Self {} + pub fn new(http_config: forge_domain::HttpConfig) -> Self { + Self { http_config } } } @@ -1049,6 +1095,7 @@ impl StrategyFactory for ForgeAuthStrategyFactory { AnthropicHttpProvider, provider_id, config, + self.http_config.clone(), ))); } @@ -1057,6 +1104,7 @@ impl StrategyFactory for ForgeAuthStrategyFactory { GithubHttpProvider, provider_id, config, + self.http_config.clone(), ))); } @@ -1064,18 +1112,24 @@ impl StrategyFactory for ForgeAuthStrategyFactory { StandardHttpProvider, provider_id, config, + self.http_config.clone(), ))) } forge_domain::AuthMethod::OAuthDevice(config) => { // Check if this is OAuth-with-API-Key flow (GitHub Copilot pattern) if config.token_refresh_url.is_some() { Ok(AnyAuthStrategy::OAuthWithApiKey( - OAuthWithApiKeyStrategy::new(provider_id, config)?, + OAuthWithApiKeyStrategy::new( + provider_id, + config, + self.http_config.clone(), + )?, )) } else { Ok(AnyAuthStrategy::OAuthDevice(OAuthDeviceStrategy::new( provider_id, config, + self.http_config.clone(), ))) } } @@ -1083,7 +1137,7 @@ impl StrategyFactory for ForgeAuthStrategyFactory { GoogleAdcStrategy::new(provider_id, required_params), )), forge_domain::AuthMethod::CodexDevice(config) => Ok(AnyAuthStrategy::CodexDevice( - CodexDeviceStrategy::new(provider_id, config), + CodexDeviceStrategy::new(provider_id, config, self.http_config.clone()), )), } } @@ -1100,7 +1154,7 @@ mod tests { #[test] fn test_create_auth_strategy_api_key() { - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory::default(); let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::ApiKey, @@ -1123,7 +1177,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory::default(); let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::OAuthCode(config), @@ -1146,7 +1200,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory::default(); let strategy = factory.create_auth_strategy( ProviderId::OPENAI, forge_domain::AuthMethod::OAuthDevice(config), @@ -1169,7 +1223,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory::default(); let strategy = factory.create_auth_strategy( ProviderId::GITHUB_COPILOT, forge_domain::AuthMethod::OAuthDevice(config), @@ -1193,7 +1247,7 @@ mod tests { custom_headers: None, }; - let factory = ForgeAuthStrategyFactory::new(); + let factory = ForgeAuthStrategyFactory::default(); let actual = factory.create_auth_strategy( ProviderId::CODEX, forge_domain::AuthMethod::CodexDevice(config), @@ -1320,6 +1374,7 @@ mod tests { &fixture_config, chrono::Duration::hours(1), false, + &forge_domain::HttpConfig::default(), ) .await .unwrap(); diff --git a/crates/forge_infra/src/auth/util.rs b/crates/forge_infra/src/auth/util.rs index de652e0f67..c7013440ff 100644 --- a/crates/forge_infra/src/auth/util.rs +++ b/crates/forge_infra/src/auth/util.rs @@ -8,6 +8,7 @@ use oauth2::basic::BasicClient; use oauth2::{ClientId, RefreshToken, TokenUrl}; use crate::auth::error::Error; +use crate::http::ClientBuilderExt; /// Calculate token expiry with fallback duration pub(crate) fn calculate_token_expiry( @@ -39,30 +40,35 @@ pub(crate) fn into_domain(token: T) -> OAuthTokenRespo } } -/// Build HTTP client with custom headers +/// Build HTTP client with custom headers, respecting proxy and TLS settings +/// from the supplied [`forge_domain::HttpConfig`]. +/// +/// **Proxy**: `reqwest` automatically reads +/// `HTTPS_PROXY`/`https_proxy`/`ALL_PROXY` for HTTPS traffic, but does **not** +/// fall back to `HTTP_PROXY` for HTTPS requests. In corporate environments +/// where only `HTTP_PROXY` is set, HTTPS requests would bypass the proxy +/// entirely and fail when direct outbound connections are blocked. +/// This function detects that situation and explicitly routes HTTPS traffic +/// through `HTTP_PROXY` as well. +/// +/// **TLS**: Corporate proxies commonly perform TLS inspection using a private +/// root CA installed in the system certificate store. `rustls` ships its own +/// Mozilla CA bundle and does **not** read the OS store, so the TLS handshake +/// fails even when the proxy is correctly configured. The `http_config` +/// parameter carries the same `accept_invalid_certs` and `root_cert_paths` +/// settings that `ForgeHttpInfra` uses, so a custom corporate CA is trusted by +/// auth requests too. pub(crate) fn build_http_client( custom_headers: Option<&HashMap>, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { - let mut builder = reqwest::Client::builder() + Ok(reqwest::Client::builder() // Disable redirects to prevent SSRF vulnerabilities - .redirect(reqwest::redirect::Policy::none()); - - if let Some(headers) = custom_headers { - let mut header_map = reqwest::header::HeaderMap::new(); - - for (key, value) in headers { - let header_name = reqwest::header::HeaderName::try_from(key.as_str()) - .map_err(|e| anyhow::anyhow!("Invalid header name '{key}': {e}"))?; - let header_value = value - .parse() - .map_err(|e| anyhow::anyhow!("Invalid header value for '{key}': {e}"))?; - header_map.insert(header_name, header_value); - } - - builder = builder.default_headers(header_map); - } - - Ok(builder.build()?) + .redirect(reqwest::redirect::Policy::none()) + .with_proxy_fallback()? + .with_tls_config(http_config) + .with_custom_headers(custom_headers.into_iter().flat_map(|m| m.iter()))? + .build()?) } /// Build OAuth credential with consistent expiry handling @@ -116,13 +122,14 @@ pub(crate) fn extract_oauth_tokens(credential: &AuthCredential) -> anyhow::Resul pub(crate) async fn refresh_access_token( config: &OAuthConfig, refresh_token: &str, + http_config: &forge_domain::HttpConfig, ) -> anyhow::Result { // Build minimal oauth2 client (just need token endpoint) let client = BasicClient::new(ClientId::new(config.client_id.to_string())) .set_token_uri(TokenUrl::new(config.token_url.to_string())?); - // Build HTTP client with custom headers - let http_client = build_http_client(config.custom_headers.as_ref())?; + // Build HTTP client with custom headers and caller-supplied TLS/proxy config + let http_client = build_http_client(config.custom_headers.as_ref(), http_config)?; let refresh_token = RefreshToken::new(refresh_token.to_string()); @@ -242,6 +249,128 @@ mod tests { use super::*; + // Serialise proxy-related tests: env vars are process-global so concurrent + // mutation causes flaky failures. + static PROXY_TEST_MUTEX: std::sync::Mutex<()> = std::sync::Mutex::new(()); + + /// Prove that a plain `reqwest::Client` (the old implementation) silently + /// ignores `HTTP_PROXY` when making HTTPS requests. The fake proxy TCP + /// listener receives **no** connection — the request bypasses it entirely. + #[tokio::test] + async fn test_old_client_ignores_http_proxy_for_https() { + let _guard = PROXY_TEST_MUTEX.lock().unwrap(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let proxy_url = format!("http://{}", listener.local_addr().unwrap()); + + // Only HTTP_PROXY is set — no HTTPS_PROXY / ALL_PROXY + unsafe { + std::env::set_var("HTTP_PROXY", &proxy_url); + std::env::remove_var("HTTPS_PROXY"); + std::env::remove_var("https_proxy"); + std::env::remove_var("ALL_PROXY"); + std::env::remove_var("all_proxy"); + } + + // OLD: bare reqwest builder — identical to the pre-fix build_http_client + let old_client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + let accept_task = tokio::spawn(async move { + tokio::time::timeout(std::time::Duration::from_millis(500), listener.accept()) + .await + .is_ok() + }); + + let _ = old_client + .post("https://github.com/login/device/code") + .send() + .await; + + let proxy_was_contacted = accept_task.await.unwrap(); + unsafe { std::env::remove_var("HTTP_PROXY") }; + + assert!( + !proxy_was_contacted, + "OLD client should bypass HTTP_PROXY for HTTPS — proxy never contacted" + ); + } + + /// Prove that `build_http_client` (the new implementation) routes HTTPS + /// requests through `HTTP_PROXY` when no `HTTPS_PROXY` / `ALL_PROXY` is + /// set. The fake proxy TCP listener **does** receive the connection. + #[tokio::test] + async fn test_new_client_routes_https_through_http_proxy() { + let _guard = PROXY_TEST_MUTEX.lock().unwrap(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let proxy_url = format!("http://{}", listener.local_addr().unwrap()); + + // Only HTTP_PROXY is set — no HTTPS_PROXY / ALL_PROXY + unsafe { + std::env::set_var("HTTP_PROXY", &proxy_url); + std::env::remove_var("HTTPS_PROXY"); + std::env::remove_var("https_proxy"); + std::env::remove_var("ALL_PROXY"); + std::env::remove_var("all_proxy"); + } + + // NEW: build_http_client with the proxy fallback logic + let new_client = build_http_client(None, &forge_domain::HttpConfig::default()).unwrap(); + + let accept_task = tokio::spawn(async move { + tokio::time::timeout(std::time::Duration::from_millis(500), listener.accept()) + .await + .is_ok() + }); + + let _ = new_client + .post("https://github.com/login/device/code") + .send() + .await; + + let proxy_was_contacted = accept_task.await.unwrap(); + unsafe { std::env::remove_var("HTTP_PROXY") }; + + assert!( + proxy_was_contacted, + "NEW client should route HTTPS traffic through HTTP_PROXY" + ); + } + + /// Prove that `build_http_client` applies `root_cert_paths` from a + /// caller-supplied [`forge_domain::HttpConfig`]. A temp file containing + /// clearly invalid cert data is passed directly in the config. The client + /// must still build successfully, confirming that parse failures are + /// silently skipped rather than propagated. + #[test] + fn test_build_http_client_loads_root_cert_from_config() { + let _guard = PROXY_TEST_MUTEX.lock().unwrap(); + + // Write content that is definitely not a valid PEM or DER certificate. + // Certificate::from_pem and from_der both return Err, which is silently + // skipped. The important thing is that the code path is exercised. + let cert_file = tempfile::NamedTempFile::new().unwrap(); + std::fs::write(cert_file.path(), b"not a certificate").unwrap(); + + let http_config = forge_domain::HttpConfig { + root_cert_paths: Some(vec![cert_file.path().to_str().unwrap().to_string()]), + ..Default::default() + }; + + let result = build_http_client(None, &http_config); + + // Client must build successfully — the invalid cert is gracefully + // skipped rather than propagating an error. + assert!( + result.is_ok(), + "build_http_client should succeed even with unparseable cert: {:?}", + result.as_ref().err() + ); + } + #[test] fn test_calculate_token_expiry_with_expires_in() { let before = Utc::now(); diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index b899eee341..db05365786 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -27,8 +27,8 @@ use crate::fs_read_dir::ForgeDirectoryReaderService; use crate::fs_remove::ForgeFileRemoveService; use crate::fs_write::ForgeFileWriteService; use crate::grpc::ForgeGrpcClient; -use crate::http::ForgeHttpInfra; use crate::inquire::ForgeInquire; +use crate::llm_client::ForgeHttpInfra; use crate::mcp_client::ForgeMcpClient; use crate::mcp_server::ForgeMcpServer; use crate::walker::ForgeWalkerService; @@ -81,9 +81,9 @@ impl ForgeInfra { output_printer.clone(), )), inquire_service: Arc::new(ForgeInquire::new()), - mcp_server: ForgeMcpServer, + mcp_server: ForgeMcpServer::new(env.http.clone()), walker_service: Arc::new(ForgeWalkerService::new()), - strategy_factory: Arc::new(ForgeAuthStrategyFactory::new()), + strategy_factory: Arc::new(ForgeAuthStrategyFactory::new(env.http.clone())), http_service, grpc_client, output_printer, diff --git a/crates/forge_infra/src/http.rs b/crates/forge_infra/src/http.rs index 3baf911e30..bce826d827 100644 --- a/crates/forge_infra/src/http.rs +++ b/crates/forge_infra/src/http.rs @@ -1,27 +1,9 @@ -use std::fs; -use std::sync::Arc; use std::time::Duration; -use anyhow::Context; -use bytes::Bytes; -use forge_app::HttpInfra; -use forge_domain::{Environment, TlsBackend, TlsVersion}; -use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; +use forge_domain::{HttpConfig, TlsBackend, TlsVersion}; +use reqwest::Certificate; use reqwest::redirect::Policy; -use reqwest::{Certificate, Client, Response, StatusCode, Url}; -use reqwest_eventsource::{EventSource, RequestBuilderExt}; -use tracing::{debug, warn}; - -const VERSION: &str = match option_env!("APP_VERSION") { - None => env!("CARGO_PKG_VERSION"), - Some(v) => v, -}; - -pub struct ForgeHttpInfra { - client: Client, - env: Environment, - file: Arc, -} +use tracing::warn; fn to_reqwest_tls(tls: TlsVersion) -> reqwest::tls::Version { use reqwest::tls::Version; @@ -33,434 +15,143 @@ fn to_reqwest_tls(tls: TlsVersion) -> reqwest::tls::Version { } } -impl ForgeHttpInfra { - pub fn new(env: Environment, file_writer: Arc) -> Self { - let env = env.clone(); - let env_http = env.clone(); - let mut client = reqwest::Client::builder() - .connect_timeout(std::time::Duration::from_secs( - env_http.http.connect_timeout, - )) - .read_timeout(std::time::Duration::from_secs(env_http.http.read_timeout)) - .pool_idle_timeout(std::time::Duration::from_secs( - env_http.http.pool_idle_timeout, - )) - .pool_max_idle_per_host(env_http.http.pool_max_idle_per_host) - .redirect(Policy::limited(env_http.http.max_redirects)) - .hickory_dns(env_http.http.hickory) - // HTTP/2 configuration from config - .http2_adaptive_window(env_http.http.adaptive_window) - .http2_keep_alive_interval(env_http.http.keep_alive_interval.map(Duration::from_secs)) - .http2_keep_alive_timeout(Duration::from_secs(env_http.http.keep_alive_timeout)) - .http2_keep_alive_while_idle(env_http.http.keep_alive_while_idle); +/// Extension methods on [`reqwest::ClientBuilder`] that act as composable +/// transformers. Each method applies one configuration concern; callers mix +/// and match them to build any HTTP client without duplicating logic. +pub(crate) trait ClientBuilderExt: Sized { + /// Applies the full [`HttpConfig`]: timeouts, connection-pooling, redirect + /// policy, Hickory DNS, HTTP/2 keep-alive, and all TLS settings via + /// [`ClientBuilderExt::with_tls_config`]. + fn with_http_config(self, config: &HttpConfig) -> Self; + + /// Applies only the TLS subset of [`HttpConfig`]: root certificate paths, + /// `accept_invalid_certs`, min/max TLS version, and TLS backend selection. + /// + /// Certificate file read or parse failures are emitted as warnings and + /// skipped rather than propagated as errors. + fn with_tls_config(self, config: &HttpConfig) -> Self; + + /// Routes HTTPS traffic through `HTTP_PROXY` when no `HTTPS_PROXY` or + /// `ALL_PROXY` is set in the environment. This compensates for reqwest's + /// intentional behaviour of applying `HTTP_PROXY` only to plaintext HTTP + /// requests, which causes HTTPS traffic to bypass the proxy in corporate + /// environments where only `HTTP_PROXY` is configured. + fn with_proxy_fallback(self) -> anyhow::Result; + + /// Adds every `(key, value)` pair from `headers` to the client's default + /// header map. The call is a no-op when the iterator is empty. Returns an + /// error if any header name or value is invalid. + fn with_custom_headers( + self, + headers: impl IntoIterator, + ) -> anyhow::Result + where + K: AsRef, + V: AsRef; +} - // Add root certificates from config - if let Some(ref cert_paths) = env_http.http.root_cert_paths { +impl ClientBuilderExt for reqwest::ClientBuilder { + fn with_http_config(self, config: &HttpConfig) -> Self { + self.connect_timeout(Duration::from_secs(config.connect_timeout)) + .read_timeout(Duration::from_secs(config.read_timeout)) + .pool_idle_timeout(Duration::from_secs(config.pool_idle_timeout)) + .pool_max_idle_per_host(config.pool_max_idle_per_host) + .redirect(Policy::limited(config.max_redirects)) + .hickory_dns(config.hickory) + .http2_adaptive_window(config.adaptive_window) + .http2_keep_alive_interval(config.keep_alive_interval.map(Duration::from_secs)) + .http2_keep_alive_timeout(Duration::from_secs(config.keep_alive_timeout)) + .http2_keep_alive_while_idle(config.keep_alive_while_idle) + .with_tls_config(config) + } + + fn with_tls_config(self, config: &HttpConfig) -> Self { + let mut builder = self; + + if let Some(ref cert_paths) = config.root_cert_paths { for cert_path in cert_paths { - match fs::read(cert_path) { + match std::fs::read(cert_path) { Ok(buf) => { if let Ok(cert) = Certificate::from_pem(&buf) { - client = client.add_root_certificate(cert); + builder = builder.add_root_certificate(cert); } else if let Ok(cert) = Certificate::from_der(&buf) { - client = client.add_root_certificate(cert); + builder = builder.add_root_certificate(cert); } else { warn!( - "Failed to parse certificate as PEM or DER format, cert = {}", - cert_path + cert = %cert_path, + "Failed to parse certificate as PEM or DER format" ); } } Err(error) => { - warn!( - "Failed to read certificate file, path = {}, error = {}", - cert_path, error - ); + warn!(cert = %cert_path, %error, "Failed to read certificate file"); } } } } - if env_http.http.accept_invalid_certs { - client = client.danger_accept_invalid_certs(true); + if config.accept_invalid_certs { + builder = builder.danger_accept_invalid_certs(true); } - if let Some(version) = env_http.http.min_tls_version { - client = client.min_tls_version(to_reqwest_tls(version)); + if let Some(version) = config.min_tls_version.clone() { + builder = builder.min_tls_version(to_reqwest_tls(version)); } - if let Some(version) = env_http.http.max_tls_version { - client = client.max_tls_version(to_reqwest_tls(version)); + if let Some(version) = config.max_tls_version.clone() { + builder = builder.max_tls_version(to_reqwest_tls(version)); } - match env_http.http.tls_backend { - TlsBackend::Rustls => { - client = client.use_rustls_tls(); - } - TlsBackend::Default => {} + match &config.tls_backend { + TlsBackend::Rustls => builder.use_rustls_tls(), + TlsBackend::Default => builder, } - - Self { env, client: client.build().unwrap(), file: file_writer } - } - - async fn get(&self, url: &Url, headers: Option) -> anyhow::Result { - self.execute_request("GET", url, |client| { - client.get(url.clone()).headers(self.headers(headers)) - }) - .await } - async fn post( - &self, - url: &Url, - headers: Option, - body: Bytes, - ) -> anyhow::Result { - let mut request_headers = self.headers(headers); - request_headers.insert("Content-Type", HeaderValue::from_static("application/json")); - - self.write_debug_request(&body); - - self.execute_request("POST", url, |client| { - client.post(url.clone()).headers(request_headers).body(body) - }) - .await - } + fn with_proxy_fallback(self) -> anyhow::Result { + let has_https_proxy = std::env::var("HTTPS_PROXY") + .or_else(|_| std::env::var("https_proxy")) + .or_else(|_| std::env::var("ALL_PROXY")) + .or_else(|_| std::env::var("all_proxy")) + .is_ok(); + + if !has_https_proxy + && let Ok(proxy_url) = + std::env::var("HTTP_PROXY").or_else(|_| std::env::var("http_proxy")) + { + return Ok(self.proxy( + reqwest::Proxy::all(&proxy_url) + .map_err(|e| anyhow::anyhow!("Invalid HTTP_PROXY URL '{proxy_url}': {e}"))?, + )); + } - async fn delete(&self, url: &Url) -> anyhow::Result { - self.execute_request("DELETE", url, |client| { - client.delete(url.clone()).headers(self.headers(None)) - }) - .await + Ok(self) } - /// Generic helper method to execute HTTP requests with consistent error - /// handling - async fn execute_request( - &self, - method: &str, - url: &Url, - request_builder: B, - ) -> anyhow::Result + fn with_custom_headers( + self, + headers: impl IntoIterator, + ) -> anyhow::Result where - B: FnOnce(&Client) -> reqwest::RequestBuilder, + K: AsRef, + V: AsRef, { - let response = request_builder(&self.client) - .send() - .await - .with_context(|| format_http_context(None, method, url))?; - - let status = response.status(); - if !status.is_success() { - let error_body = response - .text() - .await - .unwrap_or_else(|_| "Unable to read response body".to_string()); - return Err(anyhow::anyhow!(error_body)) - .with_context(|| format_http_context(Some(status), method, url)); + let mut header_map = reqwest::header::HeaderMap::new(); + + for (key, value) in headers { + let k = key.as_ref(); + let v = value.as_ref(); + let header_name = reqwest::header::HeaderName::try_from(k) + .map_err(|e| anyhow::anyhow!("Invalid header name '{k}': {e}"))?; + let header_value = reqwest::header::HeaderValue::try_from(v) + .map_err(|e| anyhow::anyhow!("Invalid header value for '{k}': {e}"))?; + header_map.insert(header_name, header_value); } - Ok(response) - } - - // OpenRouter optional headers ref: https://openrouter.ai/docs/api-reference/overview#headers - // - `HTTP-Referer`: Identifies your app on openrouter.ai - // - `X-Title`: Sets/modifies your app's title - fn headers(&self, headers: Option) -> HeaderMap { - let mut headers = headers.unwrap_or_default(); - // Only set User-Agent if the provider hasn't already set one - if !headers.contains_key("User-Agent") { - headers.insert("User-Agent", HeaderValue::from_static("Forge")); + if header_map.is_empty() { + return Ok(self); } - headers.insert("X-Title", HeaderValue::from_static("forge")); - headers.insert( - "x-app-version", - HeaderValue::from_str(format!("v{VERSION}").as_str()) - .unwrap_or(HeaderValue::from_static("v0.1.0-dev")), - ); - headers.insert( - "HTTP-Referer", - HeaderValue::from_static("https://forgecode.dev"), - ); - headers.insert( - reqwest::header::CONNECTION, - HeaderValue::from_static("keep-alive"), - ); - debug!(headers = ?Self::sanitize_headers(&headers), "Request Headers"); - headers - } - - fn sanitize_headers(headers: &HeaderMap) -> HeaderMap { - let sensitive_headers = [AUTHORIZATION.as_str()]; - headers - .iter() - .map(|(name, value)| { - let name_str = name.as_str().to_lowercase(); - let value_str = if sensitive_headers.contains(&name_str.as_str()) { - HeaderValue::from_static("[REDACTED]") - } else { - value.clone() - }; - (name.clone(), value_str) - }) - .collect() - } -} - -impl ForgeHttpInfra { - fn write_debug_request(&self, body: &Bytes) { - if let Some(debug_path) = &self.env.debug_requests { - let file_writer = self.file.clone(); - let body_clone = body.clone(); - let debug_path = debug_path.clone(); - tokio::spawn(async move { - let _ = file_writer.write(&debug_path, body_clone).await; - }); - } - } - - async fn eventsource( - &self, - url: &Url, - headers: Option, - body: Bytes, - ) -> anyhow::Result { - let mut request_headers = self.headers(headers); - request_headers.insert("Content-Type", HeaderValue::from_static("application/json")); - - self.write_debug_request(&body); - - self.client - .post(url.clone()) - .headers(request_headers) - .body(body) - .eventsource() - .with_context(|| format_http_context(None, "POST (EventSource)", url)) - } -} - -/// Helper function to format HTTP request/response context for logging and -/// error reporting -fn format_http_context>(status: Option, method: &str, url: U) -> String { - if let Some(status) = status { - format!("{} {} {}", status.as_u16(), method, url.as_ref()) - } else { - format!("{} {}", method, url.as_ref()) - } -} - -#[async_trait::async_trait] -impl HttpInfra for ForgeHttpInfra { - async fn http_get(&self, url: &Url, headers: Option) -> anyhow::Result { - self.get(url, headers).await - } - - async fn http_post( - &self, - url: &Url, - headers: Option, - body: Bytes, - ) -> anyhow::Result { - self.post(url, headers, body).await - } - - async fn http_delete(&self, url: &Url) -> anyhow::Result { - self.delete(url).await - } - - async fn http_eventsource( - &self, - url: &Url, - headers: Option, - body: Bytes, - ) -> anyhow::Result { - self.eventsource(url, headers, body).await - } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - use std::sync::Arc; - - use fake::{Fake, Faker}; - use forge_app::FileWriterInfra; - use forge_domain::{Environment, HttpConfig}; - use tokio::sync::Mutex; - - use super::*; - - #[derive(Clone)] - struct MockFileWriter { - writes: Arc>>, - } - - impl MockFileWriter { - fn new() -> Self { - Self { writes: Arc::new(Mutex::new(Vec::new())) } - } - - async fn get_writes(&self) -> Vec<(PathBuf, Bytes)> { - self.writes.lock().await.clone() - } - } - - #[async_trait::async_trait] - impl FileWriterInfra for MockFileWriter { - async fn write(&self, path: &std::path::Path, contents: Bytes) -> anyhow::Result<()> { - self.writes - .lock() - .await - .push((path.to_path_buf(), contents)); - Ok(()) - } - - async fn write_temp( - &self, - _prefix: &str, - _extension: &str, - _content: &str, - ) -> anyhow::Result { - Ok(Faker.fake()) - } - } - - fn create_test_env(debug_requests: Option) -> Environment { - Environment { debug_requests, http: HttpConfig::default(), ..Faker.fake() } - } - - #[tokio::test] - async fn test_debug_requests_none_does_not_write() { - let file_writer = MockFileWriter::new(); - let env = create_test_env(None); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("https://api.test.com/messages").unwrap(); - - // Attempt to create eventsource (which triggers debug write if enabled) - let _ = http.eventsource(&url, None, body).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let writes = file_writer.get_writes().await; - assert_eq!( - writes.len(), - 0, - "No files should be written when debug_requests is None" - ); - } - - #[tokio::test] - async fn test_debug_requests_with_valid_path() { - let file_writer = MockFileWriter::new(); - let debug_path = PathBuf::from("/tmp/forge-test/debug.json"); - let env = create_test_env(Some(debug_path.clone())); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("https://api.test.com/messages").unwrap(); - - let _ = http.eventsource(&url, None, body.clone()).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let writes = file_writer.get_writes().await; - assert_eq!(writes.len(), 1, "Should write one file"); - assert_eq!(writes[0].0, debug_path); - assert_eq!(writes[0].1, body); - } - - #[tokio::test] - async fn test_debug_requests_with_relative_path() { - let file_writer = MockFileWriter::new(); - let debug_path = PathBuf::from("./debug/requests.json"); - let env = create_test_env(Some(debug_path.clone())); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("https://api.test.com/messages").unwrap(); - - let _ = http.eventsource(&url, None, body.clone()).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let writes = file_writer.get_writes().await; - assert_eq!(writes.len(), 1, "Should write one file"); - assert_eq!(writes[0].0, debug_path); - assert_eq!(writes[0].1, body); - } - - #[tokio::test] - async fn test_debug_requests_post_none_does_not_write() { - let file_writer = MockFileWriter::new(); - let env = create_test_env(None); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("http://127.0.0.1:9/responses").unwrap(); - - let _ = http.post(&url, None, body).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let writes = file_writer.get_writes().await; - assert_eq!( - writes.len(), - 0, - "No files should be written for POST when debug_requests is None" - ); - } - - #[tokio::test] - async fn test_debug_requests_post_writes_body() { - let file_writer = MockFileWriter::new(); - let debug_path = PathBuf::from("/tmp/forge-test/debug-post.json"); - let env = create_test_env(Some(debug_path.clone())); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("http://127.0.0.1:9/responses").unwrap(); - - let _ = http.post(&url, None, body.clone()).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - let writes = file_writer.get_writes().await; - assert_eq!( - writes.len(), - 1, - "Should write one file for POST when debug_requests is set" - ); - assert_eq!(writes[0].0, debug_path); - assert_eq!(writes[0].1, body); - } - - #[tokio::test] - async fn test_debug_requests_fallback_on_dir_creation_failure() { - let file_writer = MockFileWriter::new(); - // Use a path with a parent that doesn't exist and can't be created - // (in practice, this would be a permission issue) - let debug_path = PathBuf::from("test_debug.json"); - let env = create_test_env(Some(debug_path.clone())); - let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); - - let body = Bytes::from("test request body"); - let url = Url::parse("https://api.test.com/messages").unwrap(); - - let _ = http.eventsource(&url, None, body.clone()).await; - - // Give async task time to complete - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - let writes = file_writer.get_writes().await; - // Should write to debug_path (no parent dir needed) - assert_eq!(writes.len(), 1, "Should write one file"); - assert_eq!(writes[0].0, debug_path); - assert_eq!(writes[0].1, body); + Ok(self.default_headers(header_map)) } } diff --git a/crates/forge_infra/src/lib.rs b/crates/forge_infra/src/lib.rs index 542c0f0891..89b9f576c3 100644 --- a/crates/forge_infra/src/lib.rs +++ b/crates/forge_infra/src/lib.rs @@ -15,6 +15,7 @@ mod grpc; mod http; mod inquire; mod kv_storage; +mod llm_client; mod mcp_client; mod mcp_server; mod walker; diff --git a/crates/forge_infra/src/llm_client.rs b/crates/forge_infra/src/llm_client.rs new file mode 100644 index 0000000000..d331d5be6d --- /dev/null +++ b/crates/forge_infra/src/llm_client.rs @@ -0,0 +1,395 @@ +use std::sync::{Arc, OnceLock}; + +use anyhow::Context; +use bytes::Bytes; +use forge_app::HttpInfra; +use forge_domain::Environment; +use reqwest::header::{AUTHORIZATION, HeaderMap, HeaderValue}; +use reqwest::{Client, Response, StatusCode, Url}; +use reqwest_eventsource::{EventSource, RequestBuilderExt}; +use tracing::debug; + +use crate::http::ClientBuilderExt; + +const VERSION: &str = match option_env!("APP_VERSION") { + None => env!("CARGO_PKG_VERSION"), + Some(v) => v, +}; + +pub struct ForgeHttpInfra { + client: OnceLock, + env: Environment, + file: Arc, +} + +impl ForgeHttpInfra { + pub fn new(env: Environment, file_writer: Arc) -> Self { + Self { env, client: OnceLock::new(), file: file_writer } + } + + fn client(&self) -> anyhow::Result<&Client> { + // Fast path: already initialized. + if let Some(client) = self.client.get() { + return Ok(client); + } + + // Build the client. On failure the error propagates and nothing is + // stored, so the next call will retry. + let new_client = reqwest::Client::builder() + .with_http_config(&self.env.http) + .with_proxy_fallback() + .and_then(|b| b.build().map_err(Into::into))?; + + // Store on success. If another thread raced us here and already stored + // a client, `get_or_init` returns theirs and drops ours — that's fine. + Ok(self.client.get_or_init(|| new_client)) + } + + async fn get(&self, url: &Url, headers: Option) -> anyhow::Result { + self.execute_request("GET", url, |client| { + client.get(url.clone()).headers(self.headers(headers)) + }) + .await + } + + async fn post( + &self, + url: &Url, + headers: Option, + body: Bytes, + ) -> anyhow::Result { + let mut request_headers = self.headers(headers); + request_headers.insert("Content-Type", HeaderValue::from_static("application/json")); + + self.write_debug_request(&body); + + self.execute_request("POST", url, |client| { + client.post(url.clone()).headers(request_headers).body(body) + }) + .await + } + + async fn delete(&self, url: &Url) -> anyhow::Result { + self.execute_request("DELETE", url, |client| { + client.delete(url.clone()).headers(self.headers(None)) + }) + .await + } + + /// Generic helper method to execute HTTP requests with consistent error + /// handling + async fn execute_request( + &self, + method: &str, + url: &Url, + request_builder: B, + ) -> anyhow::Result + where + B: FnOnce(&Client) -> reqwest::RequestBuilder, + { + let response = request_builder(self.client()?) + .send() + .await + .with_context(|| format_http_context(None, method, url))?; + + let status = response.status(); + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| "Unable to read response body".to_string()); + return Err(anyhow::anyhow!(error_body)) + .with_context(|| format_http_context(Some(status), method, url)); + } + + Ok(response) + } + + // OpenRouter optional headers ref: https://openrouter.ai/docs/api-reference/overview#headers + // - `HTTP-Referer`: Identifies your app on openrouter.ai + // - `X-Title`: Sets/modifies your app's title + fn headers(&self, headers: Option) -> HeaderMap { + let mut headers = headers.unwrap_or_default(); + // Only set User-Agent if the provider hasn't already set one + if !headers.contains_key("User-Agent") { + headers.insert("User-Agent", HeaderValue::from_static("Forge")); + } + headers.insert("X-Title", HeaderValue::from_static("forge")); + headers.insert( + "x-app-version", + HeaderValue::from_str(format!("v{VERSION}").as_str()) + .unwrap_or(HeaderValue::from_static("v0.1.0-dev")), + ); + headers.insert( + "HTTP-Referer", + HeaderValue::from_static("https://forgecode.dev"), + ); + headers.insert( + reqwest::header::CONNECTION, + HeaderValue::from_static("keep-alive"), + ); + debug!(headers = ?Self::sanitize_headers(&headers), "Request Headers"); + headers + } + + fn sanitize_headers(headers: &HeaderMap) -> HeaderMap { + let sensitive_headers = [AUTHORIZATION.as_str()]; + headers + .iter() + .map(|(name, value)| { + let name_str = name.as_str().to_lowercase(); + let value_str = if sensitive_headers.contains(&name_str.as_str()) { + HeaderValue::from_static("[REDACTED]") + } else { + value.clone() + }; + (name.clone(), value_str) + }) + .collect() + } + + fn write_debug_request(&self, body: &Bytes) { + if let Some(debug_path) = &self.env.debug_requests { + let file_writer = self.file.clone(); + let body_clone = body.clone(); + let debug_path = debug_path.clone(); + tokio::spawn(async move { + let _ = file_writer.write(&debug_path, body_clone).await; + }); + } + } + + async fn eventsource( + &self, + url: &Url, + headers: Option, + body: Bytes, + ) -> anyhow::Result { + let mut request_headers = self.headers(headers); + request_headers.insert("Content-Type", HeaderValue::from_static("application/json")); + + self.write_debug_request(&body); + + self.client()? + .post(url.clone()) + .headers(request_headers) + .body(body) + .eventsource() + .with_context(|| format_http_context(None, "POST (EventSource)", url)) + } +} + +fn format_http_context>(status: Option, method: &str, url: U) -> String { + if let Some(status) = status { + format!("{} {} {}", status.as_u16(), method, url.as_ref()) + } else { + format!("{} {}", method, url.as_ref()) + } +} + +#[async_trait::async_trait] +impl HttpInfra for ForgeHttpInfra { + async fn http_get(&self, url: &Url, headers: Option) -> anyhow::Result { + self.get(url, headers).await + } + + async fn http_post( + &self, + url: &Url, + headers: Option, + body: Bytes, + ) -> anyhow::Result { + self.post(url, headers, body).await + } + + async fn http_delete(&self, url: &Url) -> anyhow::Result { + self.delete(url).await + } + + async fn http_eventsource( + &self, + url: &Url, + headers: Option, + body: Bytes, + ) -> anyhow::Result { + self.eventsource(url, headers, body).await + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + use std::sync::Arc; + + use fake::{Fake, Faker}; + use forge_app::FileWriterInfra; + use forge_domain::{Environment, HttpConfig}; + use tokio::sync::Mutex; + + use super::*; + + #[derive(Clone)] + struct MockFileWriter { + writes: Arc>>, + } + + impl MockFileWriter { + fn new() -> Self { + Self { writes: Arc::new(Mutex::new(Vec::new())) } + } + + async fn get_writes(&self) -> Vec<(PathBuf, Bytes)> { + self.writes.lock().await.clone() + } + } + + #[async_trait::async_trait] + impl FileWriterInfra for MockFileWriter { + async fn write(&self, path: &std::path::Path, contents: Bytes) -> anyhow::Result<()> { + self.writes + .lock() + .await + .push((path.to_path_buf(), contents)); + Ok(()) + } + + async fn write_temp( + &self, + _prefix: &str, + _extension: &str, + _content: &str, + ) -> anyhow::Result { + Ok(Faker.fake()) + } + } + + fn create_test_env(debug_requests: Option) -> Environment { + Environment { debug_requests, http: HttpConfig::default(), ..Faker.fake() } + } + + #[tokio::test] + async fn test_debug_requests_none_does_not_write() { + let file_writer = MockFileWriter::new(); + let env = create_test_env(None); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://api.test.com/messages").unwrap(); + + let _ = http.eventsource(&url, None, body).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!( + writes.len(), + 0, + "No files should be written when debug_requests is None" + ); + } + + #[tokio::test] + async fn test_debug_requests_with_valid_path() { + let file_writer = MockFileWriter::new(); + let debug_path = PathBuf::from("/tmp/forge-test/debug.json"); + let env = create_test_env(Some(debug_path.clone())); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://api.test.com/messages").unwrap(); + + let _ = http.eventsource(&url, None, body.clone()).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!(writes.len(), 1, "Should write one file"); + assert_eq!(writes[0].0, debug_path); + assert_eq!(writes[0].1, body); + } + + #[tokio::test] + async fn test_debug_requests_with_relative_path() { + let file_writer = MockFileWriter::new(); + let debug_path = PathBuf::from("./debug/requests.json"); + let env = create_test_env(Some(debug_path.clone())); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://api.test.com/messages").unwrap(); + + let _ = http.eventsource(&url, None, body.clone()).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!(writes.len(), 1, "Should write one file"); + assert_eq!(writes[0].0, debug_path); + assert_eq!(writes[0].1, body); + } + + #[tokio::test] + async fn test_debug_requests_post_none_does_not_write() { + let file_writer = MockFileWriter::new(); + let env = create_test_env(None); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://127.0.0.1:9/responses").unwrap(); + + let _ = http.post(&url, None, body).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!( + writes.len(), + 0, + "No files should be written for POST when debug_requests is None" + ); + } + + #[tokio::test] + async fn test_debug_requests_post_writes_body() { + let file_writer = MockFileWriter::new(); + let debug_path = PathBuf::from("/tmp/forge-test/debug-post.json"); + let env = create_test_env(Some(debug_path.clone())); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://127.0.0.1:9/responses").unwrap(); + + let _ = http.post(&url, None, body.clone()).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!( + writes.len(), + 1, + "Should write one file for POST when debug_requests is set" + ); + assert_eq!(writes[0].0, debug_path); + assert_eq!(writes[0].1, body); + } + + #[tokio::test] + async fn test_debug_requests_fallback_on_dir_creation_failure() { + let file_writer = MockFileWriter::new(); + let debug_path = PathBuf::from("test_debug.json"); + let env = create_test_env(Some(debug_path.clone())); + let http = ForgeHttpInfra::new(env, Arc::new(file_writer.clone())); + + let body = Bytes::from("test request body"); + let url = Url::parse("https://api.test.com/messages").unwrap(); + + let _ = http.eventsource(&url, None, body.clone()).await; + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + let writes = file_writer.get_writes().await; + assert_eq!(writes.len(), 1, "Should write one file"); + assert_eq!(writes[0].0, debug_path); + assert_eq!(writes[0].1, body); + } +} diff --git a/crates/forge_infra/src/mcp_client.rs b/crates/forge_infra/src/mcp_client.rs index 662ca0ce85..dc7a608edc 100644 --- a/crates/forge_infra/src/mcp_client.rs +++ b/crates/forge_infra/src/mcp_client.rs @@ -1,13 +1,11 @@ use std::borrow::Cow; use std::collections::BTreeMap; use std::future::Future; -use std::str::FromStr; use std::sync::{Arc, OnceLock, RwLock}; use backon::{ExponentialBuilder, Retryable}; use forge_app::McpClientInfra; use forge_domain::{Image, McpHttpServer, McpServerConfig, ToolDefinition, ToolName, ToolOutput}; -use http::{HeaderName, HeaderValue, header}; use rmcp::model::{CallToolRequestParam, ClientInfo, Implementation, InitializeRequestParam}; use rmcp::service::RunningService; use rmcp::transport::sse_client::SseClientConfig; @@ -20,6 +18,7 @@ use tokio::io::{AsyncBufReadExt, BufReader}; use tokio::process::Command; use crate::error::Error; +use crate::http::ClientBuilderExt; const VERSION: &str = match option_env!("APP_VERSION") { Some(val) => val, @@ -34,15 +33,21 @@ pub struct ForgeMcpClient { config: McpServerConfig, env_vars: BTreeMap, resolved_config: Arc>>, + http_config: forge_domain::HttpConfig, } impl ForgeMcpClient { - pub fn new(config: McpServerConfig, env_vars: &BTreeMap) -> Self { + pub fn new( + config: McpServerConfig, + env_vars: &BTreeMap, + http_config: forge_domain::HttpConfig, + ) -> Self { Self { client: Default::default(), config, env_vars: env_vars.clone(), resolved_config: Arc::new(OnceLock::new()), + http_config, } } @@ -152,13 +157,11 @@ impl ForgeMcpClient { } fn reqwest_client(&self, config: &McpHttpServer) -> anyhow::Result { - let mut headers = header::HeaderMap::new(); - for (key, value) in config.headers.iter() { - headers.insert(HeaderName::from_str(key)?, HeaderValue::from_str(value)?); - } - - let client = reqwest::Client::builder().default_headers(headers); - Ok(client.build()?) + Ok(reqwest::Client::builder() + .with_proxy_fallback()? + .with_tls_config(&self.http_config) + .with_custom_headers(config.headers.iter())? + .build()?) } async fn list(&self) -> anyhow::Result> { diff --git a/crates/forge_infra/src/mcp_server.rs b/crates/forge_infra/src/mcp_server.rs index 1725313767..9e325a1447 100644 --- a/crates/forge_infra/src/mcp_server.rs +++ b/crates/forge_infra/src/mcp_server.rs @@ -1,12 +1,20 @@ use std::collections::BTreeMap; use forge_app::McpServerInfra; -use forge_domain::McpServerConfig; +use forge_domain::{HttpConfig, McpServerConfig}; use crate::mcp_client::ForgeMcpClient; #[derive(Clone)] -pub struct ForgeMcpServer; +pub struct ForgeMcpServer { + http_config: HttpConfig, +} + +impl ForgeMcpServer { + pub fn new(http_config: HttpConfig) -> Self { + Self { http_config } + } +} #[async_trait::async_trait] impl McpServerInfra for ForgeMcpServer { @@ -17,6 +25,10 @@ impl McpServerInfra for ForgeMcpServer { config: McpServerConfig, env_vars: &BTreeMap, ) -> anyhow::Result { - Ok(ForgeMcpClient::new(config, env_vars)) + Ok(ForgeMcpClient::new( + config, + env_vars, + self.http_config.clone(), + )) } }