diff --git a/src/dispatcher.rs b/src/dispatcher.rs index f3c8093..c36c231 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, Mutex}; use bytes::Bytes; use tokio::sync::Semaphore; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; -use tokio::task::JoinHandle; +use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use tracing::{Instrument, Span, debug, info_span, warn}; @@ -20,15 +20,17 @@ use crate::{LspError, Result}; /// At startup, the transport is split into a reader half and a writer /// half. The writer half moves into a dedicated send-loop task that /// drains an `unbounded_channel` of outgoing messages. The read-loop -/// owns the reader and spawns every request and non-lifecycle -/// notification handler against `Arc`. Each spawned request is -/// tracked in an in-flight registry keyed by `RequestId`, so a -/// `$/cancelRequest` notification can trigger the handler's -/// [`CancellationToken`] and drop the handler future at its next yield -/// — the wire then carries a `-32800 RequestCancelled` response (ADR -/// 0007). Responses and outgoing notifications all flow through the -/// same channel — the send-loop is the sole writer to the transport. -pub(crate) async fn run(server: S, transport: T, concurrency_limit: usize) -> Result<()> +/// owns the reader and spawns every spawned handler into a shared +/// [`JoinSet`] against `Arc`. Each in-flight request is also tracked +/// in a registry keyed by `RequestId` holding its [`CancellationToken`], +/// so a `$/cancelRequest` can trigger that token and drop the handler +/// future at its next yield — the wire then carries a `-32800 +/// RequestCancelled` response (ADR 0007). On `exit`, the read-loop aborts +/// the entire [`JoinSet`] so no in-flight handler is awaited to +/// completion (issue #4). Responses and outgoing notifications all flow +/// through the same channel — the send-loop is the sole writer to the +/// transport. 
+pub(crate) async fn run(server: S, transport: T, concurrency_limit: usize) -> Result where S: LanguageServer, T: Transport, @@ -41,33 +43,64 @@ where let state: SharedState = Arc::new(Mutex::new(State::Uninitialized)); let registry: Registry = Arc::new(Mutex::new(HashMap::new())); let permits = Arc::new(Semaphore::new(concurrency_limit)); + // Every spawned handler lives here. Requests also self-remove from + // `registry` on completion; this set additionally lets `exit` abort + // them all at once. + let mut tasks: JoinSet<()> = JoinSet::new(); loop { + // Reap finished handlers so the set doesn't grow unbounded over a + // long session (each completed task already released its permit). + while tasks.try_join_next().is_some() {} + let msg = match reader.recv().await { Ok(msg) => msg, Err(TransportError::Closed) => { + // Peer disconnected before `exit`. Drain whatever + // in-flight handlers have already queued, then return; + // unlike `exit`, we let outstanding handlers finish + // rather than abort them. warn!("transport closed by peer before exit notification"); drop(out_tx); let _ = send_handle.await; - return Ok(()); + return Ok(Outcome::TransportClosed); } Err(e) => return Err(Error::Transport(e)), }; - let flow = dispatch(&server, &out_tx, &state, &registry, &permits, msg).await?; + let flow = dispatch( + &server, &out_tx, &state, &registry, &permits, &mut tasks, msg, + ) + .await?; if let Flow::Exit(code) = flow { - // Drop our master sender so the send-loop can drain on its own - // once any in-flight handler tasks release their clones; then - // bail out via process::exit per LSP semantics. Spawned - // handlers and the send-loop die with the process — issue #4 - // tightens lifecycle ordering on top of this. + // `exit` means "stop now": abort every in-flight handler and + // wait for them to drop (which releases their clones of the + // outgoing sender). 
Then drop our master sender so the + // send-loop drains whatever was already queued and exits + // cleanly, and hand the exit code back to the entry point — + // which decides whether to terminate the process (binary) or + // simply return (library / tests). + tasks.shutdown().await; drop(out_tx); let _ = send_handle.await; - std::process::exit(code); + return Ok(Outcome::Exit(code)); } } } +/// What ended the dispatcher's read-loop. The entry point maps this to a +/// process exit for a real binary (`StdioBuilder::serve`) or simply +/// returns it for the library escape hatch (`lspf::serve`), so the same +/// dispatcher is testable in-process without a `process::exit` that would +/// take the test runner down with it. +pub(crate) enum Outcome { + /// The peer closed the transport before sending `exit`. + TransportClosed, + /// An `exit` notification was processed; carries the LSP exit code + /// (0 if `shutdown` preceded it, else 1). + Exit(i32), +} + async fn send_loop(mut writer: W, mut out_rx: UnboundedReceiver) { while let Some(msg) = out_rx.recv().await { if let Err(e) = writer.send(msg).await { @@ -94,16 +127,13 @@ enum Flow { Exit(i32), } -/// Entry in the in-flight registry: the task running the handler plus -/// the cancellation token wired into it. Removed atomically by +/// In-flight request registry: maps each spawned request's `RequestId` +/// to its [`CancellationToken`]. The entry is removed atomically by /// whichever happens first — the handler completing, or a -/// `$/cancelRequest` arriving for its id. -struct InFlight { - handle: JoinHandle<()>, - token: CancellationToken, -} - -type Registry = Arc>>; +/// `$/cancelRequest` arriving for its id — and that removal arbitrates +/// who writes the single wire response. The handler's [`JoinHandle`] +/// lives in the read-loop's [`JoinSet`], not here. 
+type Registry = Arc>>; #[derive(serde::Deserialize)] struct CancelParams { @@ -116,6 +146,7 @@ async fn dispatch( state: &SharedState, registry: &Registry, permits: &Arc, + tasks: &mut JoinSet<()>, msg: RawMessage, ) -> Result where @@ -125,6 +156,16 @@ where RawMessage::Request { id, method, params } => { let span = info_span!("request", method = %method, id = ?id); + // Initialize precedence: until `initialize` completes, every + // other request is refused with `ServerNotInitialized` + // *before* any handler task is spawned (issue #4). Gating the + // spawn step — not a post-spawn check inside the task — is + // what keeps the guarantee under concurrent dispatch. + if method != "initialize" && *state.lock().unwrap() == State::Uninitialized { + enqueue_error(out_tx, id, LspError::ServerNotInitialized); + return Ok(Flow::Continue); + } + match method.as_ref() { "initialize" => { if *state.lock().unwrap() != State::Uninitialized { @@ -139,30 +180,33 @@ where ); return Ok(Flow::Continue); } + // Run inline (ADR 0003): the read-loop blocks here until + // `initialize` completes, so the `state → Running` + // transition is synchronous and every later message sees + // the post-init state. Spawning instead would let the + // next message be dispatched while still `Uninitialized`, + // defeating initialize-precedence (issue #4). A slow + // `initialize` stalling the read-loop is correct per the + // LSP spec — clients may not send other requests until it + // returns. initialize is therefore not cancellable; the + // token is a never-firing placeholder. 
let params = parse_params(&params)?; - let server = Arc::clone(server); - let state = Arc::clone(state); - let permit = acquire_permit(permits).await; - spawn_request( - registry, - out_tx, - span, - id, - permit, - move |ctx, ct| async move { - let result = server.initialize(&ctx, params, ct).await; - if result.is_ok() { - *state.lock().unwrap() = State::Running; - } - result.and_then(to_value) - }, - ); + let ctx = Context::for_request(id.clone(), span.clone(), out_tx.clone()); + let result = server + .initialize(&ctx, params, CancellationToken::new()) + .instrument(span) + .await; + if result.is_ok() { + *state.lock().unwrap() = State::Running; + } + enqueue_value_response(out_tx, id, result.and_then(to_value)); } "shutdown" => { let server = Arc::clone(server); let state = Arc::clone(state); let permit = acquire_permit(permits).await; spawn_request( + tasks, registry, out_tx, span, @@ -178,18 +222,28 @@ where ); } other => { - let snapshot = *state.lock().unwrap(); - if snapshot == State::Uninitialized { - enqueue_error(out_tx, id, LspError::ServerNotInitialized); - } else { - enqueue_error(out_tx, id, LspError::MethodNotFound(other.to_string())); - } + // Uninitialized was already refused by the gate above, + // so reaching here means the server is running. + enqueue_error(out_tx, id, LspError::MethodNotFound(other.to_string())); } } } RawMessage::Notification { method, params } => { let span = info_span!("notification", method = %method); + // Initialize precedence (LSP §Initialize): until `initialize` + // completes, every notification is dropped except `exit` + // (which lets an uninitialized server still shut down) and + // `initialized` (the handshake's other half). Dropping happens + // before any handler is spawned (issue #4). 
+ if method != "initialized" + && method != "exit" + && *state.lock().unwrap() == State::Uninitialized + { + debug!(method = %method, "dropping notification before initialize"); + return Ok(Flow::Continue); + } + match method.as_ref() { "exit" => { let ctx = Context::for_notification(span.clone(), out_tx.clone()); @@ -208,6 +262,7 @@ where let params = parse_params(&params)?; let permit = acquire_permit(permits).await; spawn_notification( + tasks, server, out_tx, span, @@ -221,6 +276,7 @@ where let params = parse_params(&params)?; let permit = acquire_permit(permits).await; spawn_notification( + tasks, server, out_tx, span, @@ -258,7 +314,11 @@ where /// is still there, it writes the response; if `$/cancelRequest` /// already removed it (and wrote `-32800`), the task's response is /// dropped silently. +/// +/// The task is spawned into the shared [`JoinSet`] so `exit` can abort it +/// along with every other in-flight handler. fn spawn_request( + tasks: &mut JoinSet<()>, registry: &Registry, out_tx: &UnboundedSender, span: Span, @@ -281,7 +341,7 @@ fn spawn_request( let span_for_ctx = span.clone(); let out_tx_for_ctx = out_tx.clone(); - let handle = tokio::spawn( + tasks.spawn( async move { // Hold the permit for the lifetime of the task; dropping at // task end (whether the body finished, was cancelled, or @@ -312,10 +372,7 @@ fn spawn_request( .instrument(span), ); - registry - .lock() - .unwrap() - .insert(id, InFlight { handle, token: ct }); + registry.lock().unwrap().insert(id, ct); } fn handle_cancel(registry: &Registry, out_tx: &UnboundedSender, params: &Bytes) { @@ -327,22 +384,23 @@ fn handle_cancel(registry: &Registry, out_tx: &UnboundedSender, para return; } }; - let entry = registry.lock().unwrap().remove(&parsed.id); - if let Some(entry) = entry { + let token = registry.lock().unwrap().remove(&parsed.id); + if let Some(token) = token { // Cancel the token (wakes polite `ct.cancelled().await`s and // flips `ct.is_cancelled()`) and write the wire response. 
The // spawned task's own `select!` then drops the body future at - // its next yield — we don't call `JoinHandle::abort` directly + // its next yield — we don't abort its `JoinHandle` directly // because abort races with the polite path: it can drop the // future before the handler ever gets polled with the token - // observed. - entry.token.cancel(); + // observed. (The task stays in the `JoinSet` and is reaped once + // it finishes.) + token.cancel(); enqueue_error(out_tx, parsed.id, LspError::RequestCancelled); - drop(entry.handle); } } fn spawn_notification( + tasks: &mut JoinSet<()>, server: &Arc, out_tx: &UnboundedSender, span: tracing::Span, @@ -356,7 +414,7 @@ fn spawn_notification( let server = Arc::clone(server); let out_tx = out_tx.clone(); let span_for_task = span.clone(); - tokio::spawn( + tasks.spawn( async move { let _permit = permit; let ctx = Context::for_notification(span_for_task, out_tx); diff --git a/src/lib.rs b/src/lib.rs index 20bcbca..6de9ec0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,8 @@ where S: LanguageServer, T: Transport, { - dispatcher::run(server, transport, DEFAULT_CONCURRENCY_LIMIT).await + dispatcher::run(server, transport, DEFAULT_CONCURRENCY_LIMIT).await?; + Ok(()) } /// Like [`serve`], but with an explicit cap on in-flight handler tasks @@ -54,5 +55,6 @@ where S: LanguageServer, T: Transport, { - dispatcher::run(server, transport, concurrency_limit).await + dispatcher::run(server, transport, concurrency_limit).await?; + Ok(()) } diff --git a/src/transport/mod.rs b/src/transport/mod.rs index 5d4f086..e9b1359 100644 --- a/src/transport/mod.rs +++ b/src/transport/mod.rs @@ -94,6 +94,13 @@ impl StdioBuilder { pub async fn serve(self) -> crate::Result<()> { let transport = StdioTransport::new(); - crate::dispatcher::run(self.server, transport, self.concurrency_limit).await + match crate::dispatcher::run(self.server, transport, self.concurrency_limit).await? 
{ + // Peer hung up before `exit`: return normally and let the + // caller's `main` decide the process disposition. + crate::dispatcher::Outcome::TransportClosed => Ok(()), + // `exit` notification: terminate the process with the LSP + // exit code, per the spec's lifecycle contract. + crate::dispatcher::Outcome::Exit(code) => std::process::exit(code), + } } } diff --git a/tests/acquire_permit_span.rs b/tests/acquire_permit_span.rs new file mode 100644 index 0000000..43f8cc2 --- /dev/null +++ b/tests/acquire_permit_span.rs @@ -0,0 +1,248 @@ +//! ADR 0012: when the in-flight cap is hit, a handler's wait for a permit +//! must be visible in traces as a `handler.acquire_permit` span (issue #3). +//! +//! This lives in its own test binary on purpose. Span capture uses a +//! **process-global** subscriber: a thread-local `set_default` subscriber +//! is not reliably observed when tokio polls spawned handler tasks, so +//! under load spans go uncaptured. A dedicated binary means the global +//! subscriber sees every span this test produces and nothing else's. 
+ +use std::borrow::Cow; +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use bytes::Bytes; +use serde_json::json; + +use lspf::types::DidOpenTextDocumentParams; +use lspf::{ + Context, LanguageServer, RawMessage, RequestId, Transport, TransportError, TransportReader, + TransportWriter, +}; + +struct VecTransport { + inbox: VecDeque, + outbox: Arc>>, + done: Arc, +} + +struct VecReader { + inbox: VecDeque, + done: Arc, +} + +struct VecWriter { + outbox: Arc>>, +} + +impl Transport for VecTransport { + type Reader = VecReader; + type Writer = VecWriter; + + fn split(self) -> (Self::Reader, Self::Writer) { + ( + VecReader { + inbox: self.inbox, + done: self.done, + }, + VecWriter { + outbox: self.outbox, + }, + ) + } +} + +impl TransportReader for VecReader { + async fn recv(&mut self) -> Result { + match self.inbox.pop_front() { + Some(msg) => Ok(msg), + // Park until the test signals teardown, so the dispatcher + // doesn't tear down while handlers are still gated. + None => { + self.done.notified().await; + Err(TransportError::Closed) + } + } + } +} + +impl TransportWriter for VecWriter { + async fn send(&mut self, msg: RawMessage) -> Result<(), TransportError> { + self.outbox.lock().unwrap().push(msg); + Ok(()) + } + + async fn shutdown(self) -> Result<(), TransportError> { + Ok(()) + } +} + +/// A `didOpen` handler gated by explicit barriers rather than a fixed +/// sleep, so the test controls exactly when the permit-holder finishes. +/// Each handler reports it has started (and thus holds the permit) on +/// `started`, then parks until the test releases it. 
+struct Gated { + started: Arc, + release: Arc, +} + +impl LanguageServer for Gated { + async fn text_document_did_open(&self, ctx: &Context, params: DidOpenTextDocumentParams) { + self.started.add_permits(1); + self.release.notified().await; + ctx.publish_diagnostics(lspf::types::PublishDiagnosticsParams { + uri: params.text_document.uri, + version: Some(params.text_document.version), + diagnostics: vec![], + }); + } +} + +fn initialize_request(id: i32) -> RawMessage { + let params = json!({ "processId": null, "rootUri": null, "capabilities": {} }); + RawMessage::Request { + id: RequestId::Number(id), + method: Cow::Borrowed("initialize"), + params: Bytes::from(serde_json::to_vec(&params).unwrap()), + } +} + +fn did_open_notification(uri: &str) -> RawMessage { + let params = json!({ + "textDocument": { "uri": uri, "languageId": "plaintext", "version": 1, "text": "" } + }); + RawMessage::Notification { + method: Cow::Borrowed("textDocument/didOpen"), + params: Bytes::from(serde_json::to_vec(&params).unwrap()), + } +} + +fn count_publish_diagnostics(outbox: &[RawMessage]) -> usize { + outbox + .iter() + .filter(|m| { + matches!( + m, + RawMessage::Notification { method, .. } + if method == "textDocument/publishDiagnostics" + ) + }) + .count() +} + +/// Captures `handler.acquire_permit` span lifetimes. `on_new_span` stores +/// the open instant in the span's extensions; `on_close` computes the +/// elapsed time — i.e. how long the handler waited for a permit. 
+#[derive(Default, Clone)] +struct SpanCapture { + closed: Arc>>, +} + +struct OpenedAt(Instant); + +impl tracing_subscriber::Layer for SpanCapture +where + S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>, +{ + fn on_new_span( + &self, + _attrs: &tracing::span::Attributes<'_>, + id: &tracing::Id, + ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + if let Some(span) = ctx.span(id) { + span.extensions_mut().insert(OpenedAt(Instant::now())); + } + } + + fn on_close(&self, id: tracing::Id, ctx: tracing_subscriber::layer::Context<'_, S>) { + let Some(span) = ctx.span(&id) else { return }; + let name = span.metadata().name().to_string(); + let elapsed = span + .extensions() + .get::() + .map(|o| o.0.elapsed()) + .unwrap_or_default(); + self.closed.lock().unwrap().push((name, elapsed)); + } +} + +#[tokio::test(flavor = "current_thread")] +async fn handler_acquire_permit_span_visible_when_cap_exceeded() { + use tracing_subscriber::layer::SubscriberExt; + use tracing_subscriber::util::SubscriberInitExt; + + let capture = SpanCapture::default(); + tracing_subscriber::registry().with(capture.clone()).init(); + + let outbox = Arc::new(Mutex::new(Vec::new())); + let done = Arc::new(tokio::sync::Notify::new()); + let mut inbox = VecDeque::new(); + inbox.push_back(initialize_request(1)); + inbox.push_back(did_open_notification("file:///a")); + inbox.push_back(did_open_notification("file:///b")); + + let transport = VecTransport { + inbox, + outbox: outbox.clone(), + done: done.clone(), + }; + // Both `didOpen` handlers are gated on `release`; `started` reports + // when each one is actually running (and so holds the single permit). + // Driving the barriers explicitly — rather than racing fixed sleeps — + // keeps the queueing window deterministic, and waiting for both to + // publish guarantees every span has closed before we inspect them. 
+ let started = Arc::new(tokio::sync::Semaphore::new(0)); + let release = Arc::new(tokio::sync::Notify::new()); + let server = Gated { + started: started.clone(), + release: release.clone(), + }; + + const QUEUE_HOLD: Duration = Duration::from_millis(80); + let server_handle = tokio::spawn(async move { + let _ = lspf::serve_with_limit(server, transport, 1).await; + }); + + // First handler grabs the only permit and parks; the second is now + // queued inside `acquire_permit`. Hold it queued for a measurable + // window so its acquire span shows real wait time, then release the + // first so the second can acquire (closing its long acquire span). + let _ = started.acquire().await.unwrap(); + tokio::time::sleep(QUEUE_HOLD).await; + release.notify_one(); + + let _ = started.acquire().await.unwrap(); + release.notify_one(); + + // Wait for both handlers to publish before tearing down, so every span + // has closed. Generous cap guards against a true hang. + let start = Instant::now(); + while count_publish_diagnostics(&outbox.lock().unwrap()) < 2 { + assert!( + start.elapsed() < Duration::from_secs(5), + "handlers did not both publish within 5s" + ); + tokio::time::sleep(Duration::from_millis(5)).await; + } + done.notify_one(); + let _ = server_handle.await; + + let closed = capture.closed.lock().unwrap(); + let max_wait = closed + .iter() + .filter(|(name, _)| name == "handler.acquire_permit") + .map(|(_, d)| *d) + .max() + .unwrap_or_default(); + + // The second didOpen was kept queued for `QUEUE_HOLD` behind the + // first under cap=1, so at least one acquire span must reflect that. 
+ assert!( + max_wait >= QUEUE_HOLD / 2, + "expected an acquire span showing queueing (>= {:?}); spans={:#?}", + QUEUE_HOLD / 2, + *closed, + ); +} diff --git a/tests/cancellation.rs b/tests/cancellation.rs index c9bac1b..b05ab94 100644 --- a/tests/cancellation.rs +++ b/tests/cancellation.rs @@ -14,7 +14,6 @@ use bytes::Bytes; use serde_json::json; use tokio::sync::{mpsc, oneshot}; -use lspf::types::{InitializeParams, InitializeResult}; use lspf::{ CancellationToken, Context, LanguageServer, LspError, RawMessage, RequestId, Transport, TransportError, TransportReader, TransportWriter, @@ -77,6 +76,14 @@ fn initialize_request(id: i32) -> RawMessage { } } +fn shutdown_request(id: i32) -> RawMessage { + RawMessage::Request { + id: RequestId::Number(id), + method: Cow::Borrowed("shutdown"), + params: Bytes::from_static(b"{}"), + } +} + fn cancel_notification(id: i32) -> RawMessage { let params = json!({ "id": id }); RawMessage::Notification { @@ -109,22 +116,33 @@ async fn poll_for_response( } } -/// A server whose `initialize` sleeps for a long time, bailing politely -/// when the framework triggers its cancellation token. -struct SleepyInit; +/// Drive `initialize` to completion so the server reaches `Running` — +/// the only state in which the LSP spec permits a client to send (and +/// therefore cancel) further requests. Panics if the initialize response +/// does not arrive promptly. +async fn initialize( + in_tx: &mpsc::UnboundedSender, + outbox: &Arc>>, +) { + in_tx.send(initialize_request(1)).unwrap(); + poll_for_response(outbox, &RequestId::Number(1), Duration::from_millis(500)) + .await + .expect("initialize did not complete within 500ms"); +} + +/// A server whose `shutdown` sleeps for a long time, bailing politely +/// when the framework triggers its cancellation token. 
Cancellation is +/// exercised on `shutdown` (a post-initialize request) rather than +/// `initialize`, because the spec forbids clients from sending anything — +/// including `$/cancelRequest` — before the initialize response, and the +/// dispatcher drops such notifications (issue #4). +struct SleepyShutdown; -impl LanguageServer for SleepyInit { - async fn initialize( - &self, - _ctx: &Context, - _params: InitializeParams, - ct: CancellationToken, - ) -> Result { +impl LanguageServer for SleepyShutdown { + async fn shutdown(&self, _ctx: &Context, ct: CancellationToken) -> Result<(), LspError> { tokio::select! { _ = ct.cancelled() => Err(LspError::RequestCancelled), - _ = tokio::time::sleep(Duration::from_secs(1)) => { - Ok(InitializeResult::default()) - } + _ = tokio::time::sleep(Duration::from_secs(1)) => Ok(()), } } } @@ -139,18 +157,20 @@ async fn cancel_request_returns_request_cancelled() { outbox: outbox.clone(), }; let server_handle = tokio::spawn(async move { - let _ = lspf::serve(SleepyInit, transport).await; + let _ = lspf::serve(SleepyShutdown, transport).await; }); - in_tx.send(initialize_request(1)).unwrap(); + initialize(&in_tx, &outbox).await; + + in_tx.send(shutdown_request(2)).unwrap(); // Give the spawned handler a moment to land in its await before the cancel. 
tokio::time::sleep(Duration::from_millis(20)).await; let cancel_sent = Instant::now(); - in_tx.send(cancel_notification(1)).unwrap(); + in_tx.send(cancel_notification(2)).unwrap(); - let response = poll_for_response(&outbox, &RequestId::Number(1), Duration::from_millis(500)) + let response = poll_for_response(&outbox, &RequestId::Number(2), Duration::from_millis(500)) .await - .expect("no response for id=1 within 500ms"); + .expect("no response for id=2 within 500ms"); let elapsed = cancel_sent.elapsed(); assert!( @@ -174,19 +194,14 @@ async fn cancel_request_returns_request_cancelled() { let _ = server_handle.await; } -/// A server whose `initialize` parks on `ct.cancelled()` then asserts the +/// A server whose `shutdown` parks on `ct.cancelled()` then asserts the /// token observed the cancel, signalling via a oneshot. struct ObserveCancel { signal: Mutex>>, } impl LanguageServer for ObserveCancel { - async fn initialize( - &self, - _ctx: &Context, - _params: InitializeParams, - ct: CancellationToken, - ) -> Result { + async fn shutdown(&self, _ctx: &Context, ct: CancellationToken) -> Result<(), LspError> { ct.cancelled().await; let observed = ct.is_cancelled(); if let Some(tx) = self.signal.lock().unwrap().take() { @@ -213,10 +228,12 @@ async fn cancel_request_triggers_handler_token() { let _ = lspf::serve(server, transport).await; }); - in_tx.send(initialize_request(1)).unwrap(); + initialize(&in_tx, &outbox).await; + + in_tx.send(shutdown_request(2)).unwrap(); // Ensure the spawned handler reaches `ct.cancelled().await` before cancel arrives. 
tokio::time::sleep(Duration::from_millis(20)).await; - in_tx.send(cancel_notification(1)).unwrap(); + in_tx.send(cancel_notification(2)).unwrap(); let observed = tokio::time::timeout(Duration::from_millis(100), signal_rx) .await diff --git a/tests/concurrency_cap.rs b/tests/concurrency_cap.rs index c3caf32..38b6374 100644 --- a/tests/concurrency_cap.rs +++ b/tests/concurrency_cap.rs @@ -136,112 +136,6 @@ fn count_publish_diagnostics(outbox: &[RawMessage]) -> usize { .count() } -/// Captures `handler.acquire_permit` span lifetimes. `on_new_span` and -/// `on_close` fire exactly once per span; storing the open instant in -/// the span's extensions lets us compute the elapsed time between -/// acquire-start and acquire-finish — i.e. the queueing latency. -#[derive(Default, Clone)] -struct SpanCapture { - closed: Arc>>, -} - -struct OpenedAt(Instant); - -impl tracing_subscriber::Layer for SpanCapture -where - S: tracing::Subscriber + for<'a> tracing_subscriber::registry::LookupSpan<'a>, -{ - fn on_new_span( - &self, - _attrs: &tracing::span::Attributes<'_>, - id: &tracing::Id, - ctx: tracing_subscriber::layer::Context<'_, S>, - ) { - if let Some(span) = ctx.span(id) { - span.extensions_mut().insert(OpenedAt(Instant::now())); - } - } - - fn on_close(&self, id: tracing::Id, ctx: tracing_subscriber::layer::Context<'_, S>) { - let Some(span) = ctx.span(&id) else { return }; - let name = span.metadata().name().to_string(); - let elapsed = span - .extensions() - .get::() - .map(|o| o.0.elapsed()) - .unwrap_or_default(); - self.closed.lock().unwrap().push((name, elapsed)); - } -} - -// Single-threaded runtime: `set_default()` installs a thread-local -// subscriber, and tokio's current-thread scheduler keeps spawned tasks -// on the same thread so their spans are captured. 
-#[tokio::test(flavor = "current_thread")] -async fn handler_acquire_permit_span_visible_when_cap_exceeded() { - use tracing_subscriber::layer::SubscriberExt; - use tracing_subscriber::util::SubscriberInitExt; - - let capture = SpanCapture::default(); - let _guard = tracing_subscriber::registry() - .with(capture.clone()) - .set_default(); - - let outbox = Arc::new(Mutex::new(Vec::new())); - let done = Arc::new(tokio::sync::Notify::new()); - let mut inbox = VecDeque::new(); - inbox.push_back(initialize_request(1)); - inbox.push_back(did_open_notification("file:///a")); - inbox.push_back(did_open_notification("file:///b")); - - let transport = VecTransport { - inbox, - outbox: outbox.clone(), - done: done.clone(), - }; - let server = Sleepy { - sleep: Duration::from_millis(150), - started: Arc::new(tokio::sync::Semaphore::new(0)), - }; - - let server_handle = tokio::spawn(async move { - let _ = lspf::serve_with_limit(server, transport, 1).await; - }); - - let start = Instant::now(); - while count_publish_diagnostics(&outbox.lock().unwrap()) < 2 { - if start.elapsed() > Duration::from_millis(1000) { - break; - } - tokio::time::sleep(Duration::from_millis(10)).await; - } - done.notify_one(); - let _ = server_handle.await; - - let closed = capture.closed.lock().unwrap(); - let acquire_spans: Vec<&(String, Duration)> = closed - .iter() - .filter(|(name, _)| name == "handler.acquire_permit") - .collect(); - - // initialize + 2 × didOpen → 3 spawn sites → 3 acquire spans. - assert_eq!( - acquire_spans.len(), - 3, - "expected 3 handler.acquire_permit spans (initialize + 2 didOpen), got {:#?}", - *closed, - ); - // First two complete fast; the third queues behind the second didOpen - // handler's 150ms sleep. Allow a generous lower bound so we don't - // flake on slow CI but still prove queueing was observed. 
- let max_wait = acquire_spans.iter().map(|(_, d)| *d).max().unwrap(); - assert!( - max_wait >= Duration::from_millis(50), - "expected at least one acquire span to show queueing (>= 50ms); spans={:#?}", - acquire_spans, - ); -} - #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn cap_of_two_serializes_five_handlers_into_three_batches() { let outbox = Arc::new(Mutex::new(Vec::new())); diff --git a/tests/lifecycle_ordering.rs b/tests/lifecycle_ordering.rs new file mode 100644 index 0000000..36724df --- /dev/null +++ b/tests/lifecycle_ordering.rs @@ -0,0 +1,316 @@ +//! Lifecycle-ordering guarantees under concurrent dispatch (issue #4). +//! +//! Two LSP-spec invariants must survive the concurrent dispatcher: +//! +//! 1. **Initialize precedence** — before `initialize` completes, any +//! other inbound request is answered `ServerNotInitialized` without +//! spawning a handler, and any notification other than `initialized` +//! / `exit` is dropped. +//! 2. **Exit aborts in-flight work** — an `exit` notification aborts +//! every in-flight handler rather than awaiting it. +//! +//! Like `cancellation.rs`, these drive the dispatcher through an +//! in-process channel-backed [`Transport`] so messages can be injected +//! out of band and the outbox inspected directly. 
+ +use std::borrow::Cow; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use bytes::Bytes; +use serde_json::json; +use tokio::sync::mpsc; + +use lspf::types::{ + Diagnostic, DiagnosticSeverity, DidOpenTextDocumentParams, Position, PublishDiagnosticsParams, + Range, +}; +use lspf::{ + Context, LanguageServer, RawMessage, RequestId, Transport, TransportError, TransportReader, + TransportWriter, +}; + +struct ChannelTransport { + in_rx: mpsc::UnboundedReceiver, + outbox: Arc>>, +} + +struct ChannelReader { + in_rx: mpsc::UnboundedReceiver, +} + +struct ChannelWriter { + outbox: Arc>>, +} + +impl Transport for ChannelTransport { + type Reader = ChannelReader; + type Writer = ChannelWriter; + + fn split(self) -> (Self::Reader, Self::Writer) { + ( + ChannelReader { in_rx: self.in_rx }, + ChannelWriter { + outbox: self.outbox, + }, + ) + } +} + +impl TransportReader for ChannelReader { + async fn recv(&mut self) -> Result { + self.in_rx.recv().await.ok_or(TransportError::Closed) + } +} + +impl TransportWriter for ChannelWriter { + async fn send(&mut self, msg: RawMessage) -> Result<(), TransportError> { + self.outbox.lock().unwrap().push(msg); + Ok(()) + } + + async fn shutdown(self) -> Result<(), TransportError> { + Ok(()) + } +} + +/// A server whose every built-in override has an observable effect, so a +/// test can tell whether a handler actually ran. `didOpen` publishes a +/// diagnostic; `initialize`/`shutdown` use the default success replies. 
+struct Probe; + +impl LanguageServer for Probe { + async fn text_document_did_open(&self, ctx: &Context, params: DidOpenTextDocumentParams) { + ctx.publish_diagnostics(PublishDiagnosticsParams { + uri: params.text_document.uri, + version: Some(params.text_document.version), + diagnostics: vec![Diagnostic { + range: Range { + start: Position { + line: 0, + character: 0, + }, + end: Position { + line: 0, + character: 0, + }, + }, + severity: Some(DiagnosticSeverity::INFORMATION), + source: Some("lifecycle-probe".into()), + message: "didOpen ran".into(), + ..Diagnostic::default() + }], + }); + } +} + +fn initialize_request(id: i32) -> RawMessage { + let params = json!({ "processId": null, "rootUri": null, "capabilities": {} }); + RawMessage::Request { + id: RequestId::Number(id), + method: Cow::Borrowed("initialize"), + params: Bytes::from(serde_json::to_vec(&params).unwrap()), + } +} + +fn request(id: i32, method: &'static str) -> RawMessage { + RawMessage::Request { + id: RequestId::Number(id), + method: Cow::Borrowed(method), + params: Bytes::from_static(b"{}"), + } +} + +fn notification(method: &'static str, params: serde_json::Value) -> RawMessage { + RawMessage::Notification { + method: Cow::Borrowed(method), + params: Bytes::from(serde_json::to_vec(&params).unwrap()), + } +} + +fn did_open_notification(uri: &str) -> RawMessage { + notification( + "textDocument/didOpen", + json!({ + "textDocument": { + "uri": uri, + "languageId": "plaintext", + "version": 1, + "text": "hello" + } + }), + ) +} + +fn has_publish_diagnostics(outbox: &[RawMessage]) -> bool { + outbox.iter().any(|m| { + matches!( + m, + RawMessage::Notification { method, .. } + if method == "textDocument/publishDiagnostics" + ) + }) +} + +async fn wait_for_response( + outbox: &Arc>>, + id: &RequestId, + deadline: Duration, +) { + let start = std::time::Instant::now(); + loop { + let found = outbox + .lock() + .unwrap() + .iter() + .any(|m| matches!(m, RawMessage::Response { id: rid, .. 
} if rid == id)); + if found { + return; + } + assert!( + start.elapsed() < deadline, + "no response for {id:?} within {deadline:?}" + ); + tokio::time::sleep(Duration::from_millis(5)).await; + } +} + +fn error_code(outbox: &[RawMessage], id: &RequestId) -> Option { + outbox.iter().find_map(|m| match m { + RawMessage::Response { id: rid, result } if rid == id => match result { + Err(e) => Some(e.code), + Ok(_) => None, + }, + _ => None, + }) +} + +/// Feed a single message into a freshly-started (uninitialized) server, +/// then close the transport so `serve` returns once the message — and +/// any handler it may have spawned — is fully processed. Returns the +/// outbox. +async fn drive_uninitialized(msg: RawMessage) -> Vec { + let (in_tx, in_rx) = mpsc::unbounded_channel::(); + let outbox = Arc::new(Mutex::new(Vec::new())); + let transport = ChannelTransport { + in_rx, + outbox: outbox.clone(), + }; + + let server_handle = tokio::spawn(async move { + let _ = lspf::serve(Probe, transport).await; + }); + + in_tx.send(msg).unwrap(); + drop(in_tx); // peer disconnect → serve drains and returns + + tokio::time::timeout(Duration::from_secs(2), server_handle) + .await + .expect("serve returned within 2s") + .expect("server task did not panic"); + + let v = outbox.lock().unwrap().clone(); + v +} + +/// A server whose `didOpen` sleeps a long time before publishing, so a +/// test can tell whether an in-flight handler was aborted (no publish, +/// prompt return) or awaited to completion (publish after the sleep). 
+struct SlowOpen; + +const SLOW: Duration = Duration::from_secs(2); + +impl LanguageServer for SlowOpen { + async fn text_document_did_open(&self, ctx: &Context, params: DidOpenTextDocumentParams) { + tokio::time::sleep(SLOW).await; + ctx.publish_diagnostics(PublishDiagnosticsParams { + uri: params.text_document.uri, + version: Some(params.text_document.version), + diagnostics: vec![], + }); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn exit_aborts_in_flight_handler() { + let (in_tx, in_rx) = mpsc::unbounded_channel::(); + let outbox = Arc::new(Mutex::new(Vec::new())); + let transport = ChannelTransport { + in_rx, + outbox: outbox.clone(), + }; + + let server_handle = tokio::spawn(async move { + let _ = lspf::serve(SlowOpen, transport).await; + }); + + // Reach Running so the didOpen isn't gated, then put the slow handler + // in flight. + in_tx.send(initialize_request(1)).unwrap(); + wait_for_response(&outbox, &RequestId::Number(1), Duration::from_millis(500)).await; + in_tx.send(did_open_notification("file:///slow")).unwrap(); + tokio::time::sleep(Duration::from_millis(50)).await; // let it reach its sleep + + // `exit` must abort the in-flight handler, not await its 2s sleep. 
+ let exit_sent = std::time::Instant::now(); + in_tx.send(notification("exit", json!(null))).unwrap(); + + tokio::time::timeout(Duration::from_millis(500), server_handle) + .await + .expect("serve returned within 500ms — exit aborted the in-flight handler") + .expect("server task did not panic"); + + assert!( + exit_sent.elapsed() < SLOW, + "exit took {:?}, which means it awaited the slow handler instead of aborting it", + exit_sent.elapsed() + ); + assert!( + !has_publish_diagnostics(&outbox.lock().unwrap()), + "aborted handler must not have published its diagnostic" + ); +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn notification_before_initialize_is_dropped() { + // Every notification except `initialized` / `exit` must be dropped + // (no handler spawned, no wire output) while uninitialized. `didOpen` + // is the observable case: its handler would publish a diagnostic, so + // an empty outbox proves it never ran. + let cases: &[RawMessage] = &[ + did_open_notification("file:///a"), + notification("$/cancelRequest", json!({ "id": 1 })), + notification("$/setTrace", json!({ "value": "verbose" })), + ]; + + for msg in cases { + let method = match msg { + RawMessage::Notification { method, .. } => method.clone(), + _ => unreachable!(), + }; + let outbox = drive_uninitialized(msg.clone()).await; + assert!( + outbox.is_empty(), + "notification `{method}` before initialize should be dropped, \ + got outbox {outbox:#?}" + ); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn request_before_initialize_returns_server_not_initialized() { + // Every request method except `initialize` must be refused with + // ServerNotInitialized (-32002) while the server is uninitialized. 
+ let cases: &[&'static str] = &["shutdown", "textDocument/hover"]; + + for method in cases { + let id = RequestId::Number(1); + let outbox = drive_uninitialized(request(1, method)).await; + assert_eq!( + error_code(&outbox, &id), + Some(-32002), + "request `{method}` before initialize should return ServerNotInitialized, \ + got outbox {outbox:#?}" + ); + } +}