diff --git a/Cargo.lock b/Cargo.lock index a53db661..fc4d2bcf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -524,6 +524,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -909,6 +923,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -987,6 +1007,29 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "governor" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.4", + "hashbrown 0.16.1", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand", + "smallvec", + "spinning_top", + "web-time", +] + [[package]] name = "h2" version = "0.4.13" @@ -1015,6 +1058,12 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.16.1" @@ -1389,7 +1438,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -1646,6 +1695,7 @@ dependencies = [ "derive_more", "futures-buffered", "futures-util", + "governor", "iroh-quinn", "irpc-derive", "n0-error", @@ -1681,6 +1731,7 @@ dependencies = [ "clap", "futures-util", "getrandom 0.3.4", + "governor", "hex", "iroh", "iroh-base", @@ -1782,7 +1833,7 @@ version = "0.16.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1dc47f592c06f33f8e3aea9591776ec7c9f9e4124778ff8a3c3b87159f7e593" dependencies = [ - "hashbrown", + "hashbrown 0.16.1", ] [[package]] @@ -2063,6 +2114,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "ntimestamp" version = "1.0.0" @@ -2499,6 +2556,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quick-xml" version = "0.38.4" @@ -2607,6 +2679,15 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "rcgen" version = "0.14.7" @@ -3027,6 +3108,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591" +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.8.0-rc.4" diff --git a/Cargo.toml b/Cargo.toml index ee30e64f..5ab6d590 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,6 +61,8 @@ trybuild = "1.0.104" testresult = "0.4.1" # used in examples anyhow = { workspace = true } +# used in rate_limit example +governor = { workspace = true } [features] # enable the remote transport @@ -90,6 +92,10 @@ required-features = ["derive"] name = "storage" required-features = ["rpc", "quinn_endpoint_setup"] +[[example]] +name = "rate_limit" +required-features = ["rpc", "derive", "quinn_endpoint_setup"] + [workspace] members = ["irpc-derive", "irpc-iroh"] @@ -113,3 +119,4 @@ iroh = { version = "0.96" } iroh-base = { version = "0.96" } quinn = { package = "iroh-quinn", version = "0.16.0", default-features = false } futures-util = { version = "0.3", features = ["sink"] } +governor = { version = "0.10" } diff --git a/examples/rate_limit.rs b/examples/rate_limit.rs new file mode 100644 index 00000000..4c4d483d --- /dev/null +++ b/examples/rate_limit.rs @@ -0,0 +1,176 @@ +//! Example demonstrating per-connection and per-request rate limiting. +//! +//! Uses [`irpc::rpc::ConnectionFilter`] for per-IP connection filtering and +//! [`irpc::rpc::map_filter`] for per-request filtering with the `governor` crate. +use anyhow::{Context, Result}; +use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Quota, RateLimiter}; +use irpc::{ + channel::oneshot, + rpc::{ConnectionFilter, ListenerBuilder, RemoteService, RequestFilter}, + rpc_requests, + util::{make_client_endpoint, make_server_endpoint}, + Client, WithChannels, +}; +use n0_future::task::{self, AbortOnDropHandle}; +use serde::{Deserialize, Serialize}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::num::NonZeroU32; + +#[derive(Debug, Serialize, Deserialize)] +struct Ping { + payload: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Info; + +#[rpc_requests(message = AppMessage)] +#[derive(Serialize, Deserialize, Debug)] +enum AppProtocol { + #[rpc(tx = oneshot::Sender>)] + Ping(Ping), + #[rpc(tx = oneshot::Sender)] + Info(Info), +} + +struct AppActor { + recv: tokio::sync::mpsc::Receiver, +} + +impl AppActor { + pub fn spawn() -> AppApi { + let (tx, rx) = tokio::sync::mpsc::channel(1); + n0_future::task::spawn(Self { recv: rx }.run()); + AppApi { + inner: Client::local(tx), + } + } + + async fn run(mut self) { + while let Some(msg) = self.recv.recv().await { + match msg { + AppMessage::Ping(ping) => { + let WithChannels { tx, inner, .. } = ping; + tx.send(inner.payload).await.ok(); + } + AppMessage::Info(info) => { + let WithChannels { tx, .. } = info; + tx.send("irpc rate-limit example".to_string()).await.ok(); + } + } + } + } +} + +/// Per-connection rate limiter using governor, keyed by remote address. +struct GovernorConnectionFilter { + limiter: DefaultKeyedRateLimiter, +} + +impl GovernorConnectionFilter { + fn new(per_second: u32) -> Self { + Self { + limiter: RateLimiter::keyed(Quota::per_second( + NonZeroU32::new(per_second).expect("per_second must be > 0"), + )), + } + } +} + +impl ConnectionFilter for GovernorConnectionFilter { + fn accept(&self, addr: &SocketAddr) -> bool { + self.limiter.check_key(addr).is_ok() + } +} + +/// Per-request rate limiter: rate-limits Ping requests, always allows Info. +struct PingRateLimiter { + limiter: DefaultDirectRateLimiter, +} + +impl PingRateLimiter { + fn new(per_second: u32) -> Self { + Self { + limiter: RateLimiter::direct(Quota::per_second( + NonZeroU32::new(per_second).expect("per_second must be > 0"), + )), + } + } +} + +impl RequestFilter for PingRateLimiter { + fn accept(&self, req: &AppProtocol) -> bool { + match req { + AppProtocol::Ping(_) => self.limiter.check().is_ok(), + _ => true, + } + } +} + +struct AppApi { + inner: Client, +} + +impl AppApi { + pub fn connect(endpoint: quinn::Endpoint, addr: SocketAddr) -> Result { + Ok(AppApi { + inner: Client::quinn(endpoint, addr), + }) + } + + pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { + let local = self + .inner + .as_local() + .context("cannot listen on remote API")?; + let handler = AppProtocol::remote_handler(local); + let listener = ListenerBuilder::new(endpoint, handler) + .request_filter(PingRateLimiter::new(2)) + .connection_filter(GovernorConnectionFilter::new(10)); + Ok(AbortOnDropHandle::new(task::spawn(listener.listen()))) + } + + pub async fn ping(&self, payload: Vec) -> irpc::Result> { + self.inner.rpc(Ping { payload }).await + } + + pub async fn info(&self) -> irpc::Result { + self.inner.rpc(Info).await + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + let port = 10114; + let addr: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, port).into(); + + let (server_handle, cert) = { + let (endpoint, cert) = make_server_endpoint(addr)?; + let api = AppActor::spawn(); + let handle = api.listen(endpoint)?; + (handle, cert) + }; + + let endpoint = + make_client_endpoint(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0).into(), &[&cert])?; + + // Fire bursts of Ping with interspersed Info requests. + // Ping is rate-limited to 2/sec, Info always gets through. + for i in 0..10 { + let api = AppApi::connect(endpoint.clone(), addr)?; + match api.ping(b"hello".to_vec()).await { + Ok(response) => println!("{i}: ping = {}", String::from_utf8_lossy(&response)), + Err(e) => println!("{i}: ping rejected: {e}"), + } + let api = AppApi::connect(endpoint.clone(), addr)?; + match api.info().await { + Ok(response) => println!("{i}: info = {response}"), + Err(e) => println!("{i}: info rejected: {e}"), + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + drop(server_handle); + Ok(()) +} diff --git a/irpc-iroh/Cargo.toml b/irpc-iroh/Cargo.toml index fc63f0ff..9595eed9 100644 --- a/irpc-iroh/Cargo.toml +++ b/irpc-iroh/Cargo.toml @@ -35,3 +35,4 @@ futures-util.workspace = true hex = "0.4.3" rand = "0.9.2" anyhow = { workspace = true } +governor = { workspace = true } diff --git a/irpc-iroh/examples/rate-limit.rs b/irpc-iroh/examples/rate-limit.rs new file mode 100644 index 00000000..d5f68484 --- /dev/null +++ b/irpc-iroh/examples/rate-limit.rs @@ -0,0 +1,178 @@ +//! Example demonstrating per-connection and per-request rate limiting with iroh. +//! +//! Uses [`irpc_iroh::IrohConnectionFilter`] for per-IP connection filtering and +//! [`irpc::rpc::RequestFilter`] for per-request filtering with the `governor` crate. +use std::net::SocketAddr; +use std::num::NonZeroU32; + +use anyhow::{Context, Result}; +use governor::{DefaultDirectRateLimiter, DefaultKeyedRateLimiter, Quota, RateLimiter}; +use iroh::Endpoint; +use irpc::{ + channel::oneshot, + rpc::{RemoteService, RequestFilter}, + rpc_requests, Client, WithChannels, +}; +use irpc_iroh::IrohListenerBuilder; +use n0_future::task::{self, AbortOnDropHandle}; +use serde::{Deserialize, Serialize}; + +const ALPN: &[u8] = b"irpc-iroh/rate-limit/0"; + +#[derive(Debug, Serialize, Deserialize)] +struct Ping { + payload: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Info; + +#[rpc_requests(message = AppMessage)] +#[derive(Serialize, Deserialize, Debug)] +enum AppProtocol { + #[rpc(tx = oneshot::Sender>)] + Ping(Ping), + #[rpc(tx = oneshot::Sender)] + Info(Info), +} + +struct AppActor { + recv: tokio::sync::mpsc::Receiver, +} + +impl AppActor { + pub fn spawn() -> AppApi { + let (tx, rx) = tokio::sync::mpsc::channel(1); + n0_future::task::spawn(Self { recv: rx }.run()); + AppApi { + inner: Client::local(tx), + } + } + + async fn run(mut self) { + while let Some(msg) = self.recv.recv().await { + match msg { + AppMessage::Ping(ping) => { + let WithChannels { tx, inner, .. } = ping; + tx.send(inner.payload).await.ok(); + } + AppMessage::Info(info) => { + let WithChannels { tx, .. } = info; + tx.send("irpc-iroh rate-limit example".to_string()) + .await + .ok(); + } + } + } + } +} + +/// Per-connection rate limiter using governor, keyed by remote address. +struct GovernorConnectionFilter { + limiter: DefaultKeyedRateLimiter, +} + +impl GovernorConnectionFilter { + fn new(per_second: u32) -> Self { + Self { + limiter: RateLimiter::keyed(Quota::per_second( + NonZeroU32::new(per_second).expect("per_second must be > 0"), + )), + } + } +} + +impl irpc_iroh::IrohConnectionFilter for GovernorConnectionFilter { + fn accept(&self, addr: &SocketAddr) -> bool { + self.limiter.check_key(addr).is_ok() + } +} + +/// Per-request rate limiter: rate-limits Ping requests, always allows Info. +struct PingRateLimiter { + limiter: DefaultDirectRateLimiter, +} + +impl PingRateLimiter { + fn new(per_second: u32) -> Self { + Self { + limiter: RateLimiter::direct(Quota::per_second( + NonZeroU32::new(per_second).expect("per_second must be > 0"), + )), + } + } +} + +impl RequestFilter for PingRateLimiter { + fn accept(&self, req: &AppProtocol) -> bool { + match req { + AppProtocol::Ping(_) => self.limiter.check().is_ok(), + _ => true, + } + } +} + +struct AppApi { + inner: Client, +} + +impl AppApi { + pub fn connect(endpoint: Endpoint, addr: impl Into) -> AppApi { + AppApi { + inner: irpc_iroh::client(endpoint, addr, ALPN), + } + } + + pub fn listen(&self, endpoint: iroh::Endpoint) -> Result> { + let local = self + .inner + .as_local() + .context("cannot listen on remote API")?; + let handler = AppProtocol::remote_handler(local); + let listener = IrohListenerBuilder::new(endpoint, handler) + .request_filter(PingRateLimiter::new(2)) + .connection_filter(GovernorConnectionFilter::new(10)); + Ok(AbortOnDropHandle::new(task::spawn(listener.listen()))) + } + + pub async fn ping(&self, payload: Vec) -> irpc::Result> { + self.inner.rpc(Ping { payload }).await + } + + pub async fn info(&self) -> irpc::Result { + self.inner.rpc(Info).await + } +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt::init(); + + let server_endpoint = Endpoint::builder() + .alpns(vec![ALPN.to_vec()]) + .bind() + .await?; + let api = AppActor::spawn(); + let _server_handle = api.listen(server_endpoint.clone())?; + server_endpoint.online().await; + + let client_endpoint = Endpoint::builder().bind().await?; + + // Fire bursts of Ping with interspersed Info requests. + // Ping is rate-limited to 2/sec, Info always gets through. + for i in 0..10 { + let api = AppApi::connect(client_endpoint.clone(), server_endpoint.addr()); + match api.ping(b"hello".to_vec()).await { + Ok(response) => println!("{i}: ping = {}", String::from_utf8_lossy(&response)), + Err(e) => println!("{i}: ping rejected: {e}"), + } + let api = AppApi::connect(client_endpoint.clone(), server_endpoint.addr()); + match api.info().await { + Ok(response) => println!("{i}: info = {response}"), + Err(e) => println!("{i}: info rejected: {e}"), + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + Ok(()) +} diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index 2556f85a..8986805e 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -17,8 +17,8 @@ use iroh::{ use irpc::{ channel::oneshot, rpc::{ - Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, - MAX_MESSAGE_SIZE, + Handler, RemoteConnection, RemoteService, RequestFilter, + ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, ERROR_CODE_REQUEST_LIMITED, MAX_MESSAGE_SIZE, }, util::AsyncReadVarintExt, LocalSender, RequestError, @@ -400,37 +400,188 @@ pub async fn read_request_raw( Ok(Some((msg, rx, tx))) } -/// Utility function to listen for incoming connections and handle them with the provided handler -pub async fn listen(endpoint: iroh::Endpoint, handler: Handler) { - let mut request_id = 0u64; - let mut tasks = n0_future::task::JoinSet::new(); - loop { - let incoming = tokio::select! { - Some(res) = tasks.join_next(), if !tasks.is_empty() => { - res.expect("irpc connection task panicked"); - continue; - } - incoming = endpoint.accept() => { - match incoming { - None => break, - Some(incoming) => incoming +/// Filter for incoming iroh connections. +/// +/// Like [`irpc::rpc::ConnectionFilter`], but additionally allows filtering by +/// [`EndpointId`] — the remote node's cryptographic identity, available after +/// the QUIC handshake completes. +/// +/// Most implementations only need [`Self::accept`] and/or [`Self::accept_endpoint_id`]. +/// Override [`Self::accept_unvalidated`] for coarse pre-handshake flood protection. +pub trait IrohConnectionFilter: Send + Sync + 'static { + /// Check whether to accept a connection from the given validated address. + /// + /// Returns `true` to accept, `false` to refuse. + fn accept(&self, _addr: &std::net::SocketAddr) -> bool { + true + } + + /// Check whether to accept a connection before address validation. + /// + /// # Security + /// + /// The address has **not** been validated at this stage and can be + /// freely spoofed by an attacker. It usually should not be used for + /// access-control decisions. It is mainly useful for coarse, + /// high-threshold flood protection (e.g. blocking known-bad IPs). + /// + /// Returns `true` to accept, `false` to refuse. Defaults to `true`. + fn accept_unvalidated(&self, _addr: &std::net::SocketAddr) -> bool { + true + } + + /// Check whether to accept a connection from the given endpoint ID. + /// + /// Called after the QUIC handshake, when the remote node's identity is known. + /// + /// Returns `true` to accept, `false` to refuse. + fn accept_endpoint_id(&self, _id: &EndpointId) -> bool { + true + } +} + +/// An [`IrohConnectionFilter`] that accepts all connections. +#[derive(Debug, Clone, Default)] +pub struct AcceptAll; + +impl IrohConnectionFilter for AcceptAll {} + +fn wrap_request_filter( + handler: Handler, + filter: Arc>, +) -> Handler { + Arc::new(move |r, rx, mut tx| { + if filter.accept(&r) { + handler(r, rx, tx) + } else { + tx.reset(ERROR_CODE_REQUEST_LIMITED.into()).ok(); + drop(rx); + Box::pin(async { Ok(()) }) + } + }) +} + +/// Builder for configuring and running an iroh listener with optional filters. +pub struct IrohListenerBuilder { + endpoint: iroh::Endpoint, + handler: Handler, + connection_filter: Arc, +} + +impl IrohListenerBuilder { + /// Creates a new listener builder. + pub fn new(endpoint: iroh::Endpoint, handler: Handler) -> Self { + Self { + endpoint, + handler, + connection_filter: Arc::new(AcceptAll), + } + } + + /// Sets a connection filter for per-IP and per-endpoint-ID rate limiting. + pub fn connection_filter(mut self, filter: impl IrohConnectionFilter) -> Self { + self.connection_filter = Arc::new(filter); + self + } + + /// Sets a per-request filter. + /// + /// The filter is called with `&R` (the deserialized protocol enum) + /// before the handler. If it returns `false`, the request is dropped. + pub fn request_filter(mut self, filter: impl RequestFilter) -> Self { + self.handler = wrap_request_filter(self.handler, Arc::new(filter)); + self + } + + /// Runs the listener, accepting connections until the endpoint is closed. + pub async fn listen(self) { + IrohListener { + endpoint: self.endpoint, + handler: self.handler, + filter: self.connection_filter, + } + .run() + .await + } +} + +struct IrohListener { + endpoint: iroh::Endpoint, + handler: Handler, + filter: Arc, +} + +impl IrohListener { + async fn run(self) { + let mut request_id = 0u64; + let mut tasks = n0_future::task::JoinSet::new(); + loop { + let incoming = tokio::select! { + Some(res) = tasks.join_next(), if !tasks.is_empty() => { + res.expect("irpc connection task panicked"); + continue; } - } - }; - let handler = handler.clone(); - let fut = async move { - match incoming.await { - Ok(connection) => match handle_connection(&connection, handler).await { - Err(err) => warn!("connection closed with error: {err:?}"), - Ok(()) => debug!("connection closed"), - }, - Err(cause) => { - warn!("failed to accept connection: {cause:?}"); + incoming = self.endpoint.accept() => { + match incoming { + None => break, + Some(incoming) => incoming + } } }; - }; - let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty); - tasks.spawn(fut.instrument(span)); - request_id += 1; + let addr = incoming.remote_address(); + let validated = incoming.remote_address_validated(); + let refused = if validated { + !self.filter.accept(&addr) + } else { + !self.filter.accept_unvalidated(&addr) + }; + if refused { + incoming.refuse(); + continue; + } + let handler = self.handler.clone(); + let filter = self.filter.clone(); + let fut = async move { + let accepting = match incoming.accept() { + Ok(accepting) => accepting, + Err(cause) => { + warn!("failed to accept connection: {cause:?}"); + return; + } + }; + match accepting.await { + Ok(connection) => { + // Deferred validated-address check for initially-unvalidated connections + if !validated && !filter.accept(&addr) { + connection.close(ERROR_CODE_REQUEST_LIMITED.into(), b"rate limited"); + return; + } + // Endpoint ID check (only available after handshake) + if !filter.accept_endpoint_id(&connection.remote_id()) { + connection.close(ERROR_CODE_REQUEST_LIMITED.into(), b"rate limited"); + return; + } + match handle_connection(&connection, handler).await { + Err(err) => warn!("connection closed with error: {err:?}"), + Ok(()) => debug!("connection closed"), + } + } + Err(cause) => { + warn!("failed to accept connection: {cause:?}"); + } + }; + }; + let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty); + tasks.spawn(fut.instrument(span)); + request_id += 1; + } } } + +/// Utility function to listen for incoming connections and handle them with the provided handler +pub async fn listen( + endpoint: iroh::Endpoint, + handler: Handler, +) { + IrohListenerBuilder::new(endpoint, handler).listen().await +} diff --git a/src/lib.rs b/src/lib.rs index 1d849a2c..f2169ec7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1864,6 +1864,9 @@ pub mod rpc { /// Error code on streams if the sender tried to send an message that could not be postcard serialized. pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2; + /// Error code on connections/streams if the request was rate limited. + pub const ERROR_CODE_REQUEST_LIMITED: u32 = 3; + /// Error that can occur when writing the initial message when doing a /// cross-process RPC. #[stack_error(derive, add_meta, from_sources)] @@ -2377,6 +2380,31 @@ pub mod rpc { + 'static, >; + /// Per-request filter, called after deserialization but before the handler. + /// + /// Implement this trait to add per-request rate limiting or access control. + pub trait RequestFilter: Send + Sync + 'static { + /// Check whether to accept the request. + /// + /// Returns `true` to accept, `false` to drop the request. + fn accept(&self, request: &R) -> bool; + } + + fn wrap_request_filter( + handler: Handler, + filter: Arc>, + ) -> Handler { + Arc::new(move |r, rx, mut tx| { + if filter.accept(&r) { + handler(r, rx, tx) + } else { + tx.reset(ERROR_CODE_REQUEST_LIMITED.into()).ok(); + drop(rx); + Box::pin(async { Ok(()) }) + } + }) + } + /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`] /// and a pair of QUIC streams. /// @@ -2400,40 +2428,158 @@ pub mod rpc { } /// Utility function to listen for incoming connections and handle them with the provided handler - pub async fn listen( + pub async fn listen( endpoint: quinn::Endpoint, handler: Handler, ) { - let mut request_id = 0u64; - let mut tasks = JoinSet::new(); - loop { - let incoming = tokio::select! { - Some(res) = tasks.join_next(), if !tasks.is_empty() => { - res.expect("irpc connection task panicked"); - continue; - } - incoming = endpoint.accept() => { - match incoming { - None => break, - Some(incoming) => incoming + ListenerBuilder::new(endpoint, handler).listen().await + } + + /// Filter for incoming connections, called before accepting. + /// + /// Implement this trait to add rate limiting or other connection filtering. + /// + /// Most implementations only need [`Self::accept`], which is called with + /// a validated (non-spoofable) remote address. Override + /// [`Self::accept_unvalidated`] for coarse pre-handshake flood protection. + pub trait ConnectionFilter: Send + Sync + 'static { + /// Check whether to accept a connection from the given validated address. + /// + /// The address has been verified by QUIC source address validation. + /// + /// Returns `true` to accept, `false` to refuse. + fn accept(&self, _addr: &std::net::SocketAddr) -> bool { + true + } + + /// Check whether to accept a connection before address validation. + /// + /// # Security + /// + /// The address has **not** been validated at this stage and can be + /// freely spoofed by an attacker. It usually should not be used for + /// access-control decisions. It is mainly useful for coarse, + /// high-threshold flood protection (e.g. blocking known-bad IPs). + /// + /// Returns `true` to accept, `false` to refuse. Defaults to `true`. + fn accept_unvalidated(&self, _addr: &std::net::SocketAddr) -> bool { + true + } + } + + /// A [`ConnectionFilter`] that accepts all connections. + #[derive(Debug, Clone, Default)] + pub struct AcceptAll; + + impl ConnectionFilter for AcceptAll {} + + /// Builder for configuring and running a listener with optional filters. + /// + /// Created via [`ListenerBuilder::new`]. + pub struct ListenerBuilder { + endpoint: quinn::Endpoint, + handler: Handler, + connection_filter: Arc, + } + + impl ListenerBuilder { + /// Creates a new listener builder. + pub fn new(endpoint: quinn::Endpoint, handler: Handler) -> Self { + Self { + endpoint, + handler, + connection_filter: Arc::new(AcceptAll), + } + } + + /// Sets a connection filter for per-IP rate limiting. + /// + /// The filter is called with validated remote addresses before + /// accepting connections. See [`ConnectionFilter`] for details. + pub fn connection_filter(mut self, filter: impl ConnectionFilter) -> Self { + self.connection_filter = Arc::new(filter); + self + } + + /// Sets a per-request filter. + /// + /// The filter is called with `&R` (the deserialized protocol enum) + /// before the handler. If it returns `false`, the request is dropped. + pub fn request_filter(mut self, filter: impl RequestFilter) -> Self { + self.handler = wrap_request_filter(self.handler, Arc::new(filter)); + self + } + + /// Runs the listener, accepting connections until the endpoint is closed. + pub async fn listen(self) { + Listener { + endpoint: self.endpoint, + handler: self.handler, + filter: self.connection_filter, + } + .run() + .await + } + } + + struct Listener { + endpoint: quinn::Endpoint, + handler: Handler, + filter: Arc, + } + + impl Listener { + async fn run(self) { + let mut request_id = 0u64; + let mut tasks = JoinSet::new(); + loop { + let incoming = tokio::select! { + Some(res) = tasks.join_next(), if !tasks.is_empty() => { + res.expect("irpc connection task panicked"); + continue; } - } - }; - let handler = handler.clone(); - let fut = async move { - match incoming.await { - Ok(connection) => match handle_connection(connection, handler).await { - Err(err) => warn!("connection closed with error: {err:?}"), - Ok(()) => debug!("connection closed"), - }, - Err(cause) => { - warn!("failed to accept connection: {cause:?}"); + incoming = self.endpoint.accept() => { + match incoming { + None => break, + Some(incoming) => incoming + } } }; - }; - let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty); - tasks.spawn(fut.instrument(span)); - request_id += 1; + let addr = incoming.remote_address(); + let validated = incoming.remote_address_validated(); + let refused = if validated { + !self.filter.accept(&addr) + } else { + !self.filter.accept_unvalidated(&addr) + }; + if refused { + incoming.refuse(); + continue; + } + let handler = self.handler.clone(); + let filter = self.filter.clone(); + let fut = async move { + match incoming.await { + Ok(connection) => { + if !validated && !filter.accept(&connection.remote_address()) { + connection + .close(ERROR_CODE_REQUEST_LIMITED.into(), b"rate limited"); + return; + } + match handle_connection(connection, handler).await { + Err(err) => warn!("connection closed with error: {err:?}"), + Ok(()) => debug!("connection closed"), + } + } + Err(cause) => { + warn!("failed to accept connection: {cause:?}"); + } + }; + }; + let span = error_span!("rpc", id = request_id, remote = tracing::field::Empty); + tasks.spawn(fut.instrument(span)); + request_id += 1; + } } }