From 8bacc64cf2b0fb417b21e8e59be113d86cd288ce Mon Sep 17 00:00:00 2001 From: irving ou Date: Fri, 12 Jun 2026 17:59:57 -0400 Subject: [PATCH] fix(video-streamer): handle channel close during header write and emit a single terminal frame Hardening of the WebM shadow streamer surfaced by the integration harness. - Header-write path now treats a closed destination channel as a clean shutdown (Ok), matching the main encode loop. Previously a client that disconnected while the initial WebM headers were still being written caused `webm_stream` to return Err, which the caller turned into a spurious server_error to the client. - spawn_sending_task now arbitrates the terminal frame with a shared AtomicBool so the client receives at most one End/Error frame. Both the message-handler and the control task could previously emit one on the same termination path. - Consolidate the closed-channel detection (previously a fragile triple downcast in two places) into ChannelWriterError::is_in_chain. The classifier walks the anyhow chain and peeks inside io::Error via get_ref(), since io::Error::source() exposes the inner error's source rather than the ChannelWriterError itself. Adds two ignored, XMF-gated regression tests covering both fixes. --- crates/video-streamer/.gitignore | 3 + .../src/streamer/channel_writer.rs | 22 +++ crates/video-streamer/src/streamer/mod.rs | 67 +++++++--- .../src/streamer/tag_writers.rs | 19 +-- .../tests/webm_stream_correctness.rs | 126 ++++++++++++++++++ 5 files changed, 205 insertions(+), 32 deletions(-) create mode 100644 crates/video-streamer/.gitignore diff --git a/crates/video-streamer/.gitignore b/crates/video-streamer/.gitignore new file mode 100644 index 000000000..7a66bb30a --- /dev/null +++ b/crates/video-streamer/.gitignore @@ -0,0 +1,3 @@ +# Local-only test fixtures for the (ignored, XMF-gated) integration tests. +# These are real WebM recordings provided out-of-band, not committed to the repo. 
+testing-assets/ diff --git a/crates/video-streamer/src/streamer/channel_writer.rs b/crates/video-streamer/src/streamer/channel_writer.rs index 29af214bf..26faef96f 100644 --- a/crates/video-streamer/src/streamer/channel_writer.rs +++ b/crates/video-streamer/src/streamer/channel_writer.rs @@ -15,6 +15,28 @@ impl std::fmt::Display for ChannelWriterError { impl std::error::Error for ChannelWriterError {} +impl ChannelWriterError { + /// Returns `true` if any error in `err`'s chain is a closed-channel error. + /// + /// The destination channel closes when the consumer (the sending task / client) goes + /// away. Every write site treats this as a normal shutdown rather than a hard failure, + /// so they share this single classifier instead of re-deriving the downcast chain + /// (`TagWriterError::WriteError` -> `io::Error` -> `ChannelWriterError`) inline. + pub(crate) fn is_in_chain(err: &anyhow::Error) -> bool { + err.chain().any(|cause| { + // Either the chain link is the error itself, or it is the `io::Error` that wraps + // it. The latter case needs `get_ref()`: `io::Error::source()` exposes the inner + // error's *source* (here, nothing), not the inner error itself, so a plain chain + // walk never yields the `ChannelWriterError`. 
+ cause.downcast_ref::<ChannelWriterError>().is_some() + || cause + .downcast_ref::<io::Error>() + .and_then(io::Error::get_ref) + .is_some_and(|inner| inner.downcast_ref::<ChannelWriterError>().is_some()) + }) + } +} + pub(crate) struct ChannelWriter { writer: tokio::sync::mpsc::Sender>, #[cfg(feature = "perf-diagnostics")] diff --git a/crates/video-streamer/src/streamer/mod.rs b/crates/video-streamer/src/streamer/mod.rs index 7d4f13eb7..c5b68580d 100644 --- a/crates/video-streamer/src/streamer/mod.rs +++ b/crates/video-streamer/src/streamer/mod.rs @@ -10,7 +10,7 @@ use tokio::sync::{Mutex, Notify, watch}; use tokio_util::codec::Framed; use tracing::Instrument; use webm_iterable::WebmIterator; -use webm_iterable::errors::{TagIteratorError, TagWriterError}; +use webm_iterable::errors::TagIteratorError; use webm_iterable::matroska_spec::{Master, MatroskaSpec}; pub(crate) mod block_tag; @@ -96,7 +96,16 @@ pub fn webm_stream( let mut header_writer = HeaderWriter::new(chunk_writer); perf_debug!(?headers); for header in &headers { - header_writer.write(header)?; + if let Err(e) = header_writer.write(header) { + // A client that disconnects while we are still streaming the headers closes the + // destination channel. Treat this exactly like the main encode loop does below: + // it is normal shutdown, not a streaming failure. + if ChannelWriterError::is_in_chain(&e) { + debug!("client went away during header write; ending stream"); + return Ok(()); + } + return Err(e); + } } let (mut encode_writer, cut_block_hit_marker) = header_writer.into_encoded_writer(encode_writer_config)?; @@ -168,20 +177,12 @@ pub fn webm_stream( match encode_writer.write(tag) { Ok(WriterResult::Continue) => continue, Err(e) => { - let Some(TagWriterError::WriteError { source }) = e.downcast_ref::<TagWriterError>() else { - break Err(e); - }; - - if source.kind() != std::io::ErrorKind::Other { - break Err(e); + // A closed destination channel means the client/consumer is gone: + // normal shutdown, not a failure.
+ if ChannelWriterError::is_in_chain(&e) { + break Ok(()); } - let Some(ChannelWriterError::ChannelClosed) = - source.get_ref().and_then(|e| e.downcast_ref::<ChannelWriterError>()) - else { - break Err(e); - }; - // Channel is closed, we can break - break Ok(()); + break Err(e); } } } @@ -246,6 +247,13 @@ fn spawn_sending_task( let ws_frame_clone = Arc::clone(&ws_frame); let mut handle_shutdown_rx = shutdown_rx.clone(); + // Both the message-handler task and the control task can reach a termination path, but + // the client must receive at most one terminal frame (End/Error). This flag arbitrates: + // whoever flips it first sends the terminal frame, the other skips it. + let terminal_sent = Arc::new(std::sync::atomic::AtomicBool::new(false)); + let terminal_sent_handle = Arc::clone(&terminal_sent); + let terminal_sent_control = terminal_sent; + // Spawn a dedicated task to handle incoming messages from the client // Reasoning: tokio::select! will stuck on `chunk_receiver.recv()` when there's no more data to receive // This will disable the ability to receive shutdown signal @@ -310,10 +318,10 @@ fn spawn_sending_task( let shutdown_reason = handle_shutdown_rx.borrow().clone(); match shutdown_reason { StreamShutdown::Error(err) => { - ws_send(&ws_frame, protocol::ServerMessage::Error(err)).await; + ws_send_terminal_once(&terminal_sent_handle, &ws_frame, protocol::ServerMessage::Error(err)).await; } _ => { - ws_send(&ws_frame, protocol::ServerMessage::End).await; + ws_send_terminal_once(&terminal_sent_handle, &ws_frame, protocol::ServerMessage::End).await; } } let _ = ws_frame.lock().await.get_mut().shutdown().await; @@ -327,14 +335,15 @@ fn spawn_sending_task( let reason = shutdown_rx.borrow().clone(); match reason { StreamShutdown::Error(err) => { - ws_send(&ws_frame_clone, protocol::ServerMessage::Error(err)).await; + ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::Error(err)) + .await; } StreamShutdown::ExternalShutdown => { info!("Received
shutdown signal"); - ws_send(&ws_frame_clone, protocol::ServerMessage::End).await; + ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::End).await; } StreamShutdown::ClientDisconnected => { - ws_send(&ws_frame_clone, protocol::ServerMessage::End).await; + ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::End).await; } StreamShutdown::Running => { // Spurious wake, shouldn't happen since we only send non-Running values @@ -366,6 +375,24 @@ fn spawn_sending_task( warn!(error = %e, "Failed to send message to client"); }); } + + /// Sends a terminal frame (End/Error) only if no terminal frame has been sent yet. + /// + /// The handler and control tasks can both reach a termination path; this guarantees the + /// client sees a single terminal frame regardless of which task gets there first. + async fn ws_send_terminal_once( + terminal_sent: &std::sync::atomic::AtomicBool, + ws_frame: &Arc>>, + message: protocol::ServerMessage<'_>, + ) { + use std::sync::atomic::Ordering; + if terminal_sent + .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) + .is_ok() + { + ws_send(ws_frame, message).await; + } + } } #[derive(Clone, Debug, PartialEq)] diff --git a/crates/video-streamer/src/streamer/tag_writers.rs b/crates/video-streamer/src/streamer/tag_writers.rs index 70cd2e72e..346245917 100644 --- a/crates/video-streamer/src/streamer/tag_writers.rs +++ b/crates/video-streamer/src/streamer/tag_writers.rs @@ -618,20 +618,15 @@ where let block: MatroskaSpec = block.into(); if let Err(e) = self.writer.write(&block) { // When the client disconnects or we are shutting down, the destination channel is closed. - // This is normal control flow and is handled at a higher level. 
- if let TagWriterError::WriteError { source } = &e - && source.kind() == std::io::ErrorKind::Other - && source - .get_ref() - .and_then(|inner| inner.downcast_ref::<ChannelWriterError>()) - .is_some_and(|inner| matches!(inner, &ChannelWriterError::ChannelClosed)) - { + // This is normal control flow and is handled at a higher level, so only the genuine + // failures are logged as errors. + let e = anyhow::Error::from(e); + if ChannelWriterError::is_in_chain(&e) { perf_trace!("write_block aborted - destination channel closed"); - return Err(e.into()); + } else { + error!(error = %e, "write_block failed"); } - - error!(error = %e, "write_block failed"); - return Err(e.into()); + return Err(e); } perf_trace!("write_block completed successfully"); Ok(()) diff --git a/crates/video-streamer/tests/webm_stream_correctness.rs b/crates/video-streamer/tests/webm_stream_correctness.rs index 22a407c02..8758168e8 100644 --- a/crates/video-streamer/tests/webm_stream_correctness.rs +++ b/crates/video-streamer/tests/webm_stream_correctness.rs @@ -379,3 +379,129 @@ async fn client_disconnect_exits_cleanly() { let _ = h.writer_task.await; } + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[ignore] +/// Regression (#0): a client that goes away just as the first bytes (the WebM headers) +/// start flowing must exit cleanly (`Ok`), not surface a spurious error. +/// +/// The first Chunk a client receives is the header bytes, so disconnecting right after +/// the first Chunk lands the teardown in the header-writing phase. The previous code +/// propagated the closed-channel error from the header path as a hard failure (the main +/// encode loop already treated it as a clean shutdown, but the header path did not). +/// +/// Why a loop: whether the disconnect lands during the header writes is a scheduling +/// race. Repeating it makes hitting the header phase overwhelmingly likely.
+async fn early_disconnect_during_headers_exits_ok() { + let _permit = global_stream_test_semaphore() + .acquire() + .await + .expect("failed to acquire global test semaphore"); + init_tracing(); + if !maybe_init_xmf() { + return; + } + + for iter in 0..10u32 { + let mut h = spawn_stream_harness(asset_path("uncued-recording.webm"), LiveWriteConfig::default(), 1).await; + assert!(h.client_tx.send(vec![0]).is_ok(), "iter {iter}: failed to send Start"); + + // Pull until the first Chunk (the header bytes) arrives, then disconnect so the + // teardown races the header-writing phase. + let mut got_chunk = false; + let started_at = tokio::time::Instant::now(); + while started_at.elapsed() < Duration::from_secs(30) { + assert!(h.client_tx.send(vec![1]).is_ok(), "iter {iter}: failed to send Pull"); + if let Some(msg) = recv_server_message(&mut h.server_rx, Duration::from_secs(2)).await { + let (ty, _payload) = parse_server_message(&msg); + if ty == 0 { + got_chunk = true; + break; + } + } + } + assert!(got_chunk, "iter {iter}: never received a chunk"); + + drop(h.client_tx); + h.shutdown.notify_waiters(); + + let joined = tokio::time::timeout(Duration::from_secs(20), h.stream_task) + .await + .unwrap_or_else(|_| panic!("iter {iter}: timeout waiting for webm_stream to exit")) + .expect("webm_stream task panicked"); + + assert!( + joined.is_ok(), + "iter {iter}: webm_stream returned error on disconnect during headers: {:?}", + joined.err() + ); + + let _ = h.writer_task.await; + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[ignore] +/// Regression (#5): on shutdown the server must emit at most one terminal frame +/// (`End` or `Error`). +/// +/// Previously both the message-handler task and the control task independently sent a +/// terminal frame on every non-EOF termination path, so an external shutdown produced +/// two `End` frames. 
We reach steady state (at least one chunk) and then trigger an +/// external shutdown, draining the remaining frames and counting terminal ones. +async fn external_shutdown_emits_single_terminal_frame() { + let _permit = global_stream_test_semaphore() + .acquire() + .await + .expect("failed to acquire global test semaphore"); + init_tracing(); + if !maybe_init_xmf() { + return; + } + + let mut h = spawn_stream_harness(asset_path("uncued-recording.webm"), LiveWriteConfig::default(), 1).await; + assert!(h.client_tx.send(vec![0]).is_ok(), "failed to send Start"); + + // Drive to steady state: pull until we observe at least one Chunk. + let mut got_chunk = false; + let started_at = tokio::time::Instant::now(); + while started_at.elapsed() < Duration::from_secs(30) { + assert!(h.client_tx.send(vec![1]).is_ok(), "failed to send Pull"); + if let Some(msg) = recv_server_message(&mut h.server_rx, Duration::from_secs(2)).await { + let (ty, payload) = parse_server_message(&msg); + if ty == 2 { + panic!("received ServerMessage::Error before shutdown: {}", String::from_utf8_lossy(payload)); + } + if ty == 0 { + got_chunk = true; + break; + } + } + } + assert!(got_chunk, "never received a chunk before shutdown"); + + // Trigger external shutdown and drain remaining frames. + h.shutdown.notify_waiters(); + + let mut terminal_frames = 0u32; + for _ in 0..50 { + let _ = h.client_tx.send(vec![1]); + match recv_server_message(&mut h.server_rx, Duration::from_secs(1)).await { + Some(msg) => { + let (ty, _payload) = parse_server_message(&msg); + if ty == 2 || ty == 3 { + terminal_frames += 1; + } + } + None => break, + } + } + + assert!( + terminal_frames <= 1, + "expected at most one terminal frame on shutdown, got {terminal_frames}" + ); + + let _ = tokio::time::timeout(Duration::from_secs(10), h.stream_task).await; + let _ = h.writer_task.await; +}