diff --git a/.planning/ws-disconnect-cancellation.plan.md b/.planning/ws-disconnect-cancellation.plan.md new file mode 100644 index 0000000..cefd631 --- /dev/null +++ b/.planning/ws-disconnect-cancellation.plan.md @@ -0,0 +1,58 @@ +# WebSocket Disconnect Cancellation Plan + +## Goal + +Close issue #111 by making WebSocket handlers observe client disconnects at interpreter safe points, matching the cancellation behavior added for normal HTTP handlers. + +## Current State + +- HTTP handlers use a per-request `Arc` plus `CancelOnDrop`; the forked interpreter polls that flag and exits when axum drops the response future. +- WebSocket routes currently fork one interpreter per connection and hold it in a `parking_lot::Mutex`. +- The WS loop calls the Forge handler synchronously inside the async upgrade task. While a long handler is running, the task is not polling `socket.recv()`, so it cannot observe a closed socket or `Message::Close`. +- The forked WS interpreter currently keeps its default cancellation token instead of a connection-scoped token wired to socket lifecycle. + +## Implementation + +1. In the `"WS"` route branch in `src/runtime/server.rs`, create one connection-scoped `Arc` at upgrade time and assign it to the forked interpreter before wrapping it in the connection mutex. +2. Split the WebSocket into sender and receiver halves. +3. Spawn a lightweight receiver task that: + - forwards `Message::Text` payloads to the main per-connection loop through a bounded Tokio channel with capacity 1, + - uses non-blocking `try_send`; if a client sends more than one queued message while the previous handler is still running, treat that as connection backpressure overflow, set the cancel flag, and stop the receiver, + - treats `Message::Close`, receive errors, and stream end as disconnect, + - stores `true` into the connection cancel flag on disconnect. +4. Process text messages sequentially in the main loop: + - run the Forge handler inside `tokio::task::spawn_blocking`, entering the current tracing span like the HTTP handler path, + - clone the `Arc>` into the blocking closure; acquire and drop the `MutexGuard` entirely inside that closure. The guard must never be held across an `.await` or acquired on the async side before entering `spawn_blocking`, + - after the handler returns, skip sending if the cancel flag was set, + - if `sender.send()` fails, set the cancel flag and stop. +5. Add a local drop guard for the upgrade task so any exit path sets the cancel flag. +6. Respect client `Message::Close` by setting the flag and terminating the connection loop. +7. Preserve current non-cancellation error semantics: if the handler returns an error and the connection is still active, send `error: ` back as before. If the cancel flag is set, skip the send because the peer is gone or the connection is closing. +8. Abort the receiver task when the main connection loop exits so shutdown does not leave a detached task holding connection resources. +9. Document the Ping/Pong assumption: axum 0.8 wraps tungstenite, whose codec handles automatic Pong responses before yielding messages. Non-text messages other than `Close` remain ignored as today. +10. Avoid changing per-connection state semantics: messages on one WS connection remain sequential and share the same forked interpreter; different WS connections stay isolated. + +## Tests + +Add an integration test in `tests/server_concurrency.rs` or a new WS-focused integration test using `tokio_tungstenite`: + +1. Boot a Forge server with: + - `/ping` for readiness, + - a `@ws("/ws")` handler that writes a temp `started` sentinel, runs a long loop with at least one statement boundary per iteration, periodically writes a `progress` sentinel, and writes a `finished` sentinel only after the loop completes. +2. Connect a WS client, send one text message, wait until `started` proves the handler is running, then close/drop the client without waiting for a response. +3. Wait for `progress` to stop changing after disconnect. Because the handler writes progress from inside the loop body, a continued-running handler keeps changing this file; a cancelled handler stops. +4. Assert `finished` does not appear. If it appears, the loop completed normally instead of being cancelled. +5. Keep the loop body cancellation-friendly by using a statement boundary inside the `repeat` body; the interpreter checks `cancelled` at each `exec_stmt`. + +Note: a Forge-level positive `after_safe` sentinel is not viable because the same cancellation flag remains set after `safe { ... }` catches the first `cancelled` error; the next statement would immediately observe cancellation before writing the sentinel. + +Run: + +- `cargo fmt` +- focused WS integration test +- `cargo test --test server_concurrency` +- `cargo test` + +## Rollback + +Revert the WS branch changes and remove the new integration test. HTTP handler cancellation remains unchanged. diff --git a/CHANGELOG.md b/CHANGELOG.md index 06ee012..0ef9050 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- **WebSocket handlers now observe client disconnect cancellation** — WS connections install a connection-scoped cancellation token, run message handlers on the blocking pool, and keep polling the socket for close/error while handlers run so long-running loops exit at the next interpreter safe point. ([#146](https://github.com/humancto/forge-lang/pull/146)) - **VM error traces now include columns when available** — bytecode chunks carry source columns alongside line tables, old v1.1 bytecode still deserializes with zero columns, and standalone decorator statements now fail VM compilation instead of being silently ignored. - **OpenTelemetry feature path is now CI-tested and cheaper when inactive** — CI builds `--features otel`, the OTel export path has a smoke test, and request-span traceparent extraction is skipped unless OTel export was activated at runtime. - **Empty request IDs no longer produce blank span fields** — inbound `X-Request-Id: ` now records as `"unknown"` with a warning, and request-id extraction is covered for empty, non-ASCII, and oversized header values. diff --git a/src/runtime/server.rs b/src/runtime/server.rs index 2d946d4..fcb8bb7 100644 --- a/src/runtime/server.rs +++ b/src/runtime/server.rs @@ -56,6 +56,7 @@ use tracing::Level; use crate::interpreter::{Interpreter, RuntimeError, Value}; use crate::runtime::metadata::{CorsMode, ServerPlan}; use crate::runtime::tracing_init; +use futures_util::{SinkExt, StreamExt}; /// Cap on the recorded `request_id` length. /// @@ -168,7 +169,7 @@ impl Drop for CancelOnDrop { // these per request. tracing::debug!( target: "forge.server", - "response future dropped; cancel flag set" + "runtime future dropped; cancel flag set" ); } } @@ -238,6 +239,18 @@ fn call_handler( } } +fn call_ws_handler(interp: &mut Interpreter, handler_name: &str, text: String) -> String { + let handler = interp.env.get(handler_name); + if let Some(h) = handler { + match interp.call_function(h, vec![Value::String(text)]) { + Ok(v) => format!("{}", v), + Err(e) => format!("error: {}", e.message), + } + } else { + "handler not found".to_string() + } +} + /// Run a Forge handler with full per-request lifecycle: /// 1. Acquire a backpressure permit, or 503 if exhausted. /// 2. Set up the cancel-on-drop guard. @@ -435,9 +448,10 @@ pub async fn start_server( // WebSocket handlers hold session state across messages, so // a per-request fork is the wrong model. Each connection // gets its own forked interpreter held inside a - // parking_lot::Mutex (messages on a single connection - // arrive serially; the lock just gives us !Send across - // await). Different connections are still fully isolated. + // parking_lot::Mutex. The guard is acquired only inside + // spawn_blocking so synchronous Forge execution never + // blocks the async socket task. Different connections + // are still fully isolated. let hn = hn.clone(); app = app.route( &axum_path, @@ -447,30 +461,108 @@ pub async fn start_server( let template = state.template.clone(); let hn = hn.clone(); async move { - ws.on_upgrade(move |mut socket| async move { + ws.on_upgrade(move |socket| async move { use axum::extract::ws::Message; - let interp = Arc::new(parking_lot::Mutex::new(template.fork())); - while let Some(Ok(msg)) = socket.recv().await { - if let Message::Text(text) = msg { - let response = { - let mut interp = interp.lock(); - let handler = interp.env.get(&hn); - if let Some(h) = handler { - match interp.call_function( - h, - vec![Value::String(text.to_string())], - ) { - Ok(v) => format!("{}", v), - Err(e) => format!("error: {}", e.message), + + let cancelled = Arc::new(AtomicBool::new(false)); + let _drop_guard = CancelOnDrop(cancelled.clone()); + + let mut conn_interp = template.fork(); + conn_interp.cancelled = cancelled.clone(); + let interp = Arc::new(parking_lot::Mutex::new(conn_interp)); + let (mut sender, mut receiver) = socket.split(); + let (text_tx, mut text_rx) = + tokio::sync::mpsc::channel::(1); + + let cancel_for_receiver = cancelled.clone(); + let receiver_task = tokio::spawn(async move { + // Axum 0.8/tungstenite handles Pong replies in the + // codec before yielding messages here. We only need + // to forward text and treat Close/errors as cancel. + while let Some(msg) = receiver.next().await { + match msg { + Ok(Message::Text(text)) => { + if text_tx.try_send(text.to_string()).is_err() { + cancel_for_receiver + .store(true, Ordering::Release); + break; } - } else { - "handler not found".to_string() } - }; - let _ = - socket.send(Message::Text(response.into())).await; + Ok(Message::Close(_)) | Err(_) => { + cancel_for_receiver + .store(true, Ordering::Release); + break; + } + Ok(_) => {} + } + } + cancel_for_receiver.store(true, Ordering::Release); + }); + + while let Some(text) = text_rx.recv().await { + if cancelled.load(Ordering::Acquire) { + break; + } + + let interp_for_blocking = interp.clone(); + let hn_for_blocking = hn.clone(); + let span = tracing::Span::current(); + let join = tokio::task::spawn_blocking(move || { + let _g = span.enter(); + let mut interp = interp_for_blocking.lock(); + call_ws_handler(&mut interp, &hn_for_blocking, text) + }); + + let response = match join.await { + Ok(response) => response, + Err(join_err) if join_err.is_panic() => { + let payload = join_err.into_panic(); + let msg = if let Some(s) = + payload.downcast_ref::<&str>() + { + (*s).to_string() + } else if let Some(s) = + payload.downcast_ref::() + { + s.clone() + } else { + "".to_string() + }; + tracing::error!( + target: "forge.server", + handler = %hn, + panic = %msg, + "websocket handler panicked" + ); + "error: internal handler panic".to_string() + } + Err(join_err) => { + tracing::error!( + target: "forge.server", + handler = %hn, + error = %join_err, + "websocket handler task failed" + ); + "error: handler task failed".to_string() + } + }; + + if cancelled.load(Ordering::Acquire) { + break; + } + + if sender + .send(Message::Text(response.into())) + .await + .is_err() + { + cancelled.store(true, Ordering::Release); + break; } } + + cancelled.store(true, Ordering::Release); + receiver_task.abort(); }) } }, diff --git a/tests/server_concurrency.rs b/tests/server_concurrency.rs index a3ace1a..6a4547b 100644 --- a/tests/server_concurrency.rs +++ b/tests/server_concurrency.rs @@ -20,8 +20,10 @@ use forge_lang::lexer::Lexer; use forge_lang::parser::Parser; use forge_lang::runtime::metadata::extract_runtime_plan; use forge_lang::runtime::server::start_server; +use futures_util::SinkExt; use std::net::TcpListener; +use std::path::{Path, PathBuf}; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH}; @@ -82,6 +84,30 @@ fn spawn_test_server(source: &str) -> u16 { panic!("server failed to start on port {} within 5s", port); } +fn unique_temp_file(name: &str) -> PathBuf { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_nanos(); + std::env::temp_dir().join(format!( + "forge_{}_{}_{}.txt", + name, + std::process::id(), + unique + )) +} + +fn wait_for_path(path: &Path, timeout: Duration) -> bool { + let deadline = Instant::now() + timeout; + while Instant::now() < deadline { + if path.exists() { + return true; + } + std::thread::sleep(Duration::from_millis(25)); + } + false +} + /// Time N concurrent GET requests using blocking reqwest on N OS threads. /// Returns the total wall time from the first request issued to the last /// response received. @@ -243,15 +269,7 @@ fn closure_capturing_handlers_run_in_parallel_not_serialized() { #[test] fn schedule_mutations_do_not_leak_into_handler_forks() { - let unique = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system clock before unix epoch") - .as_nanos(); - let sentinel = std::env::temp_dir().join(format!( - "forge_schedule_handler_isolation_{}_{}.txt", - std::process::id(), - unique - )); + let sentinel = unique_temp_file("schedule_handler_isolation"); let _ = std::fs::remove_file(&sentinel); let sentinel_str = sentinel .to_str() @@ -316,6 +334,114 @@ fn schedule_mutations_do_not_leak_into_handler_forks() { let _ = std::fs::remove_file(&sentinel); } +#[test] +fn websocket_handler_cancelled_on_client_disconnect() { + let started = unique_temp_file("ws_cancel_started"); + let progress = unique_temp_file("ws_cancel_progress"); + let finished = unique_temp_file("ws_cancel_finished"); + for path in [&started, &progress, &finished] { + let _ = std::fs::remove_file(path); + } + + let started_str = started.to_str().expect("temp path should be UTF-8"); + let progress_str = progress.to_str().expect("temp path should be UTF-8"); + let finished_str = finished.to_str().expect("temp path should be UTF-8"); + + let source = r#" + @server(port: __PORT__) + + @get("/ping") + fn ping() -> Json { + return { ok: true } + } + + @ws("/ws") + fn socket(msg) { + fs.write("__STARTED__", "1") + let mut i = 0 + repeat 1000000 times { + i = i + 1 + fs.write("__PROGRESS__", str(i)) + } + fs.write("__FINISHED__", "done") + return "done" + } + "# + .replace("__STARTED__", started_str) + .replace("__PROGRESS__", progress_str) + .replace("__FINISHED__", finished_str); + + let port = spawn_test_server(&source); + let url = format!("ws://127.0.0.1:{}/ws", port); + + let rt = tokio::runtime::Runtime::new().expect("test runtime"); + rt.block_on(async { + use tokio_tungstenite::tungstenite::Message; + + let (mut ws, _) = tokio_tungstenite::connect_async(&url) + .await + .expect("connect websocket"); + ws.send(Message::Text("go".into())) + .await + .expect("send websocket message"); + + let started_path = started.clone(); + tokio::task::spawn_blocking(move || { + assert!( + wait_for_path(&started_path, Duration::from_secs(5)), + "websocket handler never started" + ); + }) + .await + .expect("wait for started sentinel"); + + let progress_path = progress.clone(); + tokio::task::spawn_blocking(move || { + assert!( + wait_for_path(&progress_path, Duration::from_secs(5)), + "websocket handler never entered progress loop" + ); + }) + .await + .expect("wait for progress sentinel"); + + let _ = ws.send(Message::Close(None)).await; + drop(ws); + }); + + let mut last_progress = + std::fs::read_to_string(&progress).expect("progress sentinel should be readable"); + let deadline = Instant::now() + Duration::from_secs(5); + let mut stabilized = false; + while Instant::now() < deadline { + assert!( + !finished.exists(), + "websocket handler completed normally instead of being cancelled" + ); + std::thread::sleep(Duration::from_millis(200)); + let current = + std::fs::read_to_string(&progress).expect("progress sentinel should remain readable"); + if current == last_progress { + stabilized = true; + break; + } + last_progress = current; + } + + assert!( + stabilized, + "websocket handler progress kept changing after client disconnect" + ); + assert!( + !finished.exists(), + "websocket handler wrote finished sentinel after disconnect" + ); + + for path in [&started, &progress, &finished] { + let _ = std::fs::remove_file(path); + } +} + #[test] fn request_id_is_generated_and_propagated() { // Two scenarios to verify: