diff --git a/lib/saluki-core/examples/basic_supervisor.rs b/lib/saluki-core/examples/basic_supervisor.rs
index 5809acf7e3..21083d858e 100644
--- a/lib/saluki-core/examples/basic_supervisor.rs
+++ b/lib/saluki-core/examples/basic_supervisor.rs
@@ -1,6 +1,7 @@
 use std::time::Duration;
 
-use saluki_core::runtime::{ProcessShutdown, Supervisable, Supervisor, SupervisorFuture};
+use async_trait::async_trait;
+use saluki_core::runtime::{InitializationError, ProcessShutdown, Supervisable, Supervisor, SupervisorFuture};
 use saluki_error::GenericError;
 use tokio::{pin, select};
 use tracing::{error, info};
@@ -129,19 +130,24 @@ impl MockWorker {
     }
 }
 
+#[async_trait]
 impl Supervisable for MockWorker {
     fn name(&self) -> &str {
        self.worker_name
     }
 
-    fn initialize(&self, mut process_shutdown: ProcessShutdown) -> Option<SupervisorFuture> {
+    async fn initialize(&self, mut process_shutdown: ProcessShutdown) -> Result<SupervisorFuture, InitializationError> {
         let worker_name = self.worker_name;
         let delay = self.delay;
         let result = self.result;
         let panic = self.panic;
         let shutdown_delay = self.shutdown_delay;
 
-        Some(Box::pin(async move {
+        // This is where async initialization would happen, e.g.:
+        //   let listener = TcpListener::bind(port).await
+        //       .map_err(|e| InitializationError::Failed { source: e.into() })?;
+
+        Ok(Box::pin(async move {
             info!(worker_name, "Worker started.");
 
             let work = if delay.is_zero() {
                 tokio::time::sleep(Duration::MAX)
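Beyond the mock above, the new hook is aimed at workers whose setup is genuinely fallible and asynchronous. Here is a standalone sketch of such a worker; the `TcpWorker` type and its `bind_addr` field are hypothetical, and the error conversions assume `GenericError` converts from standard errors via `From`, in the anyhow style that the mock's `GenericError::msg` usage suggests:

```rust
// Hypothetical worker, for illustration only: binds its listener during the new
// async `initialize` phase, so a bad address surfaces as an InitializationError
// (never restarted) instead of as a runtime failure (restart-eligible).
use async_trait::async_trait;
use saluki_core::runtime::{InitializationError, ProcessShutdown, Supervisable, SupervisorFuture};
use tokio::net::TcpListener;

struct TcpWorker {
    name: &'static str,
    bind_addr: String,
}

#[async_trait]
impl Supervisable for TcpWorker {
    fn name(&self) -> &str {
        self.name
    }

    async fn initialize(
        &self, mut process_shutdown: ProcessShutdown,
    ) -> Result<SupervisorFuture, InitializationError> {
        // Fallible async setup runs here, on the same runtime that will run the worker.
        let listener = TcpListener::bind(&self.bind_addr)
            .await
            .map_err(|e| InitializationError::Failed { source: e.into() })?;

        Ok(Box::pin(async move {
            loop {
                tokio::select! {
                    accepted = listener.accept() => {
                        // Runtime-phase errors flow out as GenericError (assumed From<io::Error>)
                        // and are subject to the supervisor's restart strategy.
                        let (_socket, _peer) = accepted?;
                        // ... hand the connection off for processing ...
                    }
                    _ = process_shutdown.wait_for_shutdown() => return Ok(()),
                }
            }
        }))
    }
}
```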
diff --git a/lib/saluki-core/src/runtime/mod.rs b/lib/saluki-core/src/runtime/mod.rs
index 6884a63890..8369acb63f 100644
--- a/lib/saluki-core/src/runtime/mod.rs
+++ b/lib/saluki-core/src/runtime/mod.rs
@@ -63,7 +63,9 @@ mod restart;
 pub use self::restart::{RestartMode, RestartStrategy};
 
 mod supervisor;
-pub use self::supervisor::{ShutdownStrategy, Supervisable, Supervisor, SupervisorError, SupervisorFuture};
+pub use self::supervisor::{
+    InitializationError, ShutdownStrategy, Supervisable, Supervisor, SupervisorError, SupervisorFuture,
+};
 
 mod shutdown;
 pub use self::shutdown::{ProcessShutdown, ShutdownHandle};
diff --git a/lib/saluki-core/src/runtime/supervisor.rs b/lib/saluki-core/src/runtime/supervisor.rs
index 2dc72efefe..183ba7e77d 100644
--- a/lib/saluki-core/src/runtime/supervisor.rs
+++ b/lib/saluki-core/src/runtime/supervisor.rs
@@ -1,5 +1,6 @@
 use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
 
+use async_trait::async_trait;
 use saluki_common::collections::FastIndexMap;
 use saluki_error::{ErrorContext as _, GenericError};
 use snafu::{OptionExt as _, Snafu};
@@ -19,6 +20,26 @@ use crate::runtime::process::{Process, ProcessExt as _};
 /// A `Future` that represents the execution of a supervised process.
 pub type SupervisorFuture = Pin<Box<dyn Future<Output = Result<(), GenericError>> + Send>>;
 
+/// A `Future` that represents the full lifecycle of a worker, including initialization.
+///
+/// Unlike [`SupervisorFuture`], which only represents the runtime phase, this future first performs async
+/// initialization and then runs the worker. This allows initialization to happen concurrently when multiple workers are
+/// spawned, and keeps the supervisor loop responsive to shutdown signals during initialization.
+type WorkerFuture = Pin<Box<dyn Future<Output = Result<(), WorkerError>> + Send>>;
+
+/// Worker lifecycle errors.
+///
+/// Distinguishes between initialization failures (which should NOT trigger restart logic) and runtime failures (which
+/// are eligible for restart).
+#[derive(Debug)]
+enum WorkerError {
+    /// The worker failed during async initialization.
+    Initialization(InitializationError),
+
+    /// The worker failed during runtime execution.
+    Runtime(GenericError),
+}
+
 /// Process errors.
 #[derive(Debug, Snafu)]
 pub enum ProcessError {
@@ -38,6 +59,28 @@ pub enum ProcessError {
     },
 }
 
+/// Initialization errors.
+///
+/// Initialization errors are distinct from runtime errors: they indicate that a process could not be started at all
+/// (e.g., failed to bind a port, missing configuration). These errors do NOT trigger restart logic; instead, they
+/// immediately propagate up and fail the supervisor.
+#[derive(Debug, Snafu)]
+#[snafu(context(suffix(false)))]
+pub enum InitializationError {
+    /// The process could not be initialized due to an error.
+    #[snafu(display("Process failed to initialize: {}", source))]
+    Failed {
+        /// The underlying error that caused initialization to fail.
+        source: GenericError,
+    },
+
+    /// The process is permanently unavailable and cannot be initialized.
+    ///
+    /// This is for cases where initialization is structurally impossible, not due to a transient error.
+    #[snafu(display("Process is permanently unavailable"))]
+    PermanentlyUnavailable,
+}
+
 /// Strategy for shutting down a process.
 pub enum ShutdownStrategy {
     /// Waits for the configured duration for the process to exit, and then forcefully aborts it otherwise.
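The two variants carry different semantics, and a worker's setup code would pick between them roughly like this (a sketch: `WorkerConfig` and its fields are made up, and `e.into()` again assumes a `From<std::io::Error>` conversion into `GenericError`):

```rust
use saluki_core::runtime::InitializationError;

// Hypothetical configuration, for illustration only.
struct WorkerConfig {
    listen_addrs: Vec<String>,
    state_path: std::path::PathBuf,
}

async fn init_resources(config: &WorkerConfig) -> Result<Vec<u8>, InitializationError> {
    // Structurally impossible to start: no restart or retry can ever fix this,
    // so report permanent unavailability rather than a transient failure.
    if config.listen_addrs.is_empty() {
        return Err(InitializationError::PermanentlyUnavailable);
    }

    // A fallible async step: the underlying error rides along as `source`, and the
    // supervisor surfaces it via `SupervisorError::FailedToInitialize`.
    tokio::fs::read(&config.state_path)
        .await
        .map_err(|e| InitializationError::Failed { source: e.into() })
}
```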
#[snafu(display("Supervisor has exceeded restart limits and was forced to shutdown."))] @@ -128,35 +184,49 @@ impl ChildSpecification { } } - fn initialize( - &self, parent_process: &Process, process_shutdown: ProcessShutdown, - ) -> Result, SupervisorError> { + fn create_process(&self, parent_process: &Process) -> Result { match self { - Self::Worker(worker) => { - let process = Process::worker(worker.name(), parent_process).context(InvalidName { - name: worker.name().to_string(), - })?; - Ok(worker.initialize(process_shutdown).map(|future| (process, future))) - } + Self::Worker(worker) => Process::worker(worker.name(), parent_process).context(InvalidName { + name: worker.name().to_string(), + }), Self::Supervisor(sup) => { - let process = Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName { + Process::supervisor(&sup.supervisor_id, Some(parent_process)).context(InvalidName { name: sup.supervisor_id.to_string(), - })?; + }) + } + } + } + fn create_worker_future( + &self, process: Process, process_shutdown: ProcessShutdown, + ) -> Result { + match self { + Self::Worker(worker) => { + let worker = Arc::clone(worker); + Ok(Box::pin(async move { + let run_future = worker + .initialize(process_shutdown) + .await + .map_err(WorkerError::Initialization)?; + run_future.await.map_err(WorkerError::Runtime) + })) + } + Self::Supervisor(sup) => { match sup.runtime_mode() { RuntimeMode::Ambient => { // Run on the parent's ambient runtime. - Ok(Some(( - process.clone(), - sup.as_nested_process(process, process_shutdown), - ))) + Ok(sup.as_nested_process(process, process_shutdown)) } RuntimeMode::Dedicated(config) => { // Spawn in a dedicated runtime on a new OS thread. + let child_name = sup.supervisor_id.to_string(); let handle = spawn_dedicated_runtime(sup.inner_clone(), config.clone(), process_shutdown) - .map_err(|_| SupervisorError::FailedToInitialize)?; + .map_err(|e| SupervisorError::FailedToInitialize { + child_name, + source: InitializationError::Failed { source: e }, + })?; - Ok(Some((process, Box::pin(handle)))) + Ok(Box::pin(async move { handle.await.map_err(WorkerError::Runtime) })) } } } @@ -202,9 +272,10 @@ where /// # Instrumentation /// /// Supervisors automatically create their own allocation group -/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory usage of the -/// supervisor itself and its children. Additionally, individual worker processes are wrapped in a dedicated -/// [`tracing::Span`] to allow tracing the casual relationship between arbitrary code and the worker executing it. +/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory +/// usage of the supervisor itself and its children. Additionally, individual worker processes are wrapped in a +/// dedicated [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing +/// it. /// /// # Restart Strategies /// @@ -254,8 +325,8 @@ impl Supervisor { /// Configures this supervisor to run in a dedicated runtime. /// - /// When this supervisor is added as a child to another supervisor, it will spawn its own - /// OS thread(s) and Tokio runtime instead of running on the parent's ambient runtime. + /// When this supervisor is added as a child to another supervisor, it will spawn its own OS thread(s) and Tokio + /// runtime instead of running on the parent's ambient runtime. 
@@ -202,9 +272,10 @@ where
 /// # Instrumentation
 ///
 /// Supervisors automatically create their own allocation group
-/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory usage of the
-/// supervisor itself and its children. Additionally, individual worker processes are wrapped in a dedicated
-/// [`tracing::Span`] to allow tracing the casual relationship between arbitrary code and the worker executing it.
+/// ([`TrackingAllocator`][memory_accounting::allocator::TrackingAllocator]), which is used to track both the memory
+/// usage of the supervisor itself and its children. Additionally, individual worker processes are wrapped in a
+/// dedicated [`tracing::Span`] to allow tracing the causal relationship between arbitrary code and the worker executing
+/// it.
 ///
 /// # Restart Strategies
 ///
@@ -254,8 +325,8 @@ impl Supervisor {
 
     /// Configures this supervisor to run in a dedicated runtime.
     ///
-    /// When this supervisor is added as a child to another supervisor, it will spawn its own
-    /// OS thread(s) and Tokio runtime instead of running on the parent's ambient runtime.
+    /// When this supervisor is added as a child to another supervisor, it will spawn its own OS thread(s) and Tokio
+    /// runtime instead of running on the parent's ambient runtime.
     ///
     /// This provides runtime isolation, which can be useful for:
     /// - CPU-bound work that shouldn't block the parent's runtime
@@ -317,7 +388,8 @@ impl Supervisor {
         let mut restart_state = RestartState::new(self.restart_strategy);
         let mut worker_state = WorkerState::new(process);
 
-        // Do the initial spawn of all child processes and supervisors.
+        // Spawn all child processes. Since initialization is folded into each worker's task, this returns immediately
+        // after spawning -- children initialize concurrently in the background.
         self.spawn_all_children(&mut worker_state)?;
 
         // Now we supervise.
@@ -335,12 +407,31 @@ impl Supervisor {
                     break;
                 },
                 worker_task_result = worker_state.wait_for_next_worker() => match worker_task_result {
-                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due to a
-                    // legitimate failure. It does allow configuring this behavior on a per-process basis, however. We don't
-                    // support dynamically adding child processes, which is the only real use case I can think of for having
-                    // non-long-lived child processes... so I think for now, we're OK just always try to restart.
+                    // TODO: Erlang/OTP defaults to always trying to restart a process, even if it doesn't terminate due
+                    // to a legitimate failure. It does allow configuring this behavior on a per-process basis, however.
+                    // We don't support dynamically adding child processes, which is the only real use case I can think
+                    // of for having non-long-lived child processes... so I think for now, we're OK just always trying to
+                    // restart.
                     Some((child_spec_idx, worker_result)) => {
                         let child_spec = self.get_child_spec(child_spec_idx);
+
+                        // Initialization failures are not eligible for restart -- they propagate immediately.
+                        // Runtime failures are converted to process errors for restart evaluation.
+                        let worker_result = match worker_result {
+                            Err(WorkerError::Initialization(e)) => {
+                                error!(supervisor_id = %self.supervisor_id, worker_name = child_spec.name(), "Child process failed to initialize: {}", e);
+                                worker_state.shutdown_workers().await;
+                                return Err(SupervisorError::FailedToInitialize {
+                                    child_name: child_spec.name().to_string(),
+                                    source: e,
+                                });
+                            }
+                            Err(WorkerError::Runtime(e)) => Err(ProcessError::Terminated { source: e }),
+                            Ok(()) => Ok(()),
+                        };
+
                         match restart_state.evaluate_restart() {
                             RestartAction::Restart(mode) => match mode {
                                 RestartMode::OneForOne => {
@@ -368,7 +459,7 @@ impl Supervisor {
         Ok(())
     }
 
-    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> SupervisorFuture {
+    fn as_nested_process(&self, process: Process, process_shutdown: ProcessShutdown) -> WorkerFuture {
         // Simple wrapper around `run_inner` to satisfy the return type signature needed when running the supervisor as
         // a nested child process in another supervisor.
         debug!(supervisor_id = %self.supervisor_id, "Nested supervisor starting.");
@@ -380,6 +471,7 @@ impl Supervisor {
             sup.run_inner(process, process_shutdown)
                 .await
                 .error_context("Nested supervisor failed to exit cleanly.")
+                .map_err(WorkerError::Runtime)
         })
     }
 
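From the caller's side, the pieces compose roughly as follows. This is a sketch: it reuses the hypothetical `TcpWorker` from earlier, and the builder calls mirror the ones exercised by the tests below rather than a documented public API:

```rust
use std::time::Duration;

use saluki_core::runtime::{RestartStrategy, Supervisor, SupervisorError};
use tokio::sync::oneshot;

async fn run_pipeline() -> Result<(), SupervisorError> {
    // Allow up to 5 runtime restarts within any 30-second window before giving up.
    let mut sup = Supervisor::new("pipeline").unwrap().with_restart_strategy(
        RestartStrategy::one_to_one().with_intensity_and_period(5, Duration::from_secs(30)),
    );
    sup.add_worker(TcpWorker {
        name: "listener",
        bind_addr: "127.0.0.1:8125".into(),
    });

    let (_shutdown_tx, shutdown_rx) = oneshot::channel::<()>();

    // Runtime failures in the worker are retried per the strategy above, while an
    // InitializationError from `initialize` fails the supervisor immediately.
    sup.run_with_shutdown(shutdown_rx).await
}
```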
@@ -389,8 +481,8 @@ impl Supervisor {
     ///
     /// If the supervisor exceeds its restart limits, or fails to initialize a child process, an error is returned.
     pub async fn run(&mut self) -> Result<(), SupervisorError> {
-        // Create a no-op `ProcessShutdown` to satisfy the `run_inner` function. This is never used since we want to
-        // run forever, but we need to satisfy the signature.
+        // Create a no-op `ProcessShutdown` to satisfy the `run_inner` function. This is never used since we want to run
+        // forever, but we need to satisfy the signature.
         let process_shutdown = ProcessShutdown::noop();
         let process = Process::supervisor(&self.supervisor_id, None).context(InvalidName {
             name: self.supervisor_id.to_string(),
@@ -417,9 +509,8 @@ impl Supervisor {
 
     /// Runs the supervisor until the given `ProcessShutdown` signal is received.
     ///
-    /// This is an internal variant of `run_with_shutdown` that takes a `ProcessShutdown` directly,
-    /// used when spawning supervisors in dedicated runtimes where the shutdown signal is already
-    /// wrapped in a `ProcessShutdown`.
+    /// This is an internal variant of `run_with_shutdown` that takes a `ProcessShutdown` directly, used when spawning
+    /// supervisors in dedicated runtimes where the shutdown signal is already wrapped in a `ProcessShutdown`.
     ///
     /// # Errors
     ///
@@ -459,7 +550,7 @@ struct ProcessState {
 }
 
 struct WorkerState {
     process: Process,
-    worker_tasks: JoinSet<Result<(), GenericError>>,
+    worker_tasks: JoinSet<Result<(), WorkerError>>,
     worker_map: FastIndexMap<Id, ProcessState>,
 }
@@ -474,28 +565,23 @@ impl WorkerState {
     fn add_worker(&mut self, worker_id: usize, child_spec: &ChildSpecification) -> Result<(), SupervisorError> {
         let (process_shutdown, shutdown_handle) = ProcessShutdown::paired();
 
-        match child_spec.initialize(&self.process, process_shutdown)? {
-            Some((process, worker)) => {
-                let shutdown_strategy = child_spec.shutdown_strategy();
-
-                let abort_handle = self.worker_tasks.spawn(worker.into_instrumented(process));
-                self.worker_map.insert(
-                    abort_handle.id(),
-                    ProcessState {
-                        worker_id,
-                        shutdown_strategy,
-                        shutdown_handle,
-                        abort_handle,
-                    },
-                );
-
-                Ok(())
-            }
-            None => Err(SupervisorError::FailedToInitialize),
-        }
+        let process = child_spec.create_process(&self.process)?;
+        let worker_future = child_spec.create_worker_future(process.clone(), process_shutdown)?;
+        let shutdown_strategy = child_spec.shutdown_strategy();
+        let abort_handle = self.worker_tasks.spawn(worker_future.into_instrumented(process));
+        self.worker_map.insert(
+            abort_handle.id(),
+            ProcessState {
+                worker_id,
+                shutdown_strategy,
+                shutdown_handle,
+                abort_handle,
+            },
+        );
+        Ok(())
     }
 
-    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), ProcessError>)> {
+    async fn wait_for_next_worker(&mut self) -> Option<(usize, Result<(), WorkerError>)> {
         debug!("Waiting for next process to complete.");
 
         match self.worker_tasks.join_next_with_id().await {
@@ -504,10 +590,7 @@ impl WorkerState {
                     .worker_map
                     .swap_remove(&worker_task_id)
                     .expect("worker task ID not found");
-                Some((
-                    process_state.worker_id,
-                    worker_result.map_err(|e| ProcessError::Terminated { source: e }),
-                ))
+                Some((process_state.worker_id, worker_result))
             }
             Some(Err(e)) => {
                 let worker_task_id = e.id();
@@ -520,7 +603,7 @@ impl WorkerState {
                 } else {
                     ProcessError::Panicked
                 };
-                Some((process_state.worker_id, Err(e)))
+                Some((process_state.worker_id, Err(WorkerError::Runtime(e.into()))))
             }
             None => None,
         }
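The bookkeeping that `add_worker` and `wait_for_next_worker` share, keying a side map by task ID so completions can be correlated back to a child, can be seen in isolation in this minimal `JoinSet` sketch (toy result type; the real code maps IDs to `ProcessState` rather than names):

```rust
use std::collections::HashMap;

use tokio::task::JoinSet;

async fn demo() {
    let mut tasks: JoinSet<Result<(), String>> = JoinSet::new();
    let mut names: HashMap<tokio::task::Id, &'static str> = HashMap::new();

    // spawn() hands back an AbortHandle whose ID keys the side map.
    let handle = tasks.spawn(async { Ok(()) });
    names.insert(handle.id(), "worker-1");

    while let Some(joined) = tasks.join_next_with_id().await {
        match joined {
            // Normal completion: the task's own Result comes paired with its ID.
            Ok((id, worker_result)) => {
                let name = names.remove(&id).expect("unknown task ID");
                println!("{name} finished: {worker_result:?}");
            }
            // Panic or abort: the JoinError also carries the ID, mirroring the
            // WorkerError::Runtime conversion in wait_for_next_worker above.
            Err(join_err) => {
                let name = names.remove(&join_err.id()).expect("unknown task ID");
                println!("{name} failed: {join_err}");
            }
        }
    }
}
```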
@@ -609,3 +692,365 @@ impl WorkerState {
         );
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
+    use async_trait::async_trait;
+    use tokio::{
+        sync::oneshot,
+        task::JoinHandle,
+        time::{sleep, timeout},
+    };
+
+    use super::*;
+
+    /// Behavior for a mock worker during initialization.
+    #[derive(Clone)]
+    enum InitBehavior {
+        /// Initialization succeeds immediately.
+        Instant,
+
+        /// Initialization takes the given duration before succeeding.
+        Slow(Duration),
+
+        /// Initialization fails with the given message.
+        Fail(&'static str),
+    }
+
+    /// Behavior for a mock worker during runtime (after initialization).
+    #[derive(Clone)]
+    enum RunBehavior {
+        /// Runs until shutdown is received.
+        UntilShutdown,
+
+        /// Fails with the given error message after the given delay.
+        FailAfter(Duration, &'static str),
+    }
+
+    /// A configurable mock worker for testing supervisor behavior.
+    struct MockWorker {
+        name: &'static str,
+        init_behavior: InitBehavior,
+        run_behavior: RunBehavior,
+        start_count: Arc<AtomicUsize>,
+    }
+
+    impl MockWorker {
+        /// Creates a worker that runs until shutdown.
+        fn long_running(name: &'static str) -> Self {
+            Self {
+                name,
+                init_behavior: InitBehavior::Instant,
+                run_behavior: RunBehavior::UntilShutdown,
+                start_count: Arc::new(AtomicUsize::new(0)),
+            }
+        }
+
+        /// Creates a worker that fails after the given delay.
+        fn failing(name: &'static str, delay: Duration) -> Self {
+            Self {
+                name,
+                init_behavior: InitBehavior::Instant,
+                run_behavior: RunBehavior::FailAfter(delay, "worker failed"),
+                start_count: Arc::new(AtomicUsize::new(0)),
+            }
+        }
+
+        /// Creates a worker that fails during initialization.
+        fn init_failure(name: &'static str) -> Self {
+            Self {
+                name,
+                init_behavior: InitBehavior::Fail("init failed"),
+                run_behavior: RunBehavior::UntilShutdown,
+                start_count: Arc::new(AtomicUsize::new(0)),
+            }
+        }
+
+        /// Creates a worker with slow initialization.
+        fn slow_init(name: &'static str, init_delay: Duration) -> Self {
+            Self {
+                name,
+                init_behavior: InitBehavior::Slow(init_delay),
+                run_behavior: RunBehavior::UntilShutdown,
+                start_count: Arc::new(AtomicUsize::new(0)),
+            }
+        }
+
+        /// Returns a shared handle to the start count for this worker.
+        fn start_count(&self) -> Arc<AtomicUsize> {
+            Arc::clone(&self.start_count)
+        }
+    }
+
+    #[async_trait]
+    impl Supervisable for MockWorker {
+        fn name(&self) -> &str {
+            self.name
+        }
+
+        fn shutdown_strategy(&self) -> ShutdownStrategy {
+            ShutdownStrategy::Graceful(Duration::from_millis(500))
+        }
+
+        async fn initialize(
+            &self, mut process_shutdown: ProcessShutdown,
+        ) -> Result<SupervisorFuture, InitializationError> {
+            match &self.init_behavior {
+                InitBehavior::Instant => {}
+                InitBehavior::Slow(delay) => {
+                    sleep(*delay).await;
+                }
+                InitBehavior::Fail(msg) => {
+                    return Err(InitializationError::Failed {
+                        source: GenericError::msg(*msg),
+                    });
+                }
+            }
+
+            let start_count = Arc::clone(&self.start_count);
+            let run_behavior = self.run_behavior.clone();
+
+            Ok(Box::pin(async move {
+                start_count.fetch_add(1, Ordering::SeqCst);
+
+                match run_behavior {
+                    RunBehavior::UntilShutdown => {
+                        process_shutdown.wait_for_shutdown().await;
+                        Ok(())
+                    }
+                    RunBehavior::FailAfter(delay, msg) => {
+                        select! {
+                            _ = sleep(delay) => {
+                                Err(GenericError::msg(msg))
+                            }
+                            _ = process_shutdown.wait_for_shutdown() => {
+                                Ok(())
+                            }
+                        }
+                    }
+                }
+            }))
+        }
+    }
+
+    /// Helper: runs a supervisor with a oneshot-based shutdown trigger.
+    /// Returns the shutdown sender and a join handle for the running supervisor task.
+    async fn run_supervisor_with_trigger(
+        mut supervisor: Supervisor,
+    ) -> (oneshot::Sender<()>, JoinHandle<Result<(), SupervisorError>>) {
+        let (tx, rx) = oneshot::channel();
+        let handle = tokio::spawn(async move { supervisor.run_with_shutdown(rx).await });
+
+        // Give the supervisor a moment to start and spawn children.
+        sleep(Duration::from_millis(50)).await;
+
+        (tx, handle)
+    }
+
+    // -- Supervisor run mode tests ---------------------------------------------------------
+
+    #[tokio::test]
+    async fn standalone_supervisor_shuts_down_cleanly() {
+        let mut sup = Supervisor::new("test-sup").unwrap();
+        sup.add_worker(MockWorker::long_running("worker1"));
+        sup.add_worker(MockWorker::long_running("worker2"));
+
+        let (tx, handle) = run_supervisor_with_trigger(sup).await;
+        tx.send(()).unwrap();
+
+        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
+        assert!(result.is_ok());
+    }
+
+    #[tokio::test]
+    async fn nested_supervisor_shuts_down_cleanly() {
+        let mut child_sup = Supervisor::new("child-sup").unwrap();
+        child_sup.add_worker(MockWorker::long_running("inner-worker"));
+
+        let mut parent_sup = Supervisor::new("parent-sup").unwrap();
+        parent_sup.add_worker(MockWorker::long_running("outer-worker"));
+        parent_sup.add_worker(child_sup);
+
+        let (tx, handle) = run_supervisor_with_trigger(parent_sup).await;
+        tx.send(()).unwrap();
+
+        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
+        assert!(result.is_ok());
+    }
+
+    #[tokio::test]
+    async fn supervisor_with_no_children_returns_error() {
+        let mut sup = Supervisor::new("empty-sup").unwrap();
+
+        let (tx, rx) = oneshot::channel::<()>();
+        let result = sup.run_with_shutdown(rx).await;
+        drop(tx);
+
+        assert!(matches!(result, Err(SupervisorError::NoChildren)));
+    }
+
+    // -- Child restart behavior tests ------------------------------------------------------
+
+    #[tokio::test]
+    async fn one_for_one_restarts_only_failed_child() {
+        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
+        let failing_count = failing.start_count();
+
+        let stable = MockWorker::long_running("stable-worker");
+        let stable_count = stable.start_count();
+
+        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
+            RestartStrategy::one_to_one().with_intensity_and_period(20, Duration::from_secs(10)),
+        );
+        sup.add_worker(stable);
+        sup.add_worker(failing);
+
+        let (tx, handle) = run_supervisor_with_trigger(sup).await;
+
+        // Wait for a few restarts to happen.
+        sleep(Duration::from_millis(300)).await;
+        let _ = tx.send(());
+
+        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
+        assert!(result.is_ok());
+
+        // The failing worker should have been started multiple times.
+        assert!(
+            failing_count.load(Ordering::SeqCst) >= 2,
+            "failing worker should have been restarted"
+        );
+        // The stable worker should only have been started once (never restarted).
+        assert_eq!(
+            stable_count.load(Ordering::SeqCst),
+            1,
+            "stable worker should not have been restarted"
+        );
+    }
+
+    #[tokio::test]
+    async fn one_for_all_restarts_all_children() {
+        let failing = MockWorker::failing("failing-worker", Duration::from_millis(50));
+        let failing_count = failing.start_count();
+
+        let stable = MockWorker::long_running("stable-worker");
+        let stable_count = stable.start_count();
+
+        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
+            RestartStrategy::one_for_all().with_intensity_and_period(20, Duration::from_secs(10)),
+        );
+        sup.add_worker(stable);
+        sup.add_worker(failing);
+
+        let (tx, handle) = run_supervisor_with_trigger(sup).await;
+
+        // Wait for at least one restart cycle.
+        sleep(Duration::from_millis(300)).await;
+        let _ = tx.send(());
+
+        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
+        assert!(result.is_ok());
+
+        // Both workers should have been started multiple times.
+        assert!(
+            failing_count.load(Ordering::SeqCst) >= 2,
+            "failing worker should have been restarted"
+        );
+        assert!(
+            stable_count.load(Ordering::SeqCst) >= 2,
+            "stable worker should also have been restarted"
+        );
+    }
+
+    #[tokio::test]
+    async fn restart_limit_exceeded_shuts_down_supervisor() {
+        let mut sup = Supervisor::new("test-sup")
+            .unwrap()
+            .with_restart_strategy(RestartStrategy::one_to_one().with_intensity_and_period(1, Duration::from_secs(10)));
+        // This worker fails immediately, which will exhaust the restart budget quickly.
+        sup.add_worker(MockWorker::failing("fast-fail", Duration::ZERO));
+
+        let (tx, rx) = oneshot::channel::<()>();
+        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
+
+        let result = timeout(Duration::from_secs(2), handle).await.unwrap().unwrap();
+        drop(tx);
+
+        assert!(matches!(result, Err(SupervisorError::Shutdown)));
+    }
+
+    // -- Initialization failure tests ------------------------------------------------------
+
+    #[tokio::test]
+    async fn init_failure_propagates_with_child_name() {
+        let mut sup = Supervisor::new("test-sup").unwrap();
+        sup.add_worker(MockWorker::long_running("good-worker"));
+        sup.add_worker(MockWorker::init_failure("bad-worker"));
+
+        let (_tx, rx) = oneshot::channel::<()>();
+        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
+            .await
+            .unwrap();
+
+        match result {
+            Err(SupervisorError::FailedToInitialize { child_name, .. }) => {
+                assert_eq!(child_name, "bad-worker");
+            }
+            other => panic!("expected FailedToInitialize, got: {:?}", other),
+        }
+    }
+
+    #[tokio::test]
+    async fn init_failure_does_not_trigger_restart() {
+        let init_fail = MockWorker::init_failure("bad-worker");
+        let start_count = init_fail.start_count();
+
+        let mut sup = Supervisor::new("test-sup").unwrap().with_restart_strategy(
+            RestartStrategy::one_to_one().with_intensity_and_period(10, Duration::from_secs(10)),
+        );
+        sup.add_worker(init_fail);
+
+        let (_tx, rx) = oneshot::channel::<()>();
+        let result = timeout(Duration::from_secs(2), sup.run_with_shutdown(rx))
+            .await
+            .unwrap();
+
+        assert!(matches!(result, Err(SupervisorError::FailedToInitialize { .. })));
+        // The worker never got past init, so start_count should be 0.
+        assert_eq!(start_count.load(Ordering::SeqCst), 0);
+    }
+
+    // -- Shutdown responsiveness tests -----------------------------------------------------
+
+    #[tokio::test]
+    async fn shutdown_completes_promptly_in_steady_state() {
+        let mut sup = Supervisor::new("test-sup").unwrap();
+        sup.add_worker(MockWorker::long_running("worker1"));
+        sup.add_worker(MockWorker::long_running("worker2"));
+
+        let (tx, handle) = run_supervisor_with_trigger(sup).await;
+        tx.send(()).unwrap();
+
+        // Shutdown should complete well within 1 second (workers respond to the shutdown signal immediately).
+        let result = timeout(Duration::from_secs(1), handle).await;
+        assert!(result.is_ok(), "shutdown should complete promptly");
+    }
+
+    #[tokio::test]
+    async fn shutdown_during_slow_init_completes_promptly() {
+        let mut sup = Supervisor::new("test-sup").unwrap();
+        // This worker takes 30 seconds to initialize -- but we'll trigger shutdown immediately.
+        sup.add_worker(MockWorker::slow_init("slow-worker", Duration::from_secs(30)));
+
+        let (tx, rx) = oneshot::channel();
+        let handle = tokio::spawn(async move { sup.run_with_shutdown(rx).await });
+
+        // Give the supervisor just enough time to spawn the task, then trigger shutdown.
+        sleep(Duration::from_millis(20)).await;
+        tx.send(()).unwrap();
+
+        // Shutdown should complete quickly even though the worker hasn't finished initializing.
+        // The supervisor loop sees the shutdown signal and aborts the still-initializing task.
+        let result = timeout(Duration::from_secs(2), handle).await;
+        assert!(result.is_ok(), "shutdown during slow init should complete promptly");
+    }
+}
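Downstream, callers get a strictly richer error than the old unit `FailedToInitialize` variant. A sketch of what consuming it looks like, using only the API surface shown in this diff:

```rust
use saluki_core::runtime::{Supervisor, SupervisorError};

async fn report(mut sup: Supervisor) {
    let (_tx, rx) = tokio::sync::oneshot::channel::<()>();
    match sup.run_with_shutdown(rx).await {
        Ok(()) => {}
        Err(SupervisorError::FailedToInitialize { child_name, source }) => {
            // The failing child is now identifiable by name, and the underlying
            // InitializationError travels along as `source`.
            eprintln!("child '{}' never started: {}", child_name, source);
        }
        Err(e) => eprintln!("supervisor failed: {}", e),
    }
}
```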