Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
os: [ubuntu-latest, macos-14, windows-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
Expand Down
10 changes: 9 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,30 @@ version = "0.12.1"
all-features = true

[dependencies]
bytes = "1.11.1"
http = "1.4.0"
http-body-util = "0.1"
hyper = { version = "1", features = ["http1", "client"] }
hyper-util = { version = "0.1", features = ["http1", "client"] }
log = "0.4"
rand = "0.8"
reqwest = { version = "0.12.9", default-features = false, features = ["blocking", "rustls-tls"] }
thiserror = "2.0.4"
tokio = {version = "1", optional = true, features = ["net"]}
tokio = {version = "1", optional = true, features = ["net", "macros"]}
url = "2"
xmltree = "0.11"

[dev-dependencies]
assert_matches = "1.5.0"
http-body-util = "0.1"
httptest = "0.16.4"
hyper = { package = "hyper", version = "1", features = ["server", "http1"] }
hyper-util = { version = "0.1", features = ["tokio"] }
paste = "1.0.15"
simplelog = "0.9"
test-log = "0.2"
tokio = {version = "1", features = ["full"]}
tokio-stream = "0.1.18"

[features]
aio = ["tokio"]
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.77.2
1.88.0
Comment thread
sfraczek marked this conversation as resolved.
220 changes: 201 additions & 19 deletions src/aio/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@ use std::future::Future;
use std::net::SocketAddr;
use std::str::FromStr;
use std::time::Duration;
use tokio::net::UdpSocket;

use http::header::HOST;
use http::Uri;
use hyper::{client::conn::http1::Builder, Request, StatusCode};
use tokio::net::{TcpStream, UdpSocket};
use tokio::time::timeout;

use crate::aio::Gateway;
use crate::common::{messages, parsing, SearchOptions};
use crate::errors::SearchError;
use crate::search::{check_is_ip_spoofed, validate_url};

const MAX_HTTP_RESPONSE_SIZE: usize = 256 * 1024;
const MAX_RESPONSE_SIZE: usize = 1500;

/// Search for a gateway with the provided options
Expand Down Expand Up @@ -43,7 +48,7 @@ pub async fn search_gateway(options: SearchOptions) -> Result<Gateway, SearchErr
addr: addr_v4,
root_url,
control_url,
control_schema_url,
control_schema_url: control_schema_url.to_owned(),
control_schema,
};

Expand Down Expand Up @@ -98,33 +103,210 @@ fn handle_broadcast_resp(from: &SocketAddr, data: &[u8]) -> Result<(SocketAddr,
}

async fn get_control_urls(addr: &SocketAddr, path: &str) -> Result<(String, String), SearchError> {
let url: reqwest::Url = format!("http://{}{}", addr, path).parse()?;

validate_url(addr.ip(), &url)?;

debug!("requesting control url from: {:?}", url);
let client = reqwest::Client::new();
let resp = client.get(url).send().await?;

debug!("requesting control url from: http://{}{}", addr, path);
let body = http_get_bounded(addr, path, MAX_HTTP_RESPONSE_SIZE).await?;
debug!("handling control response from: {}", addr);
let body = resp.bytes().await?;
parsing::parse_control_urls(body.as_ref())
parsing::parse_control_urls(std::io::Cursor::new(body))
}

async fn get_control_schemas(
addr: &SocketAddr,
control_schema_url: &str,
) -> Result<HashMap<String, Vec<String>>, SearchError> {
let url: reqwest::Url = format!("http://{}{}", addr, control_schema_url).parse()?;
debug!("requesting control schema from: http://{}{}", addr, control_schema_url);
let body = http_get_bounded(addr, control_schema_url, MAX_HTTP_RESPONSE_SIZE).await?;
debug!("handling schema response from: {}", addr);
parsing::parse_schemas(std::io::Cursor::new(body))
}

async fn http_get_bounded(addr: &SocketAddr, path: &str, memory_upper_bound: usize) -> Result<Vec<u8>, SearchError> {
use http_body_util::BodyExt;

let authority = addr.to_string();
let uri: Uri = format!("http://{}{}", addr, path).parse()?;

let url: url::Url = uri.to_string().parse()?;
validate_url(addr.ip(), &url)?;

debug!("requesting control schema from: {}", url);
let client = reqwest::Client::new();
let resp = client.get(url).send().await?;
let stream = TcpStream::connect(addr)
.await
.map_err(|e| SearchError::HttpError(e.to_string()))?;
let io = hyper_util::rt::TokioIo::new(stream);

debug!("handling schema response from: {}", addr);
let (mut sender, connection) = Builder::new()
.max_buf_size(memory_upper_bound)
.handshake(io)
.await
.map_err(|e| SearchError::HttpError(e.to_string()))?;

let req = Request::builder()
.uri(&uri)
.header(HOST, &authority)
.body(http_body_util::Empty::<bytes::Bytes>::new())
.map_err(|e| SearchError::HttpError(e.to_string()))?;

tokio::spawn(async move {
// See why we need to await connection:
// https://docs.rs/hyper/latest/hyper/client/conn/http1/struct.Builder.html#method.handshake
if let Err(e) = connection.await {
error!("http connection failed: {e}");
}
});

let resp = sender
.send_request(req)
.await
.map_err(|e| SearchError::HttpError(e.to_string()))?;

if resp.status() != StatusCode::OK {
return Err(SearchError::HttpError(format!("unexpected status: {}", resp.status())));
}

let body = http_body_util::Limited::new(resp.into_body(), memory_upper_bound)
.collect()
.await
.map_err(|e| SearchError::HttpError(e.to_string()))?
.to_bytes();

Ok(body.to_vec())
}

#[cfg(test)]
mod tests {
use std::{
convert::Infallible,
net::{Ipv4Addr, SocketAddrV4},
};

use assert_matches::assert_matches;
use http::Response;
use http_body_util::StreamBody;
use httptest::{matchers::request, responders::status_code, Expectation, ServerBuilder};
use hyper::{
body::{Bytes, Frame},
server::conn::http1,
service::service_fn,
};
use hyper_util::rt::TokioIo;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use tokio::net::TcpListener;
use tokio_stream::wrappers::ReceiverStream;

let body = resp.bytes().await?;
parsing::parse_schemas(body.as_ref())
use super::*;

fn generate_random_body(n: usize) -> Vec<u8> {
let s: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(n)
.map(char::from)
.collect();
s.into_bytes()
}

#[tokio::test]
async fn working_http_get_bounded() {
// 8k is a minimum max buffer size allowed by http1 / hyper:
// see: https://github.com/hyperium/hyper/blob/0d6c7d5469baa09e2fb127ee3758a79b3271a4f0/src/proto/h1/io.rs#L14-L18
for memory_bound in [8 * 1024, 16 * 1024, 32 * 1024] {
for body_size in (0..=memory_bound).step_by(512) {
let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
let server = ServerBuilder::new().bind_addr(bind_addr).run().unwrap();
let addr = server.addr();
let get_url = server.url("/get");
let path = get_url.path();

let test_body = generate_random_body(body_size);

server.expect(
Expectation::matching(request::method_path("GET", "/get"))
.respond_with(status_code(200).body(test_body.clone())),
);
let body = http_get_bounded(&addr, path, memory_bound).await.unwrap();

assert_eq!(test_body, body);
}
}
}

#[tokio::test]
async fn failing_http_get_bounded() {
const MEMORY_BOUND: usize = 16 * 1024;

let bind_addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0));
let server = ServerBuilder::new().bind_addr(bind_addr).run().unwrap();
let addr = server.addr();
let get_url = server.url("/get");
let path = get_url.path();

let test_body = generate_random_body(MEMORY_BOUND + 1);

server.expect(
Expectation::matching(request::method_path("GET", "/get"))
.respond_with(status_code(200).body(test_body.clone())),
);
assert_matches!(
http_get_bounded(&addr, path, MEMORY_BOUND).await,
Err(SearchError::HttpError(m)) if m == "length limit exceeded"
);
}

async fn infinite_body_handle(
_req: Request<hyper::body::Incoming>,
) -> Result<Response<StreamBody<ReceiverStream<Result<Frame<Bytes>, Infallible>>>>, Infallible> {
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, Infallible>>(2);

tokio::spawn(async move {
let chunk = Bytes::from(vec![b'A'; 4096]);
loop {
if tx.send(Ok(Frame::data(chunk.clone()))).await.is_err() {
break;
}
}
});

let stream = ReceiverStream::new(rx);
let body = StreamBody::new(stream);

Ok(Response::builder()
.header("transfer-encoding", "chunked")
.body(body)
.unwrap())
}

async fn start_infinite_server() -> Result<SocketAddr, Box<dyn std::error::Error>> {
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let listener = TcpListener::bind(addr).await?;
let addr = listener.local_addr().unwrap();
eprintln!("Listening on http://{addr}");

tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let io = TokioIo::new(stream);

tokio::spawn(async move {
if let Err(e) = http1::Builder::new()
.serve_connection(io, service_fn(infinite_body_handle))
.await
{
eprintln!("connection error: {e}");
}
});
}
});

Ok(addr)
}

#[tokio::test]
async fn search_gateway_should_fail_for_infinite_http_get_body() {
let http_addr = start_infinite_server().await.unwrap();
let local_free_port = crate::common::tests::start_broadcast_reply_sender(format!("http://{http_addr}")).await;
let options = crate::common::tests::default_options_with_using_free_port(local_free_port).await;

assert_matches!(
search_gateway(options).await,
Err(SearchError::HttpError(m)) if m == "length limit exceeded"
);
}
}
2 changes: 1 addition & 1 deletion src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ pub mod messages;
pub mod options;
pub mod parsing;
#[cfg(test)]
mod tests;
pub mod tests;

pub use self::options::SearchOptions;

Expand Down
Loading
Loading