Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,17 @@ impl DnsTcpHandler {
if !dns.submit_tcp_query(&request, self.receiver.sender()) {
tracelimit::warn_ratelimited!(
msg_len,
src_port = self.flow.src_port,
src = %self.flow.src,
"dns_tcp: query rate-limited, closing connection"
);
return Err(DnsTcpError::RateLimited);
}
tracing::trace!(
msg_len,
src = %self.flow.src,
dst = %self.flow.dst,
"dns_tcp: query submitted, entering in-flight",
);
self.buf.clear();
self.phase = Phase::InFlight;
Ok(true)
Expand All @@ -195,10 +201,15 @@ impl DnsTcpHandler {
Ok(response) => {
dns.complete_tcp_query();
let payload_len = response.response_data.len();
tracing::trace!(
payload_len,
src = %self.flow.src,
"dns_tcp: response received from backend resolver",
);
if payload_len > MAX_DNS_TCP_PAYLOAD_SIZE {
tracelimit::warn_ratelimited!(
size = payload_len,
"DNS TCP response exceeds maximum message size"
"dns_tcp: response exceeds maximum message size"
);
return Poll::Ready(Err(DnsTcpError::ResponseTooLarge));
}
Expand All @@ -217,6 +228,10 @@ impl DnsTcpHandler {
}
Err(_) => {
dns.complete_tcp_query();
tracing::trace!(
src = %self.flow.src,
"dns_tcp: query cancelled (channel closed without response)",
);
return Poll::Ready(Err(DnsTcpError::QueryCancelled));
}
},
Expand Down Expand Up @@ -288,6 +303,7 @@ mod tests {
&self,
request: &DnsRequest<'_>,
response_sender: mesh_channel_core::Sender<DnsResponse>,
_query_id: u64,
) {
response_sender.send(DnsResponse {
flow: request.flow.clone(),
Expand All @@ -298,13 +314,10 @@ mod tests {

fn test_flow() -> DnsFlow {
use smoltcp::wire::EthernetAddress;
use smoltcp::wire::IpAddress;
use smoltcp::wire::Ipv4Address;
use std::net::SocketAddr;
DnsFlow {
src_addr: IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 2)),
dst_addr: IpAddress::Ipv4(Ipv4Address::new(10, 0, 0, 1)),
src_port: 12345,
dst_port: 53,
src: SocketAddr::new([10, 0, 0, 2].into(), 12345),
dst: SocketAddr::new([10, 0, 0, 1].into(), 53),
gateway_mac: EthernetAddress([0x52, 0x55, 10, 0, 0, 1]),
client_mac: EthernetAddress([0, 0, 0, 0, 1, 0]),
transport: crate::dns_resolver::DnsTransport::Tcp,
Expand Down
20 changes: 13 additions & 7 deletions vm/devices/net/net_consomme/consomme/src/dns_resolver/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use inspect::Inspect;
use mesh_channel_core::Receiver;
use mesh_channel_core::Sender;
use smoltcp::wire::EthernetAddress;
use smoltcp::wire::IpAddress;
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
Expand Down Expand Up @@ -37,10 +37,8 @@ pub enum DnsTransport {

#[derive(Debug, Clone)]
pub struct DnsFlow {
pub src_addr: IpAddress,
pub dst_addr: IpAddress,
pub src_port: u16,
pub dst_port: u16,
pub src: SocketAddr,
pub dst: SocketAddr,
pub gateway_mac: EthernetAddress,
pub client_mac: EthernetAddress,
// Used by the glibc and Windows DNS backends. The musl resolver
Expand Down Expand Up @@ -70,7 +68,7 @@ pub struct DnsResponse {
/// TCP 2-byte length prefix). Transport framing is the responsibility of the
/// caller (see [`dns_tcp::DnsTcpHandler`]).
pub(crate) trait DnsBackend: Send + Sync {
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>);
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>, query_id: u64);
}

#[derive(Inspect)]
Expand All @@ -85,6 +83,7 @@ pub struct DnsResolver<B: DnsBackend = PlatformDnsBackend> {
udp_receiver: Receiver<DnsResponse>,
pending_requests: usize,
max_pending_requests: usize,
next_query_id: u64,
}

/// Default maximum number of pending DNS requests.
Expand All @@ -105,6 +104,7 @@ impl DnsResolver {
udp_receiver,
pending_requests: 0,
max_pending_requests,
next_query_id: 0,
})
}

Expand All @@ -122,6 +122,7 @@ impl DnsResolver {
udp_receiver,
pending_requests: 0,
max_pending_requests,
next_query_id: 0,
})
}
}
Expand All @@ -138,8 +139,10 @@ impl<B: DnsBackend> DnsResolver<B> {
response_sender: Sender<DnsResponse>,
) -> bool {
if self.pending_requests < self.max_pending_requests {
let query_id = self.next_query_id;
self.next_query_id += 1;
self.pending_requests += 1;
self.backend.query(request, response_sender);
self.backend.query(request, response_sender, query_id);
true
Comment thread
damanm24 marked this conversation as resolved.
} else {
tracelimit::warn_ratelimited!(
Expand Down Expand Up @@ -222,13 +225,16 @@ impl<B: DnsBackend> DnsResolver<B> {
udp_receiver,
pending_requests: 0,
max_pending_requests: DEFAULT_MAX_PENDING_DNS_REQUESTS,
next_query_id: 0,
}
}
}

/// Internal DNS request structure used by backend implementations.
#[derive(Debug)]
pub(crate) struct DnsRequestInternal {
#[cfg_attr(not(target_os = "windows"), expect(dead_code))]
pub query_id: u64,
pub flow: DnsFlow,
pub query: Vec<u8>,
pub response_sender: Sender<DnsResponse>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ impl DnsBackend for UnixDnsResolverBackend {
///
/// Each query spawns a blocking task that uses the appropriate resolver
/// functions for the target platform.
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>) {
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>, query_id: u64) {
let flow = request.flow.clone();
let query = request.dns_query.to_vec();

blocking::unblock(move || {
handle_dns_query(DnsRequestInternal {
query_id,
flow,
query,
response_sender,
Comment thread
damanm24 marked this conversation as resolved.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,12 @@ fn is_dns_raw_apis_supported() -> bool {

/// Context passed to the DNS query callback.
struct RawCallbackContext {
request_id: usize,
slab_key: usize,
request: DnsRequestInternal,
pending_requests: Arc<Mutex<Slab<DNS_QUERY_RAW_CANCEL>>>,
}

pub struct WindowsDnsResolverBackend {
/// Map of pending DNS requests (for cancellation support).
pending_requests: Arc<Mutex<Slab<DNS_QUERY_RAW_CANCEL>>>,
}

Expand All @@ -71,7 +70,7 @@ impl WindowsDnsResolverBackend {
}

impl DnsBackend for WindowsDnsResolverBackend {
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>) {
fn query(&self, request: &DnsRequest<'_>, response_sender: Sender<DnsResponse>, query_id: u64) {
// Clone the sender for error handling
let response_sender_clone = response_sender.clone();

Expand All @@ -92,6 +91,7 @@ impl DnsBackend for WindowsDnsResolverBackend {
// Create internal request with raw DNS bytes (no TCP prefix) so that
// SERVFAIL generation works correctly.
let internal_request = DnsRequestInternal {
query_id,
flow: request.flow.clone(),
query: request.dns_query.to_vec(),
response_sender,
Expand All @@ -102,14 +102,25 @@ impl DnsBackend for WindowsDnsResolverBackend {

// Pre-insert placeholder before calling DnsQueryRaw to avoid race condition
// where callback fires before we can insert the cancel handle.
let request_id = self
let slab_key = self
.pending_requests
.lock()
.insert(DNS_QUERY_RAW_CANCEL::default());

let pending_count = self.pending_requests.lock().len();
tracing::trace!(
query_id,
pending_count,
query_len = dns_query_size,
src = %request.flow.src,
dst = %request.flow.dst,
transport = ?request.flow.transport,
"dns_windows: submitting query to DnsQueryRaw",
);

// Create callback context
let context = Box::new(RawCallbackContext {
request_id,
slab_key,
request: internal_request,
pending_requests: self.pending_requests.clone(),
});
Expand Down Expand Up @@ -150,14 +161,21 @@ impl DnsBackend for WindowsDnsResolverBackend {
// If the callback already fired and removed the entry, this is a no-op.
{
let mut pending = self.pending_requests.lock();
if let Some(v) = pending.get_mut(request_id) {
if let Some(v) = pending.get_mut(slab_key) {
*v = cancel_handle;
}
}
} else {
// Remove placeholder since callback won't fire on error
self.pending_requests.lock().remove(request_id);
tracelimit::warn_ratelimited!("DnsQueryRaw failed with error code: {}", result);
self.pending_requests.lock().remove(slab_key);
tracelimit::warn_ratelimited!(
query_id,
src = %request.flow.src,
dst = %request.flow.dst,
transport = ?request.flow.transport,
result,
"dns_windows: DnsQueryRaw failed",
);
// SAFETY: We're reclaiming ownership of the context we just created
unsafe {
let _ = Box::from_raw(context_ptr);
Expand Down Expand Up @@ -243,9 +261,17 @@ unsafe extern "system" fn dns_query_raw_callback(

{
let mut pending = context.pending_requests.lock();
pending.remove(context.request_id);
let _ = pending.try_remove(context.slab_key);
}

tracing::trace!(
query_id = context.request.query_id,
src = %context.request.flow.src,
dst = %context.request.flow.dst,
transport = ?context.request.flow.transport,
"dns_windows: callback fired",
);

// SAFETY: query_results is provided by Windows and will be freed after processing
let response = match unsafe { process_dns_results(query_results) } {
Ok(mut response_data) => {
Expand Down
46 changes: 35 additions & 11 deletions vm/devices/net/net_consomme/consomme/src/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ impl<T: Client> Access<'_, T> {
src: SocketAddr::V6(SocketAddrV6::new(addresses.src_addr, tcp.src_port, 0, 0)),
},
};
trace_tcp_packet(&tcp, tcp.payload.len(), "recv");
trace_tcp_packet(&ft, &tcp, tcp.payload.len(), "recv");

let is_dns_tcp =
is_gateway_dns_tcp(&ft, &self.inner.state.params, self.inner.dns.is_some());
Expand Down Expand Up @@ -569,7 +569,7 @@ impl<T: Client> Sender<'_, T> {
payload: &[],
};

trace_tcp_packet(&tcp, 0, "rst xmit");
trace_tcp_packet(self.ft, &tcp, 0, "rst xmit");

self.send_packet(&tcp, None);
}
Expand Down Expand Up @@ -709,10 +709,8 @@ impl TcpConnection {
inner.initialize_from_first_client_packet(tcp)?;

let flow = crate::dns_resolver::DnsFlow {
src_addr: sender.ft.src.ip().into(),
dst_addr: sender.ft.dst.ip().into(),
src_port: sender.ft.src.port(),
dst_port: sender.ft.dst.port(),
src: sender.ft.src,
dst: sender.ft.dst,
gateway_mac: sender.state.params.gateway_mac,
client_mac: sender.state.params.client_mac,
transport: crate::dns_resolver::DnsTransport::Tcp,
Expand Down Expand Up @@ -769,6 +767,13 @@ impl TcpConnectionInner {
// Propagate guest FIN before the tx path so that poll_read can
// detect EOF on the same iteration.
if self.state.rx_fin() && !dns_handler.guest_fin() {
tracing::trace!(
src = %sender.ft.src,
dst = %sender.ft.dst,
tx_buffer_len = self.tx_buffer.len(),
tx_buffer_full = self.tx_buffer.is_full(),
"tcp: guest FIN received, signaling EOF to DNS handler",
);
dns_handler.set_guest_fin();
}

Expand All @@ -787,6 +792,14 @@ impl TcpConnectionInner {
break;
}
self.tx_buffer.extend_by(n);
tracing::trace!(
src = %sender.ft.src,
dst = %sender.ft.dst,
n,
tx_buffer_len = self.tx_buffer.len(),
tx_buffer_full = self.tx_buffer.is_full(),
"tcp: response from DNS handler into tx_buffer",
);
}
Poll::Ready(Err(_)) => {
sender.rst(self.tx_send, Some(self.rx_seq));
Expand Down Expand Up @@ -839,7 +852,11 @@ impl TcpConnectionInner {
return false;
}

tracing::debug!("connection established");
tracing::debug!(
src = %sender.ft.src,
dst = %sender.ft.dst,
"connection established",
);
self.state = TcpState::SynReceived;
}
Poll::Pending => return true,
Expand Down Expand Up @@ -955,7 +972,12 @@ impl TcpConnectionInner {
// Avoid resetting so that the guest doesn't think there is a
// responding TCP stack at this address. The guest will time out on
// its own.
tracing::debug!(error = &err as &dyn std::error::Error, "connect timed out");
tracing::debug!(
src = %sender.ft.src,
dst = %sender.ft.dst,
error = &err as &dyn std::error::Error,
"connect timed out",
);
} else {
log_connect_error(&err);
sender.rst(self.tx_send, Some(self.rx_seq));
Expand Down Expand Up @@ -1082,7 +1104,7 @@ impl TcpConnectionInner {
assert!(tx_next <= tx_end);
assert!(self.needs_ack || tx_next > self.tx_send);

trace_tcp_packet(&tcp, payload_len, "xmit");
trace_tcp_packet(sender.ft, &tcp, payload_len, "xmit");

let payload = self
.tx_buffer
Expand Down Expand Up @@ -1137,7 +1159,7 @@ impl TcpConnectionInner {
payload: &[],
};

trace_tcp_packet(&tcp, 0, "ack");
trace_tcp_packet(sender.ft, &tcp, 0, "ack");

sender.send_packet(&tcp, None);
}
Expand Down Expand Up @@ -1439,9 +1461,11 @@ impl TcpListener {
/// Logs protocol-relevant fields (flags, seq, ack, window, payload length)
/// as individual tracing fields instead of dumping the full `TcpRepr` Debug
/// output which includes raw payload bytes.
fn trace_tcp_packet(tcp: &TcpRepr<'_>, payload_len: usize, label: &str) {
fn trace_tcp_packet(ft: &FourTuple, tcp: &TcpRepr<'_>, payload_len: usize, label: &str) {
tracing::trace!(
label,
src = %ft.src,
dst = %ft.dst,
flags = match tcp.control {
TcpControl::Syn => Some("SYN"),
TcpControl::Fin => Some("FIN"),
Expand Down
Loading
Loading