Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions crates/slipstream-client/src/dns/poll.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::ClientError;
use slipstream_dns::{build_qname, encode_query, QueryParams, CLASS_IN, RR_TXT};
use slipstream_dns::{build_qname_with_limit, encode_query, QueryParams, CLASS_IN, RR_TXT};
use slipstream_ffi::picoquic::{
picoquic_cnx_t, picoquic_current_time, picoquic_prepare_packet_ex, slipstream_request_poll,
};
Expand Down Expand Up @@ -85,8 +85,12 @@ pub(crate) async fn send_poll_queries(
resolver.debug.polls_sent = resolver.debug.polls_sent.saturating_add(1);

let poll_id = *dns_id;
let qname = build_qname(&send_buf[..send_length], config.domain)
.map_err(|err| ClientError::new(err.to_string()))?;
let qname = build_qname_with_limit(
&send_buf[..send_length],
config.domain,
config.max_qname_len,
)
.map_err(|err| ClientError::new(err.to_string()))?;
let params = QueryParams {
id: poll_id,
qname: &qname,
Expand Down
45 changes: 45 additions & 0 deletions crates/slipstream-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ struct Args {
gso: bool,
#[arg(long = "domain", short = 'd', value_parser = parse_domain)]
domain: String,
#[arg(
long = "max-qname-len",
value_name = "LEN",
default_value_t = 253,
value_parser = parse_max_qname_len
)]
max_qname_len: usize,
#[arg(long = "cert", value_name = "PATH")]
cert: Option<String>,
#[arg(long = "keep-alive-interval", short = 't', default_value_t = 400)]
Expand All @@ -72,6 +79,7 @@ fn main() {
congestion_control: args.congestion_control.as_deref(),
gso: args.gso,
domain: &args.domain,
max_qname_len: args.max_qname_len,
cert: args.cert.as_deref(),
keep_alive_interval: args.keep_alive_interval as usize,
debug_poll: args.debug_poll,
Expand Down Expand Up @@ -109,6 +117,16 @@ fn parse_resolver(input: &str) -> Result<HostPort, String> {
parse_host_port(input, 53, AddressKind::Resolver).map_err(|err| err.to_string())
}

fn parse_max_qname_len(input: &str) -> Result<usize, String> {
let value: usize = input
.parse()
.map_err(|_| "Max QNAME length must be a positive integer".to_string())?;
if !(1..=253).contains(&value) {
return Err("Max QNAME length must be between 1 and 253".to_string());
}
Ok(value)
}

fn build_resolvers(matches: &clap::ArgMatches) -> Result<Vec<ResolverSpec>, String> {
let mut ordered = Vec::new();
collect_resolvers(matches, "resolver", ResolverMode::Recursive, &mut ordered)?;
Expand Down Expand Up @@ -151,6 +169,33 @@ fn collect_resolvers(
mod tests {
use super::*;

#[test]
fn max_qname_len_defaults_to_253() {
let args = Args::try_parse_from([
"slipstream-client",
"--resolver",
"1.1.1.1",
"--domain",
"example.com",
])
.expect("parse args");
assert_eq!(args.max_qname_len, 253);
}

#[test]
fn max_qname_len_rejects_out_of_range() {
let result = Args::try_parse_from([
"slipstream-client",
"--resolver",
"1.1.1.1",
"--domain",
"example.com",
"--max-qname-len",
"254",
]);
assert!(result.is_err());
}

#[test]
fn preserves_ordered_resolvers() {
let matches = Args::command()
Expand Down
13 changes: 8 additions & 5 deletions crates/slipstream-client/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::pinning::configure_pinned_certificate;
use crate::streams::{
client_callback, drain_commands, drain_stream_data, handle_command, spawn_acceptor, ClientState,
};
use slipstream_dns::{build_qname, encode_query, QueryParams, CLASS_IN, RR_TXT};
use slipstream_dns::{build_qname_with_limit, encode_query, QueryParams, CLASS_IN, RR_TXT};
use slipstream_ffi::{
configure_quic_with_custom,
picoquic::{
Expand Down Expand Up @@ -47,8 +47,7 @@ const DNS_WAKE_DELAY_MAX_US: i64 = 10_000_000;
const DNS_POLL_SLICE_US: u64 = 50_000;

pub async fn run_client(config: &ClientConfig<'_>) -> Result<i32, ClientError> {
let domain_len = config.domain.len();
let mtu = compute_mtu(domain_len)?;
let mtu = compute_mtu(config.domain, config.max_qname_len)?;
let mut resolvers = resolve_resolvers(config.resolvers, mtu, config.debug_poll)?;
if resolvers.is_empty() {
return Err(ClientError::new("At least one resolver is required"));
Expand Down Expand Up @@ -342,8 +341,12 @@ pub async fn run_client(config: &ClientConfig<'_>) -> Result<i32, ClientError> {
}
}

let qname = build_qname(&send_buf[..send_length], config.domain)
.map_err(|err| ClientError::new(err.to_string()))?;
let qname = build_qname_with_limit(
&send_buf[..send_length],
config.domain,
config.max_qname_len,
)
.map_err(|err| ClientError::new(err.to_string()))?;
let params = QueryParams {
id: dns_id,
qname: &qname,
Expand Down
17 changes: 7 additions & 10 deletions crates/slipstream-client/src/runtime/setup.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
use crate::error::ClientError;
use slipstream_dns::max_payload_len_for_domain_with_limit;
use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
use tokio::net::UdpSocket as TokioUdpSocket;

pub(crate) fn compute_mtu(domain_len: usize) -> Result<u32, ClientError> {
if domain_len >= 240 {
pub(crate) fn compute_mtu(domain: &str, max_qname_len: usize) -> Result<u32, ClientError> {
let max_payload = max_payload_len_for_domain_with_limit(domain, max_qname_len)
.map_err(|err| ClientError::new(err.to_string()))?;
if max_payload == 0 {
return Err(ClientError::new(
"Domain name is too long for DNS transport",
"Max QNAME length leaves no room for payload; adjust --max-qname-len or domain",
));
}
let mtu = ((240.0 - domain_len as f64) / 1.6) as u32;
if mtu == 0 {
return Err(ClientError::new(
"MTU computed to zero; check domain length",
));
}
Ok(mtu)
Ok(max_payload as u32)
}

pub(crate) async fn bind_udp_socket() -> Result<TokioUdpSocket, ClientError> {
Expand Down
92 changes: 92 additions & 0 deletions crates/slipstream-client/tests/max_qname_len_e2e.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
use slipstream_dns::decode_query;
use std::net::UdpSocket;
use std::path::PathBuf;
use std::process::{Child, Command, Stdio};
use std::time::{Duration, Instant};

struct ChildGuard {
child: Child,
}

impl Drop for ChildGuard {
fn drop(&mut self) {
let _ = self.child.kill();
let _ = self.child.wait();
}
}

fn bind_resolver_socket() -> std::io::Result<(UdpSocket, String)> {
if let Ok(socket) = UdpSocket::bind("[::1]:0") {
let port = socket.local_addr()?.port();
return Ok((socket, format!("[::1]:{}", port)));
}
let socket = UdpSocket::bind("127.0.0.1:0")?;
let port = socket.local_addr()?.port();
Ok((socket, format!("127.0.0.1:{}", port)))
}

#[test]
fn max_qname_len_e2e() {
let domain = "example.com";
let max_qname_len = 101usize;
let (socket, resolver) = match bind_resolver_socket() {
Ok(value) => value,
Err(err) => {
eprintln!("skipping max qname len e2e test: {}", err);
return;
}
};
socket
.set_read_timeout(Some(Duration::from_secs(5)))
.expect("set UDP timeout");

let client_bin = PathBuf::from(env!("CARGO_BIN_EXE_slipstream-client"));
let child = Command::new(client_bin)
.arg("--tcp-listen-port")
.arg("0")
.arg("--resolver")
.arg(resolver)
.arg("--domain")
.arg(domain)
.arg("--max-qname-len")
.arg(max_qname_len.to_string())
.env("RUST_LOG", "error")
.stdout(Stdio::null())
.stderr(Stdio::null())
.spawn()
.expect("start slipstream-client");
let _child_guard = ChildGuard { child };

let deadline = Instant::now() + Duration::from_secs(5);
let mut buf = [0u8; 2048];
let mut observed_len = None;

while Instant::now() < deadline {
match socket.recv_from(&mut buf) {
Ok((size, _)) => {
if let Ok(decoded) = decode_query(&buf[..size], domain) {
let qname_len = decoded.question.name.trim_end_matches('.').len();
observed_len = Some(qname_len);
break;
}
}
Err(err)
if matches!(
err.kind(),
std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
) =>
{
continue;
}
Err(err) => panic!("failed to read UDP query: {}", err),
}
}

let observed_len = observed_len.expect("no DNS query captured from client");
assert!(
observed_len <= max_qname_len,
"observed QNAME length {} exceeds limit {}",
observed_len,
max_qname_len
);
}
27 changes: 21 additions & 6 deletions crates/slipstream-dns/src/bin/bench_dns.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use slipstream_dns::{
build_qname, decode_query, decode_response, encode_query, encode_response,
max_payload_len_for_domain, QueryParams, Question, ResponseParams, CLASS_IN, RR_TXT,
build_qname_with_limit, decode_query, decode_response, encode_query, encode_response,
max_payload_len_for_domain_with_limit, QueryParams, Question, ResponseParams, CLASS_IN, RR_TXT,
};
use std::env;
use std::time::Instant;
Expand All @@ -12,6 +12,7 @@ fn main() {
let mut iterations = 10_000usize;
let mut payload_len = 256usize;
let mut domain = "test.com".to_string();
let mut max_qname_len = 253usize;

for arg in env::args().skip(1) {
if let Some(value) = arg.strip_prefix("--iterations=") {
Expand All @@ -20,13 +21,25 @@ fn main() {
payload_len = value.parse().unwrap_or(payload_len);
} else if let Some(value) = arg.strip_prefix("--domain=") {
domain = value.to_string();
} else if let Some(value) = arg.strip_prefix("--max-qname-len=") {
match value.parse::<usize>() {
Ok(value) if (1..=253).contains(&value) => max_qname_len = value,
Ok(_) => {
error!("Max QNAME length must be between 1 and 253.");
std::process::exit(1);
}
Err(_) => {
error!("Max QNAME length must be a positive integer.");
std::process::exit(1);
}
}
} else if arg == "--help" {
print_usage();
return;
}
}

let max_payload = match max_payload_len_for_domain(&domain) {
let max_payload = match max_payload_len_for_domain_with_limit(&domain, max_qname_len) {
Ok(limit) => limit,
Err(err) => {
error!("Invalid domain: {}", err);
Expand All @@ -46,7 +59,7 @@ fn main() {
}

let payload: Vec<u8> = (0..payload_len).map(|i| (i % 256) as u8).collect();
let qname = match build_qname(&payload, &domain) {
let qname = match build_qname_with_limit(&payload, &domain, max_qname_len) {
Ok(name) => name,
Err(err) => {
error!("Failed to build qname: {}", err);
Expand Down Expand Up @@ -81,7 +94,7 @@ fn main() {
let response = encode_response(&response_params).expect("encode response");

bench("build_qname", iterations, payload_len, || {
let _ = build_qname(&payload, &domain).expect("build qname");
let _ = build_qname_with_limit(&payload, &domain, max_qname_len).expect("build qname");
});
bench("encode_query", iterations, query.len(), || {
let _ = encode_query(&query_params).expect("encode query");
Expand Down Expand Up @@ -121,7 +134,9 @@ fn bench(label: &str, iterations: usize, bytes_per_iter: usize, mut f: impl FnMu
}

fn print_usage() {
println!("Usage: bench_dns [--iterations=N] [--payload-len=N] [--domain=NAME]");
println!(
"Usage: bench_dns [--iterations=N] [--payload-len=N] [--domain=NAME] [--max-qname-len=N]"
);
}

fn init_logging() {
Expand Down
Loading