diff --git a/src/connection.rs b/src/connection.rs index 2f18026..a8bb486 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -73,7 +73,7 @@ impl ManageConnection for ConnectionManager { enum Transport { Tcp(TcpOptions), - Udp, + Udp(UdpOptions), #[cfg(unix)] Unix, #[cfg(feature = "tls")] @@ -94,6 +94,17 @@ struct TcpOptions { nodelay: bool, } +struct UdpOptions { + bind_addr: Option, +} + +impl UdpOptions { + fn from_url(url: &Url) -> Self { + let bind_addr = url.query_pairs().find(|(k, _)| k == "bind").map(|(_, v)| v.to_string()); + UdpOptions { bind_addr } + } +} + #[cfg(feature = "tls")] fn get_param(url: &Url, key: &str) -> Option { return url @@ -173,7 +184,7 @@ impl Transport { if let Some(proto) = parts.next() { return match proto { "tcp" => Ok(Transport::Tcp(TcpOptions::from_url(url))), - "udp" => Ok(Transport::Udp), + "udp" => Ok(Transport::Udp(UdpOptions::from_url(url))), #[cfg(unix)] "unix" => Ok(Transport::Unix), #[cfg(feature = "tls")] @@ -186,7 +197,7 @@ impl Transport { let is_udp = url.query_pairs().any(|(ref k, ref v)| k == "udp" && v == "true"); if is_udp { - return Ok(Transport::Udp); + return Ok(Transport::Udp(UdpOptions::from_url(url))); } #[cfg(unix)] @@ -220,7 +231,7 @@ impl Connection { let is_ascii = url.query_pairs().any(|(ref k, ref v)| k == "protocol" && v == "ascii"); let stream: Stream = match transport { Transport::Tcp(options) => Stream::Tcp(tcp_stream(url, &options)?), - Transport::Udp => Stream::Udp(UdpStream::new(url)?), + Transport::Udp(options) => Stream::Udp(UdpStream::new(url, options.bind_addr.as_deref())?), #[cfg(unix)] Transport::Unix => Stream::Unix(UnixStream::connect(url.path())?), #[cfg(feature = "tls")] diff --git a/src/stream/udp_stream.rs b/src/stream/udp_stream.rs index cc17245..44c2292 100644 --- a/src/stream/udp_stream.rs +++ b/src/stream/udp_stream.rs @@ -16,9 +16,23 @@ pub struct UdpStream { } impl UdpStream { - pub fn new(addr: &Url) -> Result { - let socket = UdpSocket::bind("0.0.0.0:0")?; - socket.connect(&*addr.socket_addrs(|| None)?)?; + pub fn new(addr: &Url, bind_addr: Option<&str>) -> Result { + let remote_addrs = addr.socket_addrs(|| None)?; + + let bind = match bind_addr { + Some(addr) => format!("{}:0", addr), + None => { + // Auto-detect: use loopback if target is loopback + if remote_addrs.iter().any(|a| a.ip().is_loopback()) { + "127.0.0.1:0".to_string() + } else { + "0.0.0.0:0".to_string() + } + } + }; + + let socket = UdpSocket::bind(&bind)?; + socket.connect(&*remote_addrs)?; return Ok(UdpStream { socket, read_buf: Vec::new(),