Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ impl ManageConnection for ConnectionManager {

enum Transport {
Tcp(TcpOptions),
Udp,
Udp(UdpOptions),
#[cfg(unix)]
Unix,
#[cfg(feature = "tls")]
Expand All @@ -94,6 +94,17 @@ struct TcpOptions {
nodelay: bool,
}

struct UdpOptions {
bind_addr: Option<String>,
}

impl UdpOptions {
fn from_url(url: &Url) -> Self {
let bind_addr = url.query_pairs().find(|(k, _)| k == "bind").map(|(_, v)| v.to_string());
UdpOptions { bind_addr }
}
}

#[cfg(feature = "tls")]
fn get_param(url: &Url, key: &str) -> Option<String> {
return url
Expand Down Expand Up @@ -173,7 +184,7 @@ impl Transport {
if let Some(proto) = parts.next() {
return match proto {
"tcp" => Ok(Transport::Tcp(TcpOptions::from_url(url))),
"udp" => Ok(Transport::Udp),
"udp" => Ok(Transport::Udp(UdpOptions::from_url(url))),
#[cfg(unix)]
"unix" => Ok(Transport::Unix),
#[cfg(feature = "tls")]
Expand All @@ -186,7 +197,7 @@ impl Transport {

let is_udp = url.query_pairs().any(|(ref k, ref v)| k == "udp" && v == "true");
if is_udp {
return Ok(Transport::Udp);
return Ok(Transport::Udp(UdpOptions::from_url(url)));
}

#[cfg(unix)]
Expand Down Expand Up @@ -220,7 +231,7 @@ impl Connection {
let is_ascii = url.query_pairs().any(|(ref k, ref v)| k == "protocol" && v == "ascii");
let stream: Stream = match transport {
Transport::Tcp(options) => Stream::Tcp(tcp_stream(url, &options)?),
Transport::Udp => Stream::Udp(UdpStream::new(url)?),
Transport::Udp(options) => Stream::Udp(UdpStream::new(url, options.bind_addr.as_deref())?),
#[cfg(unix)]
Transport::Unix => Stream::Unix(UnixStream::connect(url.path())?),
#[cfg(feature = "tls")]
Expand Down
20 changes: 17 additions & 3 deletions src/stream/udp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,23 @@ pub struct UdpStream {
}

impl UdpStream {
pub fn new(addr: &Url) -> Result<Self, MemcacheError> {
let socket = UdpSocket::bind("0.0.0.0:0")?;
socket.connect(&*addr.socket_addrs(|| None)?)?;
pub fn new(addr: &Url, bind_addr: Option<&str>) -> Result<Self, MemcacheError> {
let remote_addrs = addr.socket_addrs(|| None)?;

let bind = match bind_addr {
Some(addr) => format!("{}:0", addr),
None => {
// Auto-detect: use loopback if target is loopback
if remote_addrs.iter().any(|a| a.ip().is_loopback()) {
"127.0.0.1:0".to_string()
} else {
"0.0.0.0:0".to_string()
}
}
};

let socket = UdpSocket::bind(&bind)?;
socket.connect(&*remote_addrs)?;
return Ok(UdpStream {
socket,
read_buf: Vec::new(),
Expand Down
Loading