Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/video-streamer/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Local-only test fixtures for the (ignored, XMF-gated) integration tests.
# These are real WebM recordings provided out-of-band, not committed to the repo.
testing-assets/
22 changes: 22 additions & 0 deletions crates/video-streamer/src/streamer/channel_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,28 @@ impl std::fmt::Display for ChannelWriterError {

impl std::error::Error for ChannelWriterError {}

impl ChannelWriterError {
/// Returns `true` if any error in `err`'s chain is a closed-channel error.
///
/// The destination channel closes when the consumer (the sending task / client) goes
/// away. Every write site treats this as a normal shutdown rather than a hard failure,
/// so they share this single classifier instead of re-deriving the downcast chain
/// (`TagWriterError::WriteError` -> `io::Error` -> `ChannelWriterError`) inline.
pub(crate) fn is_in_chain(err: &anyhow::Error) -> bool {
err.chain().any(|cause| {
// Either the chain link is the error itself, or it is the `io::Error` that wraps
// it. The latter case needs `get_ref()`: `io::Error::source()` exposes the inner
// error's *source* (here, nothing), not the inner error itself, so a plain chain
// walk never yields the `ChannelWriterError`.
cause.downcast_ref::<ChannelWriterError>().is_some()
|| cause
.downcast_ref::<io::Error>()
.and_then(io::Error::get_ref)
.is_some_and(|inner| inner.downcast_ref::<ChannelWriterError>().is_some())
})
}
}

pub(crate) struct ChannelWriter {
writer: tokio::sync::mpsc::Sender<Vec<u8>>,
#[cfg(feature = "perf-diagnostics")]
Expand Down
67 changes: 47 additions & 20 deletions crates/video-streamer/src/streamer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use tokio::sync::{Mutex, Notify, watch};
use tokio_util::codec::Framed;
use tracing::Instrument;
use webm_iterable::WebmIterator;
use webm_iterable::errors::{TagIteratorError, TagWriterError};
use webm_iterable::errors::TagIteratorError;
use webm_iterable::matroska_spec::{Master, MatroskaSpec};

pub(crate) mod block_tag;
Expand Down Expand Up @@ -96,7 +96,16 @@ pub fn webm_stream(
let mut header_writer = HeaderWriter::new(chunk_writer);
perf_debug!(?headers);
for header in &headers {
header_writer.write(header)?;
if let Err(e) = header_writer.write(header) {
// A client that disconnects while we are still streaming the headers closes the
// destination channel. Treat this exactly like the main encode loop does below:
// it is normal shutdown, not a streaming failure.
if ChannelWriterError::is_in_chain(&e) {
debug!("client went away during header write; ending stream");
return Ok(());
}
return Err(e);
}
}

let (mut encode_writer, cut_block_hit_marker) = header_writer.into_encoded_writer(encode_writer_config)?;
Expand Down Expand Up @@ -168,20 +177,12 @@ pub fn webm_stream(
match encode_writer.write(tag) {
Ok(WriterResult::Continue) => continue,
Err(e) => {
let Some(TagWriterError::WriteError { source }) = e.downcast_ref::<TagWriterError>() else {
break Err(e);
};

if source.kind() != std::io::ErrorKind::Other {
break Err(e);
// A closed destination channel means the client/consumer is gone:
// normal shutdown, not a failure.
if ChannelWriterError::is_in_chain(&e) {
break Ok(());
}
let Some(ChannelWriterError::ChannelClosed) =
source.get_ref().and_then(|e| e.downcast_ref::<ChannelWriterError>())
else {
break Err(e);
};
// Channel is closed, we can break
break Ok(());
break Err(e);
}
}
}
Expand Down Expand Up @@ -246,6 +247,13 @@ fn spawn_sending_task<W>(
let ws_frame_clone = Arc::clone(&ws_frame);
let mut handle_shutdown_rx = shutdown_rx.clone();

// Both the message-handler task and the control task can reach a termination path, but
// the client must receive at most one terminal frame (End/Error). This flag arbitrates:
// whoever flips it first sends the terminal frame, the other skips it.
let terminal_sent = Arc::new(std::sync::atomic::AtomicBool::new(false));
let terminal_sent_handle = Arc::clone(&terminal_sent);
let terminal_sent_control = terminal_sent;

// Spawn a dedicated task to handle incoming messages from the client.
// Reasoning: `tokio::select!` would get stuck on `chunk_receiver.recv()` when there is no more data
// to receive, which would prevent the shutdown signal from being received.
Expand Down Expand Up @@ -310,10 +318,10 @@ fn spawn_sending_task<W>(
let shutdown_reason = handle_shutdown_rx.borrow().clone();
match shutdown_reason {
StreamShutdown::Error(err) => {
ws_send(&ws_frame, protocol::ServerMessage::Error(err)).await;
ws_send_terminal_once(&terminal_sent_handle, &ws_frame, protocol::ServerMessage::Error(err)).await;
}
_ => {
ws_send(&ws_frame, protocol::ServerMessage::End).await;
ws_send_terminal_once(&terminal_sent_handle, &ws_frame, protocol::ServerMessage::End).await;
}
}
let _ = ws_frame.lock().await.get_mut().shutdown().await;
Expand All @@ -327,14 +335,15 @@ fn spawn_sending_task<W>(
let reason = shutdown_rx.borrow().clone();
match reason {
StreamShutdown::Error(err) => {
ws_send(&ws_frame_clone, protocol::ServerMessage::Error(err)).await;
ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::Error(err))
.await;
}
StreamShutdown::ExternalShutdown => {
info!("Received shutdown signal");
ws_send(&ws_frame_clone, protocol::ServerMessage::End).await;
ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::End).await;
}
StreamShutdown::ClientDisconnected => {
ws_send(&ws_frame_clone, protocol::ServerMessage::End).await;
ws_send_terminal_once(&terminal_sent_control, &ws_frame_clone, protocol::ServerMessage::End).await;
}
StreamShutdown::Running => {
// Spurious wake, shouldn't happen since we only send non-Running values
Expand Down Expand Up @@ -366,6 +375,24 @@ fn spawn_sending_task<W>(
warn!(error = %e, "Failed to send message to client");
});
}

/// Sends `message` as the terminal frame (End/Error) unless one was already claimed.
///
/// `terminal_sent` is the shared arbitration flag: the first caller to flip it from
/// `false` to `true` forwards `message` through `ws_send`; every later caller returns
/// without sending, so the client sees at most one terminal frame no matter which task
/// reaches its termination path first.
async fn ws_send_terminal_once<W: tokio::io::AsyncWrite + tokio::io::AsyncRead + Unpin + Send + 'static>(
    terminal_sent: &std::sync::atomic::AtomicBool,
    ws_frame: &Arc<Mutex<Framed<W, ProtocolCodeC>>>,
    message: protocol::ServerMessage<'_>,
) {
    use std::sync::atomic::Ordering;

    // `swap` atomically sets the flag and hands back the previous value; a previous
    // `true` means another caller already claimed the terminal frame.
    let already_claimed = terminal_sent.swap(true, Ordering::AcqRel);
    if already_claimed {
        return;
    }

    ws_send(ws_frame, message).await;
}
}

#[derive(Clone, Debug, PartialEq)]
Expand Down
19 changes: 7 additions & 12 deletions crates/video-streamer/src/streamer/tag_writers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -618,20 +618,15 @@ where
let block: MatroskaSpec = block.into();
if let Err(e) = self.writer.write(&block) {
// When the client disconnects or we are shutting down, the destination channel is closed.
// This is normal control flow and is handled at a higher level.
if let TagWriterError::WriteError { source } = &e
&& source.kind() == std::io::ErrorKind::Other
&& source
.get_ref()
.and_then(|inner| inner.downcast_ref::<ChannelWriterError>())
.is_some_and(|inner| matches!(inner, &ChannelWriterError::ChannelClosed))
{
// This is normal control flow and is handled at a higher level, so only the genuine
// failures are logged as errors.
let e = anyhow::Error::from(e);
if ChannelWriterError::is_in_chain(&e) {
perf_trace!("write_block aborted - destination channel closed");
return Err(e.into());
} else {
error!(error = %e, "write_block failed");
}

error!(error = %e, "write_block failed");
return Err(e.into());
return Err(e);
}
perf_trace!("write_block completed successfully");
Ok(())
Expand Down
126 changes: 126 additions & 0 deletions crates/video-streamer/tests/webm_stream_correctness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,129 @@ async fn client_disconnect_exits_cleanly() {

let _ = h.writer_task.await;
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
/// Regression (#0): a client disconnecting right as the first bytes (the WebM headers)
/// start flowing must make `webm_stream` finish with `Ok`, not a spurious error.
///
/// The first Chunk delivered to a client carries the header bytes, so dropping the client
/// immediately after that Chunk places the teardown in the header-writing phase. The main
/// encode loop already treated a closed channel as a clean shutdown; the header path used
/// to propagate it as a hard failure.
///
/// The scenario is repeated because whether the disconnect actually lands during the
/// header writes is a scheduling race; several rounds make hitting it very likely.
async fn early_disconnect_during_headers_exits_ok() {
    let _permit = global_stream_test_semaphore()
        .acquire()
        .await
        .expect("failed to acquire global test semaphore");
    init_tracing();
    if !maybe_init_xmf() {
        return;
    }

    for iter in 0..10u32 {
        let mut h = spawn_stream_harness(asset_path("uncued-recording.webm"), LiveWriteConfig::default(), 1).await;
        let start_sent = h.client_tx.send(vec![0]).is_ok();
        assert!(start_sent, "iter {iter}: failed to send Start");

        // Keep pulling until the first Chunk (message type 0, the header bytes) shows up
        // or the 30 second budget runs out; the loop evaluates to whether a Chunk arrived.
        let started_at = tokio::time::Instant::now();
        let got_chunk = loop {
            if started_at.elapsed() >= Duration::from_secs(30) {
                break false;
            }

            let pull_sent = h.client_tx.send(vec![1]).is_ok();
            assert!(pull_sent, "iter {iter}: failed to send Pull");

            let Some(msg) = recv_server_message(&mut h.server_rx, Duration::from_secs(2)).await else {
                continue;
            };
            let (ty, _payload) = parse_server_message(&msg);
            if ty == 0 {
                break true;
            }
        };
        assert!(got_chunk, "iter {iter}: never received a chunk");

        // Simulate the client going away so the teardown races the header-writing phase.
        drop(h.client_tx);
        h.shutdown.notify_waiters();

        let stream_result = tokio::time::timeout(Duration::from_secs(20), h.stream_task)
            .await
            .unwrap_or_else(|_| panic!("iter {iter}: timeout waiting for webm_stream to exit"))
            .expect("webm_stream task panicked");

        assert!(
            stream_result.is_ok(),
            "iter {iter}: webm_stream returned error on disconnect during headers: {:?}",
            stream_result.err()
        );

        let _ = h.writer_task.await;
    }
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[ignore]
/// Regression (#5): a shutdown must produce at most one terminal frame (`End` or `Error`).
///
/// Both the message-handler task and the control task used to send their own terminal
/// frame on every non-EOF termination path, so an external shutdown yielded two `End`
/// frames. The test first reaches steady state (at least one Chunk received), then
/// triggers an external shutdown, drains whatever frames remain and counts the terminal
/// ones.
async fn external_shutdown_emits_single_terminal_frame() {
    let _permit = global_stream_test_semaphore()
        .acquire()
        .await
        .expect("failed to acquire global test semaphore");
    init_tracing();
    if !maybe_init_xmf() {
        return;
    }

    let mut h = spawn_stream_harness(asset_path("uncued-recording.webm"), LiveWriteConfig::default(), 1).await;
    let start_sent = h.client_tx.send(vec![0]).is_ok();
    assert!(start_sent, "failed to send Start");

    // Steady state: pull until a Chunk (message type 0) is observed or the 30 second
    // budget runs out. An Error frame (type 2) before the shutdown is a test failure.
    let started_at = tokio::time::Instant::now();
    let got_chunk = loop {
        if started_at.elapsed() >= Duration::from_secs(30) {
            break false;
        }

        let pull_sent = h.client_tx.send(vec![1]).is_ok();
        assert!(pull_sent, "failed to send Pull");

        let Some(msg) = recv_server_message(&mut h.server_rx, Duration::from_secs(2)).await else {
            continue;
        };
        let (ty, payload) = parse_server_message(&msg);
        if ty == 2 {
            panic!("received ServerMessage::Error before shutdown: {}", String::from_utf8_lossy(payload));
        }
        if ty == 0 {
            break true;
        }
    };
    assert!(got_chunk, "never received a chunk before shutdown");

    // Trigger external shutdown and drain remaining frames.
    h.shutdown.notify_waiters();

    let mut terminal_frames = 0u32;
    for _ in 0..50 {
        // Best-effort pull: the send is allowed to fail once the server side is gone.
        let _ = h.client_tx.send(vec![1]);

        // A quiet second (or a closed stream) means there is nothing left to drain.
        let Some(msg) = recv_server_message(&mut h.server_rx, Duration::from_secs(1)).await else {
            break;
        };

        let (ty, _payload) = parse_server_message(&msg);
        // Message types 2 and 3 are the terminal frames this test counts.
        if matches!(ty, 2 | 3) {
            terminal_frames += 1;
        }
    }

    assert!(
        terminal_frames <= 1,
        "expected at most one terminal frame on shutdown, got {terminal_frames}"
    );

    let _ = tokio::time::timeout(Duration::from_secs(10), h.stream_task).await;
    let _ = h.writer_task.await;
}
Loading