Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use sip_addr::SipAddr;
pub use tcp_listener::TcpListenerConnection;
pub use tls::{TlsConfig, TlsListenerConnection};
pub use transport_layer::TransportLayer;
pub use transport_layer::TransportWhitelist;
pub use websocket::WebSocketListenerConnection;

#[cfg(test)]
Expand Down
4 changes: 4 additions & 0 deletions src/transport/tcp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ impl TcpListenerConnection {
continue;
}
};
if !transport_layer_inner.is_whitelisted(remote_addr.ip()).await {
debug!(remote = %remote_addr, "tcp connection rejected by whitelist");
continue;
}
let local_addr = SipAddr {
r#type: Some(rsip::transport::Transport::Tcp),
addr: remote_addr.into(),
Expand Down
4 changes: 4 additions & 0 deletions src/transport/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ impl TlsListenerConnection {
continue;
}
};
if !transport_layer_inner.is_whitelisted(remote_addr.ip()).await {
debug!(remote = %remote_addr, "tls connection rejected by whitelist");
continue;
}

let acceptor_clone = acceptor.clone();
let transport_layer_inner_ref = transport_layer_inner.clone();
Expand Down
72 changes: 70 additions & 2 deletions src/transport/transport_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rsip_dns::trust_dns_resolver::TokioAsyncResolver;
#[cfg(feature = "rsip-dns")]
use rsip_dns::ResolvableExt;

use std::net::IpAddr;
use std::sync::{Mutex, RwLock};
use std::{collections::HashMap, sync::Arc};
use tokio::select;
Expand All @@ -23,6 +24,25 @@ pub trait DomainResolver: Send + Sync {
async fn resolve(&self, target: &SipAddr) -> Result<SipAddr>;
}

#[async_trait]
pub trait TransportWhitelist: Send + Sync {
/// Return true to accept the packet/connection for the given peer IP.
async fn allow(&self, ip: IpAddr) -> bool;
}

#[async_trait]
impl<F, Fut> TransportWhitelist for F
where
F: Send + Sync + Fn(IpAddr) -> Fut,
Fut: std::future::Future<Output = bool> + Send,
{
async fn allow(&self, ip: IpAddr) -> bool {
(self)(ip).await
}
}

pub(crate) type TransportWhitelistRef = Arc<dyn TransportWhitelist>;

pub struct DefaultDomainResolver {}

impl DefaultDomainResolver {
Expand Down Expand Up @@ -107,6 +127,7 @@ pub struct TransportLayerInner {
pub(crate) transport_tx: TransportSender,
pub(crate) transport_rx: Mutex<Option<TransportReceiver>>,
pub domain_resolver: Box<dyn DomainResolver>,
whitelist: RwLock<Option<TransportWhitelistRef>>,
}
pub(crate) type TransportLayerInnerRef = Arc<TransportLayerInner>;

Expand All @@ -129,6 +150,7 @@ impl TransportLayer {
transport_tx,
transport_rx: Mutex::new(Some(transport_rx)),
domain_resolver,
whitelist: RwLock::new(None),
};
Self {
outbound: None,
Expand Down Expand Up @@ -196,9 +218,48 @@ impl TransportLayer {
}
}
}

/// Set an async whitelist callback invoked on incoming packets/connections.
pub fn set_whitelist<T>(&self, whitelist: T)
where
T: TransportWhitelist + 'static,
{
self.inner.set_whitelist(Some(Arc::new(whitelist)));
}

/// Remove the whitelist callback.
pub fn clear_whitelist(&self) {
self.inner.set_whitelist(None);
}
}

impl TransportLayerInner {
pub(super) fn set_whitelist(&self, whitelist: Option<TransportWhitelistRef>) {
match self.whitelist.write() {
Ok(mut guard) => {
*guard = whitelist;
}
Err(e) => {
warn!(error = ?e, "Failed to update whitelist");
}
}
}

pub(crate) async fn is_whitelisted(&self, ip: IpAddr) -> bool {
let whitelist = match self.whitelist.read() {
Ok(guard) => guard.clone(),
Err(e) => {
warn!(error = ?e, "Failed to read whitelist");
return true;
}
};

match whitelist {
Some(whitelist) => whitelist.allow(ip).await,
None => true,
}
}

pub(super) fn add_listener(&self, connection: SipConnection) {
match self.listens.write() {
Ok(mut listens) => {
Expand Down Expand Up @@ -349,7 +410,12 @@ impl TransportLayerInner {
let sender = self.transport_tx.clone();
match transport {
SipConnection::Udp(transport) => {
tokio::spawn(async move { transport.serve_loop(sender).await });
let transport_layer_inner = self.clone();
tokio::spawn(async move {
transport
.serve_loop_with_whitelist(sender, Some(transport_layer_inner))
.await
});
Ok(())
}
SipConnection::TcpListener(connection) => connection.serve_listener(self.clone()).await,
Expand Down Expand Up @@ -383,7 +449,9 @@ impl TransportLayerInner {
}
select! {
_ = sub_token.cancelled() => { }
_ = transport.serve_loop(sender_clone.clone()) => {
_ = async {
transport.serve_loop(sender_clone.clone()).await
} => {
}
}
info!(addr=%transport.get_addr(), "transport serve_loop exited");
Expand Down
16 changes: 16 additions & 0 deletions src/transport/udp.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{connection::TransportSender, SipAddr, SipConnection};
use crate::{
transport::transport_layer::TransportLayerInnerRef,
transport::{
connection::{KEEPALIVE_REQUEST, KEEPALIVE_RESPONSE, MAX_UDP_BUF_SIZE},
TransportEvent,
Expand Down Expand Up @@ -64,6 +65,14 @@ impl UdpConnection {
}

pub async fn serve_loop(&self, sender: TransportSender) -> Result<()> {
self.serve_loop_with_whitelist(sender, None).await
}

pub async fn serve_loop_with_whitelist(
&self,
sender: TransportSender,
transport_layer_inner: Option<TransportLayerInnerRef>,
) -> Result<()> {
let mut buf = BytesMut::with_capacity(MAX_UDP_BUF_SIZE);
buf.resize(MAX_UDP_BUF_SIZE, 0);
loop {
Expand Down Expand Up @@ -92,6 +101,13 @@ impl UdpConnection {
}
};

if let Some(transport_layer_inner) = &transport_layer_inner {
if !transport_layer_inner.is_whitelisted(addr.ip()).await {
debug!(src = %addr, "udp packet rejected by whitelist");
continue;
}
}

match &buf[..len] {
KEEPALIVE_REQUEST => {
self.inner.conn.send_to(KEEPALIVE_RESPONSE, addr).await.ok();
Expand Down
4 changes: 4 additions & 0 deletions src/transport/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ impl WebSocketListenerConnection {
continue;
}
};
if !transport_layer_inner.is_whitelisted(remote_addr.ip()).await {
debug!(remote = %remote_addr, "websocket connection rejected by whitelist");
continue;
}

debug!(remote = %remote_addr, "New WebSocket connection");

Expand Down