diff --git a/crates/fetchkit/src/fetchers/default.rs b/crates/fetchkit/src/fetchers/default.rs index 137436a..23e07a9 100644 --- a/crates/fetchkit/src/fetchers/default.rs +++ b/crates/fetchkit/src/fetchers/default.rs @@ -81,7 +81,11 @@ impl Default for DefaultFetcher { } /// Build headers for HTTP requests -fn build_headers(options: &FetchOptions, accept: &str, request: &FetchRequest) -> HeaderMap { +pub(crate) fn build_headers( + options: &FetchOptions, + accept: &str, + request: &FetchRequest, +) -> HeaderMap { let mut headers = HeaderMap::new(); let user_agent = options.user_agent.as_deref().unwrap_or(DEFAULT_USER_AGENT); headers.insert( @@ -111,7 +115,7 @@ fn build_headers(options: &FetchOptions, accept: &str, request: &FetchRequest) - /// Apply bot-auth signature headers when the feature is enabled and configured. #[cfg(feature = "bot-auth")] -fn apply_bot_auth_if_enabled( +pub(crate) fn apply_bot_auth_if_enabled( mut headers: HeaderMap, options: &FetchOptions, url: &Url, @@ -142,7 +146,11 @@ fn apply_bot_auth_if_enabled( } #[cfg(not(feature = "bot-auth"))] -fn apply_bot_auth_if_enabled(headers: HeaderMap, _options: &FetchOptions, _url: &Url) -> HeaderMap { +pub(crate) fn apply_bot_auth_if_enabled( + headers: HeaderMap, + _options: &FetchOptions, + _url: &Url, +) -> HeaderMap { headers } @@ -220,8 +228,14 @@ impl Fetcher for DefaultFetcher { }; // THREAT[TM-SSRF-010]: Follow redirects manually so every hop is re-validated. - let (response, redirect_chain) = - send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?; + let (response, redirect_chain) = send_request_following_redirects( + parsed_url, + reqwest_method, + headers, + options, + FIRST_BYTE_TIMEOUT, + ) + .await?; let status_code = response.status().as_u16(); let final_url = response.url().to_string(); @@ -394,8 +408,14 @@ impl Fetcher for DefaultFetcher { }; // THREAT[TM-SSRF-010]: Follow redirects manually with IP validation at each hop - let (response, redirect_chain) = - send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?; + let (response, redirect_chain) = send_request_following_redirects( + parsed_url, + reqwest_method, + headers, + options, + FIRST_BYTE_TIMEOUT, + ) + .await?; let status_code = response.status().as_u16(); let final_url = response.url().to_string(); @@ -446,17 +466,18 @@ impl Fetcher for DefaultFetcher { } /// Returns `(response, redirect_chain)` where redirect_chain lists intermediate URLs. -async fn send_request_following_redirects( +pub(crate) async fn send_request_following_redirects( initial_url: Url, method: reqwest::Method, headers: HeaderMap, options: &FetchOptions, + timeout: Duration, ) -> Result<(reqwest::Response, Vec), FetchError> { let mut current_url = initial_url; let mut redirect_chain = Vec::new(); for redirect_count in 0..=MAX_REDIRECTS { - let client = build_client_for_url(¤t_url, headers.clone(), options)?; + let client = build_client_for_url(¤t_url, headers.clone(), options, timeout)?; let response = client .request(method.clone(), current_url.clone()) .send() @@ -489,12 +510,13 @@ fn build_client_for_url( url: &Url, headers: HeaderMap, options: &FetchOptions, + timeout: Duration, ) -> Result { // THREAT[TM-NET-003]: New client per request prevents connection-pool state leakage let mut client_builder = reqwest::Client::builder() .default_headers(headers) - .connect_timeout(FIRST_BYTE_TIMEOUT) - .timeout(FIRST_BYTE_TIMEOUT) + .connect_timeout(timeout) + .timeout(timeout) .redirect(reqwest::redirect::Policy::none()); if !options.respect_proxy_env { diff --git a/crates/fetchkit/src/fetchers/rss_feed.rs b/crates/fetchkit/src/fetchers/rss_feed.rs index 96afa8c..f05f108 100644 --- a/crates/fetchkit/src/fetchers/rss_feed.rs +++ b/crates/fetchkit/src/fetchers/rss_feed.rs @@ -5,11 +5,12 @@ use crate::client::FetchOptions; use crate::error::FetchError; +use crate::fetchers::default::{apply_bot_auth_if_enabled, send_request_following_redirects}; use crate::fetchers::Fetcher; use crate::types::{FetchRequest, FetchResponse}; use crate::DEFAULT_USER_AGENT; use async_trait::async_trait; -use reqwest::header::{HeaderValue, ACCEPT, USER_AGENT}; +use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, USER_AGENT}; use std::time::Duration; use url::Url; @@ -86,40 +87,35 @@ impl Fetcher for RSSFeedFetcher { options: &FetchOptions, ) -> Result { let user_agent = options.user_agent.as_deref().unwrap_or(DEFAULT_USER_AGENT); - let mut client_builder = reqwest::Client::builder() - .connect_timeout(API_TIMEOUT) - .timeout(API_TIMEOUT) - .redirect(reqwest::redirect::Policy::limited(5)); - - if !options.respect_proxy_env { - client_builder = client_builder.no_proxy(); - } - - let client = client_builder - .build() - .map_err(FetchError::ClientBuildError)?; - + let mut headers = HeaderMap::new(); let ua_header = HeaderValue::from_str(user_agent) .unwrap_or_else(|_| HeaderValue::from_static(DEFAULT_USER_AGENT)); - - let response = client - .get(&request.url) - .header(USER_AGENT, ua_header) - .header( - ACCEPT, - HeaderValue::from_static( - "application/rss+xml, application/atom+xml, application/xml, text/xml, */*", - ), - ) - .send() - .await - .map_err(FetchError::from_reqwest)?; + headers.insert(USER_AGENT, ua_header); + headers.insert( + ACCEPT, + HeaderValue::from_static( + "application/rss+xml, application/atom+xml, application/xml, text/xml, */*", + ), + ); + + let parsed_url = Url::parse(&request.url).map_err(|_| FetchError::InvalidUrlScheme)?; + let headers = apply_bot_auth_if_enabled(headers, options, &parsed_url); + let (response, redirect_chain) = send_request_following_redirects( + parsed_url, + reqwest::Method::GET, + headers, + options, + API_TIMEOUT, + ) + .await?; let status_code = response.status().as_u16(); + let final_url = response.url().to_string(); if !response.status().is_success() { return Ok(FetchResponse { - url: request.url.clone(), + url: final_url, status_code, + redirect_chain, error: Some(format!("HTTP {}", status_code)), ..Default::default() }); @@ -150,29 +146,32 @@ impl Fetcher for RSSFeedFetcher { } else if is_feed_by_ct { // Content-type indicates a feed but structure wasn't recognized — return as raw XML return Ok(FetchResponse { - url: request.url.clone(), + url: final_url, status_code: 200, content: Some(body), format: Some("raw".to_string()), + redirect_chain, ..Default::default() }); } else { // Not a recognized feed format return Ok(FetchResponse { - url: request.url.clone(), + url: final_url, status_code: 200, content: Some(body), format: Some("raw".to_string()), + redirect_chain, ..Default::default() }); }; Ok(FetchResponse { - url: request.url.clone(), + url: final_url, status_code: 200, content_type: Some("text/markdown".to_string()), format: Some("rss_feed".to_string()), content: Some(content), + redirect_chain, ..Default::default() }) } diff --git a/crates/fetchkit/tests/ssrf_security.rs b/crates/fetchkit/tests/ssrf_security.rs index fefb4fb..792c8ea 100644 --- a/crates/fetchkit/tests/ssrf_security.rs +++ b/crates/fetchkit/tests/ssrf_security.rs @@ -465,6 +465,70 @@ async fn test_ssrf_010_same_host_redirect_policy_blocks_cross_host_redirect() { assert!(matches!(result, Err(FetchError::BlockedUrl))); } +#[tokio::test] +async fn test_ssrf_010_rss_fetcher_blocks_loopback_feed_by_default() { + let mock_server = MockServer::start().await; + let rss = r#" + + + Loopback Feed + + Entry + Hello + + +"#; + + Mock::given(method("GET")) + .and(path("/feed")) + .respond_with(ResponseTemplate::new(200).set_body_raw(rss, "application/rss+xml")) + .mount(&mock_server) + .await; + + let req = FetchRequest::new(format!("{}/feed", mock_server.uri())); + let result = Tool::default().execute(req).await; + + assert!(matches!(result, Err(FetchError::BlockedUrl))); +} + +#[tokio::test] +async fn test_ssrf_010_rss_fetcher_enforces_same_host_redirect_policy() { + let mock_server = MockServer::start().await; + let server_addr = mock_server.address(); + let final_feed_url = format!("http://127.0.0.1:{}/final-feed", server_addr.port()); + let rss = r#" + + + Redirected Feed + + Entry + Hello + + +"#; + + Mock::given(method("GET")) + .and(path("/feed")) + .respond_with(ResponseTemplate::new(302).insert_header("Location", &final_feed_url)) + .mount(&mock_server) + .await; + + Mock::given(method("GET")) + .and(path("/final-feed")) + .respond_with(ResponseTemplate::new(200).set_body_raw(rss, "application/rss+xml")) + .mount(&mock_server) + .await; + + let tool = Tool::builder() + .block_private_ips(false) + .same_host_redirects_only(true) + .build(); + let req = FetchRequest::new(format!("http://localhost:{}/feed", server_addr.port())); + let result = tool.execute(req).await; + + assert!(matches!(result, Err(FetchError::BlockedUrl))); +} + // ============================================================================ // TM-NET-004: Ambient proxy environment variables // ============================================================================