Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 33 additions & 11 deletions crates/fetchkit/src/fetchers/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ impl Default for DefaultFetcher {
}

/// Build headers for HTTP requests
fn build_headers(options: &FetchOptions, accept: &str, request: &FetchRequest) -> HeaderMap {
pub(crate) fn build_headers(
options: &FetchOptions,
accept: &str,
request: &FetchRequest,
) -> HeaderMap {
let mut headers = HeaderMap::new();
let user_agent = options.user_agent.as_deref().unwrap_or(DEFAULT_USER_AGENT);
headers.insert(
Expand Down Expand Up @@ -111,7 +115,7 @@ fn build_headers(options: &FetchOptions, accept: &str, request: &FetchRequest) -

/// Apply bot-auth signature headers when the feature is enabled and configured.
#[cfg(feature = "bot-auth")]
fn apply_bot_auth_if_enabled(
pub(crate) fn apply_bot_auth_if_enabled(
mut headers: HeaderMap,
options: &FetchOptions,
url: &Url,
Expand Down Expand Up @@ -142,7 +146,11 @@ fn apply_bot_auth_if_enabled(
}

#[cfg(not(feature = "bot-auth"))]
fn apply_bot_auth_if_enabled(headers: HeaderMap, _options: &FetchOptions, _url: &Url) -> HeaderMap {
pub(crate) fn apply_bot_auth_if_enabled(
headers: HeaderMap,
_options: &FetchOptions,
_url: &Url,
) -> HeaderMap {
headers
}

Expand Down Expand Up @@ -220,8 +228,14 @@ impl Fetcher for DefaultFetcher {
};

// THREAT[TM-SSRF-010]: Follow redirects manually so every hop is re-validated.
let (response, redirect_chain) =
send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;
let (response, redirect_chain) = send_request_following_redirects(
parsed_url,
reqwest_method,
headers,
options,
FIRST_BYTE_TIMEOUT,
)
.await?;

let status_code = response.status().as_u16();
let final_url = response.url().to_string();
Expand Down Expand Up @@ -394,8 +408,14 @@ impl Fetcher for DefaultFetcher {
};

// THREAT[TM-SSRF-010]: Follow redirects manually with IP validation at each hop
let (response, redirect_chain) =
send_request_following_redirects(parsed_url, reqwest_method, headers, options).await?;
let (response, redirect_chain) = send_request_following_redirects(
parsed_url,
reqwest_method,
headers,
options,
FIRST_BYTE_TIMEOUT,
)
.await?;

let status_code = response.status().as_u16();
let final_url = response.url().to_string();
Expand Down Expand Up @@ -446,17 +466,18 @@ impl Fetcher for DefaultFetcher {
}

/// Returns `(response, redirect_chain)` where redirect_chain lists intermediate URLs.
async fn send_request_following_redirects(
pub(crate) async fn send_request_following_redirects(
initial_url: Url,
method: reqwest::Method,
headers: HeaderMap,
options: &FetchOptions,
timeout: Duration,
) -> Result<(reqwest::Response, Vec<String>), FetchError> {
let mut current_url = initial_url;
let mut redirect_chain = Vec::new();

for redirect_count in 0..=MAX_REDIRECTS {
let client = build_client_for_url(&current_url, headers.clone(), options)?;
let client = build_client_for_url(&current_url, headers.clone(), options, timeout)?;
let response = client
.request(method.clone(), current_url.clone())
.send()
Expand Down Expand Up @@ -489,12 +510,13 @@ fn build_client_for_url(
url: &Url,
headers: HeaderMap,
options: &FetchOptions,
timeout: Duration,
) -> Result<reqwest::Client, FetchError> {
// THREAT[TM-NET-003]: New client per request prevents connection-pool state leakage
let mut client_builder = reqwest::Client::builder()
.default_headers(headers)
.connect_timeout(FIRST_BYTE_TIMEOUT)
.timeout(FIRST_BYTE_TIMEOUT)
.connect_timeout(timeout)
.timeout(timeout)
.redirect(reqwest::redirect::Policy::none());

if !options.respect_proxy_env {
Expand Down
61 changes: 30 additions & 31 deletions crates/fetchkit/src/fetchers/rss_feed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

use crate::client::FetchOptions;
use crate::error::FetchError;
use crate::fetchers::default::{apply_bot_auth_if_enabled, send_request_following_redirects};
use crate::fetchers::Fetcher;
use crate::types::{FetchRequest, FetchResponse};
use crate::DEFAULT_USER_AGENT;
use async_trait::async_trait;
use reqwest::header::{HeaderValue, ACCEPT, USER_AGENT};
use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, USER_AGENT};
use std::time::Duration;
use url::Url;

Expand Down Expand Up @@ -86,40 +87,35 @@ impl Fetcher for RSSFeedFetcher {
options: &FetchOptions,
) -> Result<FetchResponse, FetchError> {
let user_agent = options.user_agent.as_deref().unwrap_or(DEFAULT_USER_AGENT);
let mut client_builder = reqwest::Client::builder()
.connect_timeout(API_TIMEOUT)
.timeout(API_TIMEOUT)
.redirect(reqwest::redirect::Policy::limited(5));

if !options.respect_proxy_env {
client_builder = client_builder.no_proxy();
}

let client = client_builder
.build()
.map_err(FetchError::ClientBuildError)?;

let mut headers = HeaderMap::new();
let ua_header = HeaderValue::from_str(user_agent)
.unwrap_or_else(|_| HeaderValue::from_static(DEFAULT_USER_AGENT));

let response = client
.get(&request.url)
.header(USER_AGENT, ua_header)
.header(
ACCEPT,
HeaderValue::from_static(
"application/rss+xml, application/atom+xml, application/xml, text/xml, */*",
),
)
.send()
.await
.map_err(FetchError::from_reqwest)?;
headers.insert(USER_AGENT, ua_header);
headers.insert(
ACCEPT,
HeaderValue::from_static(
"application/rss+xml, application/atom+xml, application/xml, text/xml, */*",
),
);

let parsed_url = Url::parse(&request.url).map_err(|_| FetchError::InvalidUrlScheme)?;
let headers = apply_bot_auth_if_enabled(headers, options, &parsed_url);
let (response, redirect_chain) = send_request_following_redirects(
parsed_url,
reqwest::Method::GET,
headers,
options,
API_TIMEOUT,
)
.await?;

let status_code = response.status().as_u16();
let final_url = response.url().to_string();
if !response.status().is_success() {
return Ok(FetchResponse {
url: request.url.clone(),
url: final_url,
status_code,
redirect_chain,
error: Some(format!("HTTP {}", status_code)),
..Default::default()
});
Expand Down Expand Up @@ -150,29 +146,32 @@ impl Fetcher for RSSFeedFetcher {
} else if is_feed_by_ct {
// Content-type indicates a feed but structure wasn't recognized — return as raw XML
return Ok(FetchResponse {
url: request.url.clone(),
url: final_url,
status_code: 200,
content: Some(body),
format: Some("raw".to_string()),
redirect_chain,
..Default::default()
});
} else {
// Not a recognized feed format
return Ok(FetchResponse {
url: request.url.clone(),
url: final_url,
status_code: 200,
content: Some(body),
format: Some("raw".to_string()),
redirect_chain,
..Default::default()
});
};

Ok(FetchResponse {
url: request.url.clone(),
url: final_url,
status_code: 200,
content_type: Some("text/markdown".to_string()),
format: Some("rss_feed".to_string()),
content: Some(content),
redirect_chain,
..Default::default()
})
}
Expand Down
64 changes: 64 additions & 0 deletions crates/fetchkit/tests/ssrf_security.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,70 @@ async fn test_ssrf_010_same_host_redirect_policy_blocks_cross_host_redirect() {
assert!(matches!(result, Err(FetchError::BlockedUrl)));
}

// TM-SSRF-010 regression test: a feed URL pointing at the local mock server
// must be rejected when the tool is used with its default configuration.
//
// The mock serves a well-formed RSS document with an RSS content type, so the
// only expected reason for failure is the URL being blocked, not the payload.
// NOTE(review): this relies on the mock server's address being classified as
// loopback/private by the default tool settings — confirm against `Tool::default()`.
#[tokio::test]
async fn test_ssrf_010_rss_fetcher_blocks_loopback_feed_by_default() {
let mock_server = MockServer::start().await;
// Minimal valid RSS 2.0 payload with a single item.
let rss = r#"<?xml version="1.0"?>
<rss version="2.0">
<channel>
<title>Loopback Feed</title>
<item>
<title>Entry</title>
<description>Hello</description>
</item>
</channel>
</rss>"#;

// GET /feed answers 200 with the RSS body and an `application/rss+xml` content type.
Mock::given(method("GET"))
.and(path("/feed"))
.respond_with(ResponseTemplate::new(200).set_body_raw(rss, "application/rss+xml"))
.mount(&mock_server)
.await;

// Request the feed through the default tool (no builder overrides).
let req = FetchRequest::new(format!("{}/feed", mock_server.uri()));
let result = Tool::default().execute(req).await;

// The request must fail with `BlockedUrl` instead of returning the feed content.
assert!(matches!(result, Err(FetchError::BlockedUrl)));
}

// TM-SSRF-010 regression test: with `same_host_redirects_only(true)`, a feed
// request whose redirect changes the host name must be rejected.
//
// The initial request targets `localhost:<port>/feed`, which answers 302 with a
// `Location` of `127.0.0.1:<port>/final-feed`. Both URLs use the same mock
// server port, but the host strings differ, so the redirect counts as
// cross-host. Private-IP blocking is disabled so that the redirect policy is
// the only remaining reason for a `BlockedUrl` error.
#[tokio::test]
async fn test_ssrf_010_rss_fetcher_enforces_same_host_redirect_policy() {
let mock_server = MockServer::start().await;
let server_addr = mock_server.address();
// Redirect target: same port as the mock server, but addressed by IP literal.
let final_feed_url = format!("http://127.0.0.1:{}/final-feed", server_addr.port());
// Minimal valid RSS 2.0 payload served at the redirect target.
let rss = r#"<?xml version="1.0"?>
<rss version="2.0">
<channel>
<title>Redirected Feed</title>
<item>
<title>Entry</title>
<description>Hello</description>
</item>
</channel>
</rss>"#;

// GET /feed answers 302 pointing at the IP-literal URL.
Mock::given(method("GET"))
.and(path("/feed"))
.respond_with(ResponseTemplate::new(302).insert_header("Location", &final_feed_url))
.mount(&mock_server)
.await;

// GET /final-feed would serve the feed if the redirect were followed.
Mock::given(method("GET"))
.and(path("/final-feed"))
.respond_with(ResponseTemplate::new(200).set_body_raw(rss, "application/rss+xml"))
.mount(&mock_server)
.await;

// Allow private IPs so the initial request is permitted, but restrict
// redirects to the original host.
let tool = Tool::builder()
.block_private_ips(false)
.same_host_redirects_only(true)
.build();
// Start from the `localhost` host name so the redirect to `127.0.0.1` changes host.
let req = FetchRequest::new(format!("http://localhost:{}/feed", server_addr.port()));
let result = tool.execute(req).await;

// The cross-host redirect must surface as `BlockedUrl`.
assert!(matches!(result, Err(FetchError::BlockedUrl)));
}

// ============================================================================
// TM-NET-004: Ambient proxy environment variables
// ============================================================================
Expand Down
Loading