From dd91ad93f0bc4754d7cbfb9b51ca2c0930059de8 Mon Sep 17 00:00:00 2001 From: Tomasz Klak Date: Wed, 1 Apr 2026 16:46:09 +0200 Subject: [PATCH 1/3] Prevent unbounded memory usage when running async http get requests Add upper bound on the possible memory usage for http traffic, hardcoded to 256KB. --- Cargo.toml | 10 +- rust-toolchain | 2 +- src/aio/search.rs | 220 ++++++++++++++++++++++++++++++++++++++++---- src/common/mod.rs | 2 +- src/common/tests.rs | 36 +++++--- src/errors.rs | 6 ++ 6 files changed, 240 insertions(+), 36 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7b67b333..bc11ca89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,22 +15,30 @@ version = "0.12.1" all-features = true [dependencies] +bytes = "1.11.1" +http = "1.4.0" +http-body-util = "0.1" +hyper = { version = "1", features = ["http1", "client"] } +hyper-util = { version = "0.1", features = ["http1", "client"] } log = "0.4" rand = "0.8" reqwest = { version = "0.12.9", default-features = false, features = ["blocking", "rustls-tls"] } thiserror = "2.0.4" -tokio = {version = "1", optional = true, features = ["net"]} +tokio = {version = "1", optional = true, features = ["net", "macros"]} url = "2" xmltree = "0.11" [dev-dependencies] +assert_matches = "1.5.0" http-body-util = "0.1" +httptest = "0.16.4" hyper = { package = "hyper", version = "1", features = ["server", "http1"] } hyper-util = { version = "0.1", features = ["tokio"] } paste = "1.0.15" simplelog = "0.9" test-log = "0.2" tokio = {version = "1", features = ["full"]} +tokio-stream = "0.1.18" [features] aio = ["tokio"] diff --git a/rust-toolchain b/rust-toolchain index 369f9966..59be5921 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -1.77.2 +1.88.0 diff --git a/src/aio/search.rs b/src/aio/search.rs index 29843264..82adbb62 100644 --- a/src/aio/search.rs +++ b/src/aio/search.rs @@ -3,7 +3,11 @@ use std::future::Future; use std::net::SocketAddr; use std::str::FromStr; use std::time::Duration; -use tokio::net::UdpSocket; + +use http::header::HOST; +use http::Uri; +use hyper::{client::conn::http1::Builder, Request, StatusCode}; +use tokio::net::{TcpStream, UdpSocket}; use tokio::time::timeout; use crate::aio::Gateway; @@ -11,6 +15,7 @@ use crate::common::{messages, parsing, SearchOptions}; use crate::errors::SearchError; use crate::search::{check_is_ip_spoofed, validate_url}; +const MAX_HTTP_RESPONSE_SIZE: usize = 256 * 1024; const MAX_RESPONSE_SIZE: usize = 1500; /// Search for a gateway with the provided options @@ -43,7 +48,7 @@ pub async fn search_gateway(options: SearchOptions) -> Result Result<(SocketAddr, } async fn get_control_urls(addr: &SocketAddr, path: &str) -> Result<(String, String), SearchError> { - let url: reqwest::Url = format!("http://{}{}", addr, path).parse()?; - - validate_url(addr.ip(), &url)?; - - debug!("requesting control url from: {:?}", url); - let client = reqwest::Client::new(); - let resp = client.get(url).send().await?; - + debug!("requesting control url from: http://{}{}", addr, path); + let body = http_get_bounded(addr, path, MAX_HTTP_RESPONSE_SIZE).await?; debug!("handling control response from: {}", addr); - let body = resp.bytes().await?; - parsing::parse_control_urls(body.as_ref()) + parsing::parse_control_urls(std::io::Cursor::new(body)) } async fn get_control_schemas( addr: &SocketAddr, control_schema_url: &str, ) -> Result>, SearchError> { - let url: reqwest::Url = format!("http://{}{}", addr, control_schema_url).parse()?; + debug!("requesting control schema from: http://{}{}", addr, control_schema_url); + let body = http_get_bounded(addr, control_schema_url, MAX_HTTP_RESPONSE_SIZE).await?; + debug!("handling schema response from: {}", addr); + parsing::parse_schemas(std::io::Cursor::new(body)) +} +async fn http_get_bounded(addr: &SocketAddr, path: &str, memory_upper_bound: usize) -> Result, SearchError> { + use http_body_util::BodyExt; + + let authority = addr.to_string(); + let uri: Uri = format!("http://{}{}", addr, path).parse()?; + + let url: url::Url = uri.to_string().parse()?; validate_url(addr.ip(), &url)?; - debug!("requesting control schema from: {}", url); - let client = reqwest::Client::new(); - let resp = client.get(url).send().await?; + let stream = TcpStream::connect(addr) + .await + .map_err(|e| SearchError::HttpError(e.to_string()))?; + let io = hyper_util::rt::TokioIo::new(stream); - debug!("handling schema response from: {}", addr); + let (mut sender, connection) = Builder::new() + .max_buf_size(memory_upper_bound) + .handshake(io) + .await + .map_err(|e| SearchError::HttpError(e.to_string()))?; + + let req = Request::builder() + .uri(&uri) + .header(HOST, &authority) + .body(http_body_util::Empty::::new()) + .map_err(|e| SearchError::HttpError(e.to_string()))?; + + tokio::spawn(async move { + // See why we need to await connection: + // https://docs.rs/hyper/latest/hyper/client/conn/http1/struct.Builder.html#method.handshake + if let Err(e) = connection.await { + error!("http connection failed: {e}"); + } + }); + + let resp = sender + .send_request(req) + .await + .map_err(|e| SearchError::HttpError(e.to_string()))?; + + if resp.status() != StatusCode::OK { + return Err(SearchError::HttpError(format!("unexpected status: {}", resp.status()))); + } + + let body = http_body_util::Limited::new(resp.into_body(), memory_upper_bound) + .collect() + .await + .map_err(|e| SearchError::HttpError(e.to_string()))? + .to_bytes(); + + Ok(body.to_vec()) +} + +#[cfg(test)] +mod tests { + use std::{ + convert::Infallible, + net::{Ipv4Addr, SocketAddrV4}, + }; + + use assert_matches::assert_matches; + use http::Response; + use http_body_util::StreamBody; + use httptest::{matchers::request, responders::status_code, Expectation, ServerBuilder}; + use hyper::{ + body::{Bytes, Frame}, + server::conn::http1, + service::service_fn, + }; + use hyper_util::rt::TokioIo; + use rand::{distributions::Alphanumeric, thread_rng, Rng}; + use tokio::net::TcpListener; + use tokio_stream::wrappers::ReceiverStream; - let body = resp.bytes().await?; - parsing::parse_schemas(body.as_ref()) + use super::*; + + fn generate_random_body(n: usize) -> Vec { + let s: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(n) + .map(char::from) + .collect(); + s.into_bytes() + } + + #[tokio::test] + async fn working_http_get_bounded() { + // 8k is a minimum max buffer size allowed by http1 / hyper: + // see: https://github.com/hyperium/hyper/blob/0d6c7d5469baa09e2fb127ee3758a79b3271a4f0/src/proto/h1/io.rs#L14-L18 + for memory_bound in [8 * 1024, 16 * 1024, 32 * 1024] { + for body_size in (0..=memory_bound).step_by(512) { + let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)); + let server = ServerBuilder::new().bind_addr(bind_addr).run().unwrap(); + let addr = server.addr(); + let get_url = server.url("/get"); + let path = get_url.path(); + + let test_body = generate_random_body(body_size); + + server.expect( + Expectation::matching(request::method_path("GET", "/get")) + .respond_with(status_code(200).body(test_body.clone())), + ); + let body = http_get_bounded(&addr, path, memory_bound).await.unwrap(); + + assert_eq!(test_body, body); + } + } + } + + #[tokio::test] + async fn failing_http_get_bounded() { + const MEMORY_BOUND: usize = 16 * 1024; + + let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)); + let server = ServerBuilder::new().bind_addr(bind_addr).run().unwrap(); + let addr = server.addr(); + let get_url = server.url("/get"); + let path = get_url.path(); + + let test_body = generate_random_body(MEMORY_BOUND + 1); + + server.expect( + Expectation::matching(request::method_path("GET", "/get")) + .respond_with(status_code(200).body(test_body.clone())), + ); + assert_matches!( + http_get_bounded(&addr, path, MEMORY_BOUND).await, + Err(SearchError::HttpError(m)) if m == "length limit exceeded" + ); + } + + async fn infinite_body_handle( + _req: Request, + ) -> Result, Infallible>>>>, Infallible> { + let (tx, rx) = tokio::sync::mpsc::channel::, Infallible>>(2); + + tokio::spawn(async move { + let chunk = Bytes::from(vec![b'A'; 4096]); + loop { + if tx.send(Ok(Frame::data(chunk.clone()))).await.is_err() { + break; + } + } + }); + + let stream = ReceiverStream::new(rx); + let body = StreamBody::new(stream); + + Ok(Response::builder() + .header("transfer-encoding", "chunked") + .body(body) + .unwrap()) + } + + async fn start_infinite_server() -> Result> { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await?; + let addr = listener.local_addr().unwrap(); + eprintln!("Listening on http://{addr}"); + + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let io = TokioIo::new(stream); + + tokio::spawn(async move { + if let Err(e) = http1::Builder::new() + .serve_connection(io, service_fn(infinite_body_handle)) + .await + { + eprintln!("connection error: {e}"); + } + }); + } + }); + + Ok(addr) + } + + #[tokio::test] + async fn search_gateway_should_fail_for_infinite_http_get_body() { + let http_addr = start_infinite_server().await.unwrap(); + let local_free_port = crate::common::tests::start_broadcast_reply_sender(format!("http://{http_addr}")).await; + let options = crate::common::tests::default_options_with_using_free_port(local_free_port).await; + + assert_matches!( + search_gateway(options).await, + Err(SearchError::HttpError(m)) if m == "length limit exceeded" + ); + } } diff --git a/src/common/mod.rs b/src/common/mod.rs index 7cfac75b..876ff53f 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -2,7 +2,7 @@ pub mod messages; pub mod options; pub mod parsing; #[cfg(test)] -mod tests; +pub mod tests; pub use self::options::SearchOptions; diff --git a/src/common/tests.rs b/src/common/tests.rs index c9af0e5c..cd81be33 100644 --- a/src/common/tests.rs +++ b/src/common/tests.rs @@ -6,19 +6,21 @@ use paste::paste; use std::{ convert::Infallible, future::Future, - net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4}, time::Duration, }; use test_log::test; use tokio::net::{TcpListener, UdpSocket}; -async fn start_broadcast_reply_sender(location: String) -> u16 { - let local_free_port = { - // Not 100% reliable way to find a free port number, but should be good enough - let sock = UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0)).await.unwrap(); - let ret = sock.local_addr().unwrap().port(); - ret - }; +async fn find_free_port(ip: IpAddr) -> u16 { + // Not 100% reliable way to find a free port number, but should be good enough + let sock = UdpSocket::bind((ip, 0)).await.unwrap(); + let ret = sock.local_addr().unwrap().port(); + ret +} + +pub async fn start_broadcast_reply_sender(location: String) -> u16 { + let local_free_port = find_free_port(Ipv4Addr::LOCALHOST.into()).await; tokio::spawn(async move { tokio::time::sleep(Duration::from_secs(1)).await; @@ -35,11 +37,14 @@ async fn start_broadcast_reply_sender(location: String) -> u16 { local_free_port } -fn default_options_with_using_free_port(port: u16) -> SearchOptions { +pub async fn default_options_with_using_free_port(port: u16) -> SearchOptions { + let broadcast_ip: IpAddr = [239u8, 255, 255, 250].into(); + let free_broadcast_port = find_free_port(broadcast_ip).await; SearchOptions { bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)), timeout: Some(Duration::from_secs(5)), http_timeout: Some(Duration::from_secs(1)), + broadcast_address: (broadcast_ip, free_broadcast_port).into(), ..Default::default() } } @@ -120,7 +125,10 @@ const RESP_CONTROL_SCHEMA: &'static str = r#" "#; async fn start_http_server(responses: Vec) -> u16 { - let listener = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], 0))).await.unwrap(); + let listener = TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))) + .await + .unwrap(); + println!("starting new http listener on: {:?}", listener.local_addr()); let local_port = listener.local_addr().unwrap().port(); tokio::task::spawn(async move { @@ -187,7 +195,7 @@ macro_rules! run_tests { run_tests! { fn ip_spoofing_in_broadcast_response(search_gateway) { let local_free_port = start_broadcast_reply_sender("http://1.2.3.4:5".to_owned()).await; - let options = default_options_with_using_free_port(local_free_port); + let options = default_options_with_using_free_port(local_free_port).await; let result = search_gateway(options).await; if let Err(SearchError::SpoofedIp { src_ip, url_ip }) = result { @@ -202,7 +210,7 @@ run_tests! { let http_port = start_http_server(vec![RESP_SPOOFED_SCPDURL.to_owned()]).await; let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - let options = default_options_with_using_free_port(local_free_port); + let options = default_options_with_using_free_port(local_free_port).await; let result = search_gateway(options).await; if let Err(SearchError::SpoofedUrl { src_ip, url_host }) = result { @@ -221,7 +229,7 @@ run_tests! { .await; let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - let options = default_options_with_using_free_port(local_free_port); + let options = default_options_with_using_free_port(local_free_port).await; let result = search_gateway(options).await; if let Err(SearchError::SpoofedUrl { src_ip, url_host }) = result { @@ -235,7 +243,7 @@ run_tests! { fn non_spoofed_urls_result_in_search_gateway_success(search_gateway) { let http_port = start_http_server(vec![RESP.to_owned(), RESP_CONTROL_SCHEMA.to_owned()]).await; let local_free_port = start_broadcast_reply_sender(format!("http://127.0.0.1:{http_port}")).await; - let options = default_options_with_using_free_port(local_free_port); + let options = default_options_with_using_free_port(local_free_port).await; assert!(search_gateway(options).await.is_ok()); } } diff --git a/src/errors.rs b/src/errors.rs index 867f5af5..fa94d609 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -245,6 +245,9 @@ pub enum SearchError { #[error("Error parsing URI: {0}")] /// Invalid uri InvalidUri(#[from] url::ParseError), + #[error("Invalid HTTP Uri: {0}")] + /// Invalid HTTP Uri + InvalidHttpUri(#[from] http::uri::InvalidUri), #[error("The uri is missing the host: {0}")] /// Uri is missing host UrlMissingHost(reqwest::Url), @@ -265,6 +268,9 @@ pub enum SearchError { /// The IP which the receiving packet pretended to be from url_host: String, }, + #[error("HTTP transport error: {0}")] + /// HTTP transport error + HttpError(String), } #[cfg(feature = "aio")] From 595baa0b8522d88d421e08843d3fd2533aa13968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20K=C5=82ak?= Date: Thu, 2 Apr 2026 11:45:00 +0200 Subject: [PATCH 2/3] Use macos-14 macOS 15 introduced Local Network Access privacy restrictions that block multicast/broadcast network access for unsigned, non-entitled processes (like your test binary). For more details see: - github: https://github.com/actions/runner-images/issues/10924 - apple: https://developer.apple.com/forums/thread/770473 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7ded3586..8ccf6f66 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-14, windows-latest] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v2 From 1b2179dc466ef7d0d2f8c26cf5b303a16ee4008c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tomasz=20K=C5=82ak?= Date: Thu, 2 Apr 2026 12:33:18 +0200 Subject: [PATCH 3/3] Don't randomize broadcast port on windows --- src/common/tests.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/common/tests.rs b/src/common/tests.rs index cd81be33..8840de49 100644 --- a/src/common/tests.rs +++ b/src/common/tests.rs @@ -39,7 +39,14 @@ pub async fn start_broadcast_reply_sender(location: String) -> u16 { pub async fn default_options_with_using_free_port(port: u16) -> SearchOptions { let broadcast_ip: IpAddr = [239u8, 255, 255, 250].into(); + + // We don't want to use the standard 1900 port, to avoid any collisions with + // already existing upnp gateways in the network of the host machine (eg. your router) + #[cfg(not(target_os = "windows"))] let free_broadcast_port = find_free_port(broadcast_ip).await; + #[cfg(target_os = "windows")] + let free_broadcast_port = 49012; // Windows doesn't allow for binding to broadcast addresses using port 0 + SearchOptions { bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port)), timeout: Some(Duration::from_secs(5)),