diff --git a/Cargo.lock b/Cargo.lock index 5cc36789..9d81f1bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2947,6 +2947,15 @@ version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" +[[package]] +name = "slotmap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdd58c3c93c3d278ca835519292445cb4b0d4dc59ccfdf7ceadaab3f8aeb4038" +dependencies = [ + "version_check", +] + [[package]] name = "slug" version = "0.1.6" @@ -4267,6 +4276,7 @@ dependencies = [ "serde_json", "sqlx", "tokio", + "tokio-util", "uuid", "waymark-backend-memory", "waymark-backend-postgres", @@ -4282,7 +4292,12 @@ dependencies = [ "waymark-runner-state", "waymark-secret-string", "waymark-support-integration", - "waymark-worker-remote", + "waymark-worker-core", + "waymark-worker-process-pool", + "waymark-worker-process-spec", + "waymark-worker-python", + "waymark-worker-remote-bringup", + "waymark-worker-remote-pool", "waymark-workflow-registry-backend", ] @@ -4386,6 +4401,17 @@ dependencies = [ "tonic-build", ] +[[package]] +name = "waymark-reserved-process" +version = "0.1.0" +dependencies = [ + "thiserror", + "tokio", + "tracing", + "waymark-managed-process", + "waymark-worker-reservation", +] + [[package]] name = "waymark-runloop" version = "0.1.0" @@ -4556,7 +4582,10 @@ dependencies = [ "waymark-runloop", "waymark-runner-state", "waymark-smoke-sources", - "waymark-worker-remote", + "waymark-worker-process-spec", + "waymark-worker-python", + "waymark-worker-remote-bringup", + "waymark-worker-remote-pool", "waymark-workflow-registry-backend", ] @@ -4626,7 +4655,9 @@ dependencies = [ "waymark-scheduler-loop", "waymark-scheduler-loop-core", "waymark-webapp-bringup", - "waymark-worker-remote", + "waymark-worker-python", + "waymark-worker-remote-bringup", + "waymark-worker-remote-pool", "waymark-worker-status-reporter", ] @@ -4797,6 +4828,19 @@ 
dependencies = [ "waymark-worker-core", ] +[[package]] +name = "waymark-worker-message-protocol" +version = "0.1.0" +dependencies = [ + "prost 0.12.6", + "thiserror", + "tokio", + "tracing", + "uuid", + "waymark-proto", + "waymark-worker-metrics", +] + [[package]] name = "waymark-worker-metrics" version = "0.1.0" @@ -4806,28 +4850,126 @@ dependencies = [ ] [[package]] -name = "waymark-worker-remote" +name = "waymark-worker-process" +version = "0.1.0" +dependencies = [ + "tokio", + "tracing", + "waymark-managed-process", + "waymark-reserved-process", + "waymark-worker-message-protocol", + "waymark-worker-reservation", +] + +[[package]] +name = "waymark-worker-process-pool" version = "0.1.0" dependencies = [ - "anyhow", - "futures-core", "nonempty-collections", - "prost 0.12.6", - "serde_json", "thiserror", + "tokio", + "tracing", + "waymark-managed-process", + "waymark-reserved-process", + "waymark-worker-message-protocol", + "waymark-worker-metrics", + "waymark-worker-process", + "waymark-worker-process-spec", + "waymark-worker-reservation", + "waymark-worker-status-core", +] + +[[package]] +name = "waymark-worker-process-spec" +version = "0.1.0" +dependencies = [ + "waymark-reserved-process", + "waymark-worker-reservation", +] + +[[package]] +name = "waymark-worker-python" +version = "0.1.0" +dependencies = [ + "tokio", + "tracing", + "waymark-reserved-process", + "waymark-worker-process-spec", + "waymark-worker-reservation", +] + +[[package]] +name = "waymark-worker-remote-bridge-bringup" +version = "0.1.0" +dependencies = [ "tokio", "tokio-stream", + "tokio-util", "tonic 0.11.0", "tracing", - "uuid", - "waymark-ids", + "waymark-proto", + "waymark-worker-message-protocol", + "waymark-worker-remote-bridge-service", + "waymark-worker-reservation", +] + +[[package]] +name = "waymark-worker-remote-bridge-service" +version = "0.1.0" +dependencies = [ + "futures-core", + "prost 0.12.6", + "tokio", + "tokio-stream", + "tonic 0.11.0", + "tracing", + "waymark-proto", + 
"waymark-worker-message-protocol", + "waymark-worker-reservation", +] + +[[package]] +name = "waymark-worker-remote-bringup" +version = "0.1.0" +dependencies = [ + "thiserror", + "tokio", + "tokio-util", + "tracing", + "waymark-worker-process-pool", + "waymark-worker-process-spec", + "waymark-worker-remote-bridge-bringup", +] + +[[package]] +name = "waymark-worker-remote-pool" +version = "0.1.0" +dependencies = [ + "nonempty-collections", + "prost 0.12.6", + "serde_json", + "tokio", + "tracing", + "waymark-managed-process", "waymark-message-conversions", "waymark-proto", "waymark-worker-core", + "waymark-worker-message-protocol", "waymark-worker-metrics", + "waymark-worker-process-pool", + "waymark-worker-process-spec", "waymark-worker-status-core", ] +[[package]] +name = "waymark-worker-reservation" +version = "0.1.0" +dependencies = [ + "slotmap", + "thiserror", + "tokio", +] + [[package]] name = "waymark-worker-status-backend" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index 22092b1f..57070c59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ waymark-ids = { path = "crates/lib/ids" } waymark-ir-conversions = { path = "crates/lib/ir-conversions" } waymark-ir-format = { path = "crates/lib/ir-format" } waymark-ir-parser = { path = "crates/lib/ir-parser" } +waymark-managed-process = { path = "crates/lib/managed-process" } waymark-message-conversions = { path = "crates/lib/message-conversions" } waymark-nonzero-duration = { path = "crates/lib/nonzero-duration" } waymark-observability = { path = "crates/lib/observability" } @@ -34,6 +35,7 @@ waymark-observability-setup = { path = "crates/lib/observability-setup" } waymark-pool-status = { path = "crates/lib/pool-status" } waymark-prometheus-exporter-bringup = { path = "crates/lib/prometheus-exporter-bringup" } waymark-proto = { path = "crates/lib/proto" } +waymark-reserved-process = { path = "crates/lib/reserved-process" } waymark-runloop = { path = "crates/lib/runloop" } waymark-runner = { path = 
"crates/lib/runner" } waymark-runner-retry-policy = { path = "crates/lib/runner-retry-policy" } @@ -61,8 +63,17 @@ waymark-webapp-core = { path = "crates/lib/webapp-core" } waymark-webapp-routes = { path = "crates/lib/webapp-routes" } waymark-worker-core = { path = "crates/lib/worker-core" } waymark-worker-inline = { path = "crates/lib/worker-inline" } +waymark-worker-message-protocol = { path = "crates/lib/worker-message-protocol" } waymark-worker-metrics = { path = "crates/lib/worker-metrics" } -waymark-worker-remote = { path = "crates/lib/worker-remote" } +waymark-worker-process = { path = "crates/lib/worker-process" } +waymark-worker-process-pool = { path = "crates/lib/worker-process-pool" } +waymark-worker-process-spec = { path = "crates/lib/worker-process-spec" } +waymark-worker-python = { path = "crates/lib/worker-python" } +waymark-worker-remote-bridge-bringup = { path = "crates/lib/worker-remote-bridge-bringup" } +waymark-worker-remote-bridge-service = { path = "crates/lib/worker-remote-bridge-service" } +waymark-worker-remote-bringup = { path = "crates/lib/worker-remote-bringup" } +waymark-worker-remote-pool = { path = "crates/lib/worker-remote-pool" } +waymark-worker-reservation = { path = "crates/lib/worker-reservation" } waymark-worker-status-backend = { path = "crates/lib/worker-status-backend" } waymark-worker-status-core = { path = "crates/lib/worker-status-core" } waymark-worker-status-reporter = { path = "crates/lib/worker-status-reporter" } @@ -100,6 +111,7 @@ serde = "1" serde_json = "1" serial_test = "2" sha2 = "0.10" +slotmap = "1" sqlx = { version = "0.8", default-features = false } strum = "0.28" syn = "2" diff --git a/crates/bin/integration-test/Cargo.toml b/crates/bin/integration-test/Cargo.toml index e8f28a9b..9f79b5e7 100644 --- a/crates/bin/integration-test/Cargo.toml +++ b/crates/bin/integration-test/Cargo.toml @@ -19,7 +19,12 @@ waymark-runloop = { workspace = true } waymark-runner-state = { workspace = true } waymark-secret-string = 
{ workspace = true } waymark-support-integration = { workspace = true } -waymark-worker-remote = { workspace = true } +waymark-worker-core = { workspace = true } +waymark-worker-process-pool = { workspace = true } +waymark-worker-process-spec = { workspace = true } +waymark-worker-python = { workspace = true } +waymark-worker-remote-bringup = { workspace = true } +waymark-worker-remote-pool = { workspace = true } waymark-workflow-registry-backend = { workspace = true } anyhow = { workspace = true } @@ -30,4 +35,5 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } sqlx = { workspace = true, features = ["runtime-tokio"] } tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio-util = { workspace = true } uuid = { workspace = true } diff --git a/crates/bin/integration-test/src/main.rs b/crates/bin/integration-test/src/main.rs index 59eab5fc..c6848fdc 100644 --- a/crates/bin/integration-test/src/main.rs +++ b/crates/bin/integration-test/src/main.rs @@ -6,6 +6,7 @@ //! 3. Assert backend output matches inline Python output. use std::collections::{BTreeMap, HashMap, VecDeque}; +use std::num::NonZeroUsize; use std::path::{Path, PathBuf}; use std::process::Command; use std::sync::{Arc, Mutex}; @@ -30,7 +31,6 @@ use waymark_proto::ast as ir; use waymark_runloop::{RunLoop, RunLoopConfig}; use waymark_runner_state::RunnerState; use waymark_support_integration::{LOCAL_POSTGRES_DSN, connect_pool, ensure_local_postgres}; -use waymark_worker_remote::{PythonWorkerConfig, RemoteWorkerPool}; use waymark_workflow_registry_backend::{WorkflowRegistration, WorkflowRegistryBackend}; #[derive(Parser, Debug)] @@ -45,8 +45,8 @@ struct Args { cases: Vec, /// Number of Python workers for backend execution. - #[arg(long, default_value_t = 2)] - worker_count: usize, + #[arg(long, default_value_t = 2.try_into().unwrap())] + worker_count: NonZeroUsize, /// Timeout per backend execution, in seconds. 
#[arg(long, default_value_t = 120)] @@ -202,6 +202,8 @@ async fn main() -> Result<()> { bail!("no fixture cases selected"); } + let shutdown_token = tokio_util::sync::CancellationToken::new(); + let mut prepared_cases = Vec::new(); for case in selected_cases { prepared_cases.push(prepare_case(&repo_root, case.clone()).with_context(|| { @@ -218,9 +220,19 @@ async fn main() -> Result<()> { None }; - let worker_pool = setup_worker_pool(&repo_root, &prepared_cases, args.worker_count) - .await - .context("start integration worker pool")?; + let (worker_process_pool, bridge_server_task) = setup_worker_pool( + shutdown_token.clone(), + &repo_root, + &prepared_cases, + args.worker_count, + ) + .await + .context("start integration worker pool")?; + + let worker_process_pool = Arc::new(worker_process_pool); + let worker_pool = + waymark_worker_remote_pool::RemoteWorkerPool::new(Arc::clone(&worker_process_pool)); + let worker_pool = Arc::new(worker_pool); let mut failures = Vec::new(); let mut comparisons = 0usize; @@ -262,8 +274,10 @@ async fn main() -> Result<()> { } } } + bridge_server_task.abort(); + let _ = bridge_server_task.await; - if let Err(err) = worker_pool.shutdown().await { + if let Err(err) = worker_pool.shutdown_arc().await { eprintln!("failed to shutdown worker pool: {err}"); } @@ -287,6 +301,7 @@ async fn main() -> Result<()> { prepared_cases.len(), comparisons ); + Ok(()) } @@ -416,10 +431,14 @@ fn run_python_helper(repo_root: &Path, case: &FixtureCase) -> Result Result { + worker_count: NonZeroUsize, +) -> Result<( + waymark_worker_process_pool::Pool, + tokio::task::JoinHandle<()>, +)> { let mut modules = cases .iter() .map(|prepared| prepared.case.module_name.to_string()) @@ -427,7 +446,7 @@ async fn setup_worker_pool( modules.sort(); modules.dedup(); - let config = PythonWorkerConfig::new() + let config = waymark_worker_python::Config::new() .with_user_modules(modules) .with_python_paths(vec![ repo_root.join("python"), @@ -435,9 +454,21 @@ async fn 
setup_worker_pool( repo_root.join("tests/integration_tests"), ]); - RemoteWorkerPool::new_with_config(config, worker_count.max(1), None, None, 10) - .await - .context("create remote worker pool") + let (pool, task) = waymark_worker_remote_bringup::start( + shutdown_token, + None, + |bridge_server_addr| waymark_worker_python::Spec { + bridge_server_addr, + config, + }, + worker_count, + None, + 10.try_into().unwrap(), + ) + .await + .context("create remote worker pool")?; + + Ok((pool, task)) } async fn connect_postgres_backend() -> Result { @@ -460,11 +491,15 @@ async fn connect_postgres_backend() -> Result { Ok(PostgresBackend::new(pool)) } -async fn run_case_in_memory( +async fn run_case_in_memory( prepared: &PreparedCase, - worker_pool: RemoteWorkerPool, + worker_pool: Arc>, timeout: Duration, -) -> Result { +) -> Result +where + Spec: waymark_worker_process_spec::Spec + Send + Sync + 'static, + waymark_worker_remote_pool::RemoteWorkerPool: waymark_worker_core::BaseWorkerPool, +{ let queue = Arc::new(Mutex::new(VecDeque::new())); let backend = MemoryBackend::with_queue(queue); @@ -507,12 +542,16 @@ async fn run_case_in_memory( ))) } -async fn run_case_postgres( +async fn run_case_postgres( prepared: &PreparedCase, backend: &PostgresBackend, - worker_pool: RemoteWorkerPool, + worker_pool: Arc>, timeout: Duration, -) -> Result { +) -> Result +where + Spec: waymark_worker_process_spec::Spec + Send + Sync + 'static, + waymark_worker_remote_pool::RemoteWorkerPool: waymark_worker_core::BaseWorkerPool, +{ backend .clear_all() .await @@ -573,13 +612,19 @@ async fn run_case_postgres( ))) } -async fn run_runloop(worker_pool: RemoteWorkerPool, backend: B, timeout: Duration) -> Result<()> +async fn run_runloop( + worker_pool: Arc>, + backend: Backend, + timeout: Duration, +) -> Result<()> where - B: CoreBackend + WorkflowRegistryBackend + Clone + Send + Sync + 'static, - ::PollQueuedInstancesError: Send + Sync + 'static, - ::PollQueuedInstancesError: core::error::Error, + 
Backend: CoreBackend + WorkflowRegistryBackend + Clone + Send + Sync + 'static, + ::PollQueuedInstancesError: Send + Sync + 'static, + ::PollQueuedInstancesError: core::error::Error, + Spec: waymark_worker_process_spec::Spec + Send + Sync + 'static, + waymark_worker_remote_pool::RemoteWorkerPool: waymark_worker_core::BaseWorkerPool, { - let runloop = RunLoop::new( + let runloop = RunLoop::, _, _>::new( worker_pool, backend, RunLoopConfig { diff --git a/crates/bin/smoke/Cargo.toml b/crates/bin/smoke/Cargo.toml index 18999aae..0b4aa5c2 100644 --- a/crates/bin/smoke/Cargo.toml +++ b/crates/bin/smoke/Cargo.toml @@ -17,7 +17,10 @@ waymark-proto = { workspace = true, features = ["serde", "client", "server"] } waymark-runloop = { workspace = true } waymark-runner-state = { workspace = true } waymark-smoke-sources = { workspace = true } -waymark-worker-remote = { workspace = true } +waymark-worker-process-spec = { workspace = true } +waymark-worker-python = { workspace = true } +waymark-worker-remote-bringup = { workspace = true } +waymark-worker-remote-pool = { workspace = true } waymark-workflow-registry-backend = { workspace = true } anyhow = { workspace = true } diff --git a/crates/bin/smoke/src/main.rs b/crates/bin/smoke/src/main.rs index a8700036..de82dcec 100644 --- a/crates/bin/smoke/src/main.rs +++ b/crates/bin/smoke/src/main.rs @@ -26,7 +26,6 @@ use waymark_smoke_sources::{ build_control_flow_program, build_parallel_spread_program, build_program, build_try_except_program, build_while_loop_program, }; -use waymark_worker_remote::{PythonWorkerConfig, RemoteWorkerPool}; #[derive(Parser, Debug)] #[command( @@ -68,7 +67,13 @@ fn slugify(name: &str) -> String { .collect() } -async fn run_program_smoke(case: &SmokeCase, worker_pool: RemoteWorkerPool) -> Result<()> { +async fn run_program_smoke( + case: &SmokeCase, + worker_pool: Arc>, +) -> Result<()> +where + Spec: waymark_worker_process_spec::Spec + Send + Sync + 'static, +{ println!("\nIR program ({})", case.name); 
println!("{}", format_program(&case.program)); println!("IR inputs ({}): {:?}", case.name, case.inputs); @@ -114,7 +119,7 @@ async fn run_program_smoke(case: &SmokeCase, worker_pool: RemoteWorkerPool) -> R .queue_template_node(&entry_node, None) .map_err(|err| anyhow!(err.0))?; - let runloop = RunLoop::new( + let runloop = RunLoop::, _, _>::new( worker_pool, backend.clone(), RunLoopConfig { @@ -159,14 +164,31 @@ async fn run_program_smoke(case: &SmokeCase, worker_pool: RemoteWorkerPool) -> R } async fn run_smoke(base: i64) -> i32 { - let config = PythonWorkerConfig::new().with_user_module("tests.fixtures.test_actions"); - let worker_pool = match RemoteWorkerPool::new_with_config(config, 2, None, None, 10).await { - Ok(pool) => pool, + let config = + waymark_worker_python::Config::new().with_user_module("tests.fixtures.test_actions"); + + let result = waymark_worker_remote_bringup::start( + Default::default(), + None, + |bridge_server_addr| waymark_worker_python::Spec { + config, + bridge_server_addr, + }, + 2.try_into().unwrap(), + None, + 10.try_into().unwrap(), + ) + .await; + + let (process_pool, bridge_server_task) = match result { + Ok(val) => val, Err(err) => { println!("Failed to start python worker pool: {err}"); return 1; } }; + let worker_pool = waymark_worker_remote_pool::RemoteWorkerPool::new(process_pool); + let worker_pool = Arc::new(worker_pool); let mut cases = Vec::new(); cases.push(SmokeCase { @@ -213,7 +235,10 @@ async fn run_smoke(base: i64) -> i32 { } } - if let Err(err) = worker_pool.shutdown().await { + bridge_server_task.abort(); + let _ = bridge_server_task.await; + + if let Err(err) = worker_pool.shutdown_arc().await { println!("Failed to shut down worker pool: {err}"); } diff --git a/crates/bin/start-workers/Cargo.toml b/crates/bin/start-workers/Cargo.toml index 7f131536..90907cc2 100644 --- a/crates/bin/start-workers/Cargo.toml +++ b/crates/bin/start-workers/Cargo.toml @@ -19,7 +19,9 @@ waymark-runloop = { workspace = true } 
waymark-scheduler-loop = { workspace = true } waymark-scheduler-loop-core = { workspace = true } waymark-webapp-bringup = { workspace = true } -waymark-worker-remote = { workspace = true } +waymark-worker-python = { workspace = true } +waymark-worker-remote-bringup = { workspace = true } +waymark-worker-remote-pool = { workspace = true } waymark-worker-status-reporter = { workspace = true } anyhow = { workspace = true } diff --git a/crates/bin/start-workers/src/main.rs b/crates/bin/start-workers/src/main.rs index 5de72654..11e4e7d2 100644 --- a/crates/bin/start-workers/src/main.rs +++ b/crates/bin/start-workers/src/main.rs @@ -51,7 +51,6 @@ use waymark_nonzero_duration::NonZeroDuration; use waymark_proto::ast as ir; use waymark_runloop::RunLoopConfig; use waymark_scheduler_loop_core::WorkflowDag; -use waymark_worker_remote::{PythonWorkerConfig, RemoteWorkerPool}; #[tokio::main] async fn main() -> Result<()> { @@ -91,24 +90,27 @@ async fn main() -> Result<()> { let backend = PostgresBackend::new(pool); // Start the worker pool (bridge + python workers). 
- let mut worker_config = PythonWorkerConfig::new(); + let mut worker_config = waymark_worker_python::Config::new(); if !config.user_modules.is_empty() { worker_config = worker_config.with_user_modules(config.user_modules.clone()); } - let remote_pool = RemoteWorkerPool::new_with_config( - worker_config, - config.worker_count.get(), + let worker_process_spec_builder = |bridge_server_addr| waymark_worker_python::Spec { + bridge_server_addr, + config: worker_config, + }; + + let (process_pool, bridge_task) = waymark_worker_remote_bringup::start( + shutdown_token.clone(), Some(config.worker_grpc_addr), - config.max_action_lifecycle.map(|val| val.get()), - config.concurrent_per_worker.get(), + worker_process_spec_builder, + config.worker_count, + config.max_action_lifecycle, + config.concurrent_per_worker, ) .await?; - info!( - count = config.worker_count, - bridge_addr = %remote_pool.bridge_addr(), - "python worker pool started" - ); + + let process_pool = Arc::new(process_pool); // Start the webapp server. let webapp_backend = Arc::new(backend.clone()); @@ -160,7 +162,7 @@ async fn main() -> Result<()> { let status_reporter_handle = tokio::spawn(waymark_worker_status_reporter::run( pool_id, backend.clone(), - remote_pool.clone(), + process_pool.clone(), active_instance_gauge.clone(), config.profile_interval, shutdown_token.clone().cancelled_owned(), @@ -185,9 +187,10 @@ async fn main() -> Result<()> { }); // Run the runloop. 
+ let remote_pool = waymark_worker_remote_pool::RemoteWorkerPool::new(process_pool.clone()); let lock_uuid = LockId::new_uuid_v4(); let runloop = waymark_runloop::RunLoop::new_with_shutdown( - remote_pool.clone(), + remote_pool, backend.clone(), RunLoopConfig { max_concurrent_instances: config.max_concurrent_instances, @@ -218,12 +221,13 @@ async fn main() -> Result<()> { } let _ = shutdown_handle.await; + let _ = tokio::time::timeout(Duration::from_secs(5), bridge_task).await; let _ = tokio::time::timeout(Duration::from_secs(5), scheduler_handle).await; let _ = tokio::time::timeout(Duration::from_secs(5), garbage_collector_handle).await; let _ = tokio::time::timeout(Duration::from_secs(2), status_reporter_handle).await; let _ = tokio::time::timeout(Duration::from_secs(2), expired_lock_reclaimer_handle).await; - if let Err(err) = remote_pool.shutdown().await { + if let Err(err) = process_pool.shutdown_arc().await { warn!(error = %err, "worker pool shutdown failed"); } diff --git a/crates/lib/reserved-process/Cargo.toml b/crates/lib/reserved-process/Cargo.toml new file mode 100644 index 00000000..1dbbf3f0 --- /dev/null +++ b/crates/lib/reserved-process/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "waymark-reserved-process" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-managed-process = { workspace = true } +waymark-worker-reservation = { workspace = true } + +thiserror = { workspace = true } +tokio = { workspace = true, features = ["process"] } +tracing = { workspace = true } diff --git a/crates/lib/reserved-process/src/lib.rs b/crates/lib/reserved-process/src/lib.rs new file mode 100644 index 00000000..a5d0c722 --- /dev/null +++ b/crates/lib/reserved-process/src/lib.rs @@ -0,0 +1,65 @@ +use std::time::Duration; + +pub struct SpawnParams { + pub command: tokio::process::Command, + pub wait_for_playload_timeout: Duration, + pub graceful_shutdown_timeout: Duration, + pub kill_timeout: Duration, +} + +#[derive(Debug, 
thiserror::Error)] +pub enum SpawnError { + #[error("spawn: {0}")] + Spawn(std::io::Error), + + #[error("timed out after waiting for worker to attach")] + ReservationWaitTimeout { timeout: Duration }, + + #[error("reservation wait: {0}")] + ReservationWaitError(waymark_worker_reservation::ReservationCancelledError), +} + +pub async fn spawn( + reservation: waymark_worker_reservation::Reservation, + params: SpawnParams, +) -> Result<(waymark_managed_process::Child, Payload), SpawnError> { + let SpawnParams { + wait_for_playload_timeout, + command, + graceful_shutdown_timeout, + kill_timeout, + } = params; + + // Spawn the process. + let child = waymark_managed_process::spawn(command).map_err(SpawnError::Spawn)?; + + // Wait for the worker to connect (with timeout). + let result = tokio::time::timeout(wait_for_playload_timeout, reservation.wait()).await; + let result = match result { + Ok(Ok(channels)) => Ok(channels), + Ok(Err(err)) => Err(SpawnError::ReservationWaitError(err)), + Err(tokio::time::error::Elapsed { .. }) => Err(SpawnError::ReservationWaitTimeout { + timeout: wait_for_playload_timeout, + }), + }; + + // Pass the payload or terminate the process gracefully. 
+ let payload = match result { + Ok(channels) => channels, + Err(error) => { + let shutdown_result = child + .shutdown(graceful_shutdown_timeout, kill_timeout) + .await; + tracing::debug!( + ?error, + ?shutdown_result, + "reserved process shut down after failed payload wait" + ); + return Err(error); + } + }; + + tracing::info!("reserved process connected"); + + Ok((child, payload)) +} diff --git a/crates/lib/worker-message-protocol/Cargo.toml b/crates/lib/worker-message-protocol/Cargo.toml new file mode 100644 index 00000000..706c0476 --- /dev/null +++ b/crates/lib/worker-message-protocol/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "waymark-worker-message-protocol" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-proto = { workspace = true } +waymark-worker-metrics = { workspace = true } + +prost = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["rt"] } +tracing = { workspace = true } +uuid = { workspace = true } diff --git a/crates/lib/worker-message-protocol/src/lib.rs b/crates/lib/worker-message-protocol/src/lib.rs new file mode 100644 index 00000000..192766b0 --- /dev/null +++ b/crates/lib/worker-message-protocol/src/lib.rs @@ -0,0 +1,291 @@ +use std::{ + collections::HashMap, + sync::{Arc, atomic::AtomicU64}, +}; + +use prost::Message as _; +use uuid::Uuid; +use waymark_proto::messages as proto; +use waymark_worker_metrics::RoundTripMetrics; + +/// Channels for communicating with a connected worker. +pub struct Channels { + /// Send actions to the worker + pub to_worker: tokio::sync::mpsc::Sender, + + /// Receive results from the worker + pub from_worker: tokio::sync::mpsc::Receiver, +} + +/// Internal state shared between worker sender and reader tasks. 
+#[derive(Debug, Default)] +struct SharedState { + /// Pending ACK receivers, keyed by delivery_id + pub pending_acks: HashMap>, + + /// Pending result receivers, keyed by delivery_id + pub pending_responses: + HashMap>, +} + +pub struct Sender { + to_worker: tokio::sync::mpsc::Sender, + next_delivery: AtomicU64, + shared: Arc>, +} + +pub fn setup(channels: Channels) -> (Sender, impl Future) { + let Channels { + to_worker, + from_worker, + } = channels; + + // Set up shared state and spawn reader task + let shared = Arc::new(tokio::sync::Mutex::new(SharedState::default())); + let loop_fut = { + let shared = Arc::clone(&shared); + async move { + if let Err(err) = r#loop(from_worker, shared).await { + tracing::error!(?err, "worker message protocol loop exited"); + } + } + }; + + let sender = Sender { + to_worker, + shared, + next_delivery: AtomicU64::new(1), + }; + + (sender, loop_fut) +} + +/// Errors that can occur when sending or receiving a message. +#[derive(Debug, thiserror::Error)] +pub enum MessageError { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("failed to decode message: {0}")] + Decode(#[from] prost::DecodeError), + + #[error("failed to encode message: {0}")] + Encode(#[from] prost::EncodeError), + + #[error("channel closed")] + ChannelClosed, +} + +/// Errors that can occur when sending or receiving a message. +#[derive(Debug, thiserror::Error)] +pub enum SendActionError { + #[error("channel closed")] + ChannelClosed, +} + +/// Background task that reads messages from the worker. 
+async fn r#loop( + mut from_worker: tokio::sync::mpsc::Receiver, + shared: Arc>, +) -> Result<(), MessageError> { + while let Some(envelope) = from_worker.recv().await { + let kind = + proto::MessageKind::try_from(envelope.kind).unwrap_or(proto::MessageKind::Unspecified); + + match kind { + proto::MessageKind::Ack => { + let ack = proto::Ack::decode(envelope.payload.as_slice())?; + let mut guard = shared.lock().await; + if let Some(sender) = guard.pending_acks.remove(&ack.acked_delivery_id) { + let _ = sender.send(std::time::Instant::now()); + } else { + tracing::warn!(delivery = ack.acked_delivery_id, "unexpected ACK"); + } + } + proto::MessageKind::ActionResult => { + let response = proto::ActionResult::decode(envelope.payload.as_slice())?; + let mut guard = shared.lock().await; + if let Some(sender) = guard.pending_responses.remove(&envelope.delivery_id) { + let _ = sender.send((response, std::time::Instant::now())); + } else { + tracing::warn!(delivery = envelope.delivery_id, "orphan response"); + } + } + proto::MessageKind::Heartbeat => { + tracing::trace!(delivery = envelope.delivery_id, "heartbeat"); + } + other => { + tracing::warn!(?other, "unhandled message kind"); + } + } + } + + Ok(()) +} + +/// Payload for dispatching an action to a worker. 
+#[derive(Debug, Clone)] +pub struct ActionDispatchPayload { + /// Unique action identifier + pub action_id: String, + + /// Workflow instance this action belongs to + pub instance_id: String, + + /// Sequence number within the instance + pub sequence: u32, + + /// Name of the action function to call + pub action_name: String, + + /// Python module containing the action + pub module_name: String, + + /// Keyword arguments for the action + pub kwargs: proto::WorkflowArguments, + + /// Timeout in seconds (0 = no timeout) + pub timeout_seconds: u32, + + /// Maximum retry attempts + pub max_retries: u32, + + /// Current attempt number + pub attempt_number: u32, + + /// Dispatch token for correlation + pub dispatch_token: Uuid, +} + +impl Sender { + /// Send an action to the worker and wait for the result. + /// + /// This method: + /// 1. Allocates a delivery ID + /// 2. Creates channels for ACK and response + /// 3. Sends the action dispatch + /// 4. Waits for ACK (immediate) + /// 5. Waits for result (after execution) + /// 6. 
Returns metrics including latencies + /// + /// # Errors + /// + /// Returns an error if: + /// - The worker channel is closed (worker crashed) + /// - Response decoding fails + pub async fn send_action( + &self, + dispatch: ActionDispatchPayload, + ) -> Result { + let delivery_id = self + .next_delivery + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let send_instant = std::time::Instant::now(); + + tracing::trace!( + action_id = %dispatch.action_id, + instance_id = %dispatch.instance_id, + sequence = dispatch.sequence, + module = %dispatch.module_name, + function = %dispatch.action_name, + delivery_id, + "sending action to worker" + ); + + // Create channels for receiving ACK and response + let (ack_tx, ack_rx) = tokio::sync::oneshot::channel(); + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + + // Register pending requests + { + let mut shared = self.shared.lock().await; + shared.pending_acks.insert(delivery_id, ack_tx); + shared.pending_responses.insert(delivery_id, response_tx); + } + + // Build and send the dispatch envelope + let command = proto::ActionDispatch { + action_id: dispatch.action_id.clone(), + instance_id: dispatch.instance_id.clone(), + sequence: dispatch.sequence, + action_name: dispatch.action_name.clone(), + module_name: dispatch.module_name.clone(), + kwargs: Some(dispatch.kwargs.clone()), + timeout_seconds: Some(dispatch.timeout_seconds), + max_retries: Some(dispatch.max_retries), + attempt_number: Some(dispatch.attempt_number), + dispatch_token: Some(dispatch.dispatch_token.to_string()), + }; + + let envelope = proto::Envelope { + delivery_id, + partition_id: 0, + kind: proto::MessageKind::ActionDispatch as i32, + payload: command.encode_to_vec(), + }; + + self.send_envelope(envelope) + .await + .map_err(|_| SendActionError::ChannelClosed)?; + + // Wait for ACK (should be immediate) + let ack_instant = ack_rx.await.map_err(|_| SendActionError::ChannelClosed)?; + + // Wait for the actual response (after 
execution) + let (response, response_instant) = response_rx + .await + .map_err(|_| SendActionError::ChannelClosed)?; + + // Calculate metrics + let ack_latency = ack_instant + .checked_duration_since(send_instant) + .unwrap_or_default(); + let round_trip = response_instant + .checked_duration_since(send_instant) + .unwrap_or_default(); + let worker_duration = std::time::Duration::from_nanos( + response + .worker_end_ns + .saturating_sub(response.worker_start_ns), + ); + + tracing::trace!( + action_id = %dispatch.action_id, + ack_latency_us = ack_latency.as_micros(), + round_trip_ms = round_trip.as_millis(), + worker_duration_ms = worker_duration.as_millis(), + success = response.success, + "action completed" + ); + + Ok(RoundTripMetrics { + action_id: dispatch.action_id, + instance_id: dispatch.instance_id, + delivery_id, + sequence: dispatch.sequence, + ack_latency, + round_trip, + worker_duration, + response_payload: response + .payload + .as_ref() + .map(|payload| payload.encode_to_vec()) + .unwrap_or_default(), + success: response.success, + dispatch_token: response + .dispatch_token + .as_ref() + .and_then(|token| Uuid::parse_str(token).ok()), + error_type: response.error_type, + error_message: response.error_message, + }) + } + + /// Send an envelope to the worker. 
+ async fn send_envelope( + &self, + envelope: proto::Envelope, + ) -> Result<(), tokio::sync::mpsc::error::SendError> { + self.to_worker.send(envelope).await + } +} diff --git a/crates/lib/worker-process-pool/Cargo.toml b/crates/lib/worker-process-pool/Cargo.toml new file mode 100644 index 00000000..eadbf05b --- /dev/null +++ b/crates/lib/worker-process-pool/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "waymark-worker-process-pool" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-managed-process = { workspace = true } +waymark-reserved-process = { workspace = true } +waymark-worker-message-protocol = { workspace = true } +waymark-worker-metrics = { workspace = true } +waymark-worker-process = { workspace = true } +waymark-worker-process-spec = { workspace = true } +waymark-worker-reservation = { workspace = true } +waymark-worker-status-core = { workspace = true } + +nonempty-collections = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["process"] } +tracing = { workspace = true } diff --git a/crates/lib/worker-process-pool/src/lib.rs b/crates/lib/worker-process-pool/src/lib.rs new file mode 100644 index 00000000..ca0b5bdf --- /dev/null +++ b/crates/lib/worker-process-pool/src/lib.rs @@ -0,0 +1,521 @@ +use std::{ + num::{NonZeroU64, NonZeroUsize}, + sync::{ + Arc, Mutex as StdMutex, + atomic::{AtomicU64, AtomicUsize, Ordering}, + }, + time::{Duration, Instant}, +}; + +use nonempty_collections::NEVec; +use tokio::sync::RwLock; + +use tracing::{error, info, trace, warn}; + +use waymark_worker_metrics::{WorkerPoolMetrics, WorkerThroughputSnapshot}; + +const LATENCY_SAMPLE_SIZE: usize = 256; +const THROUGHPUT_WINDOW_SECS: u64 = 1; + +type Registry = waymark_worker_reservation::Registry; + +type WorkerId = u64; + +pub struct WorkerState { + pub handle: waymark_worker_process::Handle, + pub sender: waymark_worker_message_protocol::Sender, + pub id: WorkerId, +} + +pub 
struct Pool { + /// The spec for the worker processes. + worker_process_spec: Spec, + + /// The registry of the connecting workers. + workers_registry: Arc, + + // Worker ID sequence. + worker_id_sequence: AtomicU64, + + /// The workers in the pool (RwLock for recycling support) + worker_processes: RwLock>>, + + /// Cursor for round-robin selection + cursor: AtomicUsize, + + /// Shared metrics tracker for throughput + latency. + metrics: StdMutex, + + /// Action counts per worker slot (for lifecycle tracking) + action_counts: NEVec, + + /// In-flight action counts per worker slot (for concurrency control) + in_flight_counts: NEVec, + + /// Maximum concurrent actions per worker + max_concurrent_per_worker: NonZeroUsize, + + /// Maximum actions per worker before recycling (None = no limit) + max_action_lifecycle: Option, +} + +#[derive(Debug, thiserror::Error)] +pub enum InitError { + #[error("unable to spawn worker with index {worker_index}: {error}")] + WorkerSpawn { + error: waymark_reserved_process::SpawnError, + worker_index: usize, + }, +} + +impl Pool +where + Spec: waymark_worker_process_spec::Spec, +{ + /// Create a new worker pool with explicit concurrency limit. + pub async fn new_with_concurrency( + workers_registry: Arc, + worker_process_spec: Spec, + worker_count: NonZeroUsize, + max_action_lifecycle: Option, + max_concurrent_per_worker: NonZeroUsize, + ) -> Result { + info!( + count = worker_count, + max_action_lifecycle = ?max_action_lifecycle, + "spawning python worker pool" + ); + + // Spawn all workers in parallel to reduce boot time. 
+ let spawn_results: Vec<_> = { + let workers_registry = &workers_registry; + (0..worker_count.get()) + .map(|_| { + let reservation = workers_registry.reserve(); + let params = worker_process_spec.prepare_spawn_params(reservation.id()); + tokio::spawn(waymark_worker_process::spawn(reservation, params)) + }) + .collect() + }; + + let mut workers = Vec::with_capacity(worker_count.get()); + let mut worker_id_sequence = 0; + for (worker_index, handle) in spawn_results.into_iter().enumerate() { + let result = handle.await.unwrap(); // propagate panics + match result { + Ok((handle, sender)) => { + workers.push(Arc::new(WorkerState { + handle, + sender, + id: worker_id_sequence, + })); + worker_id_sequence += 1; + } + Err(error) => { + warn!( + worker_index, + ?error, + "failed to spawn worker, cleaning up {} already spawned", + workers.len() + ); + for worker in workers { + if let Ok(worker) = Arc::try_unwrap(worker) { + let _ = worker.handle.shutdown().await; + } + } + return Err(InitError::WorkerSpawn { + error, + worker_index, + }); + } + } + } + + info!(count = workers.len(), "worker pool ready"); + + let worker_id_sequence = AtomicU64::new(worker_id_sequence); + let worker_ids = workers.iter().map(|worker| worker.id).collect(); + let action_counts = nevec_fn(worker_count, |_| AtomicU64::new(0)); + let in_flight_counts = nevec_fn(worker_count, |_| AtomicUsize::new(0)); + Ok(Self { + worker_process_spec, + workers_registry, + worker_id_sequence, + worker_processes: RwLock::new(workers), + cursor: AtomicUsize::new(0), + metrics: StdMutex::new(WorkerPoolMetrics::new( + worker_ids, + Duration::from_secs(THROUGHPUT_WINDOW_SECS), + LATENCY_SAMPLE_SIZE, + )), + action_counts, + in_flight_counts, + max_concurrent_per_worker, + max_action_lifecycle, + }) + } +} + +fn nevec_fn(items: NonZeroUsize, mut f: impl FnMut(usize) -> T) -> NEVec { + let mut vec = NEVec::with_capacity(items, f(0)); + for index in 1..items.get() { + vec.push(f(index)); + } + vec +} + +impl Pool { + /// 
Get a worker by index. + /// + /// Returns a clone of the Arc for the worker at the given index. + pub async fn get_worker(&self, idx: usize) -> Arc { + let worker_processes = self.worker_processes.read().await; + Arc::clone(&worker_processes[idx % worker_processes.len()]) + } + + /// Get the next worker index using round-robin selection. + /// + /// This is lock-free and O(1). Returns the index that can be used + /// with `get_worker` to fetch the actual worker. + pub fn next_worker_idx(&self) -> usize { + self.cursor.fetch_add(1, Ordering::Relaxed) + } + + /// Get the number of workers in the pool. + pub fn len(&self) -> NonZeroUsize { + self.action_counts.len() + } + + /// Get the maximum concurrent actions per worker. + pub fn max_concurrent_per_worker(&self) -> NonZeroUsize { + self.max_concurrent_per_worker + } + + /// Get total capacity (worker_count * max_concurrent_per_worker). + pub fn total_capacity(&self) -> NonZeroUsize { + self.len().saturating_mul(self.max_concurrent_per_worker) + } + + /// Get total in-flight actions across all workers. + pub fn total_in_flight(&self) -> usize { + self.in_flight_counts + .iter() + .map(|c| c.load(Ordering::Relaxed)) + .sum() + } + + /// Get available capacity (total_capacity - total_in_flight). + pub fn available_capacity(&self) -> usize { + self.total_capacity() + .get() + .saturating_sub(self.total_in_flight()) + } + + /// Try to acquire a slot for the next available worker. + /// + /// Returns `Some(worker_idx)` if a slot was acquired, `None` if all workers + /// are at capacity. Uses round-robin selection among workers with capacity. 
+ pub fn try_acquire_slot(&self) -> Option { + let worker_count = self.len(); + + // Try each worker starting from the current cursor position + let start = self.cursor.fetch_add(1, Ordering::Relaxed); + for i in 0..worker_count.get() { + let idx = (start + i) % worker_count; + if self.try_acquire_slot_for_worker(idx) { + return Some(idx); + } + } + None + } + + /// Try to acquire a slot for a specific worker. + /// + /// Returns `true` if the slot was acquired, `false` if the worker is at capacity. + pub fn try_acquire_slot_for_worker(&self, worker_idx: usize) -> bool { + let Some(counter) = self.in_flight_counts.get(worker_idx % self.len()) else { + return false; + }; + + // CAS loop to atomically increment if below limit + loop { + let current = counter.load(Ordering::Acquire); + if current >= self.max_concurrent_per_worker.get() { + return false; + } + match counter.compare_exchange_weak( + current, + current + 1, + Ordering::AcqRel, + Ordering::Relaxed, + ) { + Ok(_) => return true, + Err(_) => continue, // Retry + } + } + } + + /// Release a slot for a worker. + /// + /// Should be called when an action completes (via `record_completion`). + pub fn release_slot(&self, worker_idx: usize) { + if let Some(counter) = self.in_flight_counts.get(worker_idx % self.len()) { + // Saturating sub to avoid underflow in case of bugs + let prev = counter.fetch_sub(1, Ordering::Release); + if prev == 0 { + warn!(worker_idx, "release_slot called with zero in-flight count"); + counter.store(0, Ordering::Release); + } + } + } + + /// Get in-flight count for a specific worker. + pub fn in_flight_for_worker(&self, worker_idx: usize) -> usize { + self.in_flight_counts + .get(worker_idx % self.len()) + .map(|c| c.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + /// Get a snapshot of all workers in the pool. + pub async fn workers_snapshot(&self) -> Vec> { + self.worker_processes.read().await.clone() + } + + /// Get throughput snapshots for all workers. 
+ /// + /// Returns worker throughput metrics including completion counts and rates. + pub fn throughput_snapshots(&self) -> Vec { + if let Ok(mut metrics) = self.metrics.lock() { + metrics.throughput_snapshots(Instant::now()) + } else { + Vec::new() + } + } + + /// Record the latest latency measurements for median reporting. + pub fn record_latency(&self, ack_latency: Duration, worker_duration: Duration) { + if let Ok(mut metrics) = self.metrics.lock() { + metrics.record_latency(ack_latency, worker_duration); + } + } + + /// Return the current median dequeue/handling latencies in milliseconds. + pub fn median_latencies_ms(&self) -> (Option, Option) { + if let Ok(metrics) = self.metrics.lock() { + metrics.median_latencies_ms() + } else { + (None, None) + } + } + + /// Get queue statistics: (dispatch_queue_size, total_in_flight). + pub fn queue_stats(&self) -> (usize, usize) { + let total_in_flight: usize = self + .in_flight_counts + .iter() + .map(|c| c.load(Ordering::Relaxed)) + .sum(); + // dispatch_queue_size would require access to the bridge's queue + // For now, return 0 as placeholder + (0, total_in_flight) + } +} + +impl Pool +where + Spec: waymark_worker_process_spec::Spec, +{ + /// Record an action completion for a worker and trigger recycling if needed. + /// + /// This decrements the in-flight count and increments the action count for + /// the worker at the given index. If `max_action_lifecycle` is set and the + /// count reaches or exceeds the threshold, a background task is spawned to + /// recycle the worker. 
+ pub fn record_completion(&self, worker_idx: usize, pool: Arc) + where + Spec: Send + Sync + 'static, + { + // Release the in-flight slot + self.release_slot(worker_idx); + + // Update throughput tracking + if let Ok(mut metrics) = self.metrics.lock() { + metrics.record_completion(worker_idx); + if tracing::enabled!(tracing::Level::TRACE) { + let snapshots = metrics.throughput_snapshots(Instant::now()); + if let Some(snapshot) = snapshots.get(worker_idx) { + trace!( + worker_id = snapshot.worker_id, + throughput_per_min = snapshot.throughput_per_min, + total_completed = snapshot.total_completed, + last_action_at = ?snapshot.last_action_at, + "worker throughput snapshot" + ); + } + } + } + + // Increment action count + if let Some(counter) = self.action_counts.get(worker_idx) { + let new_count = counter.fetch_add(1, Ordering::SeqCst) + 1; + + // Check if recycling is needed + if let Some(max_lifecycle) = self.max_action_lifecycle + && new_count >= max_lifecycle.get() + { + info!( + worker_idx, + action_count = new_count, + max_lifecycle, + "worker reached action lifecycle limit, scheduling recycle" + ); + // Spawn a background task to recycle this worker + tokio::spawn(async move { + if let Err(err) = pool.recycle_worker(worker_idx).await { + error!(worker_idx, ?err, "failed to recycle worker"); + } + }); + } + } + } + + /// Recycle a worker at the given index. + /// + /// Spawns a new worker and replaces the old one. The old worker + /// will be shut down once all in-flight actions complete (when + /// its Arc reference count drops to zero). 
+ async fn recycle_worker( + &self, + worker_idx: usize, + ) -> Result<(), waymark_reserved_process::SpawnError> { + // Spawn the replacement worker first + let reservation = self.workers_registry.reserve(); + let params = self + .worker_process_spec + .prepare_spawn_params(reservation.id()); + let (handle, sender) = waymark_worker_process::spawn(reservation, params).await?; + let new_worker_id = self.worker_id_sequence.fetch_add(1, Ordering::Relaxed); + let new_worker = WorkerState { + handle, + sender, + id: new_worker_id, + }; + + // Replace the worker in the pool + let old_worker = { + let mut worker_processes = self.worker_processes.write().await; + let idx = worker_idx % worker_processes.len(); + std::mem::replace(&mut worker_processes[idx], Arc::new(new_worker)) + }; + + // Reset the action count for this slot + if let Some(counter) = self + .action_counts + .get(worker_idx % self.action_counts.len()) + { + counter.store(0, Ordering::SeqCst); + } + + // Update throughput tracker with new worker ID + if let Ok(mut metrics) = self.metrics.lock() { + metrics.reset_worker(worker_idx, new_worker_id); + } + + info!( + worker_idx, + old_worker_id = old_worker.id, + new_worker_id, + "recycled worker" + ); + + // The old worker will be cleaned up when its Arc drops + // (once all in-flight actions complete) + + Ok(()) + } +} + +impl Pool { + /// Get the current action count for a worker slot. + /// + /// Returns the number of actions that have been completed by the worker + /// at the given index since it was last spawned/recycled. + #[cfg(test)] + #[allow(dead_code)] + pub(crate) fn get_action_count(&self, worker_idx: usize) -> u64 { + self.action_counts + .get(worker_idx) + .map(|c| c.load(Ordering::SeqCst)) + .unwrap_or(0) + } + + /// Get the maximum action lifecycle setting. + #[cfg(test)] + #[allow(dead_code)] + pub(crate) fn max_lifecycle(&self) -> Option { + self.max_action_lifecycle + } + + /// Gracefully shut down all workers in the pool. 
+ /// + /// Workers are shut down in order. Any workers still in use + /// (shared references exist) are skipped with a warning. + pub async fn shutdown(self) -> Result<(), waymark_managed_process::ShutdownError> { + let workers = self.worker_processes.into_inner(); + info!(count = workers.len(), "shutting down worker pool"); + + for worker in workers { + match Arc::try_unwrap(worker) { + Ok(worker) => { + worker.handle.shutdown().await?; + } + Err(arc) => { + warn!( + worker_id = arc.id, + "worker still in use during shutdown; skipping" + ); + } + } + } + + info!("worker pool shutdown complete"); + Ok(()) + } + + /// Unwrap an [`Arc`] and gracefully shut down all workers in the pool. + /// + /// See [`Pool::shutdown`]. + pub async fn shutdown_arc( + self: Arc, + ) -> Result<(), waymark_managed_process::ShutdownError> { + let Some(pool) = Arc::into_inner(self) else { + warn!("worker pool still referenced during shutdown; skipping shutdown"); + return Ok(()); + }; + pool.shutdown().await + } +} + +impl waymark_worker_status_core::WorkerPoolStats for Pool { + fn stats_snapshot(&self) -> waymark_worker_status_core::WorkerPoolStatsSnapshot { + let snapshots = self.throughput_snapshots(); + let active_workers = snapshots.len() as u16; + let throughput_per_min: f64 = snapshots.iter().map(|s| s.throughput_per_min).sum(); + let total_completed: i64 = snapshots.iter().map(|s| s.total_completed as i64).sum(); + let last_action_at = snapshots.iter().filter_map(|s| s.last_action_at).max(); + let (dispatch_queue_size, total_in_flight) = self.queue_stats(); + let (median_dequeue_ms, median_handling_ms) = self.median_latencies_ms(); + + waymark_worker_status_core::WorkerPoolStatsSnapshot { + active_workers, + throughput_per_min, + total_completed, + last_action_at, + dispatch_queue_size, + total_in_flight, + median_dequeue_ms, + median_handling_ms, + } + } +} diff --git a/crates/lib/worker-process-spec/Cargo.toml b/crates/lib/worker-process-spec/Cargo.toml new file mode 100644 
index 00000000..9426209c --- /dev/null +++ b/crates/lib/worker-process-spec/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "waymark-worker-process-spec" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-reserved-process = { workspace = true } +waymark-worker-reservation = { workspace = true } diff --git a/crates/lib/worker-process-spec/src/lib.rs b/crates/lib/worker-process-spec/src/lib.rs new file mode 100644 index 00000000..8451102f --- /dev/null +++ b/crates/lib/worker-process-spec/src/lib.rs @@ -0,0 +1,6 @@ +pub trait Spec { + fn prepare_spawn_params( + &self, + reservation_id: waymark_worker_reservation::Id, + ) -> waymark_reserved_process::SpawnParams; +} diff --git a/crates/lib/worker-process/Cargo.toml b/crates/lib/worker-process/Cargo.toml new file mode 100644 index 00000000..289a8d90 --- /dev/null +++ b/crates/lib/worker-process/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "waymark-worker-process" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-managed-process = { workspace = true } +waymark-reserved-process = { workspace = true } +waymark-worker-message-protocol = { workspace = true } +waymark-worker-reservation = { workspace = true } + +tokio = { workspace = true, features = ["rt"] } +tracing = { workspace = true } diff --git a/crates/lib/worker-process/src/lib.rs b/crates/lib/worker-process/src/lib.rs new file mode 100644 index 00000000..d7eb2fb7 --- /dev/null +++ b/crates/lib/worker-process/src/lib.rs @@ -0,0 +1,69 @@ +pub type Reservation = + waymark_worker_reservation::Reservation; + +pub async fn spawn( + reservation: Reservation, + params: waymark_reserved_process::SpawnParams, +) -> Result<(Handle, waymark_worker_message_protocol::Sender), waymark_reserved_process::SpawnError> +{ + let shutdown_params = ShutdownParams { + graceful_shutdown_timeout: params.graceful_shutdown_timeout, + kill_timeout: params.kill_timeout, + }; + + let (child, channels) = 
waymark_reserved_process::spawn(reservation, params).await?; + + let mut tasks = tokio::task::JoinSet::new(); + + let (sender, fut) = waymark_worker_message_protocol::setup(channels); + + tasks.spawn(fut); + + let handle = Handle { + child: Some(child), + tasks, + shutdown_params, + }; + + Ok((handle, sender)) +} + +struct ShutdownParams { + pub graceful_shutdown_timeout: std::time::Duration, + pub kill_timeout: std::time::Duration, +} + +pub struct Handle { + /// The managed child process. + child: Option, + + /// All of the async tasks managed by this process. + tasks: tokio::task::JoinSet<()>, + + /// The params used for shutdown. + shutdown_params: ShutdownParams, +} + +impl Handle { + /// Gracefully shut down the worker process. + pub async fn shutdown(mut self) -> Result<(), waymark_managed_process::ShutdownError> { + tracing::info!("shutting down worker"); + + // Abort the tasks. + self.tasks.shutdown().await; + + // Shutdown the managed process gracefully. + if let Some(child) = self.child.take() { + let exit_status = child + .shutdown( + self.shutdown_params.graceful_shutdown_timeout, + self.shutdown_params.kill_timeout, + ) + .await?; + tracing::debug!(?exit_status, "worker child process exited"); + } + + tracing::info!("worker shutdown complete"); + Ok(()) + } +} diff --git a/crates/lib/worker-python/Cargo.toml b/crates/lib/worker-python/Cargo.toml new file mode 100644 index 00000000..449bc118 --- /dev/null +++ b/crates/lib/worker-python/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "waymark-worker-python" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-reserved-process = { workspace = true } +waymark-worker-process-spec = { workspace = true } +waymark-worker-reservation = { workspace = true } + +tokio = { workspace = true, features = ["process"] } +tracing = { workspace = true } diff --git a/crates/lib/worker-python/src/config.rs b/crates/lib/worker-python/src/config.rs new file mode 100644 index 
// --- patch metadata (mangled in transit): index 00000000..7b82ff5b
// --- /dev/null +++ b/crates/lib/worker-python/src/config.rs @@ -0,0 +1,133 @@
use std::path::{Path, PathBuf};

/// Configuration for spawning Python workers.
#[derive(Clone, Debug)]
pub struct Config {
    /// Path to the script/executable to run (e.g., "uv" or "waymark-worker").
    pub script_path: PathBuf,

    /// Arguments to pass before the worker-specific args.
    pub script_args: Vec<String>,

    /// Python module(s) to preload (contains @action definitions).
    pub user_modules: Vec<String>,

    /// Additional paths to add to PYTHONPATH.
    pub extra_python_paths: Vec<PathBuf>,
}

impl Default for Config {
    /// Detects the default runner (`waymark-worker` from PATH, else `uv run`).
    ///
    /// NOTE(review): this scans PATH with blocking filesystem I/O; avoid
    /// calling it on a hot async path.
    fn default() -> Self {
        let (script_path, script_args) = default_runner();
        Self {
            script_path,
            script_args,
            user_modules: Vec::new(),
            extra_python_paths: Vec::new(),
        }
    }
}

impl Config {
    /// Create a new config with default runner detection.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set a single user module to preload (replaces any previous list).
    ///
    /// Generalized to accept anything convertible to `String`; existing
    /// `&str` callers are unchanged.
    pub fn with_user_module(mut self, module: impl Into<String>) -> Self {
        self.user_modules = vec![module.into()];
        self
    }

    /// Set multiple user modules to preload.
    pub fn with_user_modules(mut self, modules: Vec<String>) -> Self {
        self.user_modules = modules;
        self
    }

    /// Add extra paths to PYTHONPATH.
    pub fn with_python_paths(mut self, paths: Vec<PathBuf>) -> Self {
        self.extra_python_paths = paths;
        self
    }
}

/// Find the default Python runner.
///
/// Prefers `waymark-worker` if found in PATH, otherwise falls back to
/// `uv run python -m waymark.worker`.
fn default_runner() -> (PathBuf, Vec<String>) {
    if let Some(path) = find_executable("waymark-worker") {
        return (path, Vec::new());
    }
    (
        PathBuf::from("uv"),
        vec![
            "run".to_string(),
            "python".to_string(),
            "-m".to_string(),
            "waymark.worker".to_string(),
        ],
    )
}

/// Search PATH for an executable.
///
/// BUG(review): this is blocking disk I/O; callers on the async executor
/// should wrap it in `spawn_blocking`.
/// NOTE(review): `Path::is_file` follows symlinks, but junction and
/// permission semantics are untested — confirm on Windows.
fn find_executable(bin: impl AsRef<Path>) -> Option<PathBuf> {
    let path_var = std::env::var_os("PATH")?;
    for dir in std::env::split_paths(&path_var) {
        let candidate = dir.join(bin.as_ref());
        if candidate.is_file() {
            return Some(candidate);
        }
        #[cfg(windows)]
        {
            // Probe `<bin>.exe` too; built via OsString instead of the very
            // recent `Path::with_added_extension` so older toolchains compile.
            let mut with_exe = bin.as_ref().as_os_str().to_os_string();
            with_exe.push(".exe");
            let exe_candidate = dir.join(with_exe);
            if exe_candidate.is_file() {
                return Some(exe_candidate);
            }
        }
    }
    None
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let config = Config::new()
            .with_user_module("my_module")
            .with_python_paths(vec![PathBuf::from("/extra/path")]);

        assert_eq!(config.user_modules, vec!["my_module".to_string()]);
        assert_eq!(
            config.extra_python_paths,
            vec![PathBuf::from("/extra/path")]
        );
    }

    #[test]
    fn test_config_with_multiple_modules() {
        let config =
            Config::new().with_user_modules(vec!["module1".to_string(), "module2".to_string()]);

        assert_eq!(config.user_modules, vec!["module1", "module2"]);
    }

    #[test]
    fn test_default_runner_detection() {
        // Should return uv as fallback if waymark-worker not in PATH.
        let (path, args) = default_runner();
        // Either waymark-worker was found, or we get uv with args.
        if args.is_empty() {
            assert!(path.to_string_lossy().contains("waymark-worker"));
        } else {
            assert_eq!(path, PathBuf::from("uv"));
            assert_eq!(args, vec!["run", "python", "-m", "waymark.worker"]);
        }
    }
}
// --- patch metadata (mangled in transit):
// diff --git a/crates/lib/worker-python/src/lib.rs b/crates/lib/worker-python/src/lib.rs
// new file mode 100644 index 00000000..eab3b370 --- /dev/null
// +++ b/crates/lib/worker-python/src/lib.rs @@ -0,0 +1,91 @@
// use std::{path::PathBuf, time::Duration};
// mod config;
// pub use config::Config;
// TODO: rewrite to fully cache effective values, like workdir, as constructor.
+pub struct Spec { + pub bridge_server_addr: std::net::SocketAddr, + pub config: Config, +} + +impl waymark_worker_process_spec::Spec for Spec { + fn prepare_spawn_params( + &self, + reservation_id: waymark_worker_reservation::Id, + ) -> waymark_reserved_process::SpawnParams { + // Determine working directory and module paths + let package_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("python"); + let working_dir = if package_root.is_dir() { + Some(package_root.clone()) + } else { + None + }; + + // Build PYTHONPATH with all necessary directories + let mut module_paths = Vec::new(); + if let Some(root) = working_dir.as_ref() { + module_paths.push(root.clone()); + let src_dir = root.join("src"); + if src_dir.exists() { + module_paths.push(src_dir); + } + let proto_dir = root.join("proto"); + if proto_dir.exists() { + module_paths.push(proto_dir); + } + } + module_paths.extend(self.config.extra_python_paths.clone()); + + let joined_python_path = module_paths + .iter() + .map(|path| path.display().to_string()) + .collect::>() + .join(":"); + + let python_path = match std::env::var("PYTHONPATH") { + Ok(existing) if !existing.is_empty() => format!("{existing}:{joined_python_path}"), + _ => joined_python_path, + }; + + tracing::info!(python_path = %python_path, ?reservation_id, "configured python path for worker"); + + // Build the command + let mut command = tokio::process::Command::new(&self.config.script_path); + command.args(&self.config.script_args); + command + .arg("--bridge") + .arg(self.bridge_server_addr.to_string()) + .arg("--worker-id") + .arg(reservation_id.to_string()); + + // Add user modules + for module in &self.config.user_modules { + command.arg("--user-module").arg(module); + } + + command.env("PYTHONPATH", python_path); + + if let Some(dir) = working_dir { + tracing::info!(?dir, "using package root for worker process"); + command.current_dir(dir); + } else { + // TODO: move this fallible initialization outside of this impl. 
+ let cwd = std::env::current_dir().expect("failed to resolve current directory"); + tracing::info!( + ?cwd, + "package root missing, using current directory for worker process" + ); + command.current_dir(cwd); + } + + waymark_reserved_process::SpawnParams { + command, + // TODO: move to config + wait_for_playload_timeout: Duration::from_secs(15), + graceful_shutdown_timeout: Duration::from_secs(5), + kill_timeout: Duration::from_secs(10), + } + } +} diff --git a/crates/lib/worker-remote-bridge-bringup/Cargo.toml b/crates/lib/worker-remote-bridge-bringup/Cargo.toml new file mode 100644 index 00000000..2a04bf49 --- /dev/null +++ b/crates/lib/worker-remote-bridge-bringup/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "waymark-worker-remote-bridge-bringup" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-proto = { workspace = true } +waymark-worker-message-protocol = { workspace = true } +waymark-worker-remote-bridge-service = { workspace = true } +waymark-worker-reservation = { workspace = true } + +tokio = { workspace = true, features = ["rt", "net"] } +tokio-stream = { workspace = true, features = ["net"] } +tokio-util = { workspace = true } +tonic = { workspace = true } +tracing = { workspace = true } diff --git a/crates/lib/worker-remote-bridge-bringup/src/lib.rs b/crates/lib/worker-remote-bridge-bringup/src/lib.rs new file mode 100644 index 00000000..cec87bcc --- /dev/null +++ b/crates/lib/worker-remote-bridge-bringup/src/lib.rs @@ -0,0 +1,69 @@ +use std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, +}; + +use tokio::{net::TcpListener, task::JoinHandle}; +use tonic::transport::Server; +use tracing::{error, info}; + +use waymark_proto::messages as proto; + +type Registry = waymark_worker_reservation::Registry; + +/// Start the worker bridge server. +/// +/// If `bind_addr` is None, binds to localhost on an ephemeral port. +/// The actual bound address can be retrieved with [`Self::addr`]. 
pub async fn start(
    shutdown_token: tokio_util::sync::CancellationToken,
    workers_registry: Arc<Registry>,
    bind_addr: Option<SocketAddr>,
) -> Result<(SocketAddr, JoinHandle<()>), std::io::Error> {
    let bind_addr =
        bind_addr.unwrap_or_else(|| SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0));

    // TODO: annotate errors via custom error type.
    let listener = TcpListener::bind(bind_addr).await?;

    // The ephemeral port is only known after binding.
    let addr = listener.local_addr()?;

    info!(%addr, "worker bridge server starting");

    let service = waymark_worker_remote_bridge_service::WorkerBridgeService { workers_registry };

    // Serve until the shutdown token is cancelled; errors are logged, not
    // propagated, since the caller only holds a JoinHandle<()>.
    let task = tokio::spawn(async move {
        let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);

        let result = Server::builder()
            .add_service(proto::worker_bridge_server::WorkerBridgeServer::new(
                service,
            ))
            .serve_with_incoming_shutdown(incoming, shutdown_token.cancelled())
            .await;
        if let Err(err) = result {
            error!(?err, "worker bridge server exited with error");
        }
    });

    Ok((addr, task))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The server should bind an ephemeral port and exit cleanly on cancel.
    #[tokio::test]
    async fn test_server_starts_and_binds() {
        let shutdown_token = tokio_util::sync::CancellationToken::new();

        let registry = Default::default();

        let (addr, task) = start(shutdown_token.clone(), registry, None).await.unwrap();
        assert!(addr.port() > 0);

        shutdown_token.cancel();

        task.await.unwrap();
    }
}
diff --git a/crates/lib/worker-remote-bridge-service/Cargo.toml b/crates/lib/worker-remote-bridge-service/Cargo.toml
new file mode 100644
index 00000000..ff56fe18
--- /dev/null
+++ b/crates/lib/worker-remote-bridge-service/Cargo.toml
@@ -0,0 +1,17 @@
[package]
name = "waymark-worker-remote-bridge-service"
edition = "2024"
version.workspace = true
publish.workspace = true

[dependencies]
waymark-proto = { workspace = true }
waymark-worker-message-protocol = { workspace = true }
waymark-worker-reservation = { workspace = true }

futures-core = { workspace = true }
prost = {
workspace = true } +tokio = { workspace = true, features = ["sync", "rt"] } +tokio-stream = { workspace = true, features = ["net"] } +tonic = { workspace = true } +tracing = { workspace = true } diff --git a/crates/lib/worker-remote-bridge-service/src/lib.rs b/crates/lib/worker-remote-bridge-service/src/lib.rs new file mode 100644 index 00000000..43353dc8 --- /dev/null +++ b/crates/lib/worker-remote-bridge-service/src/lib.rs @@ -0,0 +1,99 @@ +use std::{pin::Pin, sync::Arc}; + +use futures_core::Stream; +use prost::Message; +use tokio::sync::mpsc; +use tokio_stream::{StreamExt, wrappers::ReceiverStream}; +use tonic::{Request, Response, Status, Streaming, async_trait}; + +use waymark_proto::messages as proto; + +type Registry = waymark_worker_reservation::Registry; + +/// gRPC service implementation for the WorkerBridge. +#[derive(Clone)] +pub struct WorkerBridgeService { + pub workers_registry: Arc, +} + +#[async_trait] +impl proto::worker_bridge_server::WorkerBridge for WorkerBridgeService { + type AttachStream = + Pin> + Send + 'static>>; + + async fn attach( + &self, + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); + + // Read and validate the handshake message + let handshake = stream + .message() + .await + .map_err(|err| Status::internal(format!("failed to read handshake: {err}")))? 
+ .ok_or_else(|| Status::invalid_argument("missing worker handshake"))?; + + let kind = proto::MessageKind::try_from(handshake.kind) + .map_err(|_| Status::invalid_argument("invalid message kind"))?; + + if kind != proto::MessageKind::WorkerHello { + return Err(Status::failed_precondition( + "expected WorkerHello as first message", + )); + } + + let hello = proto::WorkerHello::decode(&*handshake.payload).map_err(|err| { + Status::invalid_argument(format!("invalid WorkerHello payload: {err}")) + })?; + + let worker_id = hello.worker_id; + tracing::info!(worker_id, "worker connected and sent hello"); + + // Create channels for bidirectional communication + // Buffer size of 64 provides reasonable backpressure while allowing + // some pipelining of requests + let (to_worker_tx, to_worker_rx) = mpsc::channel(64); + let (from_worker_tx, from_worker_rx) = mpsc::channel(64); + + let reservation_id = waymark_worker_reservation::Id::from(worker_id); + let channels = waymark_worker_message_protocol::Channels { + to_worker: to_worker_tx, + from_worker: from_worker_rx, + }; + + // Complete the registration - this unblocks the spawn code + self.workers_registry + .register(reservation_id, channels) + .map_err(|err| tonic::Status::not_found(err.to_string()))?; + + // Spawn a task to read from the worker stream and forward to the channel + // TODO: move this into the `waymark_worker_message_protocol` and + // drop `rt` feature from `tokio`. 
+ tokio::spawn(async move { + loop { + let envelope = match stream.message().await { + Ok(Some(envelope)) => envelope, + Ok(None) => { + // Stream closed cleanly + tracing::info!(worker_id, "worker stream closed"); + break; + } + Err(err) => { + tracing::warn!(?err, worker_id, "worker stream receive error"); + break; + } + }; + + if from_worker_tx.send(envelope).await.is_err() { + // Receiver dropped, worker shutting down + break; + } + } + }); + + // Return a stream that sends from to_worker_rx to the Python client + let outbound = ReceiverStream::new(to_worker_rx).map(Ok::); + Ok(Response::new(Box::pin(outbound) as Self::AttachStream)) + } +} diff --git a/crates/lib/worker-remote-bringup/Cargo.toml b/crates/lib/worker-remote-bringup/Cargo.toml new file mode 100644 index 00000000..43130984 --- /dev/null +++ b/crates/lib/worker-remote-bringup/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "waymark-worker-remote-bringup" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +waymark-worker-process-pool = { workspace = true } +waymark-worker-process-spec = { workspace = true } +waymark-worker-remote-bridge-bringup = { workspace = true } + +thiserror = { workspace = true } +tokio = { workspace = true, features = ["rt", "net"] } +tokio-util = { workspace = true } +tracing = { workspace = true } diff --git a/crates/lib/worker-remote-bringup/src/lib.rs b/crates/lib/worker-remote-bringup/src/lib.rs new file mode 100644 index 00000000..fc31ea0b --- /dev/null +++ b/crates/lib/worker-remote-bringup/src/lib.rs @@ -0,0 +1,59 @@ +use std::{ + net::SocketAddr, + num::{NonZeroU64, NonZeroUsize}, + sync::Arc, +}; + +use tokio::task::JoinHandle; + +#[derive(Debug, thiserror::Error)] +pub enum StartError { + #[error("bridge: {0}")] + Bridge(#[source] std::io::Error), + + #[error("pool: {0}")] + Pool(#[source] waymark_worker_process_pool::InitError), +} + +pub async fn start( + shutdown_token: tokio_util::sync::CancellationToken, + bind_addr: Option, 
+ worker_process_spec_builder: impl FnOnce(SocketAddr) -> Spec, + worker_pool_size: NonZeroUsize, + max_action_lifecycle: Option, + max_concurrent_per_worker: NonZeroUsize, +) -> Result<(waymark_worker_process_pool::Pool, JoinHandle<()>), StartError> +where + Spec: waymark_worker_process_spec::Spec, +{ + let workers_registry = Default::default(); + + // Bringup server first. + let (bridge_addr, bridge_task) = waymark_worker_remote_bridge_bringup::start( + shutdown_token, + Arc::clone(&workers_registry), + bind_addr, + ) + .await + .map_err(StartError::Bridge)?; + + let worker_process_spec = (worker_process_spec_builder)(bridge_addr); + + let pool = waymark_worker_process_pool::Pool::new_with_concurrency( + workers_registry, + worker_process_spec, + worker_pool_size, + max_action_lifecycle, + max_concurrent_per_worker, + ) + .await + .map_err(StartError::Pool)?; + + tracing::info!( + %worker_pool_size, + %bridge_addr, + "worker pool started" + ); + + Ok((pool, bridge_task)) +} diff --git a/crates/lib/worker-remote/Cargo.toml b/crates/lib/worker-remote-pool/Cargo.toml similarity index 61% rename from crates/lib/worker-remote/Cargo.toml rename to crates/lib/worker-remote-pool/Cargo.toml index 91cf6c67..e20b8a4c 100644 --- a/crates/lib/worker-remote/Cargo.toml +++ b/crates/lib/worker-remote-pool/Cargo.toml @@ -1,25 +1,22 @@ [package] -name = "waymark-worker-remote" +name = "waymark-worker-remote-pool" edition = "2024" version.workspace = true publish.workspace = true [dependencies] -waymark-ids = { workspace = true } +waymark-managed-process = { workspace = true } waymark-message-conversions = { workspace = true } waymark-proto = { workspace = true } waymark-worker-core = { workspace = true } +waymark-worker-message-protocol = { workspace = true } waymark-worker-metrics = { workspace = true } +waymark-worker-process-pool = { workspace = true } +waymark-worker-process-spec = { workspace = true } waymark-worker-status-core = { workspace = true } -anyhow = { workspace = 
true } # TODO: drop -futures-core = { workspace = true } nonempty-collections = { workspace = true } prost = { workspace = true } serde_json = { workspace = true } -thiserror = { workspace = true } tokio = { workspace = true, features = ["process"] } -tokio-stream = { workspace = true, features = ["net"] } -tonic = { workspace = true } tracing = { workspace = true } -uuid = { workspace = true } diff --git a/crates/lib/worker-remote-pool/src/lib.rs b/crates/lib/worker-remote-pool/src/lib.rs new file mode 100644 index 00000000..f107ebac --- /dev/null +++ b/crates/lib/worker-remote-pool/src/lib.rs @@ -0,0 +1,205 @@ +mod request; +mod response; + +use std::{ + sync::{ + Arc, Mutex as StdMutex, + atomic::{AtomicBool, Ordering}, + }, + time::Duration, +}; + +use nonempty_collections::NEVec; + +use tokio::sync::{Mutex, mpsc}; + +use waymark_worker_core::{ActionCompletion, ActionRequest, WorkerPoolError, error_to_value}; + +async fn execute_remote_request( + pool: &Arc>, + request: ActionRequest, +) -> ActionCompletion +where + Spec: waymark_worker_process_spec::Spec, + Spec: Send + Sync + 'static, +{ + let executor_id = request.executor_id; + let execution_id = request.execution_id; + let attempt_number = request.attempt_number; + let dispatch_token = request.dispatch_token; + + let dispatch = match request::to_dispatch_payload(request) { + Ok(dispatch) => dispatch, + Err(short_circuit) => return short_circuit, + }; + + let worker_idx = loop { + if let Some(idx) = pool.try_acquire_slot() { + break idx; + } + tokio::time::sleep(Duration::from_millis(5)).await; + }; + + let worker = pool.get_worker(worker_idx).await; + + match worker.sender.send_action(dispatch).await { + Ok(metrics) => { + pool.record_latency(metrics.ack_latency, metrics.worker_duration); + pool.record_completion(worker_idx, Arc::clone(pool)); + ActionCompletion { + executor_id, + execution_id, + attempt_number, + dispatch_token, + result: response::decode_action_result(&metrics), + } + } + Err(err) => { + 
pool.release_slot(worker_idx); + ActionCompletion { + executor_id, + execution_id, + attempt_number, + dispatch_token, + result: error_to_value(&WorkerPoolError::new( + "RemoteWorkerPoolError", + err.to_string(), + )), + } + } + } +} + +// This type's only purpose is to provide transport layer to the underlying +// pool, however that poll should be itself capable of providing the said +// transport. +// TODO: move this into to `waymark-worker-message-protocol`; not done yet +// since it requires substantial changes to the code layout of the integration +// surfaces, and we want to keep things in place for review purposes. +// Another downside is the process pool wrapping requires an `Arc`, which may +// prevent proper shutdown - but without a real need for it (we only need +// to give out a tiny communication handle under an `Arc` - but that's also for +// later). +pub struct RemoteWorkerPool { + pool: Arc>, + request_tx: mpsc::Sender, + request_rx: StdMutex>>, + completion_tx: mpsc::Sender, + completion_rx: Mutex>, + launched: AtomicBool, +} + +impl RemoteWorkerPool { + const DEFAULT_QUEUE_CAPACITY: usize = 1024; + + pub fn new(pool: impl Into>>) -> Self { + Self::with_capacity( + pool, + Self::DEFAULT_QUEUE_CAPACITY, + Self::DEFAULT_QUEUE_CAPACITY, + ) + } + + pub fn with_capacity( + pool: impl Into>>, + request_capacity: usize, + completion_capacity: usize, + ) -> Self { + let (request_tx, request_rx) = mpsc::channel(request_capacity.max(1)); + let (completion_tx, completion_rx) = mpsc::channel(completion_capacity.max(1)); + Self { + pool: pool.into(), + request_tx, + request_rx: StdMutex::new(Some(request_rx)), + completion_tx, + completion_rx: Mutex::new(completion_rx), + launched: AtomicBool::new(false), + } + } + + pub async fn shutdown_arc( + self: Arc, + ) -> Result<(), waymark_managed_process::ShutdownError> { + let Some(inner) = Arc::into_inner(self) else { + tracing::warn!( + "remote worker pool still referenced during shutdown; skipping shutdown" + ); 
+ return Ok(()); + }; + inner.shutdown().await + } + + pub async fn shutdown(self) -> Result<(), waymark_managed_process::ShutdownError> { + self.pool.shutdown_arc().await + } +} + +impl waymark_worker_core::BaseWorkerPool for RemoteWorkerPool +where + Spec: waymark_worker_process_spec::Spec, + Spec: Send + Sync + 'static, +{ + async fn launch(&self) -> std::result::Result<(), waymark_worker_core::WorkerPoolError> { + if self.launched.swap(true, Ordering::SeqCst) { + return Ok(()); + } + + let request_rx = { + let mut guard = self.request_rx.lock().map_err(|_| { + WorkerPoolError::new("RemoteWorkerPoolError", "failed to lock request receiver") + })?; + guard.take() + }; + + let Some(mut request_rx) = request_rx else { + return Ok(()); + }; + + let pool = Arc::clone(&self.pool); + let completion_tx = self.completion_tx.clone(); + + tokio::spawn(async move { + while let Some(request) = request_rx.recv().await { + tokio::spawn({ + let completion_tx = completion_tx.clone(); + let pool = Arc::clone(&pool); + async move { + let completion = execute_remote_request(&pool, request).await; + let _ = completion_tx.send(completion).await; + } + }); + } + }); + + Ok(()) + } + + fn queue(&self, request: ActionRequest) -> Result<(), WorkerPoolError> { + self.request_tx.try_send(request).map_err(|err| { + WorkerPoolError::new( + "RemoteWorkerPoolError", + format!("failed to enqueue action request: {err}"), + ) + }) + } + + async fn poll_complete(&self) -> Option> { + let mut receiver = self.completion_rx.lock().await; + + let first = receiver.recv().await?; + + let mut completions = NEVec::new(first); + + while let Ok(item) = receiver.try_recv() { + completions.push(item); + } + + Some(completions) + } +} + +impl waymark_worker_status_core::WorkerPoolStats for RemoteWorkerPool { + fn stats_snapshot(&self) -> waymark_worker_status_core::WorkerPoolStatsSnapshot { + self.pool.stats_snapshot() + } +} diff --git a/crates/lib/worker-remote-pool/src/request.rs 
b/crates/lib/worker-remote-pool/src/request.rs new file mode 100644 index 00000000..12040c5e --- /dev/null +++ b/crates/lib/worker-remote-pool/src/request.rs @@ -0,0 +1,62 @@ +use std::collections::HashMap; + +use waymark_proto::messages as proto; +use waymark_worker_core::{ActionCompletion, ActionRequest, WorkerPoolError, error_to_value}; +use waymark_worker_message_protocol::ActionDispatchPayload; + +fn kwargs_to_workflow_arguments( + kwargs: &HashMap, +) -> proto::WorkflowArguments { + let mut arguments = Vec::with_capacity(kwargs.len()); + for (key, value) in kwargs { + let arg_value = waymark_message_conversions::json_to_workflow_argument_value(value); + arguments.push(proto::WorkflowArgument { + key: key.clone(), + value: Some(arg_value), + }); + } + proto::WorkflowArguments { arguments } +} + +pub fn to_dispatch_payload( + request: ActionRequest, +) -> Result { + let ActionRequest { + executor_id, + execution_id, + action_name, + module_name, + kwargs, + timeout_seconds, + attempt_number, + dispatch_token, + } = request; + + let Some(module_name) = module_name else { + return Err(ActionCompletion { + executor_id, + execution_id, + attempt_number, + dispatch_token, + result: error_to_value(&WorkerPoolError::new( + "RemoteWorkerPoolError", + "missing module name for action request", + )), + }); + }; + + let dispatch = ActionDispatchPayload { + action_id: execution_id.to_string(), + instance_id: executor_id.to_string(), + sequence: 0, + action_name, + module_name, + kwargs: kwargs_to_workflow_arguments(&kwargs), + timeout_seconds, + max_retries: 0, + attempt_number, + dispatch_token, + }; + + Ok(dispatch) +} diff --git a/crates/lib/worker-remote-pool/src/response.rs b/crates/lib/worker-remote-pool/src/response.rs new file mode 100644 index 00000000..0b210c25 --- /dev/null +++ b/crates/lib/worker-remote-pool/src/response.rs @@ -0,0 +1,70 @@ +use prost::Message as _; +use waymark_proto::messages as proto; +use waymark_worker_core::{WorkerPoolError, 
error_to_value}; + +fn ensure_error_fields(mut map: serde_json::Map) -> serde_json::Value { + let error_type = map + .get("type") + .and_then(|value| value.as_str()) + .unwrap_or("RemoteWorkerError") + .to_string(); + let error_message = map + .get("message") + .and_then(|value| value.as_str()) + .unwrap_or("remote worker error") + .to_string(); + if !map.contains_key("type") { + map.insert("type".to_string(), serde_json::Value::String(error_type)); + } + if !map.contains_key("message") { + map.insert( + "message".to_string(), + serde_json::Value::String(error_message), + ); + } + serde_json::Value::Object(map) +} + +fn normalize_error_value(error: serde_json::Value) -> serde_json::Value { + let serde_json::Value::Object(mut map) = error else { + return error; + }; + + if let Some(serde_json::Value::Object(exception)) = map.remove("__exception__") { + return ensure_error_fields(exception); + } + + ensure_error_fields(map) +} + +pub fn decode_action_result( + metrics: &waymark_worker_metrics::RoundTripMetrics, +) -> serde_json::Value { + let payload = proto::WorkflowArguments::decode(metrics.response_payload.as_slice()) + .map(waymark_message_conversions::workflow_arguments_to_json) + .unwrap_or(serde_json::Value::Null); + + if metrics.success { + if let serde_json::Value::Object(mut map) = payload { + if let Some(result) = map.remove("result") { + return result; + } + return serde_json::Value::Object(map); + } + return payload; + } + + if let serde_json::Value::Object(mut map) = payload { + if let Some(error) = map.remove("error") { + return normalize_error_value(error); + } + return serde_json::Value::Object(map); + } + + let error_type = metrics.error_type.as_deref().unwrap_or("RemoteWorkerError"); + let error_message = metrics + .error_message + .as_deref() + .unwrap_or("remote worker error"); + error_to_value(&WorkerPoolError::new(error_type, error_message)) +} diff --git a/crates/lib/worker-remote/src/lib.rs b/crates/lib/worker-remote/src/lib.rs deleted file 
mode 100644 index a7baf44a..00000000 --- a/crates/lib/worker-remote/src/lib.rs +++ /dev/null @@ -1,1846 +0,0 @@ -//! Remote worker process management. -//! -//! This module provides the core infrastructure for spawning and managing -//! Python worker processes that execute workflow actions. -//! -//! ## Architecture -//! -//! ```text -//! ┌─────────────────────────────────────────────────────────────────────────┐ -//! │ PythonWorkerPool │ -//! │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ -//! │ │PythonWorker │ │PythonWorker │ │PythonWorker │ ... (N workers) │ -//! │ │ (process) │ │ (process) │ │ (process) │ │ -//! │ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │ -//! │ │ │ │ │ -//! │ └───────────────┼───────────────┘ │ -//! │ │ gRPC streaming │ -//! │ ▼ │ -//! │ ┌─────────────────────┐ │ -//! │ │ WorkerBridgeServer │ │ -//! │ └─────────────────────┘ │ -//! └─────────────────────────────────────────────────────────────────────────┘ -//! ``` -//! -//! ## Worker Lifecycle -//! -//! 1. Pool spawns N worker processes, each connecting to the WorkerBridge -//! 2. Workers send `WorkerHello` to complete the handshake -//! 3. Pool sends `ActionDispatch` messages, workers respond with `ActionResult` -//! 4. Workers send `Ack` immediately upon receiving a dispatch (for latency tracking) -//! 5. On shutdown, workers are terminated gracefully -//! -//! ## Error Handling -//! -//! - Worker spawn failures are propagated immediately -//! - Connection timeouts (15s) trigger worker process termination -//! - Dropped channels indicate worker crash and are propagated as errors -//! 
- Round-robin selection ensures load distribution even with slow workers - -pub mod server_worker; - -use std::{ - collections::HashMap, - env, - net::SocketAddr, - path::PathBuf, - process::Stdio, - sync::{ - Arc, Mutex as StdMutex, - atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, - }, - time::{Duration, Instant}, -}; - -use nonempty_collections::NEVec; -use serde_json::Value; -use tokio::sync::RwLock; - -use anyhow::{Context, Result as AnyResult, anyhow}; -use prost::Message; -use tokio::{ - process::{Child, Command}, - sync::{Mutex, mpsc, oneshot}, - task::JoinHandle, - time::timeout, -}; -use tracing::{error, info, trace, warn}; -use uuid::Uuid; - -use waymark_proto::messages as proto; -use waymark_worker_core::{ - ActionCompletion, ActionRequest, BaseWorkerPool, WorkerPoolError, error_to_value, -}; -use waymark_worker_metrics::{RoundTripMetrics, WorkerPoolMetrics, WorkerThroughputSnapshot}; -use waymark_worker_status_core::{WorkerPoolStats, WorkerPoolStatsSnapshot}; - -use self::server_worker::{WorkerBridgeChannels, WorkerBridgeServer}; - -const LATENCY_SAMPLE_SIZE: usize = 256; -const THROUGHPUT_WINDOW_SECS: u64 = 1; - -/// Errors that can occur during message encoding/decoding -#[derive(Debug, thiserror::Error)] -pub enum MessageError { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - #[error("Failed to decode message: {0}")] - Decode(#[from] prost::DecodeError), - #[error("Failed to encode message: {0}")] - Encode(#[from] prost::EncodeError), - #[error("Channel closed")] - ChannelClosed, -} - -/// Configuration for spawning Python workers. 
-#[derive(Clone, Debug)] -pub struct PythonWorkerConfig { - /// Path to the script/executable to run (e.g., "uv" or "waymark-worker") - pub script_path: PathBuf, - /// Arguments to pass before the worker-specific args - pub script_args: Vec, - /// Python module(s) to preload (contains @action definitions) - pub user_modules: Vec, - /// Additional paths to add to PYTHONPATH - pub extra_python_paths: Vec, -} - -impl Default for PythonWorkerConfig { - fn default() -> Self { - let (script_path, script_args) = default_runner(); - Self { - script_path, - script_args, - user_modules: vec![], - extra_python_paths: vec![], - } - } -} - -impl PythonWorkerConfig { - /// Create a new config with default runner detection. - pub fn new() -> Self { - Self::default() - } - - /// Set the user module to preload. - pub fn with_user_module(mut self, module: &str) -> Self { - self.user_modules = vec![module.to_string()]; - self - } - - /// Set multiple user modules to preload. - pub fn with_user_modules(mut self, modules: Vec) -> Self { - self.user_modules = modules; - self - } - - /// Add extra paths to PYTHONPATH. - pub fn with_python_paths(mut self, paths: Vec) -> Self { - self.extra_python_paths = paths; - self - } -} - -/// Find the default Python runner. -/// Prefers `waymark-worker` if in PATH, otherwise uses `uv run`. -fn default_runner() -> (PathBuf, Vec) { - if let Some(path) = find_executable("waymark-worker") { - return (path, Vec::new()); - } - ( - PathBuf::from("uv"), - vec![ - "run".to_string(), - "python".to_string(), - "-m".to_string(), - "waymark.worker".to_string(), - ], - ) -} - -/// Search PATH for an executable. 
-fn find_executable(bin: &str) -> Option { - let path_var = env::var_os("PATH")?; - for dir in env::split_paths(&path_var) { - let candidate = dir.join(bin); - if candidate.is_file() { - return Some(candidate); - } - #[cfg(windows)] - { - let exe_candidate = dir.join(format!("{bin}.exe")); - if exe_candidate.is_file() { - return Some(exe_candidate); - } - } - } - None -} - -/// Payload for dispatching an action to a worker. -#[derive(Debug, Clone)] -pub struct ActionDispatchPayload { - /// Unique action identifier - pub action_id: String, - /// Workflow instance this action belongs to - pub instance_id: String, - /// Sequence number within the instance - pub sequence: u32, - /// Name of the action function to call - pub action_name: String, - /// Python module containing the action - pub module_name: String, - /// Keyword arguments for the action - pub kwargs: proto::WorkflowArguments, - /// Timeout in seconds (0 = no timeout) - pub timeout_seconds: u32, - /// Maximum retry attempts - pub max_retries: u32, - /// Current attempt number - pub attempt_number: u32, - /// Dispatch token for correlation - pub dispatch_token: Uuid, -} - -/// Internal state shared between worker sender and reader tasks. -struct SharedState { - /// Pending ACK receivers, keyed by delivery_id - pending_acks: HashMap>, - /// Pending result receivers, keyed by delivery_id - pending_responses: HashMap>, -} - -impl SharedState { - fn new() -> Self { - Self { - pending_acks: HashMap::new(), - pending_responses: HashMap::new(), - } - } -} - -/// A single Python worker process. -/// -/// Manages the lifecycle of a Python subprocess that executes actions. -/// Communication happens via gRPC streaming through the WorkerBridge. -/// -/// Workers are not meant to be used directly - use [`PythonWorkerPool`] instead. 
-pub struct PythonWorker { - /// The child process - child: Child, - /// Channel to send envelopes to the worker - sender: mpsc::Sender, - /// Shared state for pending requests - shared: Arc>, - /// Counter for delivery IDs (monotonically increasing) - next_delivery: AtomicU64, - /// Handle to the reader task - reader_handle: Option>, - /// Worker ID for logging - worker_id: u64, -} - -impl PythonWorker { - /// Spawn a new Python worker process. - /// - /// This will: - /// 1. Reserve a worker ID on the bridge - /// 2. Spawn the Python process with appropriate arguments - /// 3. Wait for the worker to connect (15s timeout) - /// 4. Set up bidirectional communication channels - /// - /// # Errors - /// - /// Returns an error if: - /// - The process fails to spawn - /// - The worker doesn't connect within 15 seconds - /// - The bridge connection is dropped - pub async fn spawn( - config: PythonWorkerConfig, - bridge: Arc, - ) -> AnyResult { - let (worker_id, connection_rx) = bridge.reserve_worker().await; - - // Determine working directory and module paths - let package_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("python"); - let working_dir = if package_root.is_dir() { - Some(package_root.clone()) - } else { - None - }; - - // Build PYTHONPATH with all necessary directories - let mut module_paths = Vec::new(); - if let Some(root) = working_dir.as_ref() { - module_paths.push(root.clone()); - let src_dir = root.join("src"); - if src_dir.exists() { - module_paths.push(src_dir); - } - let proto_dir = root.join("proto"); - if proto_dir.exists() { - module_paths.push(proto_dir); - } - } - module_paths.extend(config.extra_python_paths.clone()); - - let joined_python_path = module_paths - .iter() - .map(|path| path.display().to_string()) - .collect::>() - .join(":"); - - let python_path = match env::var("PYTHONPATH") { - Ok(existing) if !existing.is_empty() => format!("{existing}:{joined_python_path}"), - _ => joined_python_path, - }; - - info!(python_path = 
%python_path, worker_id, "configured python path for worker"); - - // Build the command - let mut command = Command::new(&config.script_path); - command.args(&config.script_args); - command - .arg("--bridge") - .arg(bridge.addr().to_string()) - .arg("--worker-id") - .arg(worker_id.to_string()); - - // Add user modules - for module in &config.user_modules { - command.arg("--user-module").arg(module); - } - - command - .stderr(Stdio::inherit()) - .env("PYTHONPATH", python_path); - - if let Some(dir) = working_dir { - info!(?dir, worker_id, "using package root for worker process"); - command.current_dir(dir); - } else { - let cwd = env::current_dir().context("failed to resolve current directory")?; - info!( - ?cwd, - worker_id, "package root missing, using current directory for worker process" - ); - command.current_dir(cwd); - } - - // Spawn the process - let mut child = match command.spawn().context("failed to launch python worker") { - Ok(child) => child, - Err(err) => { - bridge.cancel_worker(worker_id).await; - return Err(err); - } - }; - - info!( - pid = child.id(), - script = %config.script_path.display(), - worker_id, - "spawned python worker" - ); - - // Wait for the worker to connect (with timeout) - let connection = match timeout(Duration::from_secs(15), connection_rx).await { - Ok(Ok(channels)) => channels, - Ok(Err(_)) => { - bridge.cancel_worker(worker_id).await; - let _ = child.start_kill(); - let _ = child.wait().await; - return Err(anyhow!("worker bridge channel closed before attach")); - } - Err(_) => { - bridge.cancel_worker(worker_id).await; - let _ = child.start_kill(); - let _ = child.wait().await; - return Err(anyhow!( - "timed out waiting for worker {} to connect (15s)", - worker_id - )); - } - }; - - let WorkerBridgeChannels { - to_worker, - mut from_worker, - } = connection; - - // Set up shared state and spawn reader task - let shared = Arc::new(Mutex::new(SharedState::new())); - let reader_shared = Arc::clone(&shared); - let 
reader_worker_id = worker_id; - let reader_handle = tokio::spawn(async move { - if let Err(err) = Self::reader_loop(&mut from_worker, reader_shared).await { - error!( - ?err, - worker_id = reader_worker_id, - "python worker stream exited" - ); - } - }); - - info!(worker_id, "worker connected and ready"); - - Ok(Self { - child, - sender: to_worker, - shared, - next_delivery: AtomicU64::new(1), - reader_handle: Some(reader_handle), - worker_id, - }) - } - - /// Send an action to the worker and wait for the result. - /// - /// This method: - /// 1. Allocates a delivery ID - /// 2. Creates channels for ACK and response - /// 3. Sends the action dispatch - /// 4. Waits for ACK (immediate) - /// 5. Waits for result (after execution) - /// 6. Returns metrics including latencies - /// - /// # Errors - /// - /// Returns an error if: - /// - The worker channel is closed (worker crashed) - /// - Response decoding fails - pub async fn send_action( - &self, - dispatch: ActionDispatchPayload, - ) -> Result { - let delivery_id = self.next_delivery.fetch_add(1, Ordering::SeqCst); - let send_instant = Instant::now(); - - trace!( - action_id = %dispatch.action_id, - instance_id = %dispatch.instance_id, - sequence = dispatch.sequence, - module = %dispatch.module_name, - function = %dispatch.action_name, - worker_id = self.worker_id, - delivery_id, - "sending action to worker" - ); - - // Create channels for receiving ACK and response - let (ack_tx, ack_rx) = oneshot::channel(); - let (response_tx, response_rx) = oneshot::channel(); - - // Register pending requests - { - let mut shared = self.shared.lock().await; - shared.pending_acks.insert(delivery_id, ack_tx); - shared.pending_responses.insert(delivery_id, response_tx); - } - - // Build and send the dispatch envelope - let command = proto::ActionDispatch { - action_id: dispatch.action_id.clone(), - instance_id: dispatch.instance_id.clone(), - sequence: dispatch.sequence, - action_name: dispatch.action_name.clone(), - module_name: 
dispatch.module_name.clone(), - kwargs: Some(dispatch.kwargs.clone()), - timeout_seconds: Some(dispatch.timeout_seconds), - max_retries: Some(dispatch.max_retries), - attempt_number: Some(dispatch.attempt_number), - dispatch_token: Some(dispatch.dispatch_token.to_string()), - }; - - let envelope = proto::Envelope { - delivery_id, - partition_id: 0, - kind: proto::MessageKind::ActionDispatch as i32, - payload: command.encode_to_vec(), - }; - - self.send_envelope(envelope).await?; - - // Wait for ACK (should be immediate) - let ack_instant = ack_rx.await.map_err(|_| MessageError::ChannelClosed)?; - - // Wait for the actual response (after execution) - let (response, response_instant) = - response_rx.await.map_err(|_| MessageError::ChannelClosed)?; - - // Calculate metrics - let ack_latency = ack_instant - .checked_duration_since(send_instant) - .unwrap_or_default(); - let round_trip = response_instant - .checked_duration_since(send_instant) - .unwrap_or_default(); - let worker_duration = Duration::from_nanos( - response - .worker_end_ns - .saturating_sub(response.worker_start_ns), - ); - - trace!( - action_id = %dispatch.action_id, - worker_id = self.worker_id, - ack_latency_us = ack_latency.as_micros(), - round_trip_ms = round_trip.as_millis(), - worker_duration_ms = worker_duration.as_millis(), - success = response.success, - "action completed" - ); - - Ok(RoundTripMetrics { - action_id: dispatch.action_id, - instance_id: dispatch.instance_id, - delivery_id, - sequence: dispatch.sequence, - ack_latency, - round_trip, - worker_duration, - response_payload: response - .payload - .as_ref() - .map(|payload| payload.encode_to_vec()) - .unwrap_or_default(), - success: response.success, - dispatch_token: response - .dispatch_token - .as_ref() - .and_then(|token| Uuid::parse_str(token).ok()), - error_type: response.error_type, - error_message: response.error_message, - }) - } - - /// Send an envelope to the worker. 
- async fn send_envelope(&self, envelope: proto::Envelope) -> Result<(), MessageError> { - self.sender - .send(envelope) - .await - .map_err(|_| MessageError::ChannelClosed) - } - - /// Background task that reads messages from the worker. - async fn reader_loop( - incoming: &mut mpsc::Receiver, - shared: Arc>, - ) -> Result<(), MessageError> { - while let Some(envelope) = incoming.recv().await { - let kind = proto::MessageKind::try_from(envelope.kind) - .unwrap_or(proto::MessageKind::Unspecified); - - match kind { - proto::MessageKind::Ack => { - let ack = proto::Ack::decode(envelope.payload.as_slice())?; - let mut guard = shared.lock().await; - if let Some(sender) = guard.pending_acks.remove(&ack.acked_delivery_id) { - let _ = sender.send(Instant::now()); - } else { - warn!(delivery = ack.acked_delivery_id, "unexpected ACK"); - } - } - proto::MessageKind::ActionResult => { - let response = proto::ActionResult::decode(envelope.payload.as_slice())?; - let mut guard = shared.lock().await; - if let Some(sender) = guard.pending_responses.remove(&envelope.delivery_id) { - let _ = sender.send((response, Instant::now())); - } else { - warn!(delivery = envelope.delivery_id, "orphan response"); - } - } - proto::MessageKind::Heartbeat => { - trace!(delivery = envelope.delivery_id, "heartbeat"); - } - other => { - warn!(?other, "unhandled message kind"); - } - } - } - - Ok(()) - } - - /// Gracefully shut down the worker. - pub async fn shutdown(mut self) -> AnyResult<()> { - info!(worker_id = self.worker_id, "shutting down worker"); - - // Abort the reader task - if let Some(handle) = self.reader_handle.take() { - handle.abort(); - let _ = handle.await; - } - - // Kill the child process - self.child.start_kill()?; - let _ = self.child.wait().await?; - - info!(worker_id = self.worker_id, "worker shutdown complete"); - Ok(()) - } - - /// Get the worker ID. 
- pub fn worker_id(&self) -> u64 { - self.worker_id - } -} - -impl Drop for PythonWorker { - fn drop(&mut self) { - if let Some(handle) = self.reader_handle.take() { - handle.abort(); - } - if let Err(err) = self.child.start_kill() { - warn!( - ?err, - worker_id = self.worker_id, - "failed to kill python worker during drop" - ); - } - } -} - -/// Pool of Python workers for action execution. -/// -/// Provides round-robin load balancing across multiple worker processes. -/// Workers are spawned eagerly on pool creation. -/// -/// # Example -/// -/// ```ignore -/// let config = PythonWorkerConfig::new() -/// .with_user_module("my_app.actions"); -/// let pool = PythonWorkerPool::new_with_bridge_addr(config, 4, None, None, 10).await?; -/// -/// let metrics = pool.get_worker(0).await.send_action(dispatch).await?; -/// ``` -pub struct PythonWorkerPool { - /// The workers in the pool (RwLock for recycling support) - workers: RwLock>>, - /// Cursor for round-robin selection - cursor: AtomicUsize, - /// Shared metrics tracker for throughput + latency. - metrics: StdMutex, - /// Action counts per worker slot (for lifecycle tracking) - action_counts: Vec, - /// In-flight action counts per worker slot (for concurrency control) - in_flight_counts: Vec, - /// Maximum concurrent actions per worker - max_concurrent_per_worker: usize, - /// Maximum actions per worker before recycling (None = no limit) - max_action_lifecycle: Option, - /// Bridge server for spawning replacement workers - bridge: Arc, - /// Worker configuration for spawning replacements - config: PythonWorkerConfig, -} - -impl PythonWorkerPool { - /// Create a new worker pool with the given configuration. - /// - /// Spawns `count` worker processes. All workers must successfully - /// spawn and connect before this returns. 
- /// - /// # Arguments - /// - /// * `config` - Configuration for worker processes - /// * `count` - Number of workers to spawn (minimum 1) - /// * `bridge` - The WorkerBridge server workers will connect to - /// * `max_action_lifecycle` - Maximum actions per worker before recycling (None = no limit) - /// * `max_concurrent_per_worker` - Maximum concurrent actions per worker (default 10) - /// - /// # Errors - /// - /// Returns an error if any worker fails to spawn or connect. - pub async fn new( - config: PythonWorkerConfig, - count: usize, - bridge: Arc, - max_action_lifecycle: Option, - ) -> AnyResult { - Self::new_with_concurrency(config, count, bridge, max_action_lifecycle, 10).await - } - - /// Create a new worker pool with explicit concurrency limit. - pub async fn new_with_concurrency( - config: PythonWorkerConfig, - count: usize, - bridge: Arc, - max_action_lifecycle: Option, - max_concurrent_per_worker: usize, - ) -> AnyResult { - let worker_count = count.max(1); - info!( - count = worker_count, - max_action_lifecycle = ?max_action_lifecycle, - "spawning python worker pool" - ); - - // Spawn all workers in parallel to reduce boot time. 
- let spawn_handles: Vec<_> = (0..worker_count) - .map(|_| { - let cfg = config.clone(); - let br = Arc::clone(&bridge); - tokio::spawn(async move { PythonWorker::spawn(cfg, br).await }) - }) - .collect(); - - let mut workers = Vec::with_capacity(worker_count); - for (i, handle) in spawn_handles.into_iter().enumerate() { - match handle.await { - Ok(Ok(worker)) => { - workers.push(Arc::new(worker)); - } - result @ (Ok(Err(_)) | Err(_)) => { - let err = match result { - Ok(Err(e)) => e, - Err(e) => anyhow::Error::from(e), - _ => unreachable!(), - }; - warn!( - worker_index = i, - "failed to spawn worker, cleaning up {} already spawned", - workers.len() - ); - for worker in workers { - if let Ok(worker) = Arc::try_unwrap(worker) { - let _ = worker.shutdown().await; - } - } - return Err(err.context(format!("failed to spawn worker {}", i))); - } - } - } - - info!(count = workers.len(), "worker pool ready"); - - let worker_ids = workers.iter().map(|worker| worker.worker_id()).collect(); - let action_counts = (0..worker_count).map(|_| AtomicU64::new(0)).collect(); - let in_flight_counts = (0..worker_count).map(|_| AtomicUsize::new(0)).collect(); - Ok(Self { - workers: RwLock::new(workers), - cursor: AtomicUsize::new(0), - metrics: StdMutex::new(WorkerPoolMetrics::new( - worker_ids, - Duration::from_secs(THROUGHPUT_WINDOW_SECS), - LATENCY_SAMPLE_SIZE, - )), - action_counts, - in_flight_counts, - max_concurrent_per_worker: max_concurrent_per_worker.max(1), - max_action_lifecycle, - bridge, - config, - }) - } - - /// Create a new worker pool and spawn its own bridge server. 
- pub async fn new_with_bridge_addr( - config: PythonWorkerConfig, - count: usize, - bind_addr: Option, - max_action_lifecycle: Option, - max_concurrent_per_worker: usize, - ) -> AnyResult { - let bridge = WorkerBridgeServer::start(bind_addr).await?; - match Self::new_with_concurrency( - config, - count, - Arc::clone(&bridge), - max_action_lifecycle, - max_concurrent_per_worker, - ) - .await - { - Ok(pool) => Ok(pool), - Err(err) => { - bridge.shutdown().await; - Err(err) - } - } - } - - /// Return the bridge address for worker connections. - pub fn bridge_addr(&self) -> SocketAddr { - self.bridge.addr() - } - - /// Get a worker by index. - /// - /// Returns a clone of the Arc for the worker at the given index. - pub async fn get_worker(&self, idx: usize) -> Arc { - let workers = self.workers.read().await; - Arc::clone(&workers[idx % workers.len()]) - } - - /// Get the next worker index using round-robin selection. - /// - /// This is lock-free and O(1). Returns the index that can be used - /// with `get_worker` to fetch the actual worker. - pub fn next_worker_idx(&self) -> usize { - self.cursor.fetch_add(1, Ordering::Relaxed) - } - - /// Get the number of workers in the pool. - pub fn len(&self) -> usize { - self.action_counts.len() - } - - /// Check if the pool is empty. - pub fn is_empty(&self) -> bool { - self.action_counts.is_empty() - } - - /// Get the maximum concurrent actions per worker. - pub fn max_concurrent_per_worker(&self) -> usize { - self.max_concurrent_per_worker - } - - /// Get total capacity (worker_count * max_concurrent_per_worker). - pub fn total_capacity(&self) -> usize { - self.len() * self.max_concurrent_per_worker - } - - /// Get total in-flight actions across all workers. - pub fn total_in_flight(&self) -> usize { - self.in_flight_counts - .iter() - .map(|c| c.load(Ordering::Relaxed)) - .sum() - } - - /// Get available capacity (total_capacity - total_in_flight). 
- pub fn available_capacity(&self) -> usize { - self.total_capacity().saturating_sub(self.total_in_flight()) - } - - /// Try to acquire a slot for the next available worker. - /// - /// Returns `Some(worker_idx)` if a slot was acquired, `None` if all workers - /// are at capacity. Uses round-robin selection among workers with capacity. - pub fn try_acquire_slot(&self) -> Option { - let worker_count = self.len(); - if worker_count == 0 { - return None; - } - - // Try each worker starting from the current cursor position - let start = self.cursor.fetch_add(1, Ordering::Relaxed); - for i in 0..worker_count { - let idx = (start + i) % worker_count; - if self.try_acquire_slot_for_worker(idx) { - return Some(idx); - } - } - None - } - - /// Try to acquire a slot for a specific worker. - /// - /// Returns `true` if the slot was acquired, `false` if the worker is at capacity. - pub fn try_acquire_slot_for_worker(&self, worker_idx: usize) -> bool { - let Some(counter) = self.in_flight_counts.get(worker_idx % self.len()) else { - return false; - }; - - // CAS loop to atomically increment if below limit - loop { - let current = counter.load(Ordering::Acquire); - if current >= self.max_concurrent_per_worker { - return false; - } - match counter.compare_exchange_weak( - current, - current + 1, - Ordering::AcqRel, - Ordering::Relaxed, - ) { - Ok(_) => return true, - Err(_) => continue, // Retry - } - } - } - - /// Release a slot for a worker. - /// - /// Should be called when an action completes (via `record_completion`). - pub fn release_slot(&self, worker_idx: usize) { - if let Some(counter) = self.in_flight_counts.get(worker_idx % self.len()) { - // Saturating sub to avoid underflow in case of bugs - let prev = counter.fetch_sub(1, Ordering::Release); - if prev == 0 { - warn!(worker_idx, "release_slot called with zero in-flight count"); - counter.store(0, Ordering::Release); - } - } - } - - /// Get in-flight count for a specific worker. 
- pub fn in_flight_for_worker(&self, worker_idx: usize) -> usize { - self.in_flight_counts - .get(worker_idx % self.len()) - .map(|c| c.load(Ordering::Relaxed)) - .unwrap_or(0) - } - - /// Get a snapshot of all workers in the pool. - pub async fn workers_snapshot(&self) -> Vec> { - self.workers.read().await.clone() - } - - /// Get throughput snapshots for all workers. - /// - /// Returns worker throughput metrics including completion counts and rates. - pub fn throughput_snapshots(&self) -> Vec { - if let Ok(mut metrics) = self.metrics.lock() { - metrics.throughput_snapshots(Instant::now()) - } else { - Vec::new() - } - } - - /// Record the latest latency measurements for median reporting. - pub fn record_latency(&self, ack_latency: Duration, worker_duration: Duration) { - if let Ok(mut metrics) = self.metrics.lock() { - metrics.record_latency(ack_latency, worker_duration); - } - } - - /// Return the current median dequeue/handling latencies in milliseconds. - pub fn median_latencies_ms(&self) -> (Option, Option) { - if let Ok(metrics) = self.metrics.lock() { - metrics.median_latencies_ms() - } else { - (None, None) - } - } - - /// Get queue statistics: (dispatch_queue_size, total_in_flight). - pub fn queue_stats(&self) -> (usize, usize) { - let total_in_flight: usize = self - .in_flight_counts - .iter() - .map(|c| c.load(Ordering::Relaxed)) - .sum(); - // dispatch_queue_size would require access to the bridge's queue - // For now, return 0 as placeholder - (0, total_in_flight) - } - - /// Record an action completion for a worker and trigger recycling if needed. - /// - /// This decrements the in-flight count and increments the action count for - /// the worker at the given index. If `max_action_lifecycle` is set and the - /// count reaches or exceeds the threshold, a background task is spawned to - /// recycle the worker. 
- pub fn record_completion(&self, worker_idx: usize, pool: Arc) { - // Release the in-flight slot - self.release_slot(worker_idx); - - // Update throughput tracking - if let Ok(mut metrics) = self.metrics.lock() { - metrics.record_completion(worker_idx); - if tracing::enabled!(tracing::Level::TRACE) { - let snapshots = metrics.throughput_snapshots(Instant::now()); - if let Some(snapshot) = snapshots.get(worker_idx) { - trace!( - worker_id = snapshot.worker_id, - throughput_per_min = snapshot.throughput_per_min, - total_completed = snapshot.total_completed, - last_action_at = ?snapshot.last_action_at, - "worker throughput snapshot" - ); - } - } - } - - // Increment action count - if let Some(counter) = self.action_counts.get(worker_idx) { - let new_count = counter.fetch_add(1, Ordering::SeqCst) + 1; - - // Check if recycling is needed - if let Some(max_lifecycle) = self.max_action_lifecycle - && new_count >= max_lifecycle - { - info!( - worker_idx, - action_count = new_count, - max_lifecycle, - "worker reached action lifecycle limit, scheduling recycle" - ); - // Spawn a background task to recycle this worker - tokio::spawn(async move { - if let Err(err) = pool.recycle_worker(worker_idx).await { - error!(worker_idx, ?err, "failed to recycle worker"); - } - }); - } - } - } - - /// Recycle a worker at the given index. - /// - /// Spawns a new worker and replaces the old one. The old worker - /// will be shut down once all in-flight actions complete (when - /// its Arc reference count drops to zero). 
- async fn recycle_worker(&self, worker_idx: usize) -> AnyResult<()> { - // Spawn the replacement worker first - let new_worker = PythonWorker::spawn(self.config.clone(), Arc::clone(&self.bridge)).await?; - let new_worker_id = new_worker.worker_id(); - - // Replace the worker in the pool - let old_worker = { - let mut workers = self.workers.write().await; - let idx = worker_idx % workers.len(); - std::mem::replace(&mut workers[idx], Arc::new(new_worker)) - }; - - // Reset the action count for this slot - if let Some(counter) = self - .action_counts - .get(worker_idx % self.action_counts.len()) - { - counter.store(0, Ordering::SeqCst); - } - - // Update throughput tracker with new worker ID - if let Ok(mut metrics) = self.metrics.lock() { - metrics.reset_worker(worker_idx, new_worker_id); - } - - info!( - worker_idx, - old_worker_id = old_worker.worker_id(), - new_worker_id, - "recycled worker" - ); - - // The old worker will be cleaned up when its Arc drops - // (once all in-flight actions complete) - - Ok(()) - } - - /// Get the current action count for a worker slot. - /// - /// Returns the number of actions that have been completed by the worker - /// at the given index since it was last spawned/recycled. - #[cfg(test)] - #[allow(dead_code)] - pub(crate) fn get_action_count(&self, worker_idx: usize) -> u64 { - self.action_counts - .get(worker_idx) - .map(|c| c.load(Ordering::SeqCst)) - .unwrap_or(0) - } - - /// Get the maximum action lifecycle setting. - #[cfg(test)] - #[allow(dead_code)] - pub(crate) fn max_lifecycle(&self) -> Option { - self.max_action_lifecycle - } - - /// Gracefully shut down all workers in the pool. - /// - /// Workers are shut down in order. Any workers still in use - /// (shared references exist) are skipped with a warning. 
- pub async fn shutdown(self) -> AnyResult<()> { - let workers = self.workers.into_inner(); - info!(count = workers.len(), "shutting down worker pool"); - - for worker in workers { - match Arc::try_unwrap(worker) { - Ok(worker) => { - worker.shutdown().await?; - } - Err(arc) => { - warn!( - worker_id = arc.worker_id(), - "worker still in use during shutdown; skipping" - ); - } - } - } - - self.bridge.shutdown().await; - info!("worker pool shutdown complete"); - Ok(()) - } -} - -fn kwargs_to_workflow_arguments(kwargs: &HashMap) -> proto::WorkflowArguments { - let mut arguments = Vec::with_capacity(kwargs.len()); - for (key, value) in kwargs { - let arg_value = waymark_message_conversions::json_to_workflow_argument_value(value); - arguments.push(proto::WorkflowArgument { - key: key.clone(), - value: Some(arg_value), - }); - } - proto::WorkflowArguments { arguments } -} - -fn normalize_error_value(error: Value) -> Value { - let Value::Object(mut map) = error else { - return error; - }; - - if let Some(Value::Object(exception)) = map.remove("__exception__") { - return ensure_error_fields(exception); - } - - ensure_error_fields(map) -} - -fn ensure_error_fields(mut map: serde_json::Map) -> Value { - let error_type = map - .get("type") - .and_then(|value| value.as_str()) - .unwrap_or("RemoteWorkerError") - .to_string(); - let error_message = map - .get("message") - .and_then(|value| value.as_str()) - .unwrap_or("remote worker error") - .to_string(); - if !map.contains_key("type") { - map.insert("type".to_string(), Value::String(error_type)); - } - if !map.contains_key("message") { - map.insert("message".to_string(), Value::String(error_message)); - } - Value::Object(map) -} - -fn decode_action_result(metrics: &RoundTripMetrics) -> Value { - let payload = proto::WorkflowArguments::decode(metrics.response_payload.as_slice()) - .map(waymark_message_conversions::workflow_arguments_to_json) - .unwrap_or(Value::Null); - - if metrics.success { - if let Value::Object(mut map) = 
payload { - if let Some(result) = map.remove("result") { - return result; - } - return Value::Object(map); - } - return payload; - } - - if let Value::Object(mut map) = payload { - if let Some(error) = map.remove("error") { - return normalize_error_value(error); - } - return Value::Object(map); - } - - let error_type = metrics.error_type.as_deref().unwrap_or("RemoteWorkerError"); - let error_message = metrics - .error_message - .as_deref() - .unwrap_or("remote worker error"); - error_to_value(&WorkerPoolError::new(error_type, error_message)) -} - -async fn execute_remote_request( - pool: Arc, - request: ActionRequest, -) -> ActionCompletion { - let executor_id = request.executor_id; - let execution_id = request.execution_id; - let attempt_number = request.attempt_number; - let dispatch_token = request.dispatch_token; - let timeout_seconds = request.timeout_seconds; - let Some(module_name) = request.module_name.clone() else { - return ActionCompletion { - executor_id, - execution_id, - attempt_number, - dispatch_token, - result: error_to_value(&WorkerPoolError::new( - "RemoteWorkerPoolError", - "missing module name for action request", - )), - }; - }; - - let worker_idx = loop { - if let Some(idx) = pool.try_acquire_slot() { - break idx; - } - tokio::time::sleep(Duration::from_millis(5)).await; - }; - - let worker = pool.get_worker(worker_idx).await; - let dispatch = ActionDispatchPayload { - action_id: execution_id.to_string(), - instance_id: executor_id.to_string(), - sequence: 0, - action_name: request.action_name, - module_name, - kwargs: kwargs_to_workflow_arguments(&request.kwargs), - timeout_seconds, - max_retries: 0, - attempt_number, - dispatch_token, - }; - - match worker.send_action(dispatch).await { - Ok(metrics) => { - pool.record_latency(metrics.ack_latency, metrics.worker_duration); - pool.record_completion(worker_idx, Arc::clone(&pool)); - ActionCompletion { - executor_id, - execution_id, - attempt_number, - dispatch_token, - result: 
decode_action_result(&metrics), - } - } - Err(err) => { - pool.release_slot(worker_idx); - ActionCompletion { - executor_id, - execution_id, - attempt_number, - dispatch_token, - result: error_to_value(&WorkerPoolError::new( - "RemoteWorkerPoolError", - err.to_string(), - )), - } - } - } -} - -struct RemoteWorkerPoolInner { - pool: Arc, - request_tx: mpsc::Sender, - request_rx: StdMutex>>, - completion_tx: mpsc::Sender, - completion_rx: Mutex>, - launched: AtomicBool, -} - -/// BaseWorkerPool implementation backed by a Python worker cluster. -#[derive(Clone)] -pub struct RemoteWorkerPool { - inner: Arc, -} - -impl RemoteWorkerPool { - const DEFAULT_QUEUE_CAPACITY: usize = 1024; - - pub fn new(pool: Arc) -> Self { - Self::with_capacity( - pool, - Self::DEFAULT_QUEUE_CAPACITY, - Self::DEFAULT_QUEUE_CAPACITY, - ) - } - - pub fn with_capacity( - pool: Arc, - request_capacity: usize, - completion_capacity: usize, - ) -> Self { - let (request_tx, request_rx) = mpsc::channel(request_capacity.max(1)); - let (completion_tx, completion_rx) = mpsc::channel(completion_capacity.max(1)); - Self { - inner: Arc::new(RemoteWorkerPoolInner { - pool, - request_tx, - request_rx: StdMutex::new(Some(request_rx)), - completion_tx, - completion_rx: Mutex::new(completion_rx), - launched: AtomicBool::new(false), - }), - } - } - - pub async fn new_with_config( - config: PythonWorkerConfig, - count: usize, - bind_addr: Option, - max_action_lifecycle: Option, - max_concurrent_per_worker: usize, - ) -> AnyResult { - let worker_count = count.max(1); - let per_worker = max_concurrent_per_worker.max(1); - let queue_capacity = worker_count - .saturating_mul(per_worker) - .saturating_mul(2) - .max(Self::DEFAULT_QUEUE_CAPACITY); - let pool = PythonWorkerPool::new_with_bridge_addr( - config, - count, - bind_addr, - max_action_lifecycle, - max_concurrent_per_worker, - ) - .await?; - Ok(Self::with_capacity( - Arc::new(pool), - queue_capacity, - queue_capacity, - )) - } - - pub fn bridge_addr(&self) -> 
SocketAddr { - self.inner.pool.bridge_addr() - } - - pub async fn shutdown(self) -> AnyResult<()> { - match Arc::try_unwrap(self.inner) { - Ok(inner) => match Arc::try_unwrap(inner.pool) { - Ok(pool) => pool.shutdown().await, - Err(_) => { - warn!("worker pool still referenced during shutdown; skipping shutdown"); - Ok(()) - } - }, - Err(_) => { - warn!("remote worker pool still referenced during shutdown; skipping shutdown"); - Ok(()) - } - } - } -} - -impl BaseWorkerPool for RemoteWorkerPool { - async fn launch(&self) -> std::result::Result<(), waymark_worker_core::WorkerPoolError> { - if self.inner.launched.swap(true, Ordering::SeqCst) { - return Ok(()); - } - - let request_rx = { - let mut guard = self.inner.request_rx.lock().map_err(|_| { - WorkerPoolError::new("RemoteWorkerPoolError", "failed to lock request receiver") - })?; - guard.take() - }; - - let Some(mut request_rx) = request_rx else { - return Ok(()); - }; - - let pool = Arc::clone(&self.inner.pool); - let completion_tx = self.inner.completion_tx.clone(); - - tokio::spawn(async move { - while let Some(request) = request_rx.recv().await { - let completion_tx = completion_tx.clone(); - let pool = Arc::clone(&pool); - tokio::spawn(async move { - let completion = execute_remote_request(pool, request).await; - let _ = completion_tx.send(completion).await; - }); - } - }); - - Ok(()) - } - - fn queue(&self, request: ActionRequest) -> Result<(), WorkerPoolError> { - self.inner.request_tx.try_send(request).map_err(|err| { - WorkerPoolError::new( - "RemoteWorkerPoolError", - format!("failed to enqueue action request: {err}"), - ) - }) - } - - async fn poll_complete(&self) -> Option> { - let mut receiver = self.inner.completion_rx.lock().await; - - let first = receiver.recv().await?; - - let mut completions = NEVec::new(first); - - while let Ok(item) = receiver.try_recv() { - completions.push(item); - } - - Some(completions) - } -} - -impl WorkerPoolStats for PythonWorkerPool { - fn stats_snapshot(&self) -> 
WorkerPoolStatsSnapshot { - let snapshots = self.throughput_snapshots(); - let active_workers = snapshots.len() as u16; - let throughput_per_min: f64 = snapshots.iter().map(|s| s.throughput_per_min).sum(); - let total_completed: i64 = snapshots.iter().map(|s| s.total_completed as i64).sum(); - let last_action_at = snapshots.iter().filter_map(|s| s.last_action_at).max(); - let (dispatch_queue_size, total_in_flight) = self.queue_stats(); - let (median_dequeue_ms, median_handling_ms) = self.median_latencies_ms(); - - WorkerPoolStatsSnapshot { - active_workers, - throughput_per_min, - total_completed, - last_action_at, - dispatch_queue_size, - total_in_flight, - median_dequeue_ms, - median_handling_ms, - } - } -} - -impl WorkerPoolStats for RemoteWorkerPool { - fn stats_snapshot(&self) -> WorkerPoolStatsSnapshot { - self.inner.pool.stats_snapshot() - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - use std::process::Stdio; - use std::sync::{ - Arc, - atomic::{AtomicU64, AtomicUsize, Ordering}, - }; - - use serde_json::json; - use tokio::process::Child; - use waymark_ids::{ExecutionId, InstanceId}; - - use super::*; - use waymark_worker_core::BaseWorkerPool; - - #[test] - fn test_config_builder() { - let config = PythonWorkerConfig::new() - .with_user_module("my_module") - .with_python_paths(vec![PathBuf::from("/extra/path")]); - - assert_eq!(config.user_modules, vec!["my_module".to_string()]); - assert_eq!( - config.extra_python_paths, - vec![PathBuf::from("/extra/path")] - ); - } - - #[test] - fn test_config_with_multiple_modules() { - let config = PythonWorkerConfig::new() - .with_user_modules(vec!["module1".to_string(), "module2".to_string()]); - - assert_eq!(config.user_modules, vec!["module1", "module2"]); - } - - #[test] - fn test_default_runner_detection() { - // Should return uv as fallback if waymark-worker not in PATH - let (path, args) = default_runner(); - // Either waymark-worker was found, or we get uv with args - if args.is_empty() { 
- assert!(path.to_string_lossy().contains("waymark-worker")); - } else { - assert_eq!(path, PathBuf::from("uv")); - assert_eq!(args, vec!["run", "python", "-m", "waymark.worker"]); - } - } - - fn make_string_kwarg(key: &str, value: &str) -> proto::WorkflowArgument { - proto::WorkflowArgument { - key: key.to_string(), - value: Some(proto::WorkflowArgumentValue { - kind: Some(proto::workflow_argument_value::Kind::Primitive( - proto::PrimitiveWorkflowArgument { - kind: Some(proto::primitive_workflow_argument::Kind::StringValue( - value.to_string(), - )), - }, - )), - }), - } - } - - fn spawn_stub_child() -> Child { - #[cfg(windows)] - { - Command::new("cmd") - .args(["/C", "timeout", "/T", "60", "/NOBREAK"]) - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("spawn windows stub child") - } - #[cfg(not(windows))] - { - Command::new("sleep") - .arg("60") - .stdin(Stdio::null()) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .spawn() - .expect("spawn unix stub child") - } - } - - async fn test_bridge() -> Option> { - match WorkerBridgeServer::start(None).await { - Ok(server) => Some(server), - Err(err) => { - let message = format!("{err:?}"); - if message.contains("Operation not permitted") - || message.contains("Permission denied") - { - None - } else { - panic!("start worker bridge: {err}"); - } - } - } - } - - fn make_result_payload(value: Value) -> proto::WorkflowArguments { - proto::WorkflowArguments { - arguments: vec![proto::WorkflowArgument { - key: "result".to_string(), - value: Some(waymark_message_conversions::json_to_workflow_argument_value(&value)), - }], - } - } - - fn make_test_worker( - worker_id: u64, - ) -> ( - PythonWorker, - mpsc::Receiver, - mpsc::Sender, - ) { - let (to_worker, from_runner) = mpsc::channel(16); - let (to_runner, from_worker) = mpsc::channel(16); - let shared = Arc::new(Mutex::new(SharedState::new())); - let reader_shared = Arc::clone(&shared); - let reader_handle = tokio::spawn(async 
move { - let mut incoming = from_worker; - let _ = PythonWorker::reader_loop(&mut incoming, reader_shared).await; - }); - - let worker = PythonWorker { - child: spawn_stub_child(), - sender: to_worker, - shared, - next_delivery: AtomicU64::new(1), - reader_handle: Some(reader_handle), - worker_id, - }; - (worker, from_runner, to_runner) - } - - async fn make_single_worker_pool() -> Option<( - Arc, - mpsc::Receiver, - mpsc::Sender, - )> { - let bridge = test_bridge().await?; - let (worker, outgoing, incoming) = make_test_worker(0); - let pool = PythonWorkerPool { - workers: RwLock::new(vec![Arc::new(worker)]), - cursor: AtomicUsize::new(0), - metrics: StdMutex::new(WorkerPoolMetrics::new( - vec![0], - Duration::from_secs(THROUGHPUT_WINDOW_SECS), - LATENCY_SAMPLE_SIZE, - )), - action_counts: vec![AtomicU64::new(0)], - in_flight_counts: vec![AtomicUsize::new(0)], - max_concurrent_per_worker: 2, - max_action_lifecycle: None, - bridge, - config: PythonWorkerConfig::new(), - }; - Some((Arc::new(pool), outgoing, incoming)) - } - - #[tokio::test] - async fn test_send_action_roundtrip_happy_path() { - let (worker, mut outgoing, incoming) = make_test_worker(7); - let dispatch_token = Uuid::new_v4(); - - let responder = tokio::spawn(async move { - let envelope = outgoing.recv().await.expect("dispatch envelope"); - assert_eq!( - proto::MessageKind::try_from(envelope.kind).ok(), - Some(proto::MessageKind::ActionDispatch) - ); - let dispatch = proto::ActionDispatch::decode(envelope.payload.as_slice()) - .expect("decode dispatch"); - assert_eq!(dispatch.action_name, "greet"); - - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id + 100, - partition_id: 0, - kind: proto::MessageKind::Ack as i32, - payload: proto::Ack { - acked_delivery_id: envelope.delivery_id, - } - .encode_to_vec(), - }) - .await - .expect("send ack"); - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id, - partition_id: 0, - kind: proto::MessageKind::ActionResult as i32, 
- payload: proto::ActionResult { - action_id: dispatch.action_id, - success: true, - payload: Some(make_result_payload(json!("hello"))), - worker_start_ns: 10, - worker_end_ns: 42, - dispatch_token: Some(dispatch_token.to_string()), - error_type: None, - error_message: None, - } - .encode_to_vec(), - }) - .await - .expect("send action result"); - }); - - let metrics = worker - .send_action(ActionDispatchPayload { - action_id: "action-1".to_string(), - instance_id: "instance-1".to_string(), - sequence: 1, - action_name: "greet".to_string(), - module_name: "tests.actions".to_string(), - kwargs: proto::WorkflowArguments { - arguments: vec![make_string_kwarg("name", "World")], - }, - timeout_seconds: 30, - max_retries: 0, - attempt_number: 0, - dispatch_token, - }) - .await - .expect("send action"); - - responder.await.expect("responder task"); - assert!(metrics.success); - assert_eq!(metrics.action_id, "action-1"); - assert_eq!(metrics.instance_id, "instance-1"); - assert_eq!(metrics.dispatch_token, Some(dispatch_token)); - assert_eq!(metrics.worker_duration, Duration::from_nanos(32)); - - worker.shutdown().await.expect("shutdown worker"); - } - - #[test] - fn test_concurrency_slot_logic() { - // Test the atomic slot logic without spawning workers - let in_flight = AtomicUsize::new(0); - let max_concurrent = 3; - - // Helper to try acquire - let try_acquire = || { - loop { - let current = in_flight.load(Ordering::Acquire); - if current >= max_concurrent { - return false; - } - match in_flight.compare_exchange_weak( - current, - current + 1, - Ordering::AcqRel, - Ordering::Relaxed, - ) { - Ok(_) => return true, - Err(_) => continue, - } - } - }; - - // Acquire up to max - assert!(try_acquire()); - assert!(try_acquire()); - assert!(try_acquire()); - // At capacity - assert!(!try_acquire()); - assert_eq!(in_flight.load(Ordering::Relaxed), 3); - - // Release one - in_flight.fetch_sub(1, Ordering::Release); - assert_eq!(in_flight.load(Ordering::Relaxed), 2); - - // Can 
acquire again - assert!(try_acquire()); - assert!(!try_acquire()); - } - - #[tokio::test] - async fn test_execute_remote_request_happy_path() { - let Some((pool, mut outgoing, incoming)) = make_single_worker_pool().await else { - return; - }; - let request = ActionRequest { - executor_id: InstanceId::new_uuid_v4(), - execution_id: ExecutionId::new_uuid_v4(), - action_name: "double".to_string(), - module_name: Some("tests.actions".to_string()), - kwargs: HashMap::from([("value".to_string(), Value::Number(9.into()))]), - timeout_seconds: 0, - attempt_number: 1, - dispatch_token: Uuid::new_v4(), - }; - - let responder = tokio::spawn(async move { - let envelope = outgoing.recv().await.expect("dispatch envelope"); - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id + 1, - partition_id: 0, - kind: proto::MessageKind::Ack as i32, - payload: proto::Ack { - acked_delivery_id: envelope.delivery_id, - } - .encode_to_vec(), - }) - .await - .expect("send ack"); - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id, - partition_id: 0, - kind: proto::MessageKind::ActionResult as i32, - payload: proto::ActionResult { - action_id: "ignored".to_string(), - success: true, - payload: Some(make_result_payload(Value::Number(18.into()))), - worker_start_ns: 100, - worker_end_ns: 125, - dispatch_token: None, - error_type: None, - error_message: None, - } - .encode_to_vec(), - }) - .await - .expect("send result"); - }); - - let completion = execute_remote_request(Arc::clone(&pool), request.clone()).await; - responder.await.expect("responder task"); - assert_eq!(completion.executor_id, request.executor_id); - assert_eq!(completion.execution_id, request.execution_id); - assert_eq!(completion.result, Value::Number(18.into())); - assert_eq!(pool.total_in_flight(), 0); - - if let Ok(pool) = Arc::try_unwrap(pool) { - pool.shutdown().await.expect("shutdown pool"); - } - } - - #[tokio::test] - async fn 
test_remote_worker_pool_launch_queue_get_complete_happy_path() { - let Some((pool, mut outgoing, incoming)) = make_single_worker_pool().await else { - return; - }; - let remote = RemoteWorkerPool::new(Arc::clone(&pool)); - BaseWorkerPool::launch(&remote) - .await - .expect("launch remote pool"); - let request = ActionRequest { - executor_id: InstanceId::new_uuid_v4(), - execution_id: ExecutionId::new_uuid_v4(), - action_name: "square".to_string(), - module_name: Some("tests.actions".to_string()), - kwargs: HashMap::from([("value".to_string(), Value::Number(5.into()))]), - timeout_seconds: 0, - attempt_number: 1, - dispatch_token: Uuid::new_v4(), - }; - let execution_id = request.execution_id; - - let responder = tokio::spawn(async move { - let envelope = outgoing.recv().await.expect("dispatch envelope"); - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id + 5, - partition_id: 0, - kind: proto::MessageKind::Ack as i32, - payload: proto::Ack { - acked_delivery_id: envelope.delivery_id, - } - .encode_to_vec(), - }) - .await - .expect("send ack"); - incoming - .send(proto::Envelope { - delivery_id: envelope.delivery_id, - partition_id: 0, - kind: proto::MessageKind::ActionResult as i32, - payload: proto::ActionResult { - action_id: "ignored".to_string(), - success: true, - payload: Some(make_result_payload(Value::Number(25.into()))), - worker_start_ns: 300, - worker_end_ns: 360, - dispatch_token: None, - error_type: None, - error_message: None, - } - .encode_to_vec(), - }) - .await - .expect("send result"); - }); - - BaseWorkerPool::queue(&remote, request).expect("queue request"); - let maybe_completions = BaseWorkerPool::poll_complete(&remote).await; - responder.await.expect("responder task"); - let completions = maybe_completions.unwrap(); - assert_eq!(completions.len().get(), 1); - assert_eq!(completions[0].execution_id, execution_id); - assert_eq!(completions[0].result, Value::Number(25.into())); - - drop(remote); - 
tokio::time::sleep(Duration::from_millis(10)).await; - if let Ok(pool) = Arc::try_unwrap(pool) { - pool.shutdown().await.expect("shutdown pool"); - } - } - - #[test] - fn test_action_result_success_false_deserialize() { - use prost::Message; - - // These are the bytes from Python when success=False is set - // The success field is NOT included because it's the default value in proto3 - let success_false_bytes: &[u8] = &[0x0a, 0x04, 0x74, 0x65, 0x73, 0x74]; - - // These are the bytes from Python when success=True is set - let success_true_bytes: &[u8] = &[0x0a, 0x04, 0x74, 0x65, 0x73, 0x74, 0x10, 0x01]; - - // Deserialize success=False case - let result_false = - proto::ActionResult::decode(success_false_bytes).expect("decode success=false"); - assert_eq!(result_false.action_id, "test"); - assert!( - !result_false.success, - "success should be false when field is omitted (proto3 default)" - ); - - // Deserialize success=True case - let result_true = - proto::ActionResult::decode(success_true_bytes).expect("decode success=true"); - assert_eq!(result_true.action_id, "test"); - assert!( - result_true.success, - "success should be true when field is 1" - ); - } -} diff --git a/crates/lib/worker-remote/src/server_worker.rs b/crates/lib/worker-remote/src/server_worker.rs deleted file mode 100644 index 94e29782..00000000 --- a/crates/lib/worker-remote/src/server_worker.rs +++ /dev/null @@ -1,347 +0,0 @@ -//! gRPC server for Python worker connections. -//! -//! The [`WorkerBridgeServer`] provides a bidirectional streaming gRPC service -//! that Python workers connect to. The protocol works as follows: -//! -//! 1. Rust spawns a worker process and reserves a worker ID -//! 2. Python worker connects to the bridge and sends a `WorkerHello` with its ID -//! 3. The bridge matches the connection to the reservation and establishes channels -//! 4. 
Bidirectional streaming begins: Rust sends actions, Python returns results - -use std::{ - collections::HashMap, - net::{IpAddr, Ipv4Addr, SocketAddr}, - pin::Pin, - sync::{ - Arc, - atomic::{AtomicU64, Ordering}, - }, -}; - -use anyhow::{Context, Result as AnyResult}; -use futures_core::Stream; -use prost::Message; -use tokio::{ - net::TcpListener, - sync::{Mutex, mpsc, oneshot}, - task::JoinHandle, -}; -use tokio_stream::{ - StreamExt, - wrappers::{ReceiverStream, TcpListenerStream}, -}; -use tonic::{Request, Response, Status, Streaming, async_trait, transport::Server}; -use tracing::{error, info, warn}; - -use waymark_proto::messages as proto; - -/// Channels for communicating with a connected worker. -/// Created when a worker successfully completes the handshake. -pub struct WorkerBridgeChannels { - /// Send actions to the worker - pub to_worker: mpsc::Sender, - /// Receive results from the worker - pub from_worker: mpsc::Receiver, -} - -/// Internal state for pending worker connections. -/// Workers are reserved before they connect, so we can correlate -/// the connection with the spawned process. -struct WorkerBridgeState { - /// Map of worker_id -> channel sender for completing the handshake - pending: Mutex>>, -} - -impl WorkerBridgeState { - fn new() -> Self { - Self { - pending: Mutex::new(HashMap::new()), - } - } - - /// Reserve a slot for an incoming worker connection. - /// Returns a receiver that will be signaled when the worker connects. - async fn reserve_worker(&self, worker_id: u64) -> oneshot::Receiver { - let (tx, rx) = oneshot::channel(); - let mut guard = self.pending.lock().await; - guard.insert(worker_id, tx); - rx - } - - /// Cancel a pending worker reservation (e.g., if spawn fails). - async fn cancel_worker(&self, worker_id: u64) { - let mut guard = self.pending.lock().await; - guard.remove(&worker_id); - } - - /// Complete the worker registration after receiving the handshake. 
- async fn register_worker( - &self, - worker_id: u64, - channels: WorkerBridgeChannels, - ) -> Result<(), Status> { - let sender = { - let mut guard = self.pending.lock().await; - guard.remove(&worker_id) - }; - match sender { - Some(waiter) => waiter - .send(channels) - .map_err(|_| Status::unavailable("worker reservation dropped")), - None => Err(Status::failed_precondition(format!( - "unknown worker id: {}", - worker_id - ))), - } - } -} - -/// gRPC service implementation for the WorkerBridge. -#[derive(Clone)] -struct WorkerBridgeService { - state: Arc, -} - -impl WorkerBridgeService { - fn new(state: Arc) -> Self { - Self { state } - } -} - -#[async_trait] -impl proto::worker_bridge_server::WorkerBridge for WorkerBridgeService { - type AttachStream = - Pin> + Send + 'static>>; - - async fn attach( - &self, - request: Request>, - ) -> Result, Status> { - let mut stream = request.into_inner(); - - // Read and validate the handshake message - let handshake = stream - .message() - .await - .map_err(|err| Status::internal(format!("failed to read handshake: {err}")))? 
- .ok_or_else(|| Status::invalid_argument("missing worker handshake"))?; - - let kind = proto::MessageKind::try_from(handshake.kind) - .map_err(|_| Status::invalid_argument("invalid message kind"))?; - - if kind != proto::MessageKind::WorkerHello { - return Err(Status::failed_precondition( - "expected WorkerHello as first message", - )); - } - - let hello = proto::WorkerHello::decode(&*handshake.payload).map_err(|err| { - Status::invalid_argument(format!("invalid WorkerHello payload: {err}")) - })?; - - let worker_id = hello.worker_id; - info!(worker_id, "worker connected and sent hello"); - - // Create channels for bidirectional communication - // Buffer size of 64 provides reasonable backpressure while allowing - // some pipelining of requests - let (to_worker_tx, to_worker_rx) = mpsc::channel(64); - let (from_worker_tx, from_worker_rx) = mpsc::channel(64); - - // Complete the registration - this unblocks the spawn code - self.state - .register_worker( - worker_id, - WorkerBridgeChannels { - to_worker: to_worker_tx, - from_worker: from_worker_rx, - }, - ) - .await?; - - // Spawn a task to read from the worker stream and forward to the channel - let reader_state = Arc::clone(&self.state); - tokio::spawn(async move { - loop { - match stream.message().await { - Ok(Some(envelope)) => { - if from_worker_tx.send(envelope).await.is_err() { - // Receiver dropped, worker shutting down - break; - } - } - Ok(None) => { - // Stream closed cleanly - info!(worker_id, "worker stream closed"); - break; - } - Err(err) => { - warn!(?err, worker_id, "worker stream receive error"); - break; - } - } - } - // Clean up the pending map in case of reconnection attempts - reader_state.cancel_worker(worker_id).await; - }); - - // Return a stream that sends from to_worker_rx to the Python client - let outbound = ReceiverStream::new(to_worker_rx).map(Ok::); - Ok(Response::new(Box::pin(outbound) as Self::AttachStream)) - } -} - -/// gRPC server for worker connections. 
-/// -/// Workers connect via bidirectional streaming and exchange action dispatch -/// and result messages. The server handles: -/// -/// - Worker handshake and registration -/// - Channel creation for message passing -/// - Graceful shutdown -/// -/// # Example -/// -/// ```ignore -/// let bridge = WorkerBridgeServer::start(None).await?; -/// let (worker_id, connection_rx) = bridge.reserve_worker().await; -/// // ... spawn worker process with worker_id ... -/// let channels = connection_rx.await?; // Wait for worker to connect -/// ``` -pub struct WorkerBridgeServer { - addr: SocketAddr, - state: Arc, - next_worker_id: AtomicU64, - shutdown_tx: Mutex>>, - server_handle: Mutex>>, -} - -impl WorkerBridgeServer { - /// Start the worker bridge server. - /// - /// If `bind_addr` is None, binds to localhost on an ephemeral port. - /// The actual bound address can be retrieved with [`Self::addr`]. - pub async fn start(bind_addr: Option) -> AnyResult> { - let bind_addr = - bind_addr.unwrap_or_else(|| SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0)); - - let listener = TcpListener::bind(bind_addr) - .await - .context("failed to bind worker bridge listener")?; - - let addr = listener - .local_addr() - .context("failed to resolve bridge addr")?; - - info!(%addr, "worker bridge server starting"); - - let state = Arc::new(WorkerBridgeState::new()); - let (shutdown_tx, shutdown_rx) = oneshot::channel(); - - let service = WorkerBridgeService::new(Arc::clone(&state)); - let server = tokio::spawn(async move { - let incoming = TcpListenerStream::new(listener); - let shutdown = async move { - let _ = shutdown_rx.await; - }; - let result = Server::builder() - .add_service(proto::worker_bridge_server::WorkerBridgeServer::new( - service, - )) - .serve_with_incoming_shutdown(incoming, shutdown) - .await; - if let Err(err) = result { - error!(?err, "worker bridge server exited with error"); - } - }); - - Ok(Arc::new(Self { - addr, - state, - next_worker_id: AtomicU64::new(0), - 
shutdown_tx: Mutex::new(Some(shutdown_tx)), - server_handle: Mutex::new(Some(server)), - })) - } - - /// Get the address the server is bound to. - pub fn addr(&self) -> SocketAddr { - self.addr - } - - /// Reserve a worker ID and get a receiver for when the worker connects. - /// - /// Call this before spawning a worker process. Pass the worker_id to the - /// process, then await the receiver to get the communication channels - /// once the worker connects and completes the handshake. - pub async fn reserve_worker(&self) -> (u64, oneshot::Receiver) { - let worker_id = self.next_worker_id.fetch_add(1, Ordering::SeqCst); - let rx = self.state.reserve_worker(worker_id).await; - (worker_id, rx) - } - - /// Cancel a pending worker reservation. - /// - /// Call this if the worker process failed to spawn. - pub async fn cancel_worker(&self, worker_id: u64) { - self.state.cancel_worker(worker_id).await; - } - - /// Gracefully shut down the server. - pub async fn shutdown(&self) { - if let Some(tx) = self.shutdown_tx.lock().await.take() { - let _ = tx.send(()); - } - if let Some(handle) = self.server_handle.lock().await.take() - && let Err(err) = handle.await - { - warn!(?err, "worker bridge task join failed"); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_server_starts_and_binds() { - let server = match WorkerBridgeServer::start(None).await { - Ok(server) => server, - Err(err) => { - let message = format!("{err:?}"); - if message.contains("Operation not permitted") - || message.contains("Permission denied") - { - return; - } - panic!("start server: {err}"); - } - }; - assert!(server.addr().port() > 0); - server.shutdown().await; - } - - #[tokio::test] - async fn test_reserve_worker_ids_increment() { - let server = match WorkerBridgeServer::start(None).await { - Ok(server) => server, - Err(err) => { - let message = format!("{err:?}"); - if message.contains("Operation not permitted") - || message.contains("Permission denied") - { 
- return; - } - panic!("start server: {err}"); - } - }; - let (id1, _) = server.reserve_worker().await; - let (id2, _) = server.reserve_worker().await; - let (id3, _) = server.reserve_worker().await; - assert_eq!(id1, 0); - assert_eq!(id2, 1); - assert_eq!(id3, 2); - server.shutdown().await; - } -} diff --git a/crates/lib/worker-reservation/Cargo.toml b/crates/lib/worker-reservation/Cargo.toml new file mode 100644 index 00000000..4b490e16 --- /dev/null +++ b/crates/lib/worker-reservation/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "waymark-worker-reservation" +edition = "2024" +version.workspace = true +publish.workspace = true + +[dependencies] +slotmap = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true, features = ["sync"] } diff --git a/crates/lib/worker-reservation/src/id.rs b/crates/lib/worker-reservation/src/id.rs new file mode 100644 index 00000000..ca4ad9c8 --- /dev/null +++ b/crates/lib/worker-reservation/src/id.rs @@ -0,0 +1,32 @@ +slotmap::new_key_type! { + pub struct Id; +} + +impl core::fmt::Display for Id { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.0.as_ffi()) + } +} + +impl core::str::FromStr for Id { + type Err = std::num::ParseIntError; + + fn from_str(s: &str) -> Result { + let ffi = s.parse()?; + let key_data = slotmap::KeyData::from_ffi(ffi); + Ok(Self::from(key_data)) + } +} + +impl From for Id { + fn from(value: u64) -> Self { + let key_data = slotmap::KeyData::from_ffi(value); + Self::from(key_data) + } +} + +impl From for u64 { + fn from(value: Id) -> Self { + value.0.as_ffi() + } +} diff --git a/crates/lib/worker-reservation/src/lib.rs b/crates/lib/worker-reservation/src/lib.rs new file mode 100644 index 00000000..9d3b13ab --- /dev/null +++ b/crates/lib/worker-reservation/src/lib.rs @@ -0,0 +1,9 @@ +//! Tooling for tracking the state of connecting worker reservations. 
+ +mod id; +mod registry; +mod reservation; + +pub use self::id::*; +pub use self::registry::*; +pub use self::reservation::*; diff --git a/crates/lib/worker-reservation/src/registry.rs b/crates/lib/worker-reservation/src/registry.rs new file mode 100644 index 00000000..8340837b --- /dev/null +++ b/crates/lib/worker-reservation/src/registry.rs @@ -0,0 +1,76 @@ +use std::sync::Arc; + +use slotmap::SlotMap; + +use crate::Reservation; + +pub(crate) type ActiveReservations = + SlotMap>; + +pub struct Registry { + active_reservations: std::sync::Mutex>, +} + +impl Default for Registry { + fn default() -> Self { + Self { + active_reservations: Default::default(), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RegisterError { + #[error("reservation {reservation_id} not found")] + ReservationNotFound { + reservation_id: crate::Id, + payload: Payload, + }, + + #[error("unable to send payload for reservation {reservation_id}")] + PayloadSend { + reservation_id: crate::Id, + payload: Payload, + }, +} + +impl Registry { + pub fn reserve(self: &Arc) -> Reservation { + let mut active_reservations = self.active_reservations.lock().unwrap(); + + let (tx, rx) = tokio::sync::oneshot::channel(); + + let reservation_id = active_reservations.insert(tx); + + let registry = Arc::clone(self); + + Reservation::issue_from_registry(registry, reservation_id, rx) + } + + /// Complete the worker registration after receiving the handshake. 
+ pub fn register( + &self, + reservation_id: crate::Id, + payload: Payload, + ) -> Result<(), RegisterError> { + let mut active_reservations = self.active_reservations.lock().unwrap(); + + let Some(tx) = active_reservations.remove(reservation_id) else { + return Err(RegisterError::ReservationNotFound { + reservation_id, + payload, + }); + }; + + tx.send(payload) + .map_err(|payload| RegisterError::PayloadSend { + reservation_id, + payload, + }) + } + + pub(crate) fn reservation_drop_cleanup(&self, reservation_id: crate::Id) { + let mut active_reservations = self.active_reservations.lock().unwrap(); + let _ = active_reservations.remove(reservation_id); + } +} diff --git a/crates/lib/worker-reservation/src/reservation.rs b/crates/lib/worker-reservation/src/reservation.rs new file mode 100644 index 00000000..a6a69282 --- /dev/null +++ b/crates/lib/worker-reservation/src/reservation.rs @@ -0,0 +1,54 @@ +use std::sync::Arc; + +use crate::Registry; + +/// A reservation handle that allows to wait for payload. +pub struct Reservation { + registry: Arc>, + id: crate::Id, + rx: Option>, +} + +#[derive(Debug, thiserror::Error)] +#[error("reservation {reservation_id} was cancelled")] +pub struct ReservationCancelledError { + pub reservation_id: crate::Id, +} + +impl Reservation { + pub(crate) fn issue_from_registry( + registry: Arc>, + id: crate::Id, + rx: tokio::sync::oneshot::Receiver, + ) -> Self { + Self { + registry, + id, + rx: Some(rx), + } + } + + pub fn id(&self) -> crate::Id { + self.id + } + + /// Serve the reservation, by waiting for the payload to be sent. 
+ pub async fn wait(mut self) -> Result { + let channel_rx = self.rx.take().unwrap(); // only ever consumed here + match channel_rx.await { + Ok(val) => Ok(val), + Err(_) => Err(ReservationCancelledError { + reservation_id: self.id, + }), + } + } +} + +impl Drop for Reservation { + fn drop(&mut self) { + if let Some(rx) = &mut self.rx { + rx.close() + } + self.registry.reservation_drop_cleanup(self.id); + } +} diff --git a/crates/lib/worker-status-reporter/src/lib.rs b/crates/lib/worker-status-reporter/src/lib.rs index ec1dbb6f..4706a088 100644 --- a/crates/lib/worker-status-reporter/src/lib.rs +++ b/crates/lib/worker-status-reporter/src/lib.rs @@ -12,7 +12,7 @@ use waymark_nonzero_duration::NonZeroDuration; pub async fn run( pool_id: Uuid, backend: B, - worker_pool: P, + worker_pool: impl AsRef

, active_instances: Arc, interval: NonZeroDuration, shutdown: tokio_util::sync::WaitForCancellationFutureOwned, @@ -35,7 +35,7 @@ pub async fn run( loop { tokio::select! { _ = ticker.tick() => { - let stats = worker_pool.stats_snapshot(); + let stats = worker_pool.as_ref().stats_snapshot(); let actions_per_sec = stats.throughput_per_min / 60.0; let active_instances_count = active_instances.load(Ordering::SeqCst); let active_instances_u32 =