From 44fbeb0eb515f6d80700c90b7f9617d430afe1f0 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 12:24:37 -0400 Subject: [PATCH 01/15] heavy revision work based around runtimes module tree --- .gitignore | 1 + docs/agents/domain.md | 40 + docs/agents/issue-tracker.md | 22 + docs/agents/triage-labels.md | 15 + migrations/0001_init.sql | 177 +--- migrations/postgres/0001_init.sql | 135 +-- src/runtimes/checkpointer.rs | 338 ++---- src/runtimes/checkpointer_postgres.rs | 988 +++++++----------- src/runtimes/checkpointer_postgres_helpers.rs | 72 -- src/runtimes/checkpointer_sqlite.rs | 980 +++++++---------- src/runtimes/checkpointer_sqlite_helpers.rs | 71 -- src/runtimes/execution.rs | 90 +- src/runtimes/metrics_observer.rs | 63 +- src/runtimes/mod.rs | 47 +- src/runtimes/observer.rs | 100 +- src/runtimes/persistence.rs | 186 ++-- src/runtimes/replay.rs | 160 ++- src/runtimes/runner.rs | 947 +++++++---------- src/runtimes/runtime_config.rs | 160 ++- src/runtimes/session.rs | 64 +- src/runtimes/streaming.rs | 39 +- src/runtimes/types.rs | 133 +-- 22 files changed, 1671 insertions(+), 3157 deletions(-) create mode 100644 docs/agents/domain.md create mode 100644 docs/agents/issue-tracker.md create mode 100644 docs/agents/triage-labels.md delete mode 100644 src/runtimes/checkpointer_postgres_helpers.rs delete mode 100644 src/runtimes/checkpointer_sqlite_helpers.rs diff --git a/.gitignore b/.gitignore index 0e5dcb1..3fc8b17 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ data/*.sqlite data/*.tmp .vscode/ tmp/ +AGENTS.md # Prevent committing test binaries test_deprecation diff --git a/docs/agents/domain.md b/docs/agents/domain.md new file mode 100644 index 0000000..da9818b --- /dev/null +++ b/docs/agents/domain.md @@ -0,0 +1,40 @@ +# Domain Docs + +How the engineering skills should consume this repo's domain documentation when exploring the codebase. + +## Layout + +This is a single-context repo. Domain documentation is expected at the repo root: + +- **`CONTEXT.md`** — project vocabulary and domain language +- **`docs/adr/`** — architectural decision records + +If any of these files don't exist, **proceed silently**. Don't flag their absence; don't suggest creating them upfront. The producer skill (`/grill-with-docs`) creates them lazily when terms or decisions actually get resolved. + +## Before exploring, read these + +- **`CONTEXT.md`** at the repo root +- **`docs/adr/`** — read ADRs that touch the area you're about to work in + +## File structure + +```text +/ +├── CONTEXT.md +├── docs/adr/ +│ ├── 0001-event-sourced-orders.md +│ └── 0002-postgres-for-write-model.md +└── src/ +``` + +## Use the glossary's vocabulary + +When your output names a domain concept (in an issue title, a refactor proposal, a hypothesis, a test name), use the term as defined in `CONTEXT.md`. Don't drift to synonyms the glossary explicitly avoids. + +If the concept you need isn't in the glossary yet, that's a signal: either you're inventing language the project doesn't use (reconsider) or there's a real gap (note it for `/grill-with-docs`). + +## Flag ADR conflicts + +If your output contradicts an existing ADR, surface it explicitly rather than silently overriding: + +> _Contradicts ADR-0007 (event-sourced orders), but worth reopening because..._ diff --git a/docs/agents/issue-tracker.md b/docs/agents/issue-tracker.md new file mode 100644 index 0000000..6bdd84d --- /dev/null +++ b/docs/agents/issue-tracker.md @@ -0,0 +1,22 @@ +# Issue tracker: GitHub + +Issues and PRDs for this repo live as GitHub issues. Use the `gh` CLI for all operations. + +## Conventions + +- **Create an issue**: `gh issue create --title "..." --body "..."`. Use a heredoc for multi-line bodies. +- **Read an issue**: `gh issue view --comments`, filtering comments by `jq` and also fetching labels. +- **List issues**: `gh issue list --state open --json number,title,body,labels,comments --jq '[.[] | {number, title, body, labels: [.labels[].name], comments: [.comments[].body]}]'` with appropriate `--label` and `--state` filters. +- **Comment on an issue**: `gh issue comment --body "..."` +- **Apply / remove labels**: `gh issue edit --add-label "..."` / `--remove-label "..."` +- **Close**: `gh issue close --comment "..."` + +Infer the repo from `git remote -v`; `gh` does this automatically when run inside a clone. + +## When a skill says "publish to the issue tracker" + +Create a GitHub issue. + +## When a skill says "fetch the relevant ticket" + +Run `gh issue view --comments`. diff --git a/docs/agents/triage-labels.md b/docs/agents/triage-labels.md new file mode 100644 index 0000000..cd2a8c0 --- /dev/null +++ b/docs/agents/triage-labels.md @@ -0,0 +1,15 @@ +# Triage Labels + +The skills speak in terms of five canonical triage roles. This file maps those roles to the actual label strings used in this repo's issue tracker. + +| Label in skills | Label in our tracker | Meaning | +| -------------------------- | -------------------- | ---------------------------------------- | +| `needs-triage` | `needs-triage` | Maintainer needs to evaluate this issue | +| `needs-info` | `needs-info` | Waiting on reporter for more information | +| `ready-for-agent` | `ready-for-agent` | Fully specified, ready for an AFK agent | +| `ready-for-human` | `ready-for-human` | Requires human implementation | +| `wontfix` | `wontfix` | Will not be actioned | + +When a skill mentions a role (e.g. "apply the AFK-ready triage label"), use the corresponding label string from this table. + +Community/discovery labels such as `bug`, `enhancement`, `good first issue`, `help wanted`, `documentation`, `api-design`, and `breaking-change` are additive metadata. Do not use them as replacements for the canonical triage-state labels. diff --git a/migrations/0001_init.sql b/migrations/0001_init.sql index 47c0291..9a614d4 100644 --- a/migrations/0001_init.sql +++ b/migrations/0001_init.sql @@ -1,170 +1,55 @@ --- 0001_init.sql +-- SQLite schema for Weavegraph session and step checkpointing. -- --- Initial SQLite schema for Weavegraph session & step checkpointing. --- This supports a future `SQLiteCheckpointer` implementation that can: --- * Create / resume sessions by `session_id` --- * Persist a full durable checkpoint after every barrier (superstep) --- * Query historical steps (for audit, replay, diffing, debugging) +-- Two tables: +-- sessions — one row per session; carries a denormalized latest-step snapshot +-- so resume is a single point-lookup with no aggregate. +-- steps — append-only checkpoint history; one row per barrier crossing. -- --- Design notes (aligned with runtimes/checkpointer.rs & runner.rs types): --- Checkpoint fields we persist per step: --- - session_id (string) --- - step (u64 -> INTEGER) --- - state (VersionedState) -> JSON (TEXT) --- - frontier (Vec) -> JSON (TEXT) --- - versions_seen (HashMap<..>) -> JSON (TEXT) --- - ran_nodes / skipped_nodes -> JSON (TEXT) (from StepReport) --- - updated_channels -> JSON (TEXT) (Vec<&'static str>) --- - created_at timestamp --- - (Optionally) concurrency_limit is denormalized at the session level --- --- We also keep a denormalized "latest" snapshot on the `sessions` row so --- resuming a session can be a single SELECT (without an aggregate). --- --- JSON is stored as TEXT (SQLite default). The application layer (SQLx) is --- responsible for (de)serialization and validation. We may later add CHECK --- constraints using json_valid(...) if desired (requires JSON1 extension). --- --- Timestamps are stored in RFC3339/ISO8601 (UTC) via strftime. All times UTC. --- --- Foreign keys are enforced (ON DELETE CASCADE ensures step history is removed --- when a session is deleted). --- --- Step numbering starts at 1 (after first barrier) though the schema does not --- enforce an origin; the runner should ensure monotonic increment. --- --- NodeKind serialization suggestion (not enforced here): --- Start -> "Start" --- End -> "End" --- Other -> {"Other":""} --- or a simpler flat string encoding: "Start", "End", "Other:" --- (Must be consistent across state/frontier/ran/skipped arrays.) --- -PRAGMA foreign_keys = ON; - ---------------------------------------------------------------------------- --- Sessions ---------------------------------------------------------------------------- +-- All JSON columns are TEXT. Timestamps are RFC 3339 UTC. +-- Steps are immutable once written; monotonic step numbering is enforced by the +-- application layer, not the schema. CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, -- session_id - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), - updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), - - -- Concurrency limit used when the session was created (for reference / resume) - concurrency_limit INTEGER NOT NULL, - - -- Denormalized latest checkpoint snapshot (mirrors most recent row in steps) - last_step INTEGER NOT NULL DEFAULT 0, - last_state_json TEXT, -- Full VersionedState JSON (messages, extra, versions) - last_frontier_json TEXT, -- JSON array of node kinds - last_versions_seen_json TEXT -- JSON object: { "": { "messages": , "extra": , ... } } + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')), + concurrency_limit INTEGER NOT NULL, + last_step INTEGER NOT NULL DEFAULT 0, + last_state_json TEXT, + last_frontier_json TEXT, + last_versions_seen_json TEXT ); -CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC); - ---------------------------------------------------------------------------- --- Steps (historical checkpoints) ---------------------------------------------------------------------------- +CREATE INDEX IF NOT EXISTS idx_sessions_updated_at + ON sessions(updated_at DESC); CREATE TABLE IF NOT EXISTS steps ( - session_id TEXT NOT NULL, - step INTEGER NOT NULL, - created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ','now')), - - -- Durable snapshot data - state_json TEXT NOT NULL, -- Full VersionedState JSON - frontier_json TEXT NOT NULL, -- JSON array - versions_seen_json TEXT NOT NULL, -- JSON object of objects - - -- Execution metadata (from StepReport) - ran_nodes_json TEXT NOT NULL, -- JSON array - skipped_nodes_json TEXT NOT NULL, -- JSON array - updated_channels_json TEXT, -- JSON array of updated channel names (may be empty/NULL) - - -- Optional future fields (placeholders for forward-compat): - -- error_json TEXT, -- structured error info if barrier failed - -- pause_reason_json TEXT, -- if an interrupt paused the session at this step - + session_id TEXT NOT NULL, + step INTEGER NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')), + state_json TEXT NOT NULL, + frontier_json TEXT NOT NULL, + versions_seen_json TEXT NOT NULL, + ran_nodes_json TEXT NOT NULL, + skipped_nodes_json TEXT NOT NULL, + updated_channels_json TEXT, PRIMARY KEY (session_id, step), FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE ); -CREATE INDEX IF NOT EXISTS idx_steps_session_step_desc +-- Covers both DESC (latest-step resume) and ASC (replay) scans. +CREATE INDEX IF NOT EXISTS idx_steps_session_id_step ON steps(session_id, step DESC); --- Fast access to chronological iteration (ascending) -CREATE INDEX IF NOT EXISTS idx_steps_session_step_asc - ON steps(session_id, step ASC); - ---------------------------------------------------------------------------- --- Triggers ---------------------------------------------------------------------------- - --- Keep sessions.updated_at & denormalized latest snapshot in sync on insert. +-- Advance the denormalized snapshot on the parent session row after each insert. CREATE TRIGGER IF NOT EXISTS trg_steps_after_insert AFTER INSERT ON steps BEGIN UPDATE sessions - SET - updated_at = strftime('%Y-%m-%dT%H:%M:%fZ','now'), + SET updated_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now'), last_step = NEW.step, last_state_json = NEW.state_json, last_frontier_json = NEW.frontier_json, last_versions_seen_json = NEW.versions_seen_json WHERE id = NEW.session_id; END; - --- (Optional) If you later allow updating a step row (should be rare), --- keep the denormalized snapshot accurate only when updating the latest step. -CREATE TRIGGER IF NOT EXISTS trg_steps_after_update -AFTER UPDATE ON steps -WHEN (SELECT last_step FROM sessions WHERE id = NEW.session_id) = NEW.step -BEGIN - UPDATE sessions - SET - updated_at = strftime('%Y-%m-%dT%H:%M:%fZ','now'), - last_state_json = NEW.state_json, - last_frontier_json = NEW.frontier_json, - last_versions_seen_json = NEW.versions_seen_json - WHERE id = NEW.session_id; -END; - ---------------------------------------------------------------------------- --- Views (Convenience) ---------------------------------------------------------------------------- - --- Latest checkpoint per session (essentially mirrors sessions.* but sourced --- from authoritative steps table if you prefer not to trust denormalized columns). -CREATE VIEW IF NOT EXISTS v_latest_checkpoints AS -SELECT - s.id AS session_id, - s.concurrency_limit, - s.created_at AS session_created_at, - s.updated_at AS session_updated_at, - st.step, - st.created_at AS step_created_at, - st.state_json, - st.frontier_json, - st.versions_seen_json, - st.ran_nodes_json, - st.skipped_nodes_json, - st.updated_channels_json -FROM sessions s -LEFT JOIN steps st - ON st.session_id = s.id - AND st.step = s.last_step; - ---------------------------------------------------------------------------- --- Integrity / Sanity Notes (enforced at application layer for now): --- * step must be monotonic increasing per session (PRIMARY KEY + app logic) --- * versions_seen_json should contain only non-negative integers --- * frontier_json, ran_nodes_json, skipped_nodes_json are arrays of node encodings --- * state_json must contain messages + extra + version metadata --- --- Future migration ideas: --- * Move JSON to FTS5 virtual tables for semantic search over messages --- * Add error / pause tables for richer observability --- * Differential checkpoints (store deltas after initial baseline) --- --- End of migration. diff --git a/migrations/postgres/0001_init.sql b/migrations/postgres/0001_init.sql index ae483a9..928886f 100644 --- a/migrations/postgres/0001_init.sql +++ b/migrations/postgres/0001_init.sql @@ -1,124 +1,49 @@ --- 0001_init.sql +-- PostgreSQL schema for Weavegraph session and step checkpointing. -- --- Initial PostgreSQL schema for Weavegraph session & step checkpointing. --- This supports the `PostgresCheckpointer` implementation that can: --- * Create / resume sessions by `session_id` --- * Persist a full durable checkpoint after every barrier (superstep) --- * Query historical steps (for audit, replay, diffing, debugging) +-- Two tables: +-- sessions — one row per session; carries a denormalized latest-step snapshot +-- so resume is a single point-lookup with no aggregate. +-- steps — append-only checkpoint history; one row per barrier crossing. -- --- Design notes (aligned with runtimes/checkpointer.rs & runner.rs types): --- Checkpoint fields we persist per step: --- - session_id (string) --- - step (u64 -> BIGINT) --- - state (VersionedState) -> JSONB --- - frontier (Vec) -> JSONB --- - versions_seen (HashMap<..>) -> JSONB --- - ran_nodes / skipped_nodes -> JSONB (from StepReport) --- - updated_channels -> JSONB (Vec<&'static str>) --- - created_at timestamp --- - (Optionally) concurrency_limit is denormalized at the session level --- --- We also keep a denormalized "latest" snapshot on the `sessions` row so --- resuming a session can be a single SELECT (without an aggregate). --- --- JSONB is used for efficient storage and querying of JSON data. --- --- Timestamps use TIMESTAMPTZ for timezone-aware storage (all times UTC). --- --- Foreign keys are enforced (ON DELETE CASCADE ensures step history is removed --- when a session is deleted). --- --- Step numbering starts at 1 (after first barrier) though the schema does not --- enforce an origin; the runner should ensure monotonic increment. --- --- NodeKind serialization suggestion (not enforced here): --- Start -> "Start" --- End -> "End" --- Other -> {"Other":""} --- or a simpler flat string encoding: "Start", "End", "Other:" --- (Must be consistent across state/frontier/ran/skipped arrays.) - ---------------------------------------------------------------------------- --- Sessions ---------------------------------------------------------------------------- +-- All JSON columns are JSONB. Timestamps are TIMESTAMPTZ (UTC). +-- The denormalized last_* columns on sessions are advanced by the application +-- in the same transaction as each step write (no trigger). CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, -- session_id - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - - -- Concurrency limit used when the session was created (for reference / resume) - concurrency_limit BIGINT NOT NULL, - - -- Denormalized latest checkpoint snapshot (mirrors most recent row in steps) - last_step BIGINT NOT NULL DEFAULT 0, - last_state_json JSONB, -- Full VersionedState JSON (messages, extra, versions) - last_frontier_json JSONB, -- JSON array of node kinds - last_versions_seen_json JSONB -- JSON object: { "": { "messages": , "extra": , ... } } + id TEXT PRIMARY KEY, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + concurrency_limit BIGINT NOT NULL, + last_step BIGINT NOT NULL DEFAULT 0, + last_state_json JSONB, + last_frontier_json JSONB, + last_versions_seen_json JSONB ); -CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON sessions(updated_at DESC); - ---------------------------------------------------------------------------- --- Steps (historical checkpoints) ---------------------------------------------------------------------------- +CREATE INDEX IF NOT EXISTS idx_sessions_updated_at + ON sessions(updated_at DESC); CREATE TABLE IF NOT EXISTS steps ( - session_id TEXT NOT NULL, - step BIGINT NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), - - -- Durable snapshot data - state_json JSONB NOT NULL, -- Full VersionedState JSON - frontier_json JSONB NOT NULL, -- JSON array - versions_seen_json JSONB NOT NULL, -- JSON object of objects - - -- Execution metadata (from StepReport) - ran_nodes_json JSONB NOT NULL, -- JSON array - skipped_nodes_json JSONB NOT NULL, -- JSON array - updated_channels_json JSONB, -- JSON array of updated channel names (may be empty/NULL) - - -- Optional future fields (placeholders for forward-compat): - -- error_json JSONB, -- structured error info if barrier failed - -- pause_reason_json JSONB, -- if an interrupt paused the session at this step - + session_id TEXT NOT NULL, + step BIGINT NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + state_json JSONB NOT NULL, + frontier_json JSONB NOT NULL, + versions_seen_json JSONB NOT NULL, + ran_nodes_json JSONB NOT NULL, + skipped_nodes_json JSONB NOT NULL, + updated_channels_json JSONB, PRIMARY KEY (session_id, step), FOREIGN KEY (session_id) REFERENCES sessions(id) ON DELETE CASCADE ); -CREATE INDEX IF NOT EXISTS idx_steps_session_step_desc +-- Covers both DESC (latest-step resume) and ASC (replay) scans. +CREATE INDEX IF NOT EXISTS idx_steps_session_id_step ON steps(session_id, step DESC); --- Fast access to chronological iteration (ascending) -CREATE INDEX IF NOT EXISTS idx_steps_session_step_asc - ON steps(session_id, step ASC); - --- JSONB indexing for efficient containment queries used by query_steps() --- (e.g. ran_nodes_json @> '["Start"]') +-- JSONB containment indexes for query_steps ran_node / skipped_node filters. CREATE INDEX IF NOT EXISTS idx_steps_ran_nodes_gin ON steps USING GIN (ran_nodes_json); CREATE INDEX IF NOT EXISTS idx_steps_skipped_nodes_gin ON steps USING GIN (skipped_nodes_json); - - --- Denormalized session snapshot maintenance --- --- We intentionally do NOT use database triggers to maintain sessions.last_*. --- The application updates those fields in the same transaction as step writes. --- This makes the "latest" pointer monotonic by construction even if steps are --- written out-of-order (replays/imports/retries). - ---------------------------------------------------------------------------- --- Integrity / Sanity Notes (enforced at application layer for now): --- * step must be monotonic increasing per session (PRIMARY KEY + app logic) --- * versions_seen_json should contain only non-negative integers --- * frontier_json, ran_nodes_json, skipped_nodes_json are arrays of node encodings --- * state_json must contain messages + extra + version metadata --- --- Future migration ideas: --- * Move JSON to full-text search for semantic search over messages --- * Add error / pause tables for richer observability --- * Differential checkpoints (store deltas after initial baseline) --- --- End of migration. diff --git a/src/runtimes/checkpointer.rs b/src/runtimes/checkpointer.rs index 35acc14..717340e 100644 --- a/src/runtimes/checkpointer.rs +++ b/src/runtimes/checkpointer.rs @@ -1,29 +1,15 @@ -//! Checkpointer infrastructure with thread‑safe persistence. +//! Checkpointer trait, shared types, and the in-memory backend. //! -//! This module provides: -//! - Thread‑safe async operations using `tokio::sync::RwLock` -//! - Tracing integration (`#[instrument]`) for observability -//! - Multiple backends: in‑memory (volatile) and SQLite (durable) +//! Three concrete implementations are available, selected via [`CheckpointerType`]: +//! - [`InMemoryCheckpointer`] — volatile, process-local; suitable for tests and +//! single-run workloads. +//! - `SQLiteCheckpointer` (`feature = "sqlite"`) — file-backed with full step history. +//! - `PostgresCheckpointer` (`feature = "postgres"`) — server-backed with full step history. //! -//! # Thread Safety -//! All implementations use async‑aware synchronization primitives and can be -//! called across `await` points without blocking executor threads. -//! -//! # Observability -//! Each operation on `InMemoryCheckpointer` and `SQLiteCheckpointer` is -//! instrumented with structured tracing fields (e.g. `session_id`, `step`). -//! Enable debug logging to view activity: +//! Enable debug tracing to inspect operations: //! ```bash //! RUST_LOG=weavegraph::runtimes::checkpointer=debug cargo run //! ``` -//! -//! # Storage Management -//! - **InMemoryCheckpointer**: Stores only the latest checkpoint per session -//! (implicit retention; no history). -//! - **SQLiteCheckpointer**: Stores full step history for durable audit and -//! replay; use external SQL maintenance for long‑running deployments. -//! -//! See type‑level docs for cleanup guidance. use async_trait::async_trait; use chrono::{DateTime, Utc}; @@ -35,101 +21,50 @@ use crate::{ types::NodeKind, }; -/// A durable snapshot of session execution state at a barrier boundary. -/// -/// This structure captures both the current state and execution history -/// to enable full session resumption and audit trails. +/// Snapshot of session execution state captured at a barrier boundary. #[derive(Debug, Clone)] pub struct Checkpoint { - /// Unique identifier of the workflow session this checkpoint belongs to. + /// Session this checkpoint belongs to. pub session_id: String, - /// Execution step number at the time of this checkpoint. + /// Barrier step index at capture time. pub step: u64, - /// Full versioned state snapshot captured at this step. + /// Full versioned-state snapshot. pub state: VersionedState, - /// Node frontier to resume from when restoring this checkpoint. + /// Node frontier to resume from. pub frontier: Vec, - /// Scheduler version-gating state for change detection. - pub versions_seen: FxHashMap>, // scheduler gating - /// Maximum concurrent nodes configured for this session. + /// Scheduler version-gating counters. + pub versions_seen: FxHashMap>, + /// Maximum concurrent nodes for this session. pub concurrency_limit: usize, - /// Timestamp at which this checkpoint was created. + /// Wall-clock time of capture. pub created_at: DateTime, - /// Nodes that executed in this step (empty for step 0) + /// Nodes that executed in this step (empty for step 0). pub ran_nodes: Vec, - /// Nodes that were skipped in this step (empty for step 0) + /// Nodes that were skipped in this step (empty for step 0). pub skipped_nodes: Vec, - /// Channels that were updated in this step (empty for step 0) + /// Channels updated in this step (empty for step 0). pub updated_channels: Vec, } impl Checkpoint { - /// Create a checkpoint from the current session state. - /// - /// This captures a snapshot of the session's execution state that can be - /// persisted and later restored to resume execution from this point. - /// - /// # Parameters - /// - /// * `session_id` - Unique identifier for the session - /// * `session` - Current session state to checkpoint - /// - /// # Returns - /// - /// A `Checkpoint` containing all necessary state for resumption - /// - /// # Examples - /// - /// ```rust,no_run - /// # use weavegraph::runtimes::{Checkpoint, SessionState}; - /// # fn example(session_state: SessionState) { - /// let checkpoint = Checkpoint::from_session("my_session", &session_state); - /// // checkpoint can now be saved via a Checkpointer - /// # } - /// ``` + /// Build a checkpoint from raw session state, with no step execution metadata. #[must_use] pub fn from_session(session_id: &str, session: &SessionState) -> Self { Self { - session_id: session_id.to_string(), + session_id: session_id.to_owned(), step: session.step, state: session.state.clone(), frontier: session.frontier.clone(), versions_seen: session.scheduler_state.versions_seen.clone(), concurrency_limit: session.scheduler.concurrency_limit, created_at: Utc::now(), - ran_nodes: vec![], // No execution history for raw session state + ran_nodes: vec![], skipped_nodes: vec![], updated_channels: vec![], } } - /// Create a checkpoint from a completed step report. - /// - /// This captures the full execution context including what nodes ran, - /// were skipped, and which channels were updated during the step. - /// - /// # Parameters - /// - /// * `session_id` - Unique identifier for the session - /// * `session_state` - Current session state after step execution - /// * `step_report` - Details of what happened during step execution - /// - /// # Returns - /// - /// A `Checkpoint` with complete step execution metadata - /// - /// # Examples - /// - /// ```rust,no_run - /// # use weavegraph::runtimes::{Checkpoint, SessionState, StepReport}; - /// # fn example(session_state: SessionState, step_report: StepReport) { - /// let checkpoint = Checkpoint::from_step_report( - /// "my_session", - /// &session_state, - /// &step_report - /// ); - /// # } - /// ``` + /// Build a checkpoint from a completed step, capturing full execution metadata. #[must_use] pub fn from_step_report( session_id: &str, @@ -137,7 +72,7 @@ impl Checkpoint { step_report: &crate::runtimes::execution::StepReport, ) -> Self { Self { - session_id: session_id.to_string(), + session_id: session_id.to_owned(), step: session_state.step, state: session_state.state.clone(), frontier: session_state.frontier.clone(), @@ -150,18 +85,18 @@ impl Checkpoint { .barrier_outcome .updated_channels .iter() - .map(|s| (*s).to_string()) + .map(|s| s.to_string()) .collect(), } } } -/// Errors from checkpointer operations. +/// Errors returned by [`Checkpointer`] operations. #[derive(Debug, thiserror::Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] #[non_exhaustive] pub enum CheckpointerError { - /// Session was not found in the checkpointer. + /// No checkpoint found for the given session. #[error("session not found: {session_id}")] #[cfg_attr( feature = "diagnostics", @@ -177,7 +112,7 @@ pub enum CheckpointerError { session_id: String, }, - /// Backend storage error (database, filesystem, etc.). + /// A storage backend error (database, filesystem, etc.). #[error("backend error: {message}")] #[cfg_attr( feature = "diagnostics", @@ -187,180 +122,58 @@ pub enum CheckpointerError { ) )] Backend { - /// Description of the backend storage error. + /// Description of the backend error. message: String, }, - /// Other checkpointer errors. + /// An unexpected or miscellaneous checkpointer error. #[error("checkpointer error: {message}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::checkpointer::other)) )] Other { - /// Human-readable description of the error. + /// Human-readable error description. message: String, }, } -/// Selects the backing implementation of the `Checkpointer` trait. -/// -/// Variants: -/// * `InMemory` – Volatile process‑local storage. Fast, non‑durable; suitable for -/// tests and ephemeral runs. -/// * `SQLite` – Durable, file (or memory) backed storage using `SQLiteCheckpointer` -/// (see `runtimes::checkpointer_sqlite`). Persists step history and the latest -/// snapshot for session resumption. -/// -/// Note: -/// The runtime previously had an unreachable wildcard match when exhaustively -/// enumerating these variants. If additional variants are added in the future, -/// they should be explicitly matched (or a deliberate catch‑all retained). +/// Selects the backing store used by the runtime's checkpointer. #[derive(Debug, Clone, PartialEq, Eq)] pub enum CheckpointerType { - /// In‑memory (non‑durable) checkpointing. + /// Volatile, process-local storage. Fast; no durability. InMemory, #[cfg(feature = "sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] - /// SQLite‑backed durable checkpointing (see `SQLiteCheckpointer`). + /// File-backed SQLite storage with full step history. SQLite, #[cfg(feature = "postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] - /// PostgreSQL‑backed durable checkpointing (see `PostgresCheckpointer`). + /// PostgreSQL-backed storage with full step history. Postgres, } -/// Convenience alias for checkpointer operation results. +/// Convenience alias for checkpointer results. pub type Result = std::result::Result; -/// Trait for persistent storage and retrieval of workflow execution state. -/// -/// Checkpointers provide durable storage for workflow execution state, enabling -/// session resumption across process restarts. Implementations must ensure that -/// checkpoints are atomic and consistent. -/// -/// # Design Principles -/// -/// - **Atomicity**: Checkpoint saves should be all-or-nothing operations -/// - **Consistency**: The stored state should always be in a valid, resumable state -/// - **Idempotency**: Saving the same checkpoint multiple times should be safe -/// - **Isolation**: Concurrent access to different sessions should not interfere -/// -/// # Implementation Notes -/// -/// - All operations should be idempotent where possible -/// - Concurrent access to the same session should be handled gracefully -/// - Backend errors should be mapped to appropriate `CheckpointerError` variants -/// - The `save` operation replaces any existing checkpoint for the session -/// - The `load_latest` operation returns `None` for non-existent sessions -/// -/// # Thread Safety -/// -/// All implementations must be `Send + Sync` to allow usage across async tasks -/// and thread boundaries. Interior mutability should use appropriate synchronization -/// primitives (e.g., `RwLock`, `Mutex`). -/// -/// # Error Handling -/// -/// Methods should return specific `CheckpointerError` variants: -/// - `NotFound`: When a session doesn't exist (only for operations that require it) -/// - `Backend`: For storage-related errors (database, filesystem, network) -/// - `Other`: For serialization errors or other unexpected conditions -/// -/// # Examples -/// -/// ```rust,no_run -/// use weavegraph::runtimes::{Checkpointer, Checkpoint, InMemoryCheckpointer}; -/// use weavegraph::state::VersionedState; +/// Persistent storage and retrieval of workflow execution state. /// -/// # async fn example() -> Result<(), Box> { -/// let checkpointer = InMemoryCheckpointer::new(); -/// -/// // Save a checkpoint -/// let state = VersionedState::new_with_user_message("Hello"); -/// // ... create checkpoint from session state -/// # let checkpoint = todo!(); // placeholder -/// checkpointer.save(checkpoint).await?; -/// -/// // Load the latest checkpoint -/// if let Some(checkpoint) = checkpointer.load_latest("session_id").await? { -/// // Resume execution from checkpoint -/// println!("Resuming from step {}", checkpoint.step); -/// } -/// -/// // List all sessions -/// let sessions = checkpointer.list_sessions().await?; -/// println!("Found {} sessions", sessions.len()); -/// # Ok(()) -/// # } -/// ``` +/// Implementations must be `Send + Sync` and handle concurrent access to the +/// same session gracefully. The `save` operation replaces any existing +/// snapshot for the session; `load_latest` returns `None` for absent sessions. #[async_trait] pub trait Checkpointer: Send + Sync { - /// Persist the latest checkpoint for a session. - /// - /// This operation should be atomic and idempotent. If a checkpoint already - /// exists for the session, it will be replaced. The implementation should - /// ensure that concurrent saves to the same session are handled safely. - /// - /// # Parameters - /// - /// * `checkpoint` - The checkpoint data to persist - /// - /// # Returns - /// - /// * `Ok(())` - Checkpoint was successfully saved - /// * `Err(CheckpointerError)` - Save operation failed - /// - /// # Errors - /// - /// * `Backend` - Storage backend error (database, filesystem, etc.) - /// * `Other` - Serialization error or other unexpected condition + /// Persist a checkpoint, replacing any prior snapshot for the same session. async fn save(&self, checkpoint: Checkpoint) -> Result<()>; - /// Load the most recent checkpoint for a session. - /// - /// Returns `None` if no checkpoint exists for the given session ID. - /// This operation should be consistent with the latest `save` operation. - /// - /// # Parameters - /// - /// * `session_id` - Unique identifier for the session - /// - /// # Returns - /// - /// * `Ok(Some(checkpoint))` - Latest checkpoint was found and loaded - /// * `Ok(None)` - No checkpoint exists for this session - /// * `Err(CheckpointerError)` - Load operation failed - /// - /// # Errors - /// - /// * `Backend` - Storage backend error - /// * `Other` - Deserialization error or corruption + /// Return the most recent checkpoint for a session, or `None` if absent. async fn load_latest(&self, session_id: &str) -> Result>; - /// List all session IDs known to this checkpointer. - /// - /// Returns a vector of session IDs that have at least one checkpoint - /// stored. The order is implementation-defined but should be consistent. - /// - /// # Returns - /// - /// * `Ok(session_ids)` - List of all known session IDs - /// * `Err(CheckpointerError)` - List operation failed - /// - /// # Errors - /// - /// * `Backend` - Storage backend error + /// Return all session IDs that have at least one stored checkpoint. async fn list_sessions(&self) -> Result>; } -/// Simple in‑memory checkpointer with implicit retention. -/// -/// Characteristics: -/// - Volatile: process‑local only -/// - Retention: last checkpoint per session (no historical steps) -/// - Concurrency: `std::sync::RwLock` for fast synchronous access (no async overhead) -/// - Observability: `#[instrument]` on public trait methods +/// Volatile in-memory checkpointer. Keeps only the latest snapshot per session. /// /// Enable debug tracing to inspect operations: /// ```bash @@ -372,16 +185,10 @@ pub struct InMemoryCheckpointer { } impl InMemoryCheckpointer { - /// Create a new in-memory checkpointer. - /// - /// # Returns - /// - /// A new `InMemoryCheckpointer` instance + /// Create a new, empty in-memory checkpointer. #[must_use] pub fn new() -> Self { - Self { - inner: RwLock::new(FxHashMap::default()), - } + Self::default() } } @@ -389,63 +196,36 @@ impl InMemoryCheckpointer { impl Checkpointer for InMemoryCheckpointer { #[tracing::instrument(skip(self), fields(session_id = %checkpoint.session_id, step = checkpoint.step))] async fn save(&self, checkpoint: Checkpoint) -> Result<()> { - let mut map = self - .inner + self.inner .write() - .expect("InMemoryCheckpointer RwLock poisoned"); - map.insert(checkpoint.session_id.clone(), checkpoint); + .expect("InMemoryCheckpointer RwLock poisoned") + .insert(checkpoint.session_id.clone(), checkpoint); Ok(()) } #[tracing::instrument(skip(self), fields(session_id = %session_id))] async fn load_latest(&self, session_id: &str) -> Result> { - let map = self + Ok(self .inner .read() - .expect("InMemoryCheckpointer RwLock poisoned"); - Ok(map.get(session_id).cloned()) + .expect("InMemoryCheckpointer RwLock poisoned") + .get(session_id) + .cloned()) } #[tracing::instrument(skip(self))] async fn list_sessions(&self) -> Result> { - let map = self + Ok(self .inner .read() - .expect("InMemoryCheckpointer RwLock poisoned"); - Ok(map.keys().cloned().collect()) + .expect("InMemoryCheckpointer RwLock poisoned") + .keys() + .cloned() + .collect()) } } -/// Restore a `SessionState` from a persisted `Checkpoint`. -/// -/// This utility function reconstructs the in-memory session state from a -/// checkpoint, allowing execution to resume from the checkpointed step. -/// The restored state maintains all version information and scheduler state -/// for seamless continuation. -/// -/// # Parameters -/// -/// * `cp` - The checkpoint to restore from -/// -/// # Returns -/// -/// A `SessionState` ready for continued execution with: -/// - Restored versioned state channels (messages, extra) -/// - Correct step counter and frontier nodes -/// - Reconstructed scheduler with original concurrency limits -/// - Preserved version tracking for proper barrier coordination -/// -/// # Examples -/// -/// ```rust,no_run -/// # use weavegraph::runtimes::{restore_session_state, Checkpoint}; -/// # async fn example(checkpoint: Checkpoint) { -/// let session_state = restore_session_state(&checkpoint); -/// // session_state can now be used to continue execution -/// assert_eq!(session_state.step, checkpoint.step); -/// assert_eq!(session_state.frontier, checkpoint.frontier); -/// # } -/// ``` +/// Reconstruct a [`SessionState`] from a persisted [`Checkpoint`]. #[must_use = "restored session state should be used to continue execution"] pub fn restore_session_state(cp: &Checkpoint) -> SessionState { use crate::schedulers::Scheduler; diff --git a/src/runtimes/checkpointer_postgres.rs b/src/runtimes/checkpointer_postgres.rs index c652e11..b0a5681 100644 --- a/src/runtimes/checkpointer_postgres.rs +++ b/src/runtimes/checkpointer_postgres.rs @@ -1,52 +1,28 @@ /*! -PostgreSQL Checkpointer +PostgreSQL checkpointer backend. -This module provides the `PostgresCheckpointer` async implementation of the -`Checkpointer` trait defined in `runtimes/checkpointer.rs`. +Implements [`Checkpointer`] over a [`PgPool`], storing the full step history +with paginated queries and optional optimistic concurrency control. -## Features +## Schema -- **Complete Step History**: Stores full execution metadata including ran/skipped nodes -- **Pagination Support**: Efficient querying of large checkpoint histories -- **Optimistic Concurrency**: Prevention of concurrent checkpoint conflicts -- **Serde Integration**: Uses persistence models for consistent serialization +- `sessions.id` — session identifier (primary key) +- `sessions.concurrency_limit` — maximum concurrent nodes +- `sessions.last_step`, `sessions.last_*_json` — cached latest snapshot, + advanced monotonically so out-of-order replays cannot regress it +- `steps.session_id + steps.step` — composite primary key +- `steps.*_json` — JSONB columns for state, frontier, versions_seen, + ran/skipped nodes, and updated channel names -## Behavior +## Migrations -- Uses serde-based persistence models (see `runtimes::persistence`) for - encoding `VersionedState`, frontier node kinds, and `versions_seen`. -- When the `postgres-migrations` feature is enabled, embedded - migrations (`sqlx::migrate!("./migrations/postgres")`) are executed on connect; - disabling the feature assumes external migration orchestration. +Enable the `postgres-migrations` feature to apply embedded migrations on +connect; otherwise the schema must be managed externally. -## Design Goals +## NodeKind encoding -- Keep this module focused on database I/O; pure serialization lives in - the persistence module. -- Provide efficient querying with filtering and pagination support. -- Ensure data consistency with optimistic concurrency control. - -## Database Schema - -The checkpoint data maps to database tables as follows: - -- `sessions.id` ← `checkpoint.session_id` -- `sessions.concurrency_limit` ← `checkpoint.concurrency_limit` -- `steps.session_id` ← `checkpoint.session_id` -- `steps.step` ← `checkpoint.step` -- `steps.state_json` ← serialized `VersionedState` (JSONB) -- `steps.frontier_json` ← JSON array of encoded `NodeKind` (JSONB) -- `steps.versions_seen_json` ← JSON object (node → channel → version) (JSONB) -- `steps.ran_nodes_json` ← JSON array of executed nodes (JSONB) -- `steps.skipped_nodes_json` ← JSON array of skipped nodes (JSONB) -- `steps.updated_channels_json` ← JSON array of updated channel names (JSONB) - -## NodeKind Encoding - -NodeKinds are encoded as strings for JSON storage: -- `Start` → `"Start"` -- `End` → `"End"` -- `Custom(name)` → `"Custom:"` +`NodeKind` is encoded as a string: `Start` → `"Start"`, `End` → `"End"`, +`Custom(n)` → `"Custom:"`. */ use std::sync::Arc; @@ -63,89 +39,59 @@ use crate::{ types::NodeKind, }; -use super::checkpointer_postgres_helpers::{ - deserialize_json_value, require_json_field, serialize_json, -}; +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- -/// Query parameters for filtering step history. +/// Filters and pagination controls for [`PostgresCheckpointer::query_steps`]. #[derive(Debug, Clone, Default)] pub struct StepQuery { - /// Maximum number of results to return (capped at 1000) + /// Maximum results per page (defaults to 100, capped at 1 000). pub limit: Option, - /// Number of results to skip (for pagination) + /// Zero-based offset of the first result (for cursor-free pagination). pub offset: Option, - /// Filter by minimum step number (inclusive) + /// Restrict to steps with step number ≥ this value. pub min_step: Option, - /// Filter by maximum step number (inclusive) + /// Restrict to steps with step number ≤ this value. pub max_step: Option, - /// Only return steps that executed the specified node + /// Restrict to steps where this node ran. pub ran_node: Option, - /// Only return steps that skipped the specified node + /// Restrict to steps where this node was skipped. pub skipped_node: Option, } -/// Pagination information for query results. +/// Pagination metadata included in a [`StepQueryResult`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PageInfo { - /// Total number of matching records + /// Total records matching the query filters (before pagination). pub total_count: u64, - /// Number of records returned in this page + /// Records returned in this page. pub page_size: u32, - /// Zero-based offset of the first record in this page + /// Zero-based offset of the first record in this page. pub offset: u32, - /// Whether there are more records after this page + /// Whether records remain after this page. pub has_next_page: bool, } -/// Paginated query result for step history. +/// A page of checkpoint records from [`PostgresCheckpointer::query_steps`]. #[derive(Debug, Clone)] pub struct StepQueryResult { - /// The matching checkpoints + /// Checkpoints on this page, ordered by step descending. pub checkpoints: Vec, - /// Pagination metadata + /// Pagination metadata. pub page_info: PageInfo, } -/// PostgreSQL-backed checkpointer with full step history. -/// -/// Provides durable checkpoint storage with advanced querying capabilities -/// including pagination, filtering, and optimistic concurrency control. -/// -/// # Storage Growth -/// -/// This backend stores complete step history. Storage grows roughly with: -/// `(sessions × steps_per_session × state_size)`. -/// -/// For long-running applications, plan periodic cleanup to control database size: -/// -/// ## Option 1: Direct SQL maintenance (recommended) -/// -/// ```bash -/// # Delete checkpoints older than 30 days -/// psql -c "DELETE FROM steps WHERE created_at < NOW() - INTERVAL '30 days'" -/// -/// # Keep only latest 100 steps per session -/// psql -c " -/// DELETE FROM steps s -/// WHERE s.step NOT IN ( -/// SELECT step FROM steps -/// WHERE session_id = s.session_id -/// ORDER BY step DESC -/// LIMIT 100 -/// ) -/// " -/// -/// # Reclaim space -/// psql -c "VACUUM ANALYZE" -/// ``` -/// -/// ## Option 2: Application lifecycle management +// --------------------------------------------------------------------------- +// PostgresCheckpointer +// --------------------------------------------------------------------------- + +/// Durable PostgreSQL checkpointer that retains the full step history. /// -/// Delete entire sessions when workflows complete or expire. The schema includes -/// timestamps (`created_at` on steps, `updated_at` on sessions) to facilitate -/// time-based policies. +/// Storage grows with session count and step depth. For long-lived deployments +/// use periodic SQL maintenance — for example, delete steps older than N days +/// or keep only the most recent M steps per session, then `VACUUM ANALYZE`. pub struct PostgresCheckpointer { - /// Shared PostgreSQL connection pool for concurrent checkpoint operations pool: Arc, } @@ -156,231 +102,120 @@ impl std::fmt::Debug for PostgresCheckpointer { } impl PostgresCheckpointer { - /// Connect to a PostgreSQL database at `database_url`. - /// Example URL: \"postgresql://user:password@localhost/weavegraph\" + /// Open a connection pool to `database_url` and return a ready checkpointer. /// - /// Returns a configured `PostgresCheckpointer` ready for use. + /// When the `postgres-migrations` feature is enabled, embedded migrations + /// are applied automatically (idempotent). #[must_use = "checkpointer must be used to persist state"] #[instrument(skip(database_url))] pub async fn connect(database_url: &str) -> std::result::Result { let pool = PgPool::connect(database_url) .await .map_err(|e| CheckpointerError::Backend { - message: format!("connect error: {e}"), + message: format!("connect: {e}"), })?; - // Run embedded migrations only if the feature is enabled (idempotent). + #[cfg(feature = "postgres-migrations")] - { - if let Err(e) = sqlx::migrate!("./migrations/postgres").run(&pool).await { - return Err(CheckpointerError::Backend { - message: format!("migration failure: {e}"), - }); - } - } - #[cfg(not(feature = "postgres-migrations"))] - { - // Feature disabled: assume external migration orchestration already applied schema. - } + sqlx::migrate!("./migrations/postgres") + .run(&pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("migration: {e}"), + })?; + Ok(Self { pool: Arc::new(pool), }) } -} -#[async_trait::async_trait] -impl Checkpointer for PostgresCheckpointer { - #[instrument(skip(self, checkpoint), err)] - async fn save(&self, checkpoint: Checkpoint) -> Result<()> { - // Serialize using persistence module (serde-based) - let persisted_state = PersistedState::from(&checkpoint.state); - let state_json = serialize_json(&persisted_state, "state")?; - let frontier_enc: Vec = checkpoint.frontier.iter().map(|k| k.encode()).collect(); - let frontier_json = serialize_json(&frontier_enc, "frontier")?; - let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone()); - let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?; - - // Serialize step execution metadata - let ran_nodes_enc: Vec = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect(); - let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?; - let skipped_nodes_enc: Vec = checkpoint - .skipped_nodes - .iter() - .map(|k| k.encode()) - .collect(); - let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?; - let updated_channels_json = - serialize_json(&checkpoint.updated_channels, "updated_channels")?; - - let mut tx = self - .pool + async fn begin_tx(&self) -> Result { + self.pool .begin() .await .map_err(|e| CheckpointerError::Backend { - message: format!("tx begin: {e}"), - })?; + message: format!("begin transaction: {e}"), + }) + } +} - // Ensure session row (upsert with ON CONFLICT DO NOTHING) - sqlx::query( - r#" - INSERT INTO sessions (id, concurrency_limit) - VALUES ($1, $2) - ON CONFLICT (id) DO NOTHING - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.concurrency_limit as i64) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert session: {e}"), - })?; +// --------------------------------------------------------------------------- +// Checkpointer trait impl +// --------------------------------------------------------------------------- - // Insert or replace step row (upsert for idempotent re-save of same step) - sqlx::query( - r#" - INSERT INTO steps ( - session_id, - step, - state_json, - frontier_json, - versions_seen_json, - ran_nodes_json, - skipped_nodes_json, - updated_channels_json - ) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb) - ON CONFLICT (session_id, step) DO UPDATE SET - state_json = EXCLUDED.state_json, - frontier_json = EXCLUDED.frontier_json, - versions_seen_json = EXCLUDED.versions_seen_json, - ran_nodes_json = EXCLUDED.ran_nodes_json, - skipped_nodes_json = EXCLUDED.skipped_nodes_json, - updated_channels_json = EXCLUDED.updated_channels_json - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .bind(&ran_nodes_json) - .bind(&skipped_nodes_json) - .bind(&updated_channels_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert step: {e}"), - })?; +#[async_trait::async_trait] +impl Checkpointer for PostgresCheckpointer { + #[instrument(skip(self, checkpoint), err)] + async fn save(&self, checkpoint: Checkpoint) -> Result<()> { + let enc = EncodedCheckpoint::encode(&checkpoint)?; + let mut tx = self.begin_tx().await?; - // Maintain denormalized latest snapshot on sessions row. - // - // IMPORTANT: steps may be written out-of-order (replays/imports/retries), - // so this update MUST be monotonic. We only advance last_* when the new - // step is >= the currently recorded last_step. - sqlx::query( - r#" - UPDATE sessions - SET - updated_at = NOW(), - last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END, - last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END, - last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END, - last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END - WHERE id = $1 - "#, + exec_upsert_session( + &mut tx, + &checkpoint.session_id, + checkpoint.concurrency_limit, ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("update session latest: {e}"), - })?; + .await?; + exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; + exec_update_latest(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; tx.commit().await.map_err(|e| CheckpointerError::Backend { - message: format!("tx commit: {e}"), - })?; - - Ok(()) + message: format!("commit: {e}"), + }) } #[instrument(skip(self, session_id), err)] async fn load_latest(&self, session_id: &str) -> Result> { - let row_opt: Option = sqlx::query( - r#" - SELECT - s.id, - s.last_step, - s.last_state_json, - s.last_frontier_json, - s.last_versions_seen_json, - s.concurrency_limit, - s.updated_at - FROM sessions s - WHERE s.id = $1 - "#, + let row: Option = sqlx::query( + "SELECT id, last_step, last_state_json, last_frontier_json, \ + last_versions_seen_json, concurrency_limit, updated_at \ + FROM sessions WHERE id = $1", ) .bind(session_id) .fetch_optional(&*self.pool) .await .map_err(|e| CheckpointerError::Backend { - message: format!("select latest: {e}"), + message: format!("load_latest: {e}"), })?; - let row = match row_opt { + let row = match row { Some(r) => r, None => return Ok(None), }; let last_step: i64 = row.get("last_step"); + let concurrency_limit: i64 = row.get("concurrency_limit"); + let updated_at: DateTime = row.get("updated_at"); let state_json: Option = row.try_get("last_state_json") .map_err(|e| CheckpointerError::Backend { - message: format!("last_state_json read: {e}"), + message: format!("last_state_json: {e}"), })?; let frontier_json: Option = row.try_get("last_frontier_json") .map_err(|e| CheckpointerError::Backend { - message: format!("last_frontier_json read: {e}"), + message: format!("last_frontier_json: {e}"), })?; let versions_seen_json: Option = row.try_get("last_versions_seen_json") .map_err(|e| CheckpointerError::Backend { - message: format!("last_versions_seen_json read: {e}"), + message: format!("last_versions_seen_json: {e}"), })?; - let concurrency_limit: i64 = row.get("concurrency_limit"); - let updated_at: DateTime = row.get("updated_at"); + // Session row exists but no checkpoint written yet. if last_step == 0 && state_json.is_none() { - // Session row exists but no checkpoint has been persisted yet. return Ok(None); } - let state_val = require_json_field(state_json, "state_json")?; - let frontier_val = require_json_field(frontier_json, "frontier_json")?; - let versions_seen_val = require_json_field(versions_seen_json, "versions_seen_json")?; - - // Deserialize using persistence models - let persisted_state: PersistedState = deserialize_json_value(state_val, "state")?; - let state = - VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other { - message: format!("state convert: {e}"), - })?; - let frontier: Vec = frontier_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "frontier not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - let persisted_vs: PersistedVersionsSeen = - deserialize_json_value(versions_seen_val, "versions_seen")?; - let versions_seen = persisted_vs.0; + let state = decode_state(need_field(state_json, "last_state_json")?)?; + let frontier = decode_node_kinds(need_field(frontier_json, "last_frontier_json")?)?; + let versions_seen = { + let pv: PersistedVersionsSeen = from_json_value( + need_field(versions_seen_json, "last_versions_seen_json")?, + "versions_seen", + )?; + pv.0 + }; Ok(Some(Checkpoint { session_id: session_id.to_string(), @@ -390,8 +225,8 @@ impl Checkpointer for PostgresCheckpointer { versions_seen, concurrency_limit: concurrency_limit as usize, created_at: updated_at, - // Note: load_latest uses denormalized session data which doesn't include - // step execution metadata. Use query_steps() for full checkpoint details. + // Denormalized session row omits per-step execution metadata; + // call query_steps() to retrieve those fields. ran_nodes: vec![], skipped_nodes: vec![], updated_channels: vec![], @@ -400,209 +235,159 @@ impl Checkpointer for PostgresCheckpointer { #[instrument(skip(self), err)] async fn list_sessions(&self) -> Result> { - let rows = sqlx::query( - r#" - SELECT id FROM sessions - ORDER BY updated_at DESC - "#, - ) - .fetch_all(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("list sessions: {e}"), - })?; - - Ok(rows.into_iter().map(|r| r.get::("id")).collect()) + sqlx::query("SELECT id FROM sessions ORDER BY updated_at DESC") + .fetch_all(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("list_sessions: {e}"), + }) + .map(|rows| rows.into_iter().map(|r| r.get::("id")).collect()) } } -// Extended PostgresCheckpointer methods (not part of base Checkpointer trait) +// --------------------------------------------------------------------------- +// Extended methods +// --------------------------------------------------------------------------- + impl PostgresCheckpointer { - /// Query step history with filtering and pagination. - /// - /// This method provides comprehensive access to checkpoint history with - /// support for filtering by step range, node execution, and pagination - /// for efficient access to large histories. - /// - /// # Parameters + /// Query the full step history for a session with optional filters and pagination. /// - /// * `session_id` - Session to query - /// * `query` - Filter and pagination parameters + /// Results are ordered by step number descending. `limit` defaults to 100 and + /// is capped at 1 000. The total record count before pagination is reported in + /// [`PageInfo::total_count`]. /// - /// # Returns - /// - /// * `Ok(StepQueryResult)` - Matching checkpoints with pagination info - /// * `Err(CheckpointerError)` - Query execution failure - /// - /// # Examples + /// # Example /// /// ```rust,no_run /// use weavegraph::runtimes::checkpointer_postgres::{PostgresCheckpointer, StepQuery}; - /// use weavegraph::types::NodeKind; - /// - /// # async fn example() -> Result<(), Box> { - /// let checkpointer = PostgresCheckpointer::connect("postgresql://localhost/app").await?; /// - /// // Get recent steps with pagination - /// let query = StepQuery { - /// limit: Some(10), - /// offset: Some(0), - /// min_step: Some(5), + /// # async fn run() -> Result<(), Box> { + /// let cp = PostgresCheckpointer::connect("postgresql://localhost/app").await?; + /// let result = cp.query_steps("my-session", StepQuery { + /// limit: Some(25), + /// min_step: Some(10), /// ..Default::default() - /// }; - /// - /// let result = checkpointer.query_steps("session1", query).await?; - /// println!("Found {} steps", result.page_info.page_size); + /// }).await?; + /// println!("{} total, {} on page", result.page_info.total_count, result.page_info.page_size); /// # Ok(()) /// # } /// ``` #[instrument(skip(self), err)] pub async fn query_steps(&self, session_id: &str, query: StepQuery) -> Result { - // Build WHERE clause conditions + let limit = query.limit.unwrap_or(100).min(1_000); + let offset = query.offset.unwrap_or(0); + + // Build WHERE clause; $1 is always session_id. let mut conditions = vec!["st.session_id = $1".to_string()]; - let mut param_count = 1; + let mut param = 1u32; if query.min_step.is_some() { - param_count += 1; - conditions.push(format!("st.step >= ${param_count}")); + param += 1; + conditions.push(format!("st.step >= ${param}")); } if query.max_step.is_some() { - param_count += 1; - conditions.push(format!("st.step <= ${param_count}")); + param += 1; + conditions.push(format!("st.step <= ${param}")); } if query.ran_node.is_some() { - param_count += 1; - conditions.push(format!("st.ran_nodes_json @> ${param_count}::jsonb")); + param += 1; + conditions.push(format!("st.ran_nodes_json @> ${param}::jsonb")); } if query.skipped_node.is_some() { - param_count += 1; - conditions.push(format!("st.skipped_nodes_json @> ${param_count}::jsonb")); + param += 1; + conditions.push(format!("st.skipped_nodes_json @> ${param}::jsonb")); } let where_clause = conditions.join(" AND "); - // Count total matching records - let count_sql = format!("SELECT COUNT(*) as total FROM steps st WHERE {where_clause}"); - - let limit = query.limit.unwrap_or(100).min(1000); // Cap at 1000 - let offset = query.offset.unwrap_or(0); - - // Query with pagination + let count_sql = format!("SELECT COUNT(*) AS total FROM steps st WHERE {where_clause}"); let select_sql = format!( - r#"SELECT - st.session_id, - st.step, - st.state_json, - st.frontier_json, - st.versions_seen_json, - st.ran_nodes_json, - st.skipped_nodes_json, - st.updated_channels_json, - st.created_at, - s.concurrency_limit - FROM steps st - JOIN sessions s ON s.id = st.session_id - WHERE {where_clause} - ORDER BY st.step DESC - LIMIT {limit} OFFSET {offset}"# + "SELECT st.session_id, st.step, st.state_json, st.frontier_json, \ + st.versions_seen_json, st.ran_nodes_json, st.skipped_nodes_json, \ + st.updated_channels_json, st.created_at, s.concurrency_limit \ + FROM steps st \ + JOIN sessions s ON s.id = st.session_id \ + WHERE {where_clause} \ + ORDER BY st.step DESC \ + LIMIT {limit} OFFSET {offset}" ); - // Execute count query - let mut count_query = sqlx::query(&count_sql).bind(session_id); - if let Some(min_step) = query.min_step { - count_query = count_query.bind(min_step as i64); - } - if let Some(max_step) = query.max_step { - count_query = count_query.bind(max_step as i64); - } - if let Some(ran_node) = &query.ran_node { - count_query = count_query.bind(serde_json::json!([ran_node.encode()])); - } - if let Some(skipped_node) = &query.skipped_node { - count_query = count_query.bind(serde_json::json!([skipped_node.encode()])); - } - - let total_count: i64 = count_query - .fetch_one(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("count query: {e}"), - })? - .get("total"); - - // Execute select query - let mut select_query = sqlx::query(&select_sql).bind(session_id); - if let Some(min_step) = query.min_step { - select_query = select_query.bind(min_step as i64); - } - if let Some(max_step) = query.max_step { - select_query = select_query.bind(max_step as i64); - } - if let Some(ran_node) = &query.ran_node { - select_query = select_query.bind(serde_json::json!([ran_node.encode()])); - } - if let Some(skipped_node) = &query.skipped_node { - select_query = select_query.bind(serde_json::json!([skipped_node.encode()])); + // Both count and select use the same parameter binding sequence. + let total_count: i64 = { + let mut q = sqlx::query(&count_sql).bind(session_id); + if let Some(v) = query.min_step { + q = q.bind(v as i64); + } + if let Some(v) = query.max_step { + q = q.bind(v as i64); + } + if let Some(ref node) = query.ran_node { + q = q.bind(serde_json::json!([node.encode()])); + } + if let Some(ref node) = query.skipped_node { + q = q.bind(serde_json::json!([node.encode()])); + } + q } - - let rows = - select_query - .fetch_all(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("select query: {e}"), - })?; - - // Convert rows to checkpoints - let mut checkpoints = Vec::new(); - for row in rows { - let checkpoint = self.row_to_checkpoint(session_id, &row)?; - checkpoints.push(checkpoint); + .fetch_one(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("count query: {e}"), + })? + .get("total"); + + let rows = { + let mut q = sqlx::query(&select_sql).bind(session_id); + if let Some(v) = query.min_step { + q = q.bind(v as i64); + } + if let Some(v) = query.max_step { + q = q.bind(v as i64); + } + if let Some(ref node) = query.ran_node { + q = q.bind(serde_json::json!([node.encode()])); + } + if let Some(ref node) = query.skipped_node { + q = q.bind(serde_json::json!([node.encode()])); + } + q } + .fetch_all(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("select query: {e}"), + })?; - let page_info = PageInfo { - total_count: total_count as u64, - page_size: checkpoints.len() as u32, - offset, - has_next_page: (offset + limit) < total_count as u32, - }; + let checkpoints = rows + .iter() + .map(|r| self.row_to_checkpoint(session_id, r)) + .collect::>>()?; Ok(StepQueryResult { + page_info: PageInfo { + total_count: total_count as u64, + page_size: checkpoints.len() as u32, + offset, + has_next_page: (offset + limit) < total_count as u32, + }, checkpoints, - page_info, }) } - /// Save a checkpoint with optimistic concurrency control. - /// - /// This method prevents concurrent modifications by checking that the - /// session's last step matches the expected value before saving. - /// This ensures checkpoint sequence integrity in multi-writer scenarios. - /// - /// # Parameters + /// Save a checkpoint only when `sessions.last_step` equals `expected_last_step`. /// - /// * `checkpoint` - The checkpoint to save - /// * `expected_last_step` - Expected current step number (for concurrency control) + /// Uses `SELECT … FOR UPDATE` to serialize concurrent writers. Pass `None` to + /// skip the check (equivalent to [`Checkpointer::save`]). /// - /// # Returns - /// - /// * `Ok(())` - Checkpoint saved successfully - /// * `Err(CheckpointerError::Backend)` - Concurrency conflict or storage error - /// - /// # Examples + /// # Example /// /// ```rust,no_run /// use weavegraph::runtimes::checkpointer_postgres::PostgresCheckpointer; /// - /// # async fn example() -> Result<(), Box> { - /// let checkpointer = PostgresCheckpointer::connect("postgresql://localhost/app").await?; + /// # async fn run() -> Result<(), Box> { + /// let cp = PostgresCheckpointer::connect("postgresql://localhost/app").await?; /// # let checkpoint = todo!(); - /// - /// // Save step 5, expecting current step to be 4 - /// match checkpointer.save_with_concurrency_check(checkpoint, Some(4)).await { - /// Ok(()) => println!("Checkpoint saved successfully"), - /// Err(e) => println!("Concurrency conflict or error: {}", e), - /// } + /// cp.save_with_concurrency_check(checkpoint, Some(4)).await?; /// # Ok(()) /// # } /// ``` @@ -612,53 +397,18 @@ impl PostgresCheckpointer { checkpoint: Checkpoint, expected_last_step: Option, ) -> Result<()> { - // Serialize checkpoint data - let persisted_state = PersistedState::from(&checkpoint.state); - let state_json = serialize_json(&persisted_state, "state")?; - let frontier_enc: Vec = checkpoint.frontier.iter().map(|k| k.encode()).collect(); - let frontier_json = serialize_json(&frontier_enc, "frontier")?; - let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone()); - let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?; - let ran_nodes_enc: Vec = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect(); - let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?; - let skipped_nodes_enc: Vec = checkpoint - .skipped_nodes - .iter() - .map(|k| k.encode()) - .collect(); - let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?; - let updated_channels_json = - serialize_json(&checkpoint.updated_channels, "updated_channels")?; - - let mut tx = self - .pool - .begin() - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("tx begin: {e}"), - })?; + let enc = EncodedCheckpoint::encode(&checkpoint)?; + let mut tx = self.begin_tx().await?; - // Ensure session row exists - sqlx::query( - r#" - INSERT INTO sessions (id, concurrency_limit) - VALUES ($1, $2) - ON CONFLICT (id) DO NOTHING - "#, + exec_upsert_session( + &mut tx, + &checkpoint.session_id, + checkpoint.concurrency_limit, ) - .bind(&checkpoint.session_id) - .bind(checkpoint.concurrency_limit as i64) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert session: {e}"), - })?; + .await?; - // Check concurrency constraint if specified. - // - // We take a row-level lock so concurrent writers serialize their checks. - if let Some(expected_step) = expected_last_step { - let current_step: i64 = + if let Some(expected) = expected_last_step { + let actual: i64 = sqlx::query_scalar("SELECT last_step FROM sessions WHERE id = $1 FOR UPDATE") .bind(&checkpoint.session_id) .fetch_one(&mut *tx) @@ -667,164 +417,226 @@ impl PostgresCheckpointer { message: format!("concurrency check: {e}"), })?; - if current_step != expected_step as i64 { + if actual != expected as i64 { return Err(CheckpointerError::Backend { message: format!( - "concurrency conflict: expected step {}, found {}", - expected_step, current_step + "concurrency conflict: expected last_step {expected}, found {actual}" ), }); } } - // Insert or replace step row (upsert for idempotent re-save of same step) - sqlx::query( - r#" - INSERT INTO steps ( - session_id, - step, - state_json, - frontier_json, - versions_seen_json, - ran_nodes_json, - skipped_nodes_json, - updated_channels_json - ) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb) - ON CONFLICT (session_id, step) DO UPDATE SET - state_json = EXCLUDED.state_json, - frontier_json = EXCLUDED.frontier_json, - versions_seen_json = EXCLUDED.versions_seen_json, - ran_nodes_json = EXCLUDED.ran_nodes_json, - skipped_nodes_json = EXCLUDED.skipped_nodes_json, - updated_channels_json = EXCLUDED.updated_channels_json - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .bind(&ran_nodes_json) - .bind(&skipped_nodes_json) - .bind(&updated_channels_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert step: {e}"), - })?; - - // Maintain denormalized latest snapshot (monotonic). - sqlx::query( - r#" - UPDATE sessions - SET - updated_at = NOW(), - last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END, - last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END, - last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END, - last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END - WHERE id = $1 - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("update session latest: {e}"), - })?; + exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; + exec_update_latest(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; tx.commit().await.map_err(|e| CheckpointerError::Backend { - message: format!("tx commit: {e}"), - })?; - - Ok(()) + message: format!("commit: {e}"), + }) } - /// Helper to convert a database row to a Checkpoint. fn row_to_checkpoint(&self, session_id: &str, row: &PgRow) -> Result { let step: i64 = row.get("step"); - let state_json: Value = row.get("state_json"); - let frontier_json: Value = row.get("frontier_json"); - let versions_seen_json: Value = row.get("versions_seen_json"); - let ran_nodes_json: Value = row.get("ran_nodes_json"); - let skipped_nodes_json: Value = row.get("skipped_nodes_json"); + let created_at: DateTime = row.get("created_at"); + let concurrency_limit: i64 = row.get("concurrency_limit"); + let updated_channels_json: Option = row.try_get("updated_channels_json") .map_err(|e| CheckpointerError::Backend { - message: format!("updated_channels_json read: {e}"), + message: format!("updated_channels_json: {e}"), })?; - let created_at: DateTime = row.get("created_at"); - let concurrency_limit: i64 = row.get("concurrency_limit"); - - // Deserialize using persistence models - let persisted_state: PersistedState = deserialize_json_value(state_json, "state")?; - let state = - VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other { - message: format!("state convert: {e}"), - })?; - - let frontier: Vec = frontier_json - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "frontier not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - - let ran_nodes: Vec = ran_nodes_json - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "ran_nodes not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - - let skipped_nodes: Vec = skipped_nodes_json - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "skipped_nodes not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); let updated_channels: Vec = match updated_channels_json { None => vec![], Some(v) => v .as_array() .ok_or_else(|| CheckpointerError::Other { - message: "updated_channels not array".to_string(), + message: "updated_channels_json: expected array".to_string(), })? .iter() - .filter_map(|v| v.as_str()) - .map(|s| s.to_string()) + .filter_map(Value::as_str) + .map(str::to_owned) .collect(), }; - let persisted_vs: PersistedVersionsSeen = - deserialize_json_value(versions_seen_json, "versions_seen")?; - let versions_seen = persisted_vs.0; + let versions_seen = { + let pv: PersistedVersionsSeen = + from_json_value(row.get("versions_seen_json"), "versions_seen")?; + pv.0 + }; Ok(Checkpoint { session_id: session_id.to_string(), step: step as u64, - state, - frontier, + state: decode_state(row.get("state_json"))?, + frontier: decode_node_kinds(row.get("frontier_json"))?, + ran_nodes: decode_node_kinds(row.get("ran_nodes_json"))?, + skipped_nodes: decode_node_kinds(row.get("skipped_nodes_json"))?, versions_seen, concurrency_limit: concurrency_limit as usize, created_at, - ran_nodes, - skipped_nodes, updated_channels, }) } } + +// --------------------------------------------------------------------------- +// Serialization helpers +// --------------------------------------------------------------------------- + +struct EncodedCheckpoint { + state_json: String, + frontier_json: String, + versions_seen_json: String, + ran_nodes_json: String, + skipped_nodes_json: String, + updated_channels_json: String, +} + +impl EncodedCheckpoint { + fn encode(cp: &Checkpoint) -> Result { + let frontier_enc: Vec = cp.frontier.iter().map(NodeKind::encode).collect(); + let ran_nodes_enc: Vec = cp.ran_nodes.iter().map(NodeKind::encode).collect(); + let skipped_enc: Vec = cp.skipped_nodes.iter().map(NodeKind::encode).collect(); + + Ok(Self { + state_json: to_json(&PersistedState::from(&cp.state), "state")?, + frontier_json: to_json(&frontier_enc, "frontier")?, + versions_seen_json: to_json( + &PersistedVersionsSeen(cp.versions_seen.clone()), + "versions_seen", + )?, + ran_nodes_json: to_json(&ran_nodes_enc, "ran_nodes")?, + skipped_nodes_json: to_json(&skipped_enc, "skipped_nodes")?, + updated_channels_json: to_json(&cp.updated_channels, "updated_channels")?, + }) + } +} + +fn to_json(value: &T, ctx: &'static str) -> Result { + serde_json::to_string(value).map_err(|e| CheckpointerError::Other { + message: format!("{ctx} serialize: {e}"), + }) +} + +fn from_json_value(value: Value, ctx: &'static str) -> Result { + serde_json::from_value(value).map_err(|e| CheckpointerError::Other { + message: format!("{ctx} parse: {e}"), + }) +} + +fn need_field(opt: Option, name: &'static str) -> Result { + opt.ok_or_else(|| CheckpointerError::Other { + message: format!("missing field {name}"), + }) +} + +fn decode_state(v: Value) -> Result { + let persisted: PersistedState = from_json_value(v, "state")?; + VersionedState::try_from(persisted).map_err(|e| CheckpointerError::Other { + message: format!("state convert: {e}"), + }) +} + +fn decode_node_kinds(v: Value) -> Result> { + v.as_array() + .ok_or_else(|| CheckpointerError::Other { + message: "expected JSON array of node kinds".to_string(), + }) + .map(|arr| { + arr.iter() + .filter_map(Value::as_str) + .map(NodeKind::decode) + .collect() + }) +} + +// --------------------------------------------------------------------------- +// SQL execution helpers +// --------------------------------------------------------------------------- + +type Tx = sqlx::Transaction<'static, sqlx::Postgres>; + +async fn exec_upsert_session( + tx: &mut Tx, + session_id: &str, + concurrency_limit: usize, +) -> Result<()> { + sqlx::query( + "INSERT INTO sessions (id, concurrency_limit) VALUES ($1, $2) \ + ON CONFLICT (id) DO NOTHING", + ) + .bind(session_id) + .bind(concurrency_limit as i64) + .execute(&mut **tx) + .await + .map(|_| ()) + .map_err(|e| CheckpointerError::Backend { + message: format!("upsert session: {e}"), + }) +} + +async fn exec_upsert_step( + tx: &mut Tx, + session_id: &str, + step: u64, + enc: &EncodedCheckpoint, +) -> Result<()> { + sqlx::query( + "INSERT INTO steps ( + session_id, step, + state_json, frontier_json, versions_seen_json, + ran_nodes_json, skipped_nodes_json, updated_channels_json + ) VALUES ($1, $2, $3::jsonb, $4::jsonb, $5::jsonb, $6::jsonb, $7::jsonb, $8::jsonb) + ON CONFLICT (session_id, step) DO UPDATE SET + state_json = EXCLUDED.state_json, + frontier_json = EXCLUDED.frontier_json, + versions_seen_json = EXCLUDED.versions_seen_json, + ran_nodes_json = EXCLUDED.ran_nodes_json, + skipped_nodes_json = EXCLUDED.skipped_nodes_json, + updated_channels_json = EXCLUDED.updated_channels_json", + ) + .bind(session_id) + .bind(step as i64) + .bind(&enc.state_json) + .bind(&enc.frontier_json) + .bind(&enc.versions_seen_json) + .bind(&enc.ran_nodes_json) + .bind(&enc.skipped_nodes_json) + .bind(&enc.updated_channels_json) + .execute(&mut **tx) + .await + .map(|_| ()) + .map_err(|e| CheckpointerError::Backend { + message: format!("upsert step: {e}"), + }) +} + +async fn exec_update_latest( + tx: &mut Tx, + session_id: &str, + step: u64, + enc: &EncodedCheckpoint, +) -> Result<()> { + // Advance last_* only when the incoming step is >= the current last_step so + // that out-of-order replays or imports cannot regress the cached latest snapshot. + sqlx::query( + "UPDATE sessions SET + updated_at = NOW(), + last_step = CASE WHEN last_step <= $2 THEN $2 ELSE last_step END, + last_state_json = CASE WHEN last_step <= $2 THEN $3::jsonb ELSE last_state_json END, + last_frontier_json = CASE WHEN last_step <= $2 THEN $4::jsonb ELSE last_frontier_json END, + last_versions_seen_json = CASE WHEN last_step <= $2 THEN $5::jsonb ELSE last_versions_seen_json END + WHERE id = $1", + ) + .bind(session_id) + .bind(step as i64) + .bind(&enc.state_json) + .bind(&enc.frontier_json) + .bind(&enc.versions_seen_json) + .execute(&mut **tx) + .await + .map(|_| ()) + .map_err(|e| CheckpointerError::Backend { + message: format!("update session latest: {e}"), + }) +} diff --git a/src/runtimes/checkpointer_postgres_helpers.rs b/src/runtimes/checkpointer_postgres_helpers.rs deleted file mode 100644 index ce8338f..0000000 --- a/src/runtimes/checkpointer_postgres_helpers.rs +++ /dev/null @@ -1,72 +0,0 @@ -/*! -Helper utilities for PostgreSQL checkpointer operations. - -This module provides consistent JSON serialization/deserialization helpers -that standardize error handling and context reporting across the PostgreSQL -checkpointer implementation. All functions convert serde errors to appropriate -`CheckpointerError` variants with contextual information. - -## Usage - -These helpers should be used for all JSON operations in the PostgreSQL checkpointer -to ensure consistent error handling and debugging information. - -```rust,ignore -// Serialization -let json_str = serialize_json(&data, "user_data")?; - -// Deserialization -let data: MyType = deserialize_json(&json_str, "user_data")?; - -// Value deserialization -let data: MyType = deserialize_json_value(json_value, "user_data")?; - -// Field validation -let required_field = require_json_field(optional_field, "state_json")?; -``` -*/ - -use crate::runtimes::checkpointer::CheckpointerError; -use crate::utils::json_ext::{deserialize_with_context, serialize_with_context}; -use serde_json::Value; - -/// Helper for JSON serialization with consistent error formatting. -pub(super) fn serialize_json( - value: &T, - context: &'static str, -) -> Result { - serialize_with_context(value, context, |e, ctx| CheckpointerError::Other { - message: format!("{ctx} serialize: {e}"), - }) -} - -/// Helper for JSON deserialization with consistent error formatting. -#[allow(dead_code)] -pub(super) fn deserialize_json( - json: &str, - context: &'static str, -) -> Result { - deserialize_with_context(json, context, |e, ctx| CheckpointerError::Other { - message: format!("{ctx} parse: {e}"), - }) -} - -/// Helper for JSON value deserialization with consistent error formatting. -pub(super) fn deserialize_json_value( - value: Value, - context: &'static str, -) -> Result { - serde_json::from_value(value).map_err(|e| CheckpointerError::Other { - message: format!("{context} parse (serde): {e}"), - }) -} - -/// Helper for extracting required JSON fields with consistent error formatting. -pub(super) fn require_json_field( - field: Option, - field_name: &'static str, -) -> Result { - field.ok_or_else(|| CheckpointerError::Other { - message: format!("missing {field_name} for persisted checkpoint"), - }) -} diff --git a/src/runtimes/checkpointer_sqlite.rs b/src/runtimes/checkpointer_sqlite.rs index ae26ef4..4dbce05 100644 --- a/src/runtimes/checkpointer_sqlite.rs +++ b/src/runtimes/checkpointer_sqlite.rs @@ -1,60 +1,34 @@ /*! -SQLite Checkpointer +SQLite checkpointer backend. -This module provides the `SQLiteCheckpointer` async implementation of the -`Checkpointer` trait defined in `runtimes/checkpointer.rs`. +Implements [`Checkpointer`] over a [`SqlitePool`], storing the full step +history with paginated queries and optional optimistic concurrency control. -## Features +## Schema -- **Complete Step History**: Stores full execution metadata including ran/skipped nodes -- **Pagination Support**: Efficient querying of large checkpoint histories -- **Optimistic Concurrency**: Prevention of concurrent checkpoint conflicts -- **Serde Integration**: Uses persistence models for consistent serialization +- `sessions.id` — session identifier (primary key) +- `sessions.concurrency_limit` — maximum concurrent nodes +- `sessions.last_step`, `sessions.last_*_json` — cached latest snapshot, + advanced by an `AFTER INSERT` trigger on `steps` +- `steps.session_id + steps.step` — composite primary key +- `steps.*_json` — TEXT columns containing JSON for state, frontier, + versions_seen, ran/skipped nodes, and updated channel names -## Behavior +## Migrations -- Uses serde-based persistence models (see `runtimes::persistence`) for - encoding `VersionedState`, frontier node kinds, and `versions_seen`. -- When the `sqlite-migrations` feature is enabled (default), embedded - migrations (`sqlx::migrate!("./migrations")`) are executed on connect; - disabling the feature assumes external migration orchestration. +Enable the `sqlite-migrations` feature (default) to apply embedded migrations +on connect; otherwise the schema must be managed externally. -## Design Goals +## NodeKind encoding -- Keep this module focused on database I/O; pure serialization lives in - the persistence module. -- Provide efficient querying with filtering and pagination support. -- Ensure data consistency with optimistic concurrency control. - -## Database Schema - -The checkpoint data maps to database tables as follows: - -- `sessions.id` ← `checkpoint.session_id` -- `sessions.concurrency_limit` ← `checkpoint.concurrency_limit` -- `steps.session_id` ← `checkpoint.session_id` -- `steps.step` ← `checkpoint.step` -- `steps.state_json` ← serialized `VersionedState` -- `steps.frontier_json` ← JSON array of encoded `NodeKind` -- `steps.versions_seen_json` ← JSON object (node → channel → version) -- `steps.ran_nodes_json` ← JSON array of executed nodes -- `steps.skipped_nodes_json` ← JSON array of skipped nodes -- `steps.updated_channels_json` ← JSON array of updated channel names - -## NodeKind Encoding - -NodeKinds are encoded as strings for JSON storage: -- `Start` → `"Start"` -- `End` → `"End"` -- `Custom(name)` → `"Custom:"` +`NodeKind` is encoded as a string: `Start` → `"Start"`, `End` → `"End"`, +`Custom(n)` → `"Custom:"`. */ use std::sync::Arc; use chrono::{DateTime, Utc}; -use serde_json::Value; use sqlx::{Row, SqlitePool, sqlite::SqliteRow}; -use thiserror::Error; use tracing::instrument; use crate::{ @@ -64,155 +38,59 @@ use crate::{ types::NodeKind, }; -use super::checkpointer_sqlite_helpers::{ - deserialize_json, deserialize_json_value, require_json_field, serialize_json, -}; +// --------------------------------------------------------------------------- +// Public types +// --------------------------------------------------------------------------- -/// Query parameters for filtering step history. +/// Filters and pagination controls for [`SQLiteCheckpointer::query_steps`]. #[derive(Debug, Clone, Default)] pub struct StepQuery { - /// Maximum number of results to return (capped at 1000) + /// Maximum results per page (defaults to 100, capped at 1 000). pub limit: Option, - /// Number of results to skip (for pagination) + /// Zero-based offset of the first result (for cursor-free pagination). pub offset: Option, - /// Filter by minimum step number (inclusive) + /// Restrict to steps with step number ≥ this value. pub min_step: Option, - /// Filter by maximum step number (inclusive) + /// Restrict to steps with step number ≤ this value. pub max_step: Option, - /// Only return steps that executed the specified node + /// Restrict to steps where this node ran. pub ran_node: Option, - /// Only return steps that skipped the specified node + /// Restrict to steps where this node was skipped. pub skipped_node: Option, } -/// Pagination information for query results. +/// Pagination metadata included in a [`StepQueryResult`]. #[derive(Debug, Clone, PartialEq, Eq)] pub struct PageInfo { - /// Total number of matching records + /// Total records matching the query filters (before pagination). pub total_count: u64, - /// Number of records returned in this page + /// Records returned in this page. pub page_size: u32, - /// Zero-based offset of the first record in this page + /// Zero-based offset of the first record in this page. pub offset: u32, - /// Whether there are more records after this page + /// Whether records remain after this page. pub has_next_page: bool, } -/// Paginated query result for step history. +/// A page of checkpoint records from [`SQLiteCheckpointer::query_steps`]. #[derive(Debug, Clone)] pub struct StepQueryResult { - /// The matching checkpoints + /// Checkpoints on this page, ordered by step descending. pub checkpoints: Vec, - /// Pagination metadata + /// Pagination metadata. pub page_info: PageInfo, } -/// Errors that can occur within the SQLite-backed checkpointer. -#[derive(Debug, Error)] -#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] -pub enum SQLiteCheckpointerError { - /// An underlying SQLx database error. - #[error("SQLx error: {0}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic( - code(weavegraph::sqlite::sqlx), - help("Ensure the SQLite database URL is valid and accessible.") - ) - )] - Sqlx(#[from] sqlx::Error), - - /// A JSON serialization or deserialization error. - #[error("JSON serialization error: {0}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic( - code(weavegraph::sqlite::serde), - help("Check serialized shapes for state/frontier/versions_seen.") - ) - )] - Serde(#[from] serde_json::Error), - - /// A required field was missing from a persisted row. - #[error("Missing persisted field: {0}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic( - code(weavegraph::sqlite::missing), - help("Backfill or re-run migrations to populate the missing field.") - ) - )] - Missing(&'static str), - - /// A generic backend error. - #[error("Backend error: {0}")] - #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::sqlite::backend)))] - Backend(String), - - /// Any other error not covered by the above variants. - #[error("Other error: {0}")] - #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::sqlite::other)))] - Other(String), -} +// --------------------------------------------------------------------------- +// SQLiteCheckpointer +// --------------------------------------------------------------------------- -impl From for CheckpointerError { - fn from(e: SQLiteCheckpointerError) -> Self { - match e { - SQLiteCheckpointerError::Sqlx(err) => CheckpointerError::Backend { - message: err.to_string(), - }, - SQLiteCheckpointerError::Serde(err) => CheckpointerError::Other { - message: err.to_string(), - }, - SQLiteCheckpointerError::Missing(what) => CheckpointerError::Other { - message: format!("missing persisted field: {what}"), - }, - SQLiteCheckpointerError::Backend(msg) => CheckpointerError::Backend { message: msg }, - SQLiteCheckpointerError::Other(msg) => CheckpointerError::Other { message: msg }, - } - } -} - -/// SQLite-backed checkpointer with full step history. -/// -/// Provides durable checkpoint storage with advanced querying capabilities -/// including pagination, filtering, and optimistic concurrency control. -/// -/// # Storage Growth -/// -/// This backend stores complete step history. Storage grows roughly with: -/// `(sessions × steps_per_session × state_size)`. -/// -/// For long-running applications, plan periodic cleanup to control database size: -/// -/// ## Option 1: Direct SQL maintenance (recommended) -/// -/// ```bash -/// # Delete checkpoints older than 30 days -/// sqlite3 workflow.db "DELETE FROM steps WHERE created_at < datetime('now', '-30 days')" -/// -/// # Keep only latest 100 steps per session -/// sqlite3 workflow.db " -/// DELETE FROM steps -/// WHERE step NOT IN ( -/// SELECT step FROM steps -/// WHERE session_id = steps.session_id -/// ORDER BY step DESC -/// LIMIT 100 -/// ) -/// " -/// -/// # Reclaim space -/// sqlite3 workflow.db "VACUUM" -/// ``` +/// Durable SQLite checkpointer that retains the full step history. /// -/// ## Option 2: Application lifecycle management -/// -/// Delete entire sessions when workflows complete or expire. The schema includes -/// timestamps (`created_at` on steps, `updated_at` on sessions) to facilitate -/// time-based policies. +/// Storage grows with session count and step depth. For long-lived deployments +/// use periodic SQL maintenance — for example, delete steps older than N days +/// or keep only the most recent M steps per session, then `VACUUM`. pub struct SQLiteCheckpointer { - /// Shared SQLite connection pool for concurrent checkpoint operations pool: Arc, } @@ -223,200 +101,117 @@ impl std::fmt::Debug for SQLiteCheckpointer { } impl SQLiteCheckpointer { - /// Connect (or create) a SQLite database at `database_url`. - /// Example URL: \"sqlite://weavegraph.db\" + /// Open (or create) a SQLite database at `database_url` and return a ready checkpointer. + /// + /// When the `sqlite-migrations` feature is enabled (default), embedded migrations + /// are applied automatically (idempotent). /// - /// Returns a configured `SQLiteCheckpointer` ready for use. + /// Example URL: `"sqlite://weavegraph.db"` #[must_use = "checkpointer must be used to persist state"] #[instrument(skip(database_url))] pub async fn connect(database_url: &str) -> std::result::Result { - let pool = - SqlitePool::connect(database_url) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("connect error: {e}"), - })?; - // Run embedded migrations only if the feature is enabled (idempotent). + let pool = SqlitePool::connect(database_url) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("connect: {e}"), + })?; + #[cfg(feature = "sqlite-migrations")] - { - if let Err(e) = sqlx::migrate!("./migrations").run(&pool).await { - return Err(CheckpointerError::Backend { - message: format!("migration failure: {e}"), - }); - } - } - #[cfg(not(feature = "sqlite-migrations"))] - { - // Feature disabled: assume external migration orchestration already applied schema. - } + sqlx::migrate!("./migrations") + .run(&pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("migration: {e}"), + })?; + Ok(Self { pool: Arc::new(pool), }) } + + async fn begin_tx(&self) -> Result { + self.pool + .begin() + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("begin transaction: {e}"), + }) + } } +// --------------------------------------------------------------------------- +// Checkpointer trait impl +// --------------------------------------------------------------------------- + #[async_trait::async_trait] impl Checkpointer for SQLiteCheckpointer { #[instrument(skip(self, checkpoint), err)] async fn save(&self, checkpoint: Checkpoint) -> Result<()> { - // Serialize using persistence module (serde-based) - let persisted_state = PersistedState::from(&checkpoint.state); - let state_json = serialize_json(&persisted_state, "state")?; - let frontier_enc: Vec = checkpoint.frontier.iter().map(|k| k.encode()).collect(); - let frontier_json = serialize_json(&frontier_enc, "frontier")?; - let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone()); - let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?; - - // Serialize step execution metadata - let ran_nodes_enc: Vec = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect(); - let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?; - let skipped_nodes_enc: Vec = checkpoint - .skipped_nodes - .iter() - .map(|k| k.encode()) - .collect(); - let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?; - let updated_channels_json = - serialize_json(&checkpoint.updated_channels, "updated_channels")?; - - let mut tx = self - .pool - .begin() - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("tx begin: {e}"), - })?; + let enc = EncodedCheckpoint::encode(&checkpoint)?; + let mut tx = self.begin_tx().await?; - // Ensure session row - sqlx::query( - r#" - INSERT OR IGNORE INTO sessions (id, concurrency_limit) - VALUES (?1, ?2) - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.concurrency_limit as i64) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert session: {e}"), - })?; - - // Insert or replace step row (allows idempotent re-save of same step) - sqlx::query( - r#" - INSERT OR REPLACE INTO steps ( - session_id, - step, - state_json, - frontier_json, - versions_seen_json, - ran_nodes_json, - skipped_nodes_json, - updated_channels_json - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .bind(&ran_nodes_json) - .bind(&skipped_nodes_json) - .bind(&updated_channels_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert step: {e}"), - })?; + exec_insert_session(&mut tx, &checkpoint.session_id, checkpoint.concurrency_limit).await?; + exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; tx.commit().await.map_err(|e| CheckpointerError::Backend { - message: format!("tx commit: {e}"), - })?; - - Ok(()) + message: format!("commit: {e}"), + }) } #[instrument(skip(self, session_id), err)] async fn load_latest(&self, session_id: &str) -> Result> { - let row_opt: Option = sqlx::query( - r#" - SELECT - s.id, - s.last_step, - s.last_state_json, - s.last_frontier_json, - s.last_versions_seen_json, - s.concurrency_limit, - s.updated_at - FROM sessions s - WHERE s.id = ?1 - "#, + let row: Option = sqlx::query( + "SELECT id, last_step, last_state_json, last_frontier_json, \ + last_versions_seen_json, concurrency_limit, updated_at \ + FROM sessions WHERE id = ?1", ) .bind(session_id) .fetch_optional(&*self.pool) .await .map_err(|e| CheckpointerError::Backend { - message: format!("select latest: {e}"), + message: format!("load_latest: {e}"), })?; - let row = match row_opt { + let row = match row { Some(r) => r, None => return Ok(None), }; let last_step: i64 = row.get("last_step"); - - let state_json: Option = - row.try_get("last_state_json") - .map_err(|e| CheckpointerError::Backend { - message: format!("last_state_json read: {e}"), - })?; - let frontier_json: Option = - row.try_get("last_frontier_json") - .map_err(|e| CheckpointerError::Backend { - message: format!("last_frontier_json read: {e}"), - })?; - let versions_seen_json: Option = - row.try_get("last_versions_seen_json") - .map_err(|e| CheckpointerError::Backend { - message: format!("last_versions_seen_json read: {e}"), - })?; let concurrency_limit: i64 = row.get("concurrency_limit"); let updated_at_str: String = row.get("updated_at"); + let state_json: Option = row.try_get("last_state_json").map_err(|e| { + CheckpointerError::Backend { + message: format!("last_state_json: {e}"), + } + })?; + let frontier_json: Option = row.try_get("last_frontier_json").map_err(|e| { + CheckpointerError::Backend { + message: format!("last_frontier_json: {e}"), + } + })?; + let versions_seen_json: Option = + row.try_get("last_versions_seen_json").map_err(|e| { + CheckpointerError::Backend { + message: format!("last_versions_seen_json: {e}"), + } + })?; + + // Session row exists but no checkpoint written yet. if last_step == 0 && state_json.is_none() { - // Session row exists but no checkpoint has been persisted yet. return Ok(None); } - let state_payload = require_json_field(state_json, "state_json")?; - let frontier_payload = require_json_field(frontier_json, "frontier_json")?; - let versions_seen_payload = require_json_field(versions_seen_json, "versions_seen_json")?; - - let state_val: Value = deserialize_json(&state_payload, "state")?; - let frontier_val: Value = deserialize_json(&frontier_payload, "frontier")?; - let versions_seen_val: Value = deserialize_json(&versions_seen_payload, "versions_seen")?; - - // Deserialize using persistence models - let persisted_state: PersistedState = deserialize_json_value(state_val, "state")?; - let state = - VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other { - message: format!("state convert: {e}"), - })?; - let frontier: Vec = frontier_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "frontier not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - let persisted_vs: PersistedVersionsSeen = - deserialize_json_value(versions_seen_val, "versions_seen")?; - let versions_seen = persisted_vs.0; + let state = decode_state(&need_field(state_json, "last_state_json")?)?; + let frontier = decode_node_kinds(&need_field(frontier_json, "last_frontier_json")?)?; + let versions_seen = { + let pv: PersistedVersionsSeen = from_json_str( + &need_field(versions_seen_json, "last_versions_seen_json")?, + "versions_seen", + )?; + pv.0 + }; let created_at = DateTime::parse_from_rfc3339(&updated_at_str) .map(|dt| dt.with_timezone(&Utc)) @@ -430,8 +225,8 @@ impl Checkpointer for SQLiteCheckpointer { versions_seen, concurrency_limit: concurrency_limit as usize, created_at, - // Note: load_latest uses denormalized session data which doesn't include - // step execution metadata. Use query_steps() for full checkpoint details. + // Denormalized session row omits per-step execution metadata; + // call query_steps() to retrieve those fields. ran_nodes: vec![], skipped_nodes: vec![], updated_channels: vec![], @@ -440,204 +235,155 @@ impl Checkpointer for SQLiteCheckpointer { #[instrument(skip(self), err)] async fn list_sessions(&self) -> Result> { - let rows = sqlx::query( - r#" - SELECT id FROM sessions - ORDER BY updated_at DESC - "#, - ) - .fetch_all(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("list sessions: {e}"), - })?; - - Ok(rows.into_iter().map(|r| r.get::("id")).collect()) + sqlx::query("SELECT id FROM sessions ORDER BY updated_at DESC") + .fetch_all(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("list_sessions: {e}"), + }) + .map(|rows| rows.into_iter().map(|r| r.get::("id")).collect()) } } -// Extended SQLiteCheckpointer methods (not part of base Checkpointer trait) +// --------------------------------------------------------------------------- +// Extended methods +// --------------------------------------------------------------------------- + impl SQLiteCheckpointer { - /// Query step history with filtering and pagination. - /// - /// This method provides comprehensive access to checkpoint history with - /// support for filtering by step range, node execution, and pagination - /// for efficient access to large histories. - /// - /// # Parameters + /// Query the full step history for a session with optional filters and pagination. /// - /// * `session_id` - Session to query - /// * `query` - Filter and pagination parameters + /// Results are ordered by step number descending. `limit` defaults to 100 and + /// is capped at 1 000. The total record count before pagination is reported in + /// [`PageInfo::total_count`]. /// - /// # Returns - /// - /// * `Ok(StepQueryResult)` - Matching checkpoints with pagination info - /// * `Err(CheckpointerError)` - Query execution failure - /// - /// # Examples + /// # Example /// /// ```rust,no_run /// use weavegraph::runtimes::checkpointer_sqlite::{SQLiteCheckpointer, StepQuery}; - /// use weavegraph::types::NodeKind; - /// - /// # async fn example() -> Result<(), Box> { - /// let checkpointer = SQLiteCheckpointer::connect("sqlite://app.db").await?; /// - /// // Get recent steps with pagination - /// let query = StepQuery { - /// limit: Some(10), - /// offset: Some(0), - /// min_step: Some(5), + /// # async fn run() -> Result<(), Box> { + /// let cp = SQLiteCheckpointer::connect("sqlite://app.db").await?; + /// let result = cp.query_steps("my-session", StepQuery { + /// limit: Some(25), + /// min_step: Some(10), /// ..Default::default() - /// }; - /// - /// let result = checkpointer.query_steps("session1", query).await?; - /// println!("Found {} steps", result.page_info.page_size); + /// }).await?; + /// println!("{} total, {} on page", result.page_info.total_count, result.page_info.page_size); /// # Ok(()) /// # } /// ``` #[instrument(skip(self), err)] pub async fn query_steps(&self, session_id: &str, query: StepQuery) -> Result { - // Build WHERE clause conditions + let limit = query.limit.unwrap_or(100).min(1_000); + let offset = query.offset.unwrap_or(0); + + // Build WHERE clause; ?1 is always session_id. let mut conditions = vec!["session_id = ?1".to_string()]; - let mut param_count = 1; + let mut param = 1u32; if query.min_step.is_some() { - param_count += 1; - conditions.push(format!("step >= ?{param_count}")); + param += 1; + conditions.push(format!("step >= ?{param}")); } if query.max_step.is_some() { - param_count += 1; - conditions.push(format!("step <= ?{param_count}")); + param += 1; + conditions.push(format!("step <= ?{param}")); } if query.ran_node.is_some() { - param_count += 1; - conditions.push(format!( - "JSON_EXTRACT(ran_nodes_json, '$') LIKE ?{param_count}" - )); + param += 1; + conditions.push(format!("JSON_EXTRACT(ran_nodes_json, '$') LIKE ?{param}")); } if query.skipped_node.is_some() { - param_count += 1; + param += 1; conditions.push(format!( - "JSON_EXTRACT(skipped_nodes_json, '$') LIKE ?{param_count}" + "JSON_EXTRACT(skipped_nodes_json, '$') LIKE ?{param}" )); } let where_clause = conditions.join(" AND "); - - // Count total matching records - let count_sql = format!("SELECT COUNT(*) as total FROM steps WHERE {where_clause}"); - - let limit = query.limit.unwrap_or(100).min(1000); // Cap at 1000 - let offset = query.offset.unwrap_or(0); - - // Query with pagination + let count_sql = format!("SELECT COUNT(*) AS total FROM steps WHERE {where_clause}"); let select_sql = format!( - r#"SELECT - session_id, step, state_json, frontier_json, versions_seen_json, - ran_nodes_json, skipped_nodes_json, updated_channels_json, created_at - FROM steps - WHERE {where_clause} - ORDER BY step DESC - LIMIT {limit} OFFSET {offset}"# + "SELECT session_id, step, state_json, frontier_json, versions_seen_json, \ + ran_nodes_json, skipped_nodes_json, updated_channels_json, created_at \ + FROM steps WHERE {where_clause} \ + ORDER BY step DESC LIMIT {limit} OFFSET {offset}" ); - // Execute count query - let mut count_query = sqlx::query(&count_sql).bind(session_id); - if let Some(min_step) = query.min_step { - count_query = count_query.bind(min_step as i64); - } - if let Some(max_step) = query.max_step { - count_query = count_query.bind(max_step as i64); - } - if let Some(ran_node) = &query.ran_node { - count_query = count_query.bind(format!("%{}%", ran_node.encode())); - } - if let Some(skipped_node) = &query.skipped_node { - count_query = count_query.bind(format!("%{}%", skipped_node.encode())); - } - - let total_count: i64 = count_query - .fetch_one(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("count query: {e}"), - })? - .get("total"); - - // Execute select query - let mut select_query = sqlx::query(&select_sql).bind(session_id); - if let Some(min_step) = query.min_step { - select_query = select_query.bind(min_step as i64); - } - if let Some(max_step) = query.max_step { - select_query = select_query.bind(max_step as i64); - } - if let Some(ran_node) = &query.ran_node { - select_query = select_query.bind(format!("%{}%", ran_node.encode())); - } - if let Some(skipped_node) = &query.skipped_node { - select_query = select_query.bind(format!("%{}%", skipped_node.encode())); + // Both count and select use the same parameter binding sequence. + let total_count: i64 = { + let mut q = sqlx::query(&count_sql).bind(session_id); + if let Some(v) = query.min_step { + q = q.bind(v as i64); + } + if let Some(v) = query.max_step { + q = q.bind(v as i64); + } + if let Some(ref node) = query.ran_node { + q = q.bind(format!("%{}%", node.encode())); + } + if let Some(ref node) = query.skipped_node { + q = q.bind(format!("%{}%", node.encode())); + } + q } - - let rows = - select_query - .fetch_all(&*self.pool) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("select query: {e}"), - })?; - - // Convert rows to checkpoints - let mut checkpoints = Vec::new(); - for row in rows { - let checkpoint = self.row_to_checkpoint(session_id, &row).await?; - checkpoints.push(checkpoint); + .fetch_one(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("count query: {e}"), + })? + .get("total"); + + let rows = { + let mut q = sqlx::query(&select_sql).bind(session_id); + if let Some(v) = query.min_step { + q = q.bind(v as i64); + } + if let Some(v) = query.max_step { + q = q.bind(v as i64); + } + if let Some(ref node) = query.ran_node { + q = q.bind(format!("%{}%", node.encode())); + } + if let Some(ref node) = query.skipped_node { + q = q.bind(format!("%{}%", node.encode())); + } + q } + .fetch_all(&*self.pool) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("select query: {e}"), + })?; - let page_info = PageInfo { - total_count: total_count as u64, - page_size: checkpoints.len() as u32, - offset, - has_next_page: (offset + limit) < total_count as u32, - }; + let checkpoints = rows + .iter() + .map(|r| self.row_to_checkpoint(session_id, r)) + .collect::>>()?; Ok(StepQueryResult { + page_info: PageInfo { + total_count: total_count as u64, + page_size: checkpoints.len() as u32, + offset, + has_next_page: (offset + limit) < total_count as u32, + }, checkpoints, - page_info, }) } - /// Save a checkpoint with optimistic concurrency control. + /// Save a checkpoint only when `sessions.last_step` equals `expected_last_step`. /// - /// This method prevents concurrent modifications by checking that the - /// session's last step matches the expected value before saving. - /// This ensures checkpoint sequence integrity in multi-writer scenarios. + /// Pass `None` to skip the check (equivalent to [`Checkpointer::save`]). /// - /// # Parameters - /// - /// * `checkpoint` - The checkpoint to save - /// * `expected_last_step` - Expected current step number (for concurrency control) - /// - /// # Returns - /// - /// * `Ok(())` - Checkpoint saved successfully - /// * `Err(CheckpointerError::Backend)` - Concurrency conflict or storage error - /// - /// # Examples + /// # Example /// /// ```rust,no_run /// use weavegraph::runtimes::checkpointer_sqlite::SQLiteCheckpointer; /// - /// # async fn example() -> Result<(), Box> { - /// let checkpointer = SQLiteCheckpointer::connect("sqlite://app.db").await?; + /// # async fn run() -> Result<(), Box> { + /// let cp = SQLiteCheckpointer::connect("sqlite://app.db").await?; /// # let checkpoint = todo!(); - /// - /// // Save step 5, expecting current step to be 4 - /// match checkpointer.save_with_concurrency_check(checkpoint, Some(4)).await { - /// Ok(()) => println!("Checkpoint saved successfully"), - /// Err(e) => println!("Concurrency conflict or error: {}", e), - /// } + /// cp.save_with_concurrency_check(checkpoint, Some(4)).await?; /// # Ok(()) /// # } /// ``` @@ -647,35 +393,13 @@ impl SQLiteCheckpointer { checkpoint: Checkpoint, expected_last_step: Option, ) -> Result<()> { - // Serialize checkpoint data - let persisted_state = PersistedState::from(&checkpoint.state); - let state_json = serialize_json(&persisted_state, "state")?; - let frontier_enc: Vec = checkpoint.frontier.iter().map(|k| k.encode()).collect(); - let frontier_json = serialize_json(&frontier_enc, "frontier")?; - let persisted_vs = PersistedVersionsSeen(checkpoint.versions_seen.clone()); - let versions_seen_json = serialize_json(&persisted_vs, "versions_seen")?; - let ran_nodes_enc: Vec = checkpoint.ran_nodes.iter().map(|k| k.encode()).collect(); - let ran_nodes_json = serialize_json(&ran_nodes_enc, "ran_nodes")?; - let skipped_nodes_enc: Vec = checkpoint - .skipped_nodes - .iter() - .map(|k| k.encode()) - .collect(); - let skipped_nodes_json = serialize_json(&skipped_nodes_enc, "skipped_nodes")?; - let updated_channels_json = - serialize_json(&checkpoint.updated_channels, "updated_channels")?; - - let mut tx = self - .pool - .begin() - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("tx begin: {e}"), - })?; + let enc = EncodedCheckpoint::encode(&checkpoint)?; + let mut tx = self.begin_tx().await?; + + exec_insert_session(&mut tx, &checkpoint.session_id, checkpoint.concurrency_limit).await?; - // Check concurrency constraint if specified - if let Some(expected_step) = expected_last_step { - let current_step: Option = + if let Some(expected) = expected_last_step { + let current: Option = sqlx::query_scalar("SELECT last_step FROM sessions WHERE id = ?1") .bind(&checkpoint.session_id) .fetch_optional(&mut *tx) @@ -684,151 +408,49 @@ impl SQLiteCheckpointer { message: format!("concurrency check: {e}"), })?; - match current_step { - Some(step) if step != expected_step as i64 => { + match current { + Some(actual) if actual != expected as i64 => { return Err(CheckpointerError::Backend { message: format!( - "concurrency conflict: expected step {}, found {}", - expected_step, step + "concurrency conflict: expected last_step {expected}, found {actual}" ), }); } - None if expected_step != 0 => { + None if expected != 0 => { return Err(CheckpointerError::Backend { message: format!( - "concurrency conflict: session not found, expected step {}", - expected_step + "concurrency conflict: session not found, expected step {expected}" ), }); } - _ => {} // Check passed + _ => {} } } - // Ensure session row exists - sqlx::query( - r#" - INSERT OR IGNORE INTO sessions (id, concurrency_limit) - VALUES (?1, ?2) - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.concurrency_limit as i64) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert session: {e}"), - })?; - - // Insert step row (fail if step already exists to prevent overwrites) - sqlx::query( - r#" - INSERT INTO steps ( - session_id, - step, - state_json, - frontier_json, - versions_seen_json, - ran_nodes_json, - skipped_nodes_json, - updated_channels_json - ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) - "#, - ) - .bind(&checkpoint.session_id) - .bind(checkpoint.step as i64) - .bind(&state_json) - .bind(&frontier_json) - .bind(&versions_seen_json) - .bind(&ran_nodes_json) - .bind(&skipped_nodes_json) - .bind(&updated_channels_json) - .execute(&mut *tx) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("insert step: {e}"), - })?; + exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; tx.commit().await.map_err(|e| CheckpointerError::Backend { - message: format!("tx commit: {e}"), - })?; - - Ok(()) + message: format!("commit: {e}"), + }) } - /// Helper to convert a database row to a Checkpoint. - async fn row_to_checkpoint( - &self, - session_id: &str, - row: &sqlx::sqlite::SqliteRow, - ) -> Result { + fn row_to_checkpoint(&self, session_id: &str, row: &SqliteRow) -> Result { let step: i64 = row.get("step"); + let created_at_str: String = row.get("created_at"); let state_json: String = row.get("state_json"); let frontier_json: String = row.get("frontier_json"); let versions_seen_json: String = row.get("versions_seen_json"); let ran_nodes_json: String = row.get("ran_nodes_json"); let skipped_nodes_json: String = row.get("skipped_nodes_json"); - let updated_channels_json: String = row.get("updated_channels_json"); - let created_at_str: String = row.get("created_at"); - // Deserialize using persistence models - let state_val: Value = deserialize_json(&state_json, "state")?; - let frontier_val: Value = deserialize_json(&frontier_json, "frontier")?; - let versions_seen_val: Value = deserialize_json(&versions_seen_json, "versions_seen")?; - let ran_nodes_val: Value = deserialize_json(&ran_nodes_json, "ran_nodes")?; - let skipped_nodes_val: Value = deserialize_json(&skipped_nodes_json, "skipped_nodes")?; - let updated_channels_val: Value = - deserialize_json(&updated_channels_json, "updated_channels")?; - - let persisted_state: PersistedState = deserialize_json_value(state_val, "state")?; - let state = - VersionedState::try_from(persisted_state).map_err(|e| CheckpointerError::Other { - message: format!("state convert: {e}"), - })?; - - let frontier: Vec = frontier_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "frontier not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - - let ran_nodes: Vec = ran_nodes_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "ran_nodes not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - - let skipped_nodes: Vec = skipped_nodes_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "skipped_nodes not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(NodeKind::decode) - .collect(); - - let updated_channels: Vec = updated_channels_val - .as_array() - .ok_or_else(|| CheckpointerError::Other { - message: "updated_channels not array".to_string(), - })? - .iter() - .filter_map(|v| v.as_str()) - .map(|s| s.to_string()) - .collect(); + let updated_channels_json: Option = + row.try_get("updated_channels_json").ok().flatten(); + let updated_channels = match updated_channels_json { + Some(ref json) => from_json_str::>(json, "updated_channels")?, + None => vec![], + }; - let persisted_vs: PersistedVersionsSeen = - deserialize_json_value(versions_seen_val, "versions_seen")?; - let versions_seen = persisted_vs.0; + let pv: PersistedVersionsSeen = from_json_str(&versions_seen_json, "versions_seen")?; let created_at = DateTime::parse_from_rfc3339(&created_at_str) .map(|dt| dt.with_timezone(&Utc)) @@ -837,14 +459,128 @@ impl SQLiteCheckpointer { Ok(Checkpoint { session_id: session_id.to_string(), step: step as u64, - state, - frontier, - versions_seen, - concurrency_limit: 1, // Will need to be retrieved from session table if needed + state: decode_state(&state_json)?, + frontier: decode_node_kinds(&frontier_json)?, + ran_nodes: decode_node_kinds(&ran_nodes_json)?, + skipped_nodes: decode_node_kinds(&skipped_nodes_json)?, + versions_seen: pv.0, + // query_steps() does not join sessions; concurrency_limit is not available per-row. + concurrency_limit: 1, created_at, - ran_nodes, - skipped_nodes, updated_channels, }) } } + +// --------------------------------------------------------------------------- +// Serialization helpers +// --------------------------------------------------------------------------- + +struct EncodedCheckpoint { + state_json: String, + frontier_json: String, + versions_seen_json: String, + ran_nodes_json: String, + skipped_nodes_json: String, + updated_channels_json: String, +} + +impl EncodedCheckpoint { + fn encode(cp: &Checkpoint) -> Result { + let frontier_enc: Vec = cp.frontier.iter().map(NodeKind::encode).collect(); + let ran_nodes_enc: Vec = cp.ran_nodes.iter().map(NodeKind::encode).collect(); + let skipped_enc: Vec = cp.skipped_nodes.iter().map(NodeKind::encode).collect(); + + Ok(Self { + state_json: to_json(&PersistedState::from(&cp.state), "state")?, + frontier_json: to_json(&frontier_enc, "frontier")?, + versions_seen_json: to_json( + &PersistedVersionsSeen(cp.versions_seen.clone()), + "versions_seen", + )?, + ran_nodes_json: to_json(&ran_nodes_enc, "ran_nodes")?, + skipped_nodes_json: to_json(&skipped_enc, "skipped_nodes")?, + updated_channels_json: to_json(&cp.updated_channels, "updated_channels")?, + }) + } +} + +fn to_json(value: &T, ctx: &'static str) -> Result { + serde_json::to_string(value).map_err(|e| CheckpointerError::Other { + message: format!("{ctx} serialize: {e}"), + }) +} + +fn from_json_str(json: &str, ctx: &'static str) -> Result { + serde_json::from_str(json).map_err(|e| CheckpointerError::Other { + message: format!("{ctx} parse: {e}"), + }) +} + +fn need_field(opt: Option, name: &'static str) -> Result { + opt.ok_or_else(|| CheckpointerError::Other { + message: format!("missing field {name}"), + }) +} + +fn decode_state(json: &str) -> Result { + let persisted: PersistedState = from_json_str(json, "state")?; + VersionedState::try_from(persisted).map_err(|e| CheckpointerError::Other { + message: format!("state convert: {e}"), + }) +} + +fn decode_node_kinds(json: &str) -> Result> { + let encoded: Vec = from_json_str(json, "node_kinds")?; + Ok(encoded.iter().map(|s| NodeKind::decode(s)).collect()) +} + +// --------------------------------------------------------------------------- +// SQL execution helpers +// --------------------------------------------------------------------------- + +type Tx = sqlx::Transaction<'static, sqlx::Sqlite>; + +async fn exec_insert_session( + tx: &mut Tx, + session_id: &str, + concurrency_limit: usize, +) -> Result<()> { + sqlx::query("INSERT OR IGNORE INTO sessions (id, concurrency_limit) VALUES (?1, ?2)") + .bind(session_id) + .bind(concurrency_limit as i64) + .execute(&mut **tx) + .await + .map(|_| ()) + .map_err(|e| CheckpointerError::Backend { + message: format!("insert session: {e}"), + }) +} + +async fn exec_upsert_step( + tx: &mut Tx, + session_id: &str, + step: u64, + enc: &EncodedCheckpoint, +) -> Result<()> { + sqlx::query( + "INSERT OR REPLACE INTO steps \ + (session_id, step, state_json, frontier_json, versions_seen_json, \ + ran_nodes_json, skipped_nodes_json, updated_channels_json) \ + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + ) + .bind(session_id) + .bind(step as i64) + .bind(&enc.state_json) + .bind(&enc.frontier_json) + .bind(&enc.versions_seen_json) + .bind(&enc.ran_nodes_json) + .bind(&enc.skipped_nodes_json) + .bind(&enc.updated_channels_json) + .execute(&mut **tx) + .await + .map(|_| ()) + .map_err(|e| CheckpointerError::Backend { + message: format!("upsert step: {e}"), + }) +} diff --git a/src/runtimes/checkpointer_sqlite_helpers.rs b/src/runtimes/checkpointer_sqlite_helpers.rs deleted file mode 100644 index 1338c73..0000000 --- a/src/runtimes/checkpointer_sqlite_helpers.rs +++ /dev/null @@ -1,71 +0,0 @@ -/*! -Helper utilities for SQLite checkpointer operations. - -This module provides consistent JSON serialization/deserialization helpers -that standardize error handling and context reporting across the SQLite -checkpointer implementation. All functions convert serde errors to appropriate -`CheckpointerError` variants with contextual information. - -## Usage - -These helpers should be used for all JSON operations in the SQLite checkpointer -to ensure consistent error handling and debugging information. - -```rust,ignore -// Serialization -let json_str = serialize_json(&data, "user_data")?; - -// Deserialization -let data: MyType = deserialize_json(&json_str, "user_data")?; - -// Value deserialization -let data: MyType = deserialize_json_value(json_value, "user_data")?; - -// Field validation -let required_field = require_json_field(optional_field, "state_json")?; -``` -*/ - -use crate::runtimes::checkpointer::CheckpointerError; -use crate::utils::json_ext::{deserialize_with_context, serialize_with_context}; -use serde_json::Value; - -/// Helper for JSON serialization with consistent error formatting. -pub(super) fn serialize_json( - value: &T, - context: &'static str, -) -> Result { - serialize_with_context(value, context, |e, ctx| CheckpointerError::Other { - message: format!("{ctx} serialize: {e}"), - }) -} - -/// Helper for JSON deserialization with consistent error formatting. -pub(super) fn deserialize_json( - json: &str, - context: &'static str, -) -> Result { - deserialize_with_context(json, context, |e, ctx| CheckpointerError::Other { - message: format!("{ctx} parse: {e}"), - }) -} - -/// Helper for JSON value deserialization with consistent error formatting. -pub(super) fn deserialize_json_value( - value: Value, - context: &'static str, -) -> Result { - serde_json::from_value(value).map_err(|e| CheckpointerError::Other { - message: format!("{context} parse (serde): {e}"), - }) -} - -/// Helper for extracting required JSON fields with consistent error formatting. -pub(super) fn require_json_field( - field: Option, - field_name: &'static str, -) -> Result { - field.ok_or_else(|| CheckpointerError::Other { - message: format!("missing {field_name} for persisted checkpoint"), - }) -} diff --git a/src/runtimes/execution.rs b/src/runtimes/execution.rs index d750ccf..a5183ce 100644 --- a/src/runtimes/execution.rs +++ b/src/runtimes/execution.rs @@ -1,44 +1,25 @@ -//! Step execution types and logic for workflow runs. -//! -//! This module defines the types used to represent step execution results, -//! pause conditions, and execution options during workflow processing. +//! Step execution types for workflow runs. use crate::app::BarrierOutcome; use crate::node::NodePartial; use crate::runtimes::session::{SessionState, StateVersions}; use crate::types::NodeKind; -/// Result of executing one superstep in a session. +/// Result of executing one superstep. /// /// The embedded [`BarrierOutcome`] carries the canonical ordering of -/// updates/errors so callers can persist and resume without drift. -/// -/// # Examples -/// -/// ```rust,no_run -/// use weavegraph::runtimes::StepReport; -/// -/// fn analyze_step(report: &StepReport) { -/// println!("Step {} completed", report.step); -/// println!("Ran {} nodes, skipped {}", -/// report.ran_nodes.len(), -/// report.skipped_nodes.len()); -/// if report.completed { -/// println!("Workflow finished!"); -/// } -/// } -/// ``` +/// updates and errors so callers can persist and resume without drift. #[derive(Debug, Clone)] pub struct StepReport { - /// The step number that was executed. + /// Step index that was executed. pub step: u64, /// Nodes that ran during this step. pub ran_nodes: Vec, - /// Nodes that were skipped (e.g., End nodes or version-gated). + /// Nodes that were skipped (version-gated or End nodes). pub skipped_nodes: Vec, - /// The outcome from applying the barrier. + /// Outcome from applying the barrier. pub barrier_outcome: BarrierOutcome, - /// The frontier for the next step. + /// Frontier for the next step. pub next_frontier: Vec, /// Channel versions after this step completed. pub state_versions: StateVersions, @@ -46,77 +27,50 @@ pub struct StepReport { pub completed: bool, } -/// Options for controlling step execution behavior. -/// -/// Use these options to implement human-in-the-loop workflows, debugging, -/// or step-by-step execution patterns. -/// -/// # Examples -/// -/// ```rust -/// use weavegraph::runtimes::StepOptions; -/// use weavegraph::types::NodeKind; -/// -/// // Pause before a specific node -/// let options = StepOptions { -/// interrupt_before: vec![NodeKind::Custom("approval".into())], -/// interrupt_after: vec![], -/// interrupt_each_step: false, -/// }; -/// ``` +/// Controls which nodes or steps trigger an execution pause. #[derive(Debug, Clone, Default)] pub struct StepOptions { - /// Nodes to pause execution before (for human-in-the-loop). + /// Pause before executing these nodes. pub interrupt_before: Vec, - /// Nodes to pause execution after. + /// Pause after executing these nodes. pub interrupt_after: Vec, - /// Whether to pause after each step (debugging mode). + /// Pause after every step (debugging mode). pub interrupt_each_step: bool, } -/// The reason why execution was paused. -/// -/// When a workflow is paused (not completed), this enum indicates -/// why the pause occurred, enabling appropriate handling. +/// Reason why execution was paused mid-run. #[derive(Debug, Clone)] pub enum PausedReason { - /// Paused before executing the specified node. + /// Paused before the specified node ran. BeforeNode(NodeKind), - /// Paused after executing the specified node. + /// Paused after the specified node ran. AfterNode(NodeKind), - /// Paused after completing the specified step number. + /// Paused after the specified step completed. AfterStep(u64), } -/// Extended step report when execution is paused. -/// -/// Contains the full session state at the point of pause, allowing -/// inspection, modification, or later resumption. +/// Snapshot returned when execution is paused rather than completed. #[derive(Debug, Clone)] pub struct PausedReport { - /// The complete session state at the pause point. + /// Session state at the pause point. pub session_state: SessionState, /// Why execution was paused. pub reason: PausedReason, } -/// Result of attempting to run a step. -/// -/// Either the step completed normally, or execution was paused -/// for one of several reasons (human-in-the-loop, debugging, etc.). +/// Outcome of a single step attempt. #[derive(Debug, Clone)] pub enum StepResult { - /// The step completed and execution can continue. + /// Step ran to completion; execution can continue. Completed(StepReport), - /// Execution was paused before completion. + /// Execution was paused before the step finished. Paused(PausedReport), } -/// Internal outcome from scheduler after normalization. -/// -/// Contains ordered partials ready for barrier application. +/// Normalized scheduler output, ready for barrier application. pub(crate) struct SchedulerOutcome { pub ran_nodes: Vec, pub skipped_nodes: Vec, pub partials: Vec, } + diff --git a/src/runtimes/metrics_observer.rs b/src/runtimes/metrics_observer.rs index 216747f..3f68d97 100644 --- a/src/runtimes/metrics_observer.rs +++ b/src/runtimes/metrics_observer.rs @@ -1,30 +1,25 @@ -//! Feature-gated [`RuntimeObserver`] implementation using the [`metrics`] crate facade. +//! Feature-gated [`RuntimeObserver`] backed by the [`metrics`] crate. //! -//! Enable with `features = ["metrics"]`. This module emits standard counters and -//! histograms that any `metrics`-compatible recorder (e.g. `metrics-exporter-prometheus`) -//! can capture. +//! Enable with `features = ["metrics"]`. Metrics are forwarded to whatever +//! [`metrics`]-compatible recorder is installed (e.g. `metrics-exporter-prometheus`). //! -//! # Metric inventory +//! ## Metric inventory //! //! | Metric | Kind | Labels | Description | //! |--------|------|--------|-------------| //! | `weavegraph.node.invocations` | counter | `node`, `outcome` | Completed node executions | -//! | `weavegraph.node.step_duration_ms` | histogram | `node` | Superstep duration (shared across parallel nodes) | +//! | `weavegraph.node.step_duration_ms` | histogram | `node` | Superstep wall-clock duration | //! | `weavegraph.invocation.count` | counter | `outcome` | Completed workflow invocations | -//! | `weavegraph.invocation.duration_ms` | histogram | (none) | Invocation wall-clock duration | +//! | `weavegraph.invocation.duration_ms` | histogram | — | Invocation wall-clock duration | //! | `weavegraph.checkpoint.saves` | counter | `backend` | Successful checkpoint saves | //! | `weavegraph.checkpoint.save_duration_ms` | histogram | `backend` | Checkpoint save duration | -//! | `weavegraph.checkpoint.loads` | counter | `backend` | Sessions resumed from a checkpoint | -//! | `weavegraph.event_bus.emits` | counter | `scope` | Events emitted through the event bus | +//! | `weavegraph.checkpoint.loads` | counter | `backend` | Sessions resumed from checkpoint | +//! | `weavegraph.event_bus.emits` | counter | `scope` | Events emitted through the bus | //! -//! # Cardinality note +//! Per-session and per-invocation identifiers are deliberately excluded from labels +//! to keep cardinality bounded in long-running services. //! -//! Labels are kept conservative by default. `session_id` and `invocation_id` are -//! intentionally **not** included as labels to avoid unbounded cardinality in -//! long-running services. The `node` label uses the node kind's string encoding -//! (e.g. `"features"`, `"strategy"`). -//! -//! # Usage +//! ## Usage //! //! ```rust,no_run //! use std::sync::Arc; @@ -44,31 +39,23 @@ use std::panic::RefUnwindSafe; use crate::runtimes::observer::{ CheckpointLoadMeta, CheckpointSaveMeta, EventBusEmitMeta, InvocationFinishMeta, - InvocationStartMeta, NodeFinishMeta, NodeOutcome, RuntimeObserver, + InvocationOutcome, NodeFinishMeta, NodeOutcome, RuntimeObserver, }; -/// A [`RuntimeObserver`] that emits metrics via the [`metrics`] crate facade. -/// -/// Install a compatible recorder (e.g. `metrics-exporter-prometheus`) before -/// starting the runner to have these metrics exported to your observability stack. +/// A [`RuntimeObserver`] that forwards lifecycle events to the [`metrics`] crate. /// -/// See the [module documentation](self) for the full metric inventory. +/// Install a compatible recorder before starting the runner; see the +/// [module documentation](self) for the full metric inventory. #[derive(Debug, Clone, Copy)] pub struct MetricsObserver; -// MetricsObserver holds no interior mutability and all metrics calls are -// thread-safe through the global recorder, so RefUnwindSafe is safe to assert. impl RefUnwindSafe for MetricsObserver {} impl RuntimeObserver for MetricsObserver { - fn on_invocation_start(&self, _meta: &InvocationStartMeta<'_>) { - // Nothing to emit at start — counts and durations are emitted on finish. - } - fn on_invocation_finish(&self, meta: &InvocationFinishMeta<'_>) { let outcome = match meta.outcome { - crate::runtimes::observer::InvocationOutcome::Completed => "completed", - crate::runtimes::observer::InvocationOutcome::Error => "error", + InvocationOutcome::Completed => "completed", + InvocationOutcome::Error => "error", }; metrics::counter!("weavegraph.invocation.count", "outcome" => outcome).increment(1); metrics::histogram!("weavegraph.invocation.duration_ms").record(meta.duration_ms as f64); @@ -78,19 +65,19 @@ impl RuntimeObserver for MetricsObserver { let node = meta.node_kind.encode().to_string(); let outcome = match meta.outcome { NodeOutcome::Completed => "completed", - NodeOutcome::Error => "error", NodeOutcome::Skipped => "skipped", + NodeOutcome::Error => "error", }; + if meta.outcome != NodeOutcome::Skipped { + metrics::histogram!("weavegraph.node.step_duration_ms", "node" => node.clone()) + .record(meta.step_duration_ms as f64); + } metrics::counter!( "weavegraph.node.invocations", - "node" => node.clone(), + "node" => node, "outcome" => outcome ) .increment(1); - if meta.outcome != NodeOutcome::Skipped { - metrics::histogram!("weavegraph.node.step_duration_ms", "node" => node) - .record(meta.step_duration_ms as f64); - } } fn on_checkpoint_load(&self, meta: &CheckpointLoadMeta<'_>) { @@ -103,12 +90,12 @@ impl RuntimeObserver for MetricsObserver { fn on_checkpoint_save(&self, meta: &CheckpointSaveMeta<'_>) { let backend = meta.backend.to_string(); - metrics::counter!("weavegraph.checkpoint.saves", "backend" => backend.clone()).increment(1); metrics::histogram!( "weavegraph.checkpoint.save_duration_ms", - "backend" => backend + "backend" => backend.clone() ) .record(meta.duration_ms as f64); + metrics::counter!("weavegraph.checkpoint.saves", "backend" => backend).increment(1); } fn on_event_bus_emit(&self, meta: &EventBusEmitMeta<'_>) { diff --git a/src/runtimes/mod.rs b/src/runtimes/mod.rs index cca02a8..8885af1 100644 --- a/src/runtimes/mod.rs +++ b/src/runtimes/mod.rs @@ -1,41 +1,22 @@ -//! Workflow runtime infrastructure for session management and state persistence. +//! Runtime infrastructure for stepwise workflow execution and state persistence. //! -//! This module provides the core runtime components for executing workflows with -//! support for checkpointing, session management, and resumable execution. The -//! runtime layer abstracts over different persistence backends while maintaining -//! a consistent API for workflow execution. -//! -//! # Architecture -//! -//! The runtime is built around several key abstractions: -//! -//! - **[`AppRunner`]** - Main orchestrator for stepwise workflow execution -//! - **[`Checkpointer`]** - Trait for pluggable state persistence -//! - **[`SessionState`]** - In-memory representation of execution state -//! - **Persistence Models** - Serde-friendly types for state serialization -//! -//! # Persistence Backends -//! -//! - **[`InMemoryCheckpointer`]** - Volatile storage for testing and development -//! - **[`SQLiteCheckpointer`]** - Durable SQLite-backed persistence -//! -//! # Usage Example +//! This module provides the components needed to run, checkpoint, and resume +//! workflows across sessions. The runtime layer is backend-agnostic: the same +//! [`AppRunner`] API works with in-memory, SQLite, and (with the `postgres` +//! feature) PostgreSQL checkpointers. //! //! ```rust,no_run //! use weavegraph::runtimes::{AppRunner, CheckpointerType}; //! use weavegraph::state::VersionedState; //! # use weavegraph::app::App; //! # async fn example(app: App) -> Result<(), Box> { -//! //! let mut runner = AppRunner::builder() //! .app(app) //! .checkpointer(CheckpointerType::InMemory) //! .build() //! .await; -//! let initial_state = VersionedState::new_with_user_message("Hello"); -//! -//! // Create session and run to completion -//! runner.create_session("session_1".to_string(), initial_state).await?; +//! let state = VersionedState::new_with_user_message("Hello"); +//! runner.create_session("session_1".to_string(), state).await?; //! let final_state = runner.run_until_complete("session_1").await?; //! # Ok(()) //! # } @@ -45,13 +26,9 @@ pub mod checkpointer; #[cfg(feature = "postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "postgres")))] pub mod checkpointer_postgres; -#[cfg(feature = "postgres")] -mod checkpointer_postgres_helpers; #[cfg(feature = "sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] pub mod checkpointer_sqlite; -#[cfg(feature = "sqlite")] -mod checkpointer_sqlite_helpers; pub mod execution; #[cfg(feature = "metrics")] #[cfg_attr(docsrs, doc(cfg(feature = "metrics")))] @@ -78,16 +55,9 @@ pub use checkpointer_postgres::{ #[cfg(feature = "sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] pub use checkpointer_sqlite::{PageInfo, SQLiteCheckpointer, StepQuery, StepQueryResult}; - -// Re-export execution types pub use execution::{PausedReason, PausedReport, StepOptions, StepReport, StepResult}; - -// Re-export session types -pub use session::{SessionInit, SessionState, StateVersions}; - -// Re-export runner pub use runner::{AppRunner, AppRunnerBuilder, RunMetadata}; - +pub use session::{SessionInit, SessionState, StateVersions}; pub use replay::{ ReplayComparison, ReplayConformanceError, ReplayRun, StateNormalizeProfile, compare_event_sequences, compare_event_sequences_with, compare_final_state, @@ -96,7 +66,6 @@ pub use replay::{ }; pub use runtime_config::{EventBusConfig, RuntimeConfig, SinkConfig}; pub use types::{SessionId, StepNumber}; - #[cfg(feature = "metrics")] pub use metrics_observer::MetricsObserver; pub use observer::{ diff --git a/src/runtimes/observer.rs b/src/runtimes/observer.rs index 0f3962a..04b1834 100644 --- a/src/runtimes/observer.rs +++ b/src/runtimes/observer.rs @@ -1,14 +1,10 @@ -//! Runtime observer trait and metadata types for workflow telemetry hooks. +//! [`RuntimeObserver`] trait and associated metadata types for workflow telemetry. //! -//! `RuntimeObserver` is an opt-in interface that receives structured callbacks -//! at key points during graph execution: invocation boundaries, per-node -//! completion, checkpoint operations, and event-bus emissions. All methods -//! have default no-op implementations, so implementors only override the hooks -//! they care about. +//! Provides opt-in hooks at key execution boundaries: invocation start/finish, +//! per-node completion, checkpoint operations, and event-bus emissions. Every +//! method defaults to a no-op; implement only the hooks you need. //! -//! # Usage -//! -//! Register an observer when building a runner: +//! ## Usage //! //! ```rust,no_run //! use std::sync::Arc; @@ -36,19 +32,17 @@ //! # } //! ``` //! -//! # Performance contract +//! ## Performance //! -//! Observer hooks are called **synchronously** on the execution thread. They must -//! be fast and non-blocking. Panics inside hooks are caught by the runner, which -//! logs a warning via [`tracing`] and continues execution — the graph is never -//! brought down by an observer failure. +//! Hooks are called **synchronously** on the execution thread and must be fast +//! and non-blocking. Panics are caught by the runner, which emits a +//! `tracing::warn!` and continues — an observer failure never aborts the graph. //! -//! # Note on per-node timing in 0.6.0 +//! ## Per-node timing in 0.6.0 //! -//! In this release, `step_duration_ms` in [`NodeFinishMeta`] reflects the elapsed -//! time for the **entire superstep** that contained the node, not the per-node -//! wall time. Nodes within the same superstep share the step's duration. Per-node -//! timing would require scheduler-level instrumentation and is planned for a +//! `step_duration_ms` in [`NodeFinishMeta`] is the wall-clock duration of the +//! **entire superstep**, shared by every node that ran in the same parallel step. +//! Per-node timing requires scheduler-level instrumentation and is planned for a //! future release. use std::fmt; @@ -56,10 +50,6 @@ use std::panic::RefUnwindSafe; use crate::types::NodeKind; -// ============================================================================ -// Outcome enums -// ============================================================================ - /// Outcome of a completed workflow invocation. #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[non_exhaustive] @@ -82,16 +72,11 @@ pub enum NodeOutcome { Skipped, } -// ============================================================================ -// Metadata structs — all #[non_exhaustive] so fields can be added without -// breaking implementors that destructure them (though &-access is idiomatic). -// ============================================================================ - /// Metadata supplied to [`RuntimeObserver::on_invocation_start`]. #[derive(Debug)] #[non_exhaustive] pub struct InvocationStartMeta<'a> { - /// The session identifier for this invocation. + /// Session identifier for this invocation. pub session_id: &'a str, /// Stable fingerprint of the compiled graph definition. /// @@ -103,7 +88,7 @@ pub struct InvocationStartMeta<'a> { #[derive(Debug)] #[non_exhaustive] pub struct InvocationFinishMeta<'a> { - /// The session identifier. + /// Session identifier. pub session_id: &'a str, /// Stable fingerprint of the compiled graph definition. pub graph_id: &'a str, @@ -115,20 +100,17 @@ pub struct InvocationFinishMeta<'a> { /// Metadata supplied to [`RuntimeObserver::on_node_finish`]. /// -/// See [module-level note](self) on per-node timing in 0.6.0. +/// See the [per-node timing note](self#per-node-timing-in-060) in the module docs. #[derive(Debug)] #[non_exhaustive] pub struct NodeFinishMeta<'a> { /// The node that completed. pub node_kind: &'a NodeKind, - /// The session identifier. + /// Session identifier. pub session_id: &'a str, - /// The step number within which this node executed. + /// Step number within which this node executed. pub step: u64, - /// Elapsed time for the superstep containing this node, in milliseconds. - /// - /// All nodes in the same superstep share this value. Per-node timing - /// is not available in 0.6.0. + /// Superstep wall-clock duration in milliseconds; shared by all nodes in the step. pub step_duration_ms: u64, /// Outcome of this node. pub outcome: NodeOutcome, @@ -138,11 +120,11 @@ pub struct NodeFinishMeta<'a> { #[derive(Debug)] #[non_exhaustive] pub struct CheckpointLoadMeta<'a> { - /// The session identifier. + /// Session identifier. pub session_id: &'a str, - /// Human-readable backend name (e.g. `"sqlite"`, `"postgres"`, `"in-memory"`). + /// Backend name (e.g. `"sqlite"`, `"postgres"`, `"in-memory"`). pub backend: &'a str, - /// The step number that was loaded from the checkpoint. + /// Step number loaded from the checkpoint. pub step: u64, } @@ -150,11 +132,11 @@ pub struct CheckpointLoadMeta<'a> { #[derive(Debug)] #[non_exhaustive] pub struct CheckpointSaveMeta<'a> { - /// The session identifier. + /// Session identifier. pub session_id: &'a str, - /// Human-readable backend name. + /// Backend name (e.g. `"sqlite"`, `"postgres"`, `"in-memory"`). pub backend: &'a str, - /// The step number that was saved. + /// Step number saved. pub step: u64, /// Wall-clock duration of the save operation in milliseconds. pub duration_ms: u64, @@ -164,43 +146,33 @@ pub struct CheckpointSaveMeta<'a> { #[derive(Debug)] #[non_exhaustive] pub struct EventBusEmitMeta<'a> { - /// The scope label of the emitted event (e.g. `"features"`, `"__weavegraph_stream_end__"`). + /// Scope label of the emitted event (e.g. `"features"`, `"__weavegraph_stream_end__"`). pub scope: &'a str, } -// ============================================================================ -// RuntimeObserver trait -// ============================================================================ - -/// Observer interface for runtime telemetry hooks. +/// Opt-in telemetry interface for the runtime. /// /// Register an implementation with /// [`AppRunnerBuilder::observer`](crate::runtimes::runner::AppRunnerBuilder::observer). -/// All methods default to no-ops; implement only the callbacks you need. -/// -/// # Safety contract +/// All methods default to no-ops; override only the hooks you need. /// -/// Implementations **must not panic** — panics are caught by the runner and -/// produce a `tracing::warn!` log entry. The supertrait bound [`RefUnwindSafe`] -/// is required to make this catch-and-continue safe without `AssertUnwindSafe` -/// wrappers at every callsite. +/// ## Contract /// -/// Implementations must be `Send + Sync` as the runner can share them across -/// async tasks. +/// Implementations must be `Send + Sync + 'static` — the runner shares observers +/// across async tasks. The [`RefUnwindSafe`] bound lets the runner catch panics +/// inside hooks without an `AssertUnwindSafe` wrapper at every callsite; a +/// panicking hook produces a `tracing::warn!` and execution continues. pub trait RuntimeObserver: Send + Sync + fmt::Debug + RefUnwindSafe + 'static { - /// Called immediately before a workflow invocation begins running. + /// Called immediately before a workflow invocation begins. fn on_invocation_start(&self, _meta: &InvocationStartMeta<'_>) {} /// Called after a workflow invocation finishes (successfully or with an error). fn on_invocation_finish(&self, _meta: &InvocationFinishMeta<'_>) {} - /// Called once for each node after the superstep containing it completes. - /// - /// In 0.6.0, `step_duration_ms` is the superstep duration shared by all - /// nodes in the same parallel step. See the [module note](self). + /// Called once per node after the superstep containing it completes. fn on_node_finish(&self, _meta: &NodeFinishMeta<'_>) {} - /// Called after a checkpoint is successfully loaded during session creation. + /// Called after a checkpoint is successfully loaded during session resumption. fn on_checkpoint_load(&self, _meta: &CheckpointLoadMeta<'_>) {} /// Called after a checkpoint is successfully saved. diff --git a/src/runtimes/persistence.rs b/src/runtimes/persistence.rs index 6743ce1..6a7ccd6 100644 --- a/src/runtimes/persistence.rs +++ b/src/runtimes/persistence.rs @@ -1,24 +1,17 @@ -/*! -Persistence primitives for serializing/deserializing Weavegraph runtime -state and checkpoints (used by the SQLite checkpointer and any future -persistent backends). - -Design Goals: -- Provide explicit serde-friendly structs decoupled from internal - in-memory representations. -- Keep conversion logic localized (From / TryFrom impls) so the - checkpointer code is lean and declarative. -- Allow forward compatibility (unknown NodeKind encodings round-trip - as `NodeKind::Custom(encoded_string)`). - -This module intentionally does NOT perform I/O. It is pure data -transformation and (de)serialization glue. -*/ +//! Serialization bridge between live runtime types and their persisted forms. +//! +//! The `Persisted*` types are stable, serde-friendly representations of runtime +//! state, checkpoints, and scheduler metadata. Conversion between live and persisted +//! types is handled through `From` / `TryFrom` impls, keeping checkpointer code +//! free of transformation detail. +//! +//! This module performs no I/O. use chrono::Utc; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; +use thiserror::Error; use crate::{ channels::{Channel, ExtrasChannel, MessagesChannel}, @@ -29,26 +22,56 @@ use crate::{ utils::json_ext::JsonSerializable, }; -/// Blanket implementation of JsonSerializable for all suitable types using PersistenceError. +/// Errors arising from persistence conversion or JSON serialization. +#[derive(Debug, Error)] +#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] +pub enum PersistenceError { + /// A serde_json serialization or deserialization failure. + #[error("serialization error: {source}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic( + code(weavegraph::persistence::serde), + help("Check that the JSON structure matches the Persisted* types.") + ) + )] + Serde { + /// The underlying serde_json error. + #[source] + source: serde_json::Error, + }, + + /// A catch-all for other persistence failures. + #[error("persistence error: {0}")] + #[cfg_attr( + feature = "diagnostics", + diagnostic(code(weavegraph::persistence::other)) + )] + Other(String), +} + +/// Convenience alias for persistence results. +pub type Result = std::result::Result; + impl JsonSerializable for T where T: serde::Serialize + for<'de> serde::de::DeserializeOwned, { fn to_json_string(&self) -> std::result::Result { - serde_json::to_string(self).map_err(|e| PersistenceError::Serde { source: e }) + serde_json::to_string(self).map_err(|source| PersistenceError::Serde { source }) } fn from_json_str(s: &str) -> std::result::Result { - serde_json::from_str(s).map_err(|e| PersistenceError::Serde { source: e }) + serde_json::from_str(s).map_err(|source| PersistenceError::Serde { source }) } } -/// Channel that stores a vector collection (e.g., messages) with version metadata. +/// Persisted form of a versioned list channel (e.g., messages, errors). #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PersistedVecChannel { - /// Version counter for change-detection. + /// Change-detection counter; starts at 1 for a freshly created channel. pub version: u32, - /// The stored items. + /// Stored items. #[serde(default)] pub items: Vec, } @@ -62,12 +85,12 @@ impl Default for PersistedVecChannel { } } -/// Channel that stores a map collection (e.g., extra) with version metadata. +/// Persisted form of a versioned map channel (e.g., extras). #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PersistedMapChannel { - /// Version counter for change-detection. + /// Change-detection counter; starts at 1 for a freshly created channel. pub version: u32, - /// The stored key-value map. + /// Stored key-value pairs. #[serde(default)] pub map: FxHashMap, } @@ -81,100 +104,53 @@ impl Default for PersistedMapChannel { } } -/// Complete persisted shape of the in‑memory VersionedState. +/// Persisted snapshot of a [`VersionedState`]. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PersistedState { /// Persisted messages channel. pub messages: PersistedVecChannel, - /// Persisted extra key-value channel. + /// Persisted extras channel. pub extra: PersistedMapChannel, /// Persisted errors channel. #[serde(default)] pub errors: PersistedVecChannel, } -/// Wrapper for the scheduler versions_seen structure. +/// Persisted form of the scheduler's `versions_seen` tracking map. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct PersistedVersionsSeen(pub FxHashMap>); -/// Full persisted checkpoint representation. -/// (Step history tables may store multiple instances of this shape.) +/// Complete persisted representation of a runtime checkpoint. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct PersistedCheckpoint { - /// Unique session identifier. + /// Session this checkpoint belongs to. pub session_id: String, - /// Workflow step number for this checkpoint. + /// Monotonically increasing step counter. pub step: u64, /// Full state snapshot at this step. pub state: PersistedState, - /// Frontier encoded as string vector using NodeKind::encode(). + /// Frontier nodes encoded via [`NodeKind::encode`]. pub frontier: Vec, /// Scheduler version-gating state. pub versions_seen: PersistedVersionsSeen, - /// Maximum concurrent nodes for this session. + /// Maximum concurrent node executions for this session. pub concurrency_limit: usize, - /// RFC3339 string form of creation time (keeps chrono::DateTime out of serialized shape). + /// RFC 3339 timestamp, avoiding a `chrono` type in the serialized shape. pub created_at: String, - /// Nodes that executed in this step, encoded as strings + /// Nodes that executed in this step, encoded as strings. #[serde(default)] pub ran_nodes: Vec, - /// Nodes that were skipped in this step, encoded as strings + /// Nodes that were skipped in this step, encoded as strings. #[serde(default)] pub skipped_nodes: Vec, - /// Channels that were updated in this step + /// Channels that changed during this step. #[serde(default)] pub updated_channels: Vec, } -use thiserror::Error; - -/// Bidirectional conversion and serialization errors for persistence models. -#[derive(Debug, Error)] -#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] -pub enum PersistenceError { - /// A required field was absent from the persisted data. - #[error("missing field: {0}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic( - code(weavegraph::persistence::missing_field), - help("Populate the field in the persisted JSON before conversion.") - ) - )] - MissingField(&'static str), - - /// A JSON serialization or deserialization error. - #[error("JSON serialization/deserialization failed: {source}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic( - code(weavegraph::persistence::serde), - help("Ensure the JSON structure matches Persisted* types; serde error: {source}.") - ) - )] - Serde { - /// The underlying serde_json error. - #[source] - source: serde_json::Error, - }, - - /// Any other persistence error. - #[error("persistence error: {0}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic(code(weavegraph::persistence::other)) - )] - Other(String), -} - -/// Convenience alias for persistence operation results. -pub type Result = std::result::Result; - -/* ---------- VersionedState <-> PersistedState Conversions ---------- */ - impl From<&VersionedState> for PersistedState { fn from(s: &VersionedState) -> Self { - PersistedState { + Self { messages: PersistedVecChannel { version: s.messages.version(), items: s.messages.snapshot(), @@ -195,7 +171,7 @@ impl TryFrom for VersionedState { type Error = PersistenceError; fn try_from(p: PersistedState) -> Result { - Ok(VersionedState { + Ok(Self { messages: MessagesChannel::new(p.messages.items, p.messages.version), extra: ExtrasChannel::new(p.extra.map, p.extra.version), errors: crate::channels::ErrorsChannel::new(p.errors.items, p.errors.version), @@ -203,11 +179,9 @@ impl TryFrom for VersionedState { } } -/* ---------- versions_seen conversions ---------- */ - impl From<&FxHashMap>> for PersistedVersionsSeen { fn from(v: &FxHashMap>) -> Self { - PersistedVersionsSeen(v.clone()) + Self(v.clone()) } } @@ -217,20 +191,18 @@ impl From for FxHashMap> { } } -/* ---------- Checkpoint <-> PersistedCheckpoint Conversions ---------- */ - impl From<&Checkpoint> for PersistedCheckpoint { fn from(cp: &Checkpoint) -> Self { - PersistedCheckpoint { + Self { session_id: cp.session_id.clone(), step: cp.step, state: PersistedState::from(&cp.state), - frontier: cp.frontier.iter().map(|k| k.encode()).collect(), + frontier: cp.frontier.iter().map(NodeKind::encode).collect(), versions_seen: PersistedVersionsSeen(cp.versions_seen.clone()), concurrency_limit: cp.concurrency_limit, created_at: cp.created_at.to_rfc3339(), - ran_nodes: cp.ran_nodes.iter().map(|k| k.encode()).collect(), - skipped_nodes: cp.skipped_nodes.iter().map(|k| k.encode()).collect(), + ran_nodes: cp.ran_nodes.iter().map(NodeKind::encode).collect(), + skipped_nodes: cp.skipped_nodes.iter().map(NodeKind::encode).collect(), updated_channels: cp.updated_channels.clone(), } } @@ -241,32 +213,20 @@ impl TryFrom for Checkpoint { fn try_from(p: PersistedCheckpoint) -> Result { let state = VersionedState::try_from(p.state)?; - let frontier: Vec = p.frontier.iter().map(|s| NodeKind::decode(s)).collect(); - let ran_nodes: Vec = p.ran_nodes.iter().map(|s| NodeKind::decode(s)).collect(); - let skipped_nodes: Vec = p - .skipped_nodes - .iter() - .map(|s| NodeKind::decode(s)) - .collect(); - let parsed_dt = chrono::DateTime::parse_from_rfc3339(&p.created_at) + let created_at = chrono::DateTime::parse_from_rfc3339(&p.created_at) .map(|dt| dt.with_timezone(&Utc)) .unwrap_or_else(|_| Utc::now()); - Ok(Checkpoint { + Ok(Self { session_id: p.session_id, step: p.step, state, - frontier, + frontier: p.frontier.iter().map(|s| NodeKind::decode(s)).collect(), versions_seen: p.versions_seen.0, concurrency_limit: p.concurrency_limit, - created_at: parsed_dt, - ran_nodes, - skipped_nodes, + created_at, + ran_nodes: p.ran_nodes.iter().map(|s| NodeKind::decode(s)).collect(), + skipped_nodes: p.skipped_nodes.iter().map(|s| NodeKind::decode(s)).collect(), updated_channels: p.updated_channels, }) } } - -/* ---------- Convenience JSON helpers (using JsonSerializable trait from utils::json_ext) ---------- */ - -// Both PersistedState and PersistedCheckpoint automatically implement JsonSerializable -// through the blanket implementation above, providing to_json_string() and from_json_str() methods. diff --git a/src/runtimes/replay.rs b/src/runtimes/replay.rs index cdd9c72..99a0bc6 100644 --- a/src/runtimes/replay.rs +++ b/src/runtimes/replay.rs @@ -1,8 +1,7 @@ //! Replay conformance helpers for comparing workflow runs. //! -//! These helpers are intentionally small and test-friendly. They normalize common -//! nondeterministic fields, compare final state and event streams, and return -//! human-readable differences that can be used in ordinary assertions. +//! Normalizes nondeterministic fields, compares final state and event streams, +//! and returns human-readable differences suitable for test assertions. use serde_json::{Value, json}; use thiserror::Error; @@ -24,13 +23,10 @@ pub struct ReplayRun { } impl ReplayRun { - /// Create a replay run from final state and captured events. + /// Construct from final state and captured events. #[must_use] pub fn new(final_state: VersionedState, events: Vec) -> Self { - Self { - final_state, - events, - } + Self { final_state, events } } } @@ -41,40 +37,36 @@ pub struct ReplayComparison { } impl ReplayComparison { - /// Create a successful comparison with no differences. + /// No differences found. #[must_use] pub fn matched() -> Self { - Self { - differences: Vec::new(), - } + Self { differences: Vec::new() } } - /// Create a comparison with the supplied differences. + /// Construct with the supplied differences. #[must_use] pub fn with_differences(differences: Vec) -> Self { Self { differences } } - /// Return true when no differences were found. + /// Returns `true` when no differences were found. #[must_use] pub fn is_match(&self) -> bool { self.differences.is_empty() } - /// Return the differences found during comparison. + /// Returns the differences found during comparison. #[must_use] pub fn differences(&self) -> &[String] { &self.differences } - /// Convert this report into a `Result` suitable for test assertions. + /// Convert into a `Result` for use in assertions. pub fn assert_matches(self) -> Result<(), ReplayConformanceError> { if self.is_match() { Ok(()) } else { - Err(ReplayConformanceError::Mismatch { - differences: self.differences, - }) + Err(ReplayConformanceError::Mismatch { differences: self.differences }) } } } @@ -96,20 +88,19 @@ pub enum ReplayConformanceError { }, } -/// Normalize an event for replay comparison. +/// Strip nondeterministic fields from an event for replay comparison. /// -/// The default normalizer uses Weavegraph's JSON event shape and removes the -/// top-level timestamp, which is normally wall-clock dependent. +/// Removes the top-level `timestamp` field from the event's JSON representation. #[must_use] pub fn normalize_event(event: &Event) -> Value { let mut value = event.to_json_value(); - if let Value::Object(object) = &mut value { - object.remove("timestamp"); + if let Value::Object(obj) = &mut value { + obj.remove("timestamp"); } value } -/// Normalize a final state into a JSON value for stable comparison and diffs. +/// Serialize a final state to JSON for stable comparison. #[must_use] pub fn normalize_state(state: &VersionedState) -> Value { json!({ @@ -145,7 +136,7 @@ pub fn compare_event_sequences(left: &[Event], right: &[Event]) -> ReplayCompari /// Compare two event streams with a caller-provided normalizer. /// /// Use this when domain events contain timestamps, generated IDs, or other -/// values that should be compared semantically rather than byte-for-byte. +/// values that need semantic rather than byte-for-byte comparison. #[must_use] pub fn compare_event_sequences_with( left: &[Event], @@ -171,15 +162,13 @@ where )); } - let shared_len = left_values.len().min(right_values.len()); - for index in 0..shared_len { - if left_values[index] != right_values[index] { - differences.push(format!( - "event {index} differs: left={} right={}", - left_values[index], right_values[index] - )); - break; - } + if let Some((i, (l, r))) = left_values + .iter() + .zip(&right_values) + .enumerate() + .find(|(_, (l, r))| l != r) + { + differences.push(format!("event {i} differs: left={l} right={r}")); } ReplayComparison::with_differences(differences) @@ -202,39 +191,30 @@ where F: Fn(&Event) -> Value, { let mut differences = Vec::new(); - - let state_comparison = compare_final_state(&left.final_state, &right.final_state); - differences.extend(state_comparison.differences().iter().cloned()); - - let event_comparison = - compare_event_sequences_with(&left.events, &right.events, event_normalizer); - differences.extend(event_comparison.differences().iter().cloned()); - + let state_cmp = compare_final_state(&left.final_state, &right.final_state); + differences.extend_from_slice(state_cmp.differences()); + let event_cmp = compare_event_sequences_with(&left.events, &right.events, event_normalizer); + differences.extend_from_slice(event_cmp.differences()); ReplayComparison::with_differences(differences) } -// ============================================================================ -// Normalization profiles (WG-006) -// ============================================================================ - /// A filter profile for [`normalize_state_with`] and [`compare_final_state_with`]. /// -/// A profile lists extra-map keys that should be excluded from normalized state -/// output. This is the primary mechanism for separating durable state from -/// per-invocation scratch values during replay comparison and resume assertions. +/// Lists extra-map keys to exclude from normalized state output — the primary +/// mechanism for separating durable state from per-invocation scratch values +/// during replay comparison and resume assertions. /// -/// # Conflict detection +/// ## Conflict detection /// -/// When a key is added via [`ignore_key`](Self::ignore_key), the profile records -/// the key's [`StateLifecycle`] annotation. If the same storage key is later -/// registered with a **different** lifecycle annotation, the method panics with a -/// clear message. This prevents subtle bugs from defining the same slot constant -/// twice with conflicting policies. +/// [`ignore_key`](Self::ignore_key) records the key's [`StateLifecycle`] +/// annotation. Registering the same storage key with a **different** lifecycle +/// annotation panics with a clear message, surfacing configuration mistakes at +/// test time rather than silently producing wrong results. /// -/// Raw-string keys added via [`ignore_extra_keys`](Self::ignore_extra_keys) carry -/// no lifecycle annotation and do not trigger conflict detection. +/// Raw-string keys added via [`ignore_extra_keys`](Self::ignore_extra_keys) +/// carry no lifecycle annotation and do not trigger conflict detection. /// -/// # Examples +/// ## Examples /// /// ```rust /// use weavegraph::runtimes::replay::{StateNormalizeProfile, normalize_state_with}; @@ -250,13 +230,12 @@ where /// ``` #[derive(Debug, Default, Clone)] pub struct StateNormalizeProfile { - /// (storage_key, optional lifecycle annotation). - /// `None` = added via raw string; `Some(lc)` = added via typed StateKey. + // Each entry: (storage_key, lifecycle annotation if added via typed StateKey). ignored: Vec<(String, Option)>, } impl StateNormalizeProfile { - /// Create an empty profile (no keys ignored; equivalent to `normalize_state`). + /// Create an empty profile (no keys ignored). #[must_use] pub fn new() -> Self { Self::default() @@ -264,9 +243,8 @@ impl StateNormalizeProfile { /// Ignore the given raw storage key strings during normalization. /// - /// Use this for quick ad-hoc ignores. Prefer [`ignore_key`](Self::ignore_key) - /// when you have a typed `StateKey` constant, as it also validates lifecycle - /// consistency. + /// Prefer [`ignore_key`](Self::ignore_key) when a typed `StateKey` constant + /// is available, as it also validates lifecycle consistency. #[must_use] pub fn ignore_extra_keys(mut self, keys: I) -> Self where @@ -281,10 +259,8 @@ impl StateNormalizeProfile { /// Ignore the storage slot identified by `key` during normalization. /// - /// The key's [`StateLifecycle`] annotation is recorded. If the same storage - /// key has previously been registered with a different lifecycle annotation, - /// this method **panics** — this is intentional: it surfaces a configuration - /// mistake at test/startup time rather than silently producing wrong results. + /// Panics if the same storage key has already been registered with a + /// different [`StateLifecycle`] annotation. #[must_use] pub fn ignore_key(mut self, key: StateKey) -> Self { self.add_raw(key.storage_key(), Some(key.lifecycle())); @@ -294,14 +270,12 @@ impl StateNormalizeProfile { fn add_raw(&mut self, storage_key: String, lifecycle: Option) { if let Some((_, existing_lc)) = self.ignored.iter().find(|(k, _)| k == &storage_key) { match (existing_lc, &lifecycle) { - (Some(a), Some(b)) if a != b => { - panic!( - "StateNormalizeProfile: conflicting lifecycle annotations for key {:?}: \ - already registered as {:?}, attempted to re-register as {:?}. \ - Ensure the same StateKey constant is used throughout.", - storage_key, a, b - ); - } + (Some(a), Some(b)) if a != b => panic!( + "StateNormalizeProfile: conflicting lifecycle annotations for key {:?}: \ + already registered as {:?}, attempted to re-register as {:?}. \ + Ensure the same StateKey constant is used throughout.", + storage_key, a, b + ), _ => {} // duplicate or compatible — idempotent } return; @@ -309,19 +283,19 @@ impl StateNormalizeProfile { self.ignored.push((storage_key, lifecycle)); } - /// Iterate over the concrete storage key strings this profile ignores. + /// Iterate over the storage key strings this profile ignores. pub fn ignored_keys(&self) -> impl Iterator { self.ignored.iter().map(|(k, _)| k.as_str()) } } -/// Normalize a final state into a JSON value, excluding keys listed in `profile`. +/// Normalize a final state to JSON, excluding keys listed in `profile`. /// -/// Identical to [`normalize_state`] except the caller can suppress named keys -/// from the `extra` map. Use this to compare only durable state when some extra -/// entries are per-invocation scratch that should not influence the comparison. +/// Identical to [`normalize_state`] except named keys are suppressed from the +/// `extra` map. Use this to compare only durable state when some extra entries +/// are per-invocation scratch. /// -/// # Examples +/// ## Examples /// /// ```rust /// use weavegraph::runtimes::replay::{StateNormalizeProfile, normalize_state_with}; @@ -352,8 +326,7 @@ pub fn normalize_state_with(state: &VersionedState, profile: &StateNormalizeProf /// Compare two final states using a caller-provided normalization profile. /// /// Equivalent to [`compare_final_state`] but filters the `extra` map through -/// `profile` before comparing. Use this to assert that durable state matches -/// while ignoring known per-invocation scratch keys. +/// `profile` before comparing. #[must_use] pub fn compare_final_state_with( left: &VersionedState, @@ -374,8 +347,8 @@ pub fn compare_final_state_with( /// Compare two captured runs using a state profile and a caller-provided event normalizer. /// /// Combines [`compare_final_state_with`] and [`compare_event_sequences_with`] into -/// a single assertion. Use this as the single call in iterative resume tests that -/// need both durable-state filtering and custom event normalization. +/// a single assertion. Use this in iterative resume tests that need both +/// durable-state filtering and custom event normalization. #[must_use] pub fn compare_replay_runs_with_profile( left: &ReplayRun, @@ -387,14 +360,9 @@ where F: Fn(&Event) -> Value, { let mut differences = Vec::new(); - - let state_comparison = - compare_final_state_with(&left.final_state, &right.final_state, state_profile); - differences.extend(state_comparison.differences().iter().cloned()); - - let event_comparison = - compare_event_sequences_with(&left.events, &right.events, event_normalizer); - differences.extend(event_comparison.differences().iter().cloned()); - + let state_cmp = compare_final_state_with(&left.final_state, &right.final_state, state_profile); + differences.extend_from_slice(state_cmp.differences()); + let event_cmp = compare_event_sequences_with(&left.events, &right.events, event_normalizer); + differences.extend_from_slice(event_cmp.differences()); ReplayComparison::with_differences(differences) } diff --git a/src/runtimes/runner.rs b/src/runtimes/runner.rs index 15dc3fe..9a64c98 100644 --- a/src/runtimes/runner.rs +++ b/src/runtimes/runner.rs @@ -1,13 +1,10 @@ -//! Main workflow runner coordinating session management, execution, and event streaming. +//! Runtime execution engine for weavegraph sessions. //! -//! The `AppRunner` is the central coordinator that brings together: -//! - Session state management (from [`session`](super::session)) -//! - Step execution logic (from [`execution`](super::execution)) -//! - Event stream handling (see [`App::event_stream`](crate::app::App::event_stream) and -//! [`App::invoke_streaming`](crate::app::App::invoke_streaming)) -//! -//! For most use cases, interact with `AppRunner` directly rather than -//! the constituent modules. +//! [`AppRunner`] coordinates the session lifecycle, step dispatch, +//! checkpointing, and event streaming for workflow graphs. For simple +//! single-invocation use cases, call [`App::invoke`]. Construct an +//! `AppRunner` directly when you need custom event sinks, pluggable +//! checkpointers, or iterative (multi-turn) sessions. use crate::app::{App, BarrierOutcome}; use crate::channels::Channel; @@ -41,15 +38,11 @@ use thiserror::Error; use tokio::task::JoinError; use tracing::instrument; -// ============================================================================ -// Private helpers -// ============================================================================ +// ─── private helpers ───────────────────────────────────────────────────────── -/// An [`EventEmitter`] wrapper that calls an observer's `on_event_bus_emit` -/// hook after each successful (or failed) emit attempt. -/// -/// Built lazily in `schedule_step` when an observer is present; otherwise the -/// raw emitter is used directly, paying zero overhead. +/// Wraps an [`EventEmitter`] to fire the observer's `on_event_bus_emit` hook +/// after every emit. Constructed only when an observer is present; the +/// common observer-free path incurs zero overhead. struct ObservingEmitter { inner: Arc, observer: Arc, @@ -67,23 +60,19 @@ impl EventEmitter for ObservingEmitter { fn emit(&self, event: Event) -> Result<(), EmitterError> { let scope = event.scope_label().unwrap_or("unknown").to_owned(); let result = self.inner.emit(event); - let meta = EventBusEmitMeta { scope: &scope }; - // Safety: `Arc` is AssertUnwindSafe-safe - // because we require RefUnwindSafe as a supertrait on RuntimeObserver. - if std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { - self.observer.on_event_bus_emit(&meta) - })) - .is_err() - { - tracing::warn!("RuntimeObserver::on_event_bus_emit panicked; execution continues"); - } + let obs = Arc::clone(&self.observer); + call_observer_hook( + std::panic::AssertUnwindSafe(move || { + obs.on_event_bus_emit(&EventBusEmitMeta { scope: &scope }) + }), + "on_event_bus_emit", + ); result } } -/// Call an observer hook, catching any panic and logging it as a warning. -/// -/// Hooks must not kill graph execution; this helper enforces that contract. +/// Call an observer hook, catching and logging any panic so that observer +/// failures never abort graph execution. fn call_observer_hook(f: F, hook_name: &'static str) where F: FnOnce() + std::panic::UnwindSafe, @@ -96,92 +85,45 @@ where } } -/// Runtime execution engine for workflow graphs with session management and event streaming. -/// -/// `AppRunner` wraps an [`App`] and manages the runtime execution environment, -/// including: -/// - **Session Management**: Multiple isolated workflow executions -/// - **Event Streaming**: Custom EventBus with pluggable sinks -/// - **Checkpointing**: State persistence and recovery -/// - **Step Control**: Pausing, resuming, and interrupting execution -/// -/// # Architecture: App vs AppRunner -/// -/// - **`App`**: The workflow graph structure (nodes, edges, topology) -/// - **`AppRunner`**: The runtime environment (sessions, events, checkpoints) -/// -/// This separation allows: -/// - One `App` to be reused across multiple `AppRunner` instances -/// - Each runner to have isolated EventBus configuration -/// - Per-request event streaming in web servers -/// -/// # EventBus Integration -/// -/// The `AppRunner` owns the [`EventBus`] that receives events -/// from workflow nodes. When you need custom event handling: -/// -/// ```text -/// ❌ WRONG: App.invoke() → Uses default EventBus (stdout only) -/// ✅ RIGHT: AppRunner::builder() with .event_bus(bus) → Custom EventBus with your sinks -/// ``` -/// -/// # Usage Patterns +// ─── AppRunner ─────────────────────────────────────────────────────────────── + +/// Runtime execution engine for workflow graphs. /// -/// ## Simple Execution (via App.invoke) +/// `AppRunner` owns the session table, event bus, and optional checkpointer +/// for one or more concurrent workflow sessions. Use [`AppRunner::builder`] +/// to configure and construct an instance. /// -/// For basic workflows where stdout logging is sufficient: +/// # App vs AppRunner /// -/// ```rust,no_run -/// # use weavegraph::app::App; -/// # use weavegraph::state::VersionedState; -/// # async fn example(app: App) -> Result<(), Box> { -/// // App.invoke() creates an AppRunner internally with default EventBus -/// let final_state = app.invoke( -/// VersionedState::new_with_user_message("Hello") -/// ).await?; -/// # Ok(()) -/// # } -/// ``` +/// - **`App`**: Graph topology (nodes, edges). +/// - **`AppRunner`**: Execution environment (sessions, events, checkpoints). /// -/// ## Advanced Execution (Direct AppRunner) +/// A single `Arc` can be shared across many runners — one per request +/// or tenant — each with its own event bus and isolated session table. /// -/// For production systems needing event streaming, use `AppRunner` directly: +/// # Examples /// /// ```rust,no_run /// # use weavegraph::app::App; /// # use weavegraph::state::VersionedState; /// use weavegraph::event_bus::{EventBus, ChannelSink}; /// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// # async fn example(app: App) -> Result<(), Box> { /// -/// // Create channel for event streaming +/// # async fn example(app: App) -> Result<(), Box> { /// let (tx, rx) = flume::unbounded(); -/// -/// // Build custom EventBus /// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); /// -/// // Create runner with custom EventBus /// let mut runner = AppRunner::builder() /// .app(app) /// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) /// .event_bus(bus) /// .build() /// .await; /// /// let session_id = "my-session".to_string(); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("Hello") -/// ).await?; -/// -/// // Events stream to the channel while workflow runs -/// tokio::spawn(async move { -/// while let Ok(event) = rx.recv_async().await { -/// println!("Event: {:?}", event); -/// } -/// }); -/// +/// runner +/// .create_session(session_id.clone(), VersionedState::new_with_user_message("Hello")) +/// .await?; /// runner.run_until_complete(&session_id).await?; /// # Ok(()) /// # } @@ -189,13 +131,12 @@ where /// /// # See Also /// -/// - [`builder()`](Self::builder) - Recommended for custom event handling -/// - [`App::invoke()`](crate::app::App::invoke) - Simple execution with defaults -/// - Example: `examples/streaming_events.rs` - Complete streaming demonstration +/// - [`builder()`](Self::builder) — recommended construction path +/// - [`App::invoke()`](crate::app::App::invoke) — simple single-invocation shortcut pub struct AppRunner { app: Arc, sessions: FxHashMap, - checkpointer: Option>, // optional pluggable persistence + checkpointer: Option>, autosave: bool, event_bus: EventBus, event_stream_taken: bool, @@ -204,7 +145,7 @@ pub struct AppRunner { observer: Option>, } -/// Errors that can occur during workflow execution. +/// Errors produced during workflow execution. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] #[non_exhaustive] @@ -321,79 +262,34 @@ enum CompletionEventPolicy { KeepStreamOpen, } -// ============================================================================ -// Builder Pattern -// ============================================================================ +// ─── Builder ───────────────────────────────────────────────────────────────── -/// Builder for constructing [`AppRunner`] instances with a fluent API. -/// -/// This builder is the canonical way to construct `AppRunner` instances. -/// It provides a single, discoverable interface for all configuration options. -/// -/// # Examples +/// Fluent builder for [`AppRunner`]. /// -/// ## Basic usage with defaults +/// Obtain one via [`AppRunner::builder`]. Setting [`app`](Self::app) or +/// [`app_arc`](Self::app_arc) is required; all other options have defaults. /// -/// ```rust,no_run -/// # use weavegraph::app::App; -/// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// # async fn example(app: App) -> Result<(), Box> { +/// # Defaults /// -/// let runner = AppRunner::builder() -/// .app(app) -/// .checkpointer(CheckpointerType::InMemory) -/// .build() -/// .await; -/// # Ok(()) -/// # } -/// ``` +/// - `checkpointer`: [`CheckpointerType::InMemory`] +/// - `autosave`: `true` +/// - `event_bus`: from the app's [`RuntimeConfig`](crate::runtimes::RuntimeConfig) +/// - `start_listener`: `true` /// -/// ## Full configuration with custom EventBus +/// # Examples /// /// ```rust,no_run /// # use weavegraph::app::App; -/// use weavegraph::event_bus::{EventBus, ChannelSink, StdOutSink}; /// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// # async fn example(app: App) -> Result<(), Box> { -/// -/// let (tx, rx) = flume::unbounded(); -/// let bus = EventBus::with_sinks(vec![ -/// Box::new(StdOutSink::default()), -/// Box::new(ChannelSink::new(tx)), -/// ]); +/// use weavegraph::event_bus::{EventBus, ChannelSink}; /// +/// # async fn example(app: App) -> Result<(), Box> { +/// let (tx, _rx) = flume::unbounded(); /// let runner = AppRunner::builder() /// .app(app) -/// .checkpointer(CheckpointerType::SQLite) -/// .event_bus(bus) -/// .autosave(true) -/// .start_listener(true) -/// .build() -/// .await; -/// # Ok(()) -/// # } -/// ``` -/// -/// ## Using `Arc` for shared workflows -/// -/// ```rust,no_run -/// # use weavegraph::app::App; -/// use std::sync::Arc; -/// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// # async fn example(app: App) -> Result<(), Box> { -/// -/// let shared_app = Arc::new(app); -/// -/// // Create multiple runners sharing the same App -/// let runner1 = AppRunner::builder() -/// .app_arc(shared_app.clone()) -/// .checkpointer(CheckpointerType::InMemory) -/// .build() -/// .await; -/// -/// let runner2 = AppRunner::builder() -/// .app_arc(shared_app) /// .checkpointer(CheckpointerType::InMemory) +/// .event_bus(EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))])) +/// .autosave(true) /// .build() /// .await; /// # Ok(()) @@ -417,13 +313,7 @@ impl Default for AppRunnerBuilder { } impl AppRunnerBuilder { - /// Create a new builder with default settings. - /// - /// Defaults: - /// - `checkpointer`: `InMemory` - /// - `autosave`: `true` - /// - `event_bus`: Uses the app's runtime config when built - /// - `start_listener`: `true` + /// Create a builder with default settings. #[must_use] pub fn new() -> Self { Self { @@ -440,43 +330,39 @@ impl AppRunnerBuilder { /// Set the workflow application (takes ownership). /// - /// This is required before calling [`build()`](Self::build). + /// Required before calling [`build`](Self::build). #[must_use] pub fn app(mut self, app: App) -> Self { self.app = Some(Arc::new(app)); self } - /// Set the workflow application from an existing `Arc`. + /// Set the workflow application from a shared `Arc`. /// - /// Use this when sharing an `App` across multiple runners to avoid cloning. + /// Use when sharing an `App` across multiple runners. #[must_use] pub fn app_arc(mut self, app: Arc) -> Self { self.app = Some(app); self } - /// Set the checkpointer type for state persistence. - /// - /// Defaults to [`CheckpointerType::InMemory`]. + /// Set the checkpointer backend. Defaults to [`CheckpointerType::InMemory`]. #[must_use] pub fn checkpointer(mut self, checkpointer_type: CheckpointerType) -> Self { self.checkpointer_type = checkpointer_type; self } - /// Set a custom checkpointer implementation. + /// Provide a custom [`Checkpointer`] implementation. /// - /// When both enum-based and custom checkpointers are configured, the custom - /// checkpointer takes precedence. + /// Takes precedence over the enum-based checkpointer type when set. #[must_use] pub fn checkpointer_custom(mut self, checkpointer: Arc) -> Self { self.checkpointer_custom = Some(checkpointer); self } - /// Set whether to automatically save checkpoints after each step. - /// + /// Control whether checkpoints are saved automatically after each step. /// Defaults to `true`. #[must_use] pub fn autosave(mut self, autosave: bool) -> Self { @@ -484,27 +370,24 @@ impl AppRunnerBuilder { self } - /// Set a custom EventBus for event handling. + /// Provide a custom [`EventBus`]. /// - /// If not set, the runner will use the EventBus configured in the app's - /// [`RuntimeConfig`](crate::runtimes::RuntimeConfig). + /// If not set, the bus from the app's [`RuntimeConfig`](crate::runtimes::RuntimeConfig) + /// is used. #[must_use] pub fn event_bus(mut self, event_bus: EventBus) -> Self { self.event_bus = Some(event_bus); self } - /// Set whether to start the event listener immediately. - /// - /// Defaults to `true`. Set to `false` if you need to configure - /// the EventBus further before starting. + /// Whether to start the event listener on build. Defaults to `true`. #[must_use] pub fn start_listener(mut self, start: bool) -> Self { self.start_listener = start; self } - /// Set a runtime clock that will be injected into node contexts. + /// Inject a clock into every node context. #[must_use] pub fn clock(mut self, clock: Arc) -> Self { self.clock = Some(clock); @@ -513,9 +396,8 @@ impl AppRunnerBuilder { /// Attach a [`RuntimeObserver`] to receive telemetry hooks during execution. /// - /// The observer is called synchronously at invocation boundaries, per-node - /// completion, checkpoint operations, and event-bus emissions. It pays zero - /// runtime cost when not set (`None` default). + /// Hooks fire at invocation boundaries, per-node completion, checkpoint + /// operations, and event-bus emissions. Zero cost when not set. /// /// # Examples /// @@ -551,8 +433,8 @@ impl AppRunnerBuilder { /// /// # Panics /// - /// Panics if [`app()`](Self::app) or [`app_arc()`](Self::app_arc) was not called. - /// Use [`try_build()`](Self::try_build) for a fallible version. + /// Panics if [`app`](Self::app) or [`app_arc`](Self::app_arc) was not + /// called. Use [`try_build`](Self::try_build) for a fallible version. pub async fn build(self) -> AppRunner { self.try_build() .await @@ -571,12 +453,6 @@ impl AppRunnerBuilder { } else { AppRunner::checkpointer_type_label(&self.checkpointer_type).to_string() }; - let runtime_metadata = RunnerRuntimeMetadata { - clock, - checkpointer_descriptor, - observer: self.observer, - }; - Some( AppRunner::with_arc_and_bus( app, @@ -585,34 +461,23 @@ impl AppRunnerBuilder { self.autosave, event_bus, self.start_listener, - runtime_metadata, + RunnerRuntimeMetadata { + clock, + checkpointer_descriptor, + observer: self.observer, + }, ) .await, ) } } +// ─── AppRunner impl ─────────────────────────────────────────────────────────── + impl AppRunner { /// Create a new [`AppRunnerBuilder`] for fluent configuration. /// - /// This is the **preferred method** for constructing `AppRunner` instances. - /// - /// # Examples - /// - /// ```rust,no_run - /// # use weavegraph::app::App; - /// use weavegraph::runtimes::{AppRunner, CheckpointerType}; - /// # async fn example(app: App) -> Result<(), Box> { - /// - /// let runner = AppRunner::builder() - /// .app(app) - /// .checkpointer(CheckpointerType::InMemory) - /// .autosave(true) - /// .build() - /// .await; - /// # Ok(()) - /// # } - /// ``` + /// This is the preferred construction path for `AppRunner`. #[must_use] pub fn builder() -> AppRunnerBuilder { AppRunnerBuilder::new() @@ -640,10 +505,6 @@ impl AppRunner { .unwrap_or_else(|_| "weavegraph.db".to_string()); format!("sqlite://{fallback}") }); - // Ensure underlying sqlite file exists. Steps: - // 1. Strip "sqlite://" scheme to get filesystem path. - // 2. Create parent directories if needed. - // 3. Attempt to create the file (ignore errors if it already exists or any failure). if let Some(path) = db_url.strip_prefix("sqlite://") { let path = path.trim(); if !path.is_empty() { @@ -652,7 +513,6 @@ impl AppRunner { let _ = std::fs::create_dir_all(parent); } if !p.exists() { - // Ignore result; if it already exists or we lack permission we proceed anyway. let _ = std::fs::File::create_new(p); } } @@ -709,8 +569,6 @@ impl AppRunner { start_listener: bool, runtime_metadata: RunnerRuntimeMetadata, ) -> Self { - // Precedence rule: custom checkpointer always wins when provided. - // If custom is None, fall back to enum-based factory instantiation. let checkpointer = if let Some(custom) = checkpointer_custom { Some(custom) } else { @@ -735,9 +593,8 @@ impl AppRunner { /// Subscribe to the underlying event stream. /// - /// Returns a handle that yields events as they are emitted by workflow nodes. - /// Subsequent calls after the first return `None` until the stream is - /// finalized (e.g., when a session completes and the runner resets the flag). + /// Returns the stream on the first call; subsequent calls return `None` + /// until the stream is finalized after a session completes. pub fn event_stream(&mut self) -> Option { if self.event_stream_taken { return None; @@ -746,26 +603,22 @@ impl AppRunner { Some(self.event_bus.subscribe()) } - /// Initialize a new session with the given initial state + /// Initialize a new session, or resume from the latest checkpoint. #[instrument(skip(self, initial_state, session_id), err)] pub async fn create_session( &mut self, session_id: String, initial_state: VersionedState, ) -> Result { - // If checkpointer present and session exists, load instead of creating anew - let restored_checkpoint = if let Some(cp) = &self.checkpointer { - cp.load_latest(&session_id) + if let Some(cp) = &self.checkpointer + && let Some(stored) = cp + .load_latest(&session_id) .await .map_err(RunnerError::Checkpointer)? - } else { - None - }; - - if let Some(stored) = restored_checkpoint { - let restored = restore_session_state(&stored); + { let restored_step = stored.step; - self.sessions.insert(session_id.clone(), restored); + self.sessions + .insert(session_id.clone(), restore_session_state(&stored)); if let Some(obs) = &self.observer { let backend = self.checkpointer_descriptor.as_str(); let sid = session_id.as_str(); @@ -794,36 +647,31 @@ impl AppRunner { if frontier.is_empty() { return Err(RunnerError::NoStartNodes); } - let default_limit = std::thread::available_parallelism() + let concurrency = std::thread::available_parallelism() .map(|n| n.get()) .unwrap_or(1); - let scheduler = Scheduler::new(default_limit); let session_state = SessionState { state: initial_state, step: 0, frontier, - scheduler, + scheduler: Scheduler::new(concurrency), scheduler_state: SchedulerState::default(), }; - self.sessions - .insert(session_id.clone(), session_state.clone()); - if let Some(cp) = &self.checkpointer { - let _ = cp - .save(Checkpoint::from_session(&session_id, &session_state)) - .await; + self.sessions.insert(session_id.clone(), session_state); + if let Some(cp) = &self.checkpointer + && let Some(ss) = self.sessions.get(&session_id) + { + let _ = cp.save(Checkpoint::from_session(&session_id, ss)).await; } Ok(SessionInit::Fresh) } - /// Initialize or resume a session for repeated invocations under one durable lineage. + /// Initialize or resume a session for repeated invocations under one lineage. /// - /// This method behaves like [`create_session`](Self::create_session), then prepares - /// the session to run from `entry_node`. Passing [`NodeKind::Start`] uses the - /// graph's outgoing edges from the virtual Start node, matching normal session - /// initialization. Passing a custom node runs directly from that registered node. - /// - /// The session step counter is not reset when a checkpoint is resumed, so steps - /// remain monotonic across repeated invocations. + /// Behaves like [`create_session`](Self::create_session), then sets the + /// initial frontier to `entry_node`. Pass [`NodeKind::Start`] for normal + /// graph initialization. The step counter is not reset on checkpoint resume, + /// keeping it monotonic across repeated inputs. #[instrument(skip(self, session_id, initial_state), err)] pub async fn create_iterative_session( &mut self, @@ -839,17 +687,15 @@ impl AppRunner { Ok(init) } - /// Apply an input patch, restart the session frontier, and run to completion. + /// Apply an input patch, reset the frontier, and run to completion. /// - /// The existing session state is updated through the same deterministic barrier - /// path used for node outputs. The frontier is then reset to `entry_node` and the - /// scheduler's version-gating state is cleared so the entry path executes for this - /// logical invocation even when two consecutive input patches serialize to the - /// same state. + /// The patch is merged through the same deterministic barrier path used + /// for node outputs. The frontier and scheduler version-gate state are + /// then reset to `entry_node` so the entry path executes even when two + /// consecutive patches produce identical serialized state. /// - /// Use [`create_iterative_session`](Self::create_iterative_session) before the - /// first call, including after process restart, so the latest checkpoint is loaded - /// into the runner. + /// Call [`create_iterative_session`](Self::create_iterative_session) before + /// the first invocation, including after process restart. #[instrument(skip(self, input), err)] pub async fn invoke_next( &mut self, @@ -866,10 +712,10 @@ impl AppRunner { /// Emit the terminal stream marker for a completed iterative session. /// - /// `invoke_next` keeps long-lived event subscriptions open between logical - /// inputs. Call this after the final input when a subscriber should receive - /// [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE) and the stream - /// should close cleanly. + /// `invoke_next` keeps the event stream open between logical inputs. Call + /// this after the final input to deliver the + /// [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE) sentinel and + /// close the stream cleanly. pub fn finish_iterative_session(&mut self, session_id: &str) -> Result<(), RunnerError> { let (_, _, final_step) = self.finalize_state_snapshot(session_id)?; self.emit_completion_event( @@ -885,15 +731,14 @@ impl AppRunner { session_id: &str, frontier: Vec, ) -> Result<(), RunnerError> { - let session_state = - self.sessions - .get_mut(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; - - session_state.frontier = frontier; - session_state.scheduler_state = SchedulerState::default(); + let ss = self + .sessions + .get_mut(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; + ss.frontier = frontier; + ss.scheduler_state = SchedulerState::default(); Ok(()) } @@ -935,7 +780,7 @@ impl AppRunner { session_id: &str, input: NodePartial, ) -> Result<(), RunnerError> { - let mut updated_state = self + let mut state = self .sessions .get(session_id) .ok_or_else(|| RunnerError::SessionNotFound { @@ -943,49 +788,40 @@ impl AppRunner { })? .state .clone(); - self.app - .apply_barrier(&mut updated_state, &[], vec![input]) + .apply_barrier(&mut state, &[], vec![input]) .await .map_err(RunnerError::AppBarrier)?; - - let session_state = - self.sessions - .get_mut(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; - session_state.state = updated_state; + self.sessions + .get_mut(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })? + .state = state; Ok(()) } - /// Execute one superstep for the given session + /// Execute one superstep for the given session. #[instrument(skip(self, options), err)] pub async fn run_step( &mut self, session_id: &str, options: StepOptions, ) -> Result { - // Phase 3.1 (Clone Reduction - A): capture minimal snapshots without cloning full session let (current_step, current_frontier, current_versions) = { - let current_session_state = - self.sessions - .get(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; + let ss = self + .sessions + .get(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; let versions = StateVersions { - messages_version: current_session_state.state.messages.version(), - extra_version: current_session_state.state.extra.version(), + messages_version: ss.state.messages.version(), + extra_version: ss.state.extra.version(), }; - ( - current_session_state.step, - current_session_state.frontier.clone(), - versions, - ) + (ss.step, ss.frontier.clone(), versions) }; - // Check if already completed if current_frontier.is_empty() || current_frontier.iter().all(|n| *n == NodeKind::End) { return Ok(StepResult::Completed(StepReport { step: current_step, @@ -998,11 +834,8 @@ impl AppRunner { })); } - // Check for interrupt_before for node in ¤t_frontier { if options.interrupt_before.contains(node) { - // SAFETY: We verified session existence above with the same session_id. - // If this fails, we have a logic bug (e.g., concurrent mutation). let session_state = self .sessions .get(session_id) @@ -1017,83 +850,26 @@ impl AppRunner { } } - // Take ownership of session state for execution (eliminates full clone) - // SAFETY: We verified session existence above with the same session_id. - let mut session_state = - self.sessions - .remove(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; + let mut session_state = self + .sessions + .remove(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; - // Execute one superstep; on error, emit an ErrorEvent and rethrow let step_report = match self.run_one_superstep(session_id, &mut session_state).await { Ok(rep) => rep, Err(e) => { - // Build error event - let event = match &e { - RunnerError::Scheduler(source) => match source { - crate::schedulers::SchedulerError::NodeNotFound { kind, step } => { - ErrorEvent { - when: chrono::Utc::now(), - scope: ErrorScope::Scheduler { step: *step }, - error: WeaveError::msg(format!( - "node {:?} not found in registry", - kind - )), - tags: vec!["scheduler".into(), "node_not_found".into()], - context: serde_json::json!({ - "kind": kind.encode() - }), - } - } - crate::schedulers::SchedulerError::NodeRun { kind, step, source } => { - ErrorEvent { - when: chrono::Utc::now(), - scope: ErrorScope::Node { - kind: kind.encode().to_string(), - step: *step, - }, - error: WeaveError::msg(format!("{}", source)), - tags: vec!["node".into()], - context: serde_json::json!({}), - } - } - crate::schedulers::SchedulerError::Join(_) => ErrorEvent { - when: chrono::Utc::now(), - scope: ErrorScope::Scheduler { - step: session_state.step, - }, - error: WeaveError::msg(format!("{}", e)), - tags: vec!["scheduler".into()], - context: serde_json::json!({}), - }, - }, - _ => ErrorEvent { - when: chrono::Utc::now(), - scope: ErrorScope::Runner { - session: session_id.to_string(), - step: session_state.step, - }, - error: WeaveError::msg(format!("{}", e)), - tags: vec!["runner".into()], - context: serde_json::json!({ - "frontier": session_state.frontier.iter().map(|k| k.encode()).collect::>() - }), - }, - }; - // Inject via barrier mechanics by applying a synthetic NodePartial with errors field - let mut update_state = session_state.state.clone(); - let partial = NodePartial::new().with_errors(vec![event]); - // Apply directly using reducer registry through App + let error_event = self.build_error_event(&e, session_id, &session_state); + let mut error_state = session_state.state.clone(); + let partial = NodePartial::new().with_errors(vec![error_event]); let _ = self .app - .apply_barrier(&mut update_state, &[], vec![partial]) + .apply_barrier(&mut error_state, &[], vec![partial]) .await; - session_state.state = update_state; - // Save back to sessions map so callers can inspect accumulated errors - self.sessions.insert(session_id.to_string(), session_state); - // Re-persist if autosave + session_state.state = error_state; + self.sessions + .insert(session_id.to_string(), session_state); if self.autosave && let Some(cp) = &self.checkpointer && let Some(s) = self.sessions.get(session_id) @@ -1104,17 +880,13 @@ impl AppRunner { } }; - // Evaluate post-execution interrupts BEFORE reinserting to minimize clones - // If an interrupt triggers, we insert a clone for persistence and move original into PausedReport. if let Some(node) = step_report .ran_nodes .iter() .find(|n| options.interrupt_after.contains(n)) { - // Persist a clone, return original in PausedReport let persisted = session_state.clone(); self.sessions.insert(session_id.to_string(), persisted); - // Re-persist via helper self.maybe_checkpoint(session_id, step_report.step).await; return Ok(StepResult::Paused(PausedReport { session_state, @@ -1124,7 +896,6 @@ impl AppRunner { if options.interrupt_each_step { let persisted = session_state.clone(); self.sessions.insert(session_id.to_string(), persisted); - // Re-persist via helper self.maybe_checkpoint(session_id, step_report.step).await; return Ok(StepResult::Paused(PausedReport { session_state, @@ -1132,14 +903,61 @@ impl AppRunner { })); } - // Normal completion path: reinsert owned session_state directly (no clone) self.sessions.insert(session_id.to_string(), session_state); - // Persist via helper self.maybe_checkpoint(session_id, step_report.step).await; Ok(StepResult::Completed(step_report)) } - /// Schedule one step: invoke scheduler and normalize outputs to ordered partials. + fn build_error_event( + &self, + e: &RunnerError, + session_id: &str, + session_state: &SessionState, + ) -> ErrorEvent { + match e { + RunnerError::Scheduler(source) => match source { + SchedulerError::NodeNotFound { kind, step } => ErrorEvent { + when: chrono::Utc::now(), + scope: ErrorScope::Scheduler { step: *step }, + error: WeaveError::msg(format!("node {:?} not found in registry", kind)), + tags: vec!["scheduler".into(), "node_not_found".into()], + context: serde_json::json!({ "kind": kind.encode() }), + }, + SchedulerError::NodeRun { kind, step, source } => ErrorEvent { + when: chrono::Utc::now(), + scope: ErrorScope::Node { + kind: kind.encode().to_string(), + step: *step, + }, + error: WeaveError::msg(source.to_string()), + tags: vec!["node".into()], + context: serde_json::json!({}), + }, + SchedulerError::Join(_) => ErrorEvent { + when: chrono::Utc::now(), + scope: ErrorScope::Scheduler { + step: session_state.step, + }, + error: WeaveError::msg(e.to_string()), + tags: vec!["scheduler".into()], + context: serde_json::json!({}), + }, + }, + _ => ErrorEvent { + when: chrono::Utc::now(), + scope: ErrorScope::Runner { + session: session_id.to_string(), + step: session_state.step, + }, + error: WeaveError::msg(e.to_string()), + tags: vec!["runner".into()], + context: serde_json::json!({ + "frontier": session_state.frontier.iter().map(|k| k.encode()).collect::>() + }), + }, + } + } + #[inline] async fn schedule_step( &self, @@ -1148,7 +966,6 @@ impl AppRunner { step: u64, ) -> Result { let snapshot = session_state.state.snapshot(); - // If an observer is attached, wrap the emitter to fire on_event_bus_emit for each emit. let emitter: Arc = if let Some(obs) = &self.observer { Arc::new(ObservingEmitter { inner: self.event_bus.get_emitter(), @@ -1157,13 +974,13 @@ impl AppRunner { } else { self.event_bus.get_emitter() }; - let result = session_state + let raw = session_state .scheduler .superstep( &mut session_state.scheduler_state, self.app.nodes(), session_state.frontier.clone(), - snapshot.clone(), + snapshot, step, SchedulerRunContext { event_emitter: emitter, @@ -1172,26 +989,16 @@ impl AppRunner { }, ) .await?; - - let mut partials_by_kind: FxHashMap = FxHashMap::default(); - for (k, partial) in result.outputs { - partials_by_kind.insert(k, partial); - } - let executed_nodes = result.ran_nodes.clone(); - let partials = executed_nodes - .iter() - .cloned() - .filter_map(|k| partials_by_kind.remove(&k)) - .collect(); - + let mut by_kind: FxHashMap = + raw.outputs.into_iter().collect(); + let partials = raw.ran_nodes.iter().filter_map(|k| by_kind.remove(k)).collect(); Ok(SchedulerOutcome { - ran_nodes: executed_nodes, - skipped_nodes: result.skipped_nodes, + ran_nodes: raw.ran_nodes, + skipped_nodes: raw.skipped_nodes, partials, }) } - /// Apply barrier and update session state with the results. #[tracing::instrument(skip(self, session_state, partials, ran), err)] async fn apply_barrier_and_update( &self, @@ -1199,17 +1006,16 @@ impl AppRunner { ran: &[NodeKind], partials: Vec, ) -> Result { - let mut update_state = session_state.state.clone(); + let mut state = session_state.state.clone(); let outcome = self .app - .apply_barrier(&mut update_state, ran, partials) + .apply_barrier(&mut state, ran, partials) .await .map_err(RunnerError::AppBarrier)?; - session_state.state = update_state; + session_state.state = state; Ok(outcome) } - /// Compute next frontier from barrier outcome, resolving commands and conditional edges. #[inline] fn compute_next_frontier( &self, @@ -1218,95 +1024,86 @@ impl AppRunner { barrier: &BarrierOutcome, step: u64, ) -> Vec { - let mut next_frontier: Vec = Vec::new(); let graph_edges = self.app.edges(); let conditional_edges = self.app.conditional_edges(); let state_snapshot = session_state.state.snapshot(); - let mut frontier_commands_by_node: FxHashMap> = + let mut commands_by_node: FxHashMap> = FxHashMap::default(); - for (origin, command) in &barrier.frontier_commands { - frontier_commands_by_node + for (origin, cmd) in &barrier.frontier_commands { + commands_by_node .entry(origin.clone()) .or_default() - .push(command.clone()); + .push(cmd.clone()); } - for id in ran.iter() { - let default_edges = graph_edges.get(id).cloned().unwrap_or_default(); - let mut next_targets: Vec = Vec::new(); - let mut frontier_replaced = false; - - if let Some(commands) = frontier_commands_by_node.get(id) { - // Commands are processed in emission order to preserve author intent. - for command in commands { - match command { - FrontierCommand::Replace(entries) => { - if frontier_replaced { - tracing::warn!( - step, - origin = %id.encode(), - target = %entries.iter().fold(String::new(), - |acc, e| format!("{} + {}", acc, e.to_node_kind()) - ), - "Replace frontier command has been issued once already during this step, skipping." - ); - continue; - } - next_targets = entries.iter().map(NodeRoute::to_node_kind).collect(); - frontier_replaced = true; + let mut next_frontier: Vec = Vec::new(); + + for node in ran { + let default_edges = graph_edges.get(node).cloned().unwrap_or_default(); + let mut targets: Vec = Vec::new(); + let mut replaced = false; + + for cmd in commands_by_node.get(node).into_iter().flatten() { + match cmd { + FrontierCommand::Replace(entries) => { + if replaced { + tracing::warn!( + step, + origin = %node.encode(), + target = %entries.iter().fold(String::new(), + |acc, e| format!("{} + {}", acc, e.to_node_kind()) + ), + "Replace frontier command has been issued once already during this step, skipping." + ); + continue; } - FrontierCommand::Append(entries) => { - if next_targets.is_empty() && !frontier_replaced { - next_targets.extend(default_edges.clone()); - } - next_targets.extend(entries.iter().map(NodeRoute::to_node_kind)); + targets = entries.iter().map(NodeRoute::to_node_kind).collect(); + replaced = true; + } + FrontierCommand::Append(entries) => { + if targets.is_empty() && !replaced { + targets.extend(default_edges.iter().cloned()); } + targets.extend(entries.iter().map(NodeRoute::to_node_kind)); } } - - if next_targets.is_empty() && !frontier_replaced { - next_targets.extend(default_edges.clone()); - } - } else { - next_targets.extend(default_edges.clone()); } - if !frontier_replaced { - for conditional_edge in conditional_edges.iter().filter(|ce| ce.from() == id) { - tracing::debug!(from = ?conditional_edge.from(), step, "evaluating conditional edge"); - let target_node_names = (conditional_edge.predicate())(state_snapshot.clone()); + if targets.is_empty() && !replaced { + targets = default_edges; + } - for target_name in target_node_names { - let target = if target_name == "End" { + if !replaced { + for ce in conditional_edges.iter().filter(|ce| ce.from() == node) { + tracing::debug!(from = ?ce.from(), step, "evaluating conditional edge"); + for name in (ce.predicate())(state_snapshot.clone()) { + let target = if name == "End" { NodeKind::End - } else if target_name == "Start" { + } else if name == "Start" { NodeKind::Start } else { - NodeKind::Custom(target_name.clone()) + NodeKind::Custom(name) }; - tracing::debug!(target = ?target, step, "conditional edge routed"); - - next_targets.push(target); + targets.push(target); } } } - for target in next_targets { - let is_valid_target = match &target { + for target in targets { + let valid = match &target { NodeKind::End | NodeKind::Start => true, NodeKind::Custom(_) => self.app.nodes().contains_key(&target), }; - - if is_valid_target { + if valid { if !next_frontier.contains(&target) { next_frontier.push(target); } } else { tracing::warn!( step, - origin = %id.encode(), + origin = %node.encode(), target = %target.encode(), "frontier target not found; skipping" ); @@ -1317,45 +1114,42 @@ impl AppRunner { next_frontier } - /// Conditionally persist a checkpoint for the given session if autosave is enabled. async fn maybe_checkpoint(&self, session_id: &str, step: u64) { - let checkpoint_span = tracing::info_span!("checkpoint", step); - checkpoint_span + tracing::info_span!("checkpoint", step) .in_scope(|| async { - if self.autosave - && let Some(checkpointer) = &self.checkpointer - && let Some(session_state) = self.sessions.get(session_id) + if !self.autosave { + return; + } + let (Some(cp), Some(ss)) = + (&self.checkpointer, self.sessions.get(session_id)) + else { + return; + }; + let save_start = std::time::Instant::now(); + if cp + .save(Checkpoint::from_session(session_id, ss)) + .await + .is_ok() + && let Some(obs) = &self.observer { - let start = std::time::Instant::now(); - let result = checkpointer - .save(Checkpoint::from_session(session_id, session_state)) - .await; - let duration_ms = start.elapsed().as_millis() as u64; - if result.is_ok() - && let Some(obs) = &self.observer - { - let backend = self.checkpointer_descriptor.as_str(); - call_observer_hook( - || { - obs.on_checkpoint_save(&CheckpointSaveMeta { - session_id, - backend, - step, - duration_ms, - }) - }, - "on_checkpoint_save", - ); - } + let backend = self.checkpointer_descriptor.as_str(); + let duration_ms = save_start.elapsed().as_millis() as u64; + call_observer_hook( + || { + obs.on_checkpoint_save(&CheckpointSaveMeta { + session_id, + backend, + step, + duration_ms, + }) + }, + "on_checkpoint_save", + ); } }) .await; } - /// Helper method that executes exactly one superstep on the given session state. - /// - /// Applies barrier outcomes (including frontier commands) and returns the updated - /// step report with deterministic routing decisions. #[instrument(skip(self, session_state), err)] async fn run_one_superstep( &self, @@ -1368,44 +1162,40 @@ impl AppRunner { tracing::debug!(step, "starting superstep"); - // Phase 1: schedule and normalize outputs - let schedule_span = tracing::info_span!( + let scheduler_outcome = tracing::info_span!( "schedule", step, frontier_len = session_state.frontier.len() - ); - let scheduler_outcome = schedule_span - .in_scope(|| self.schedule_step(session_id, session_state, step)) - .await?; + ) + .in_scope(|| self.schedule_step(session_id, session_state, step)) + .await?; - // Phase 2: apply barrier and update state - let errors_in_partials = scheduler_outcome + let errors_in_partials: usize = scheduler_outcome .partials .iter() .filter_map(|p| p.errors.as_ref()) .map(|e| e.len()) - .sum::(); - let barrier_span = tracing::info_span!( + .sum(); + let barrier_outcome = tracing::info_span!( "barrier", ran_nodes_len = scheduler_outcome.ran_nodes.len(), errors_in_partials - ); - let barrier_outcome = barrier_span - .in_scope(|| { - self.apply_barrier_and_update( - session_state, - &scheduler_outcome.ran_nodes, - scheduler_outcome.partials, - ) - }) - .await?; + ) + .in_scope(|| { + self.apply_barrier_and_update( + session_state, + &scheduler_outcome.ran_nodes, + scheduler_outcome.partials, + ) + }) + .await?; - // Phase 3: compute next frontier - let commands_count = barrier_outcome.frontier_commands.len(); - let conditional_edges_evaluated = self.app.conditional_edges().len(); - let frontier_span = - tracing::info_span!("frontier", commands_count, conditional_edges_evaluated); - let next_frontier = frontier_span.in_scope(|| { + let next_frontier = tracing::info_span!( + "frontier", + commands_count = barrier_outcome.frontier_commands.len(), + conditional_edges_evaluated = self.app.conditional_edges().len() + ) + .in_scope(|| { self.compute_next_frontier( session_state, &scheduler_outcome.ran_nodes, @@ -1424,8 +1214,6 @@ impl AppRunner { let completed = next_frontier.is_empty() || next_frontier.iter().all(|n| *n == NodeKind::End); - - // Update session state session_state.frontier = next_frontier.clone(); let state_versions = StateVersions { @@ -1433,7 +1221,6 @@ impl AppRunner { extra_version: session_state.state.extra.version(), }; - // Emit per-node finish hooks (step-level timing, shared across all nodes in superstep). if let Some(obs) = &self.observer { let step_duration_ms = step_start.elapsed().as_millis() as u64; for node_kind in &scheduler_outcome.ran_nodes { @@ -1477,11 +1264,11 @@ impl AppRunner { }) } - /// Runs the workflow to completion (until End nodes or an empty frontier is reached). + /// Run the session to completion (until the frontier is empty or all End nodes). /// - /// This is the canonical single-invocation execution method. For iterative - /// (multi-input) sessions, use [`create_iterative_session`](Self::create_iterative_session) - /// and [`invoke_next`](Self::invoke_next) instead. + /// This is the canonical single-invocation path. For multi-turn sessions + /// use [`create_iterative_session`](Self::create_iterative_session) and + /// [`invoke_next`](Self::invoke_next). #[instrument(skip(self, session_id), err)] pub async fn run_until_complete( &mut self, @@ -1515,46 +1302,54 @@ impl AppRunner { let invocation_start = std::time::Instant::now(); loop { - // Check if we're done before trying to run - let session_state = - self.sessions - .get(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; + let ss = self + .sessions + .get(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; - if self.is_session_complete(session_state) { + if self.is_session_complete(ss) { tracing::info!( session = %session_id, - step = session_state.step, + step = ss.step, "frontier reached terminal state" ); break; } - // Run one step - let step_result = match self.run_step(session_id, StepOptions::default()).await { - Ok(res) => res, + match self.run_step(session_id, StepOptions::default()).await { + Ok(StepResult::Completed(report)) if report.completed => break, + Ok(StepResult::Completed(_)) => {} + Ok(StepResult::Paused(_)) => { + let step = self.sessions.get(session_id).map(|s| s.step); + self.emit_completion_event( + session_id, + StreamEndReason::Error { + step, + error: "execution paused unexpectedly".to_string(), + }, + completion_policy, + ); + return Err(RunnerError::UnexpectedPause); + } Err(err) => { - let reason = err.to_string(); - let step = self.sessions.get(session_id).map(|state| state.step); + let step = self.sessions.get(session_id).map(|s| s.step); self.emit_completion_event( session_id, StreamEndReason::Error { step, - error: reason, + error: err.to_string(), }, completion_policy, ); if let Some(obs) = &self.observer { let duration_ms = invocation_start.elapsed().as_millis() as u64; - let graph_id = self.app.graph_definition_hash(); - let sid = session_id; let gid = graph_id.as_str(); call_observer_hook( || { obs.on_invocation_finish(&InvocationFinishMeta { - session_id: sid, + session_id, graph_id: gid, duration_ms, outcome: InvocationOutcome::Error, @@ -1565,27 +1360,6 @@ impl AppRunner { } return Err(err); } - }; - - match step_result { - StepResult::Completed(report) => { - if report.completed { - break; - } - } - StepResult::Paused(_) => { - // This shouldn't happen with default options, but handle gracefully - let step = self.sessions.get(session_id).map(|state| state.step); - self.emit_completion_event( - session_id, - StreamEndReason::Error { - step, - error: "execution paused unexpectedly".to_string(), - }, - completion_policy, - ); - return Err(RunnerError::UnexpectedPause); - } } } @@ -1593,10 +1367,7 @@ impl AppRunner { let (final_state, versions, final_step) = self.finalize_state_snapshot(session_id)?; let messages_snapshot = final_state.messages.snapshot(); let extra_snapshot = final_state.extra.snapshot(); - let messages_version = versions.messages_version; - let extra_version = versions.extra_version; - // Print final state summary (matching App::invoke output) for (i, m) in messages_snapshot.iter().enumerate() { tracing::debug!( session = %session_id, @@ -1608,13 +1379,12 @@ impl AppRunner { } tracing::debug!( session = %session_id, - messages_version, + messages_version = versions.messages_version, "messages channel version" ); - tracing::debug!( session = %session_id, - extra_version, + extra_version = versions.extra_version, keys = extra_snapshot.len(), "extra channel summary" ); @@ -1650,31 +1420,19 @@ impl AppRunner { Ok(final_state) } - /// Get a snapshot of the current session state. - /// - /// # Parameters - /// - /// * `session_id` - The session identifier - /// - /// # Returns - /// - /// `Some(&SessionState)` if the session exists, `None` otherwise + /// Return a reference to the current state of a session. #[must_use] pub fn get_session(&self, session_id: &str) -> Option<&SessionState> { self.sessions.get(session_id) } /// List all active session IDs. - /// - /// # Returns - /// - /// A vector of session ID references #[must_use] pub fn list_sessions(&self) -> Vec<&String> { self.sessions.keys().collect() } - /// Return metadata for this runner and its compiled graph. + /// Return runtime metadata for this runner and its compiled graph. #[must_use] pub fn run_metadata(&self) -> RunMetadata { RunMetadata { @@ -1689,37 +1447,30 @@ impl AppRunner { }, } } -} -impl AppRunner { - /// Determine if a session has reached a terminal frontier (no work or only End nodes). #[inline] fn is_session_complete(&self, session_state: &SessionState) -> bool { session_state.frontier.is_empty() || session_state.frontier.iter().all(|n| *n == NodeKind::End) } - /// Return the final state clone, channel versions, and last step for the session. - /// Logging should occur after retrieval by the caller. #[inline] fn finalize_state_snapshot( &self, session_id: &str, ) -> Result<(VersionedState, StateVersions, u64), RunnerError> { - let session_state = - self.sessions - .get(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; - - let final_state = session_state.state.clone(); - let state_versions = StateVersions { + let ss = self + .sessions + .get(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; + let final_state = ss.state.clone(); + let versions = StateVersions { messages_version: final_state.messages.version(), extra_version: final_state.extra.version(), }; - let final_step = session_state.step; - Ok((final_state, state_versions, final_step)) + Ok((final_state, versions, ss.step)) } fn finalize_event_stream(&mut self, session_id: &str, reason: StreamEndReason) { diff --git a/src/runtimes/runtime_config.rs b/src/runtimes/runtime_config.rs index d5f4edf..a97f368 100644 --- a/src/runtimes/runtime_config.rs +++ b/src/runtimes/runtime_config.rs @@ -1,4 +1,4 @@ -//! Runtime configuration types for controlling event bus, sinks, and diagnostics. +//! Configuration types for the runtime event bus, sinks, and diagnostics. use std::sync::Arc; use crate::event_bus::{EventBus, EventSink, MemorySink, StdOutSink}; @@ -6,18 +6,18 @@ use crate::utils::clock::Clock; use super::Checkpointer; -/// Configuration for a single [`AppRunner`](crate::runtimes::runner::AppRunner) instance. +/// Configuration for a single [`AppRunner`](crate::runtimes::runner::AppRunner). #[derive(Clone)] pub struct RuntimeConfig { - /// Optional session ID to use; a new UUID is generated if `None`. + /// Session ID used for persistence; a UUID is generated when `None`. pub session_id: Option, - /// Custom [`Checkpointer`] to use instead of the built-in types. + /// Custom checkpointer; takes precedence over the built-in SQLite/Postgres backends. pub checkpointer_custom: Option>, - /// SQLite database file name; defaults to `SQLITE_DB_NAME` env var or `weavegraph.db`. + /// SQLite database file. Falls back to `SQLITE_DB_NAME` env var, then `weavegraph.db`. pub sqlite_db_name: Option, - /// Event bus configuration used to build the [`EventBus`]. + /// Event bus settings applied when building the [`EventBus`]. pub event_bus: EventBusConfig, - /// Optional runtime clock injected into node execution contexts. + /// Clock injected into node execution contexts. pub clock: Option>, } @@ -38,7 +38,7 @@ impl Default for RuntimeConfig { Self { session_id: None, checkpointer_custom: None, - sqlite_db_name: Self::resolve_sqlite_db_name(None), + sqlite_db_name: Self::resolve_db_name(None), event_bus: EventBusConfig::default(), clock: None, } @@ -46,53 +46,52 @@ impl Default for RuntimeConfig { } impl RuntimeConfig { - fn resolve_sqlite_db_name(provided: Option) -> Option { - if let Some(name) = provided { - return Some(name); - } - dotenvy::dotenv().ok(); - Some(std::env::var("SQLITE_DB_NAME").unwrap_or_else(|_| "weavegraph.db".to_string())) + fn resolve_db_name(name: Option) -> Option { + name.or_else(|| { + dotenvy::dotenv().ok(); + Some(std::env::var("SQLITE_DB_NAME").unwrap_or_else(|_| "weavegraph.db".to_string())) + }) } - /// Create a new `RuntimeConfig` with the given session ID and optional SQLite DB name. + /// Build a `RuntimeConfig` with the given session ID and SQLite database file name. pub fn new(session_id: Option, sqlite_db_name: Option) -> Self { Self { session_id, checkpointer_custom: None, - sqlite_db_name: Self::resolve_sqlite_db_name(sqlite_db_name), + sqlite_db_name: Self::resolve_db_name(sqlite_db_name), event_bus: EventBusConfig::default(), clock: None, } } + /// Attach a custom [`Checkpointer`]. #[must_use] - /// Set a custom [`Checkpointer`] for this configuration. pub fn checkpointer_custom(mut self, checkpointer: Arc) -> Self { self.checkpointer_custom = Some(checkpointer); self } + /// Return the custom checkpointer if one was set. #[must_use] - /// Return the custom checkpointer if one has been set. pub fn custom_checkpointer(&self) -> Option> { self.checkpointer_custom.clone() } + /// Attach a runtime clock injected into [`NodeContext`](crate::node::NodeContext). #[must_use] - /// Set the runtime clock injected into [`NodeContext`](crate::node::NodeContext). pub fn with_clock(mut self, clock: Arc) -> Self { self.clock = Some(clock); self } + /// Return the configured clock, if any. #[must_use] - /// Return the configured runtime clock, if any. pub fn clock(&self) -> Option> { self.clock.clone() } + /// Describe the clock setting: `"configured"` or `"unset"`. #[must_use] - /// Return a descriptor for the configured clock mode. pub fn clock_mode(&self) -> &'static str { if self.clock.is_some() { "configured" @@ -101,76 +100,71 @@ impl RuntimeConfig { } } - /// Return a deterministic hash of runtime configuration metadata. + /// A deterministic hex fingerprint of this configuration's metadata. #[must_use] pub fn config_hash(&self) -> String { - let mut parts = vec!["weavegraph-runtime-config-v1".to_string()]; - parts.push(format!( - "session_id:{}", - self.session_id.as_deref().unwrap_or("") - )); - parts.push(format!( - "sqlite_db_name:{}", - self.sqlite_db_name.as_deref().unwrap_or("") - )); - parts.push(format!( - "custom_checkpointer:{}", - self.checkpointer_custom.is_some() - )); - parts.push(format!("clock:{}", self.clock_mode())); - parts.extend(self.event_bus.metadata_signature()); - hash_parts(&parts) + let parts: Vec = [ + "weavegraph-runtime-config-v1".to_string(), + format!("session_id:{}", self.session_id.as_deref().unwrap_or("")), + format!("sqlite_db_name:{}", self.sqlite_db_name.as_deref().unwrap_or("")), + format!("custom_checkpointer:{}", self.checkpointer_custom.is_some()), + format!("clock:{}", self.clock_mode()), + ] + .into_iter() + .chain(self.event_bus.metadata_signature()) + .collect(); + fnv1a_hex(&parts) } + /// Replace the event bus configuration. #[must_use] - /// Replace the event bus configuration for this runtime. pub fn with_event_bus(mut self, event_bus: EventBusConfig) -> Self { self.event_bus = event_bus; self } + /// Use a stdout-only event bus. #[must_use] - /// Configure the runtime with a stdout-only event bus. pub fn with_stdout_event_bus(self) -> Self { self.with_event_bus(EventBusConfig::with_stdout_only()) } + /// Use an in-memory event bus (silent; useful in tests). #[must_use] - /// Configure the runtime with an in-memory event bus (useful for testing). pub fn with_memory_event_bus(self) -> Self { self.with_event_bus(EventBusConfig::with_memory_sink()) } } -fn hash_parts(parts: &[String]) -> String { - const FNV_OFFSET: u64 = 0xcbf29ce484222325; - const FNV_PRIME: u64 = 0x100000001b3; +fn fnv1a_hex(parts: &[String]) -> String { + const OFFSET: u64 = 0xcbf29ce484222325; + const PRIME: u64 = 0x100000001b3; - let mut hash = FNV_OFFSET; + let mut hash = OFFSET; for part in parts { - for byte in part.as_bytes().iter().copied().chain([0xff]) { + for byte in part.bytes().chain([0xff]) { hash ^= u64::from(byte); - hash = hash.wrapping_mul(FNV_PRIME); + hash = hash.wrapping_mul(PRIME); } } format!("{hash:016x}") } -/// Selects the output target for an [`EventBusConfig`] sink entry. +/// Sink target for an [`EventBusConfig`] entry. #[derive(Clone, Debug, PartialEq, Eq)] pub enum SinkConfig { /// Write events to standard output. StdOut, - /// Capture events in memory (useful for testing). + /// Capture events in memory. Memory, } -/// Configuration for building the [`EventBus`] used by a runtime. +/// Settings for the [`EventBus`] used by a runtime. #[derive(Clone, Debug)] pub struct EventBusConfig { - /// Broadcast channel capacity; events are dropped when the buffer is full. + /// Broadcast channel capacity; events are dropped when the buffer fills. pub buffer_capacity: usize, - /// Ordered list of sink targets that will receive events. + /// Ordered list of sinks that receive events. pub sinks: Vec, diagnostics: DiagnosticsConfig, } @@ -179,8 +173,10 @@ impl EventBusConfig { /// Default broadcast channel capacity. pub const DEFAULT_BUFFER_CAPACITY: usize = 1024; + /// Build an `EventBusConfig` with the given capacity and sink list. + /// + /// A zero capacity is silently promoted to [`DEFAULT_BUFFER_CAPACITY`](Self::DEFAULT_BUFFER_CAPACITY). #[must_use] - /// Create an `EventBusConfig` with the given capacity and sinks. pub fn new(buffer_capacity: usize, sinks: Vec) -> Self { Self { buffer_capacity: if buffer_capacity == 0 { @@ -193,21 +189,20 @@ impl EventBusConfig { } } + /// A single stdout sink at the default capacity. #[must_use] - /// Create an `EventBusConfig` with a single stdout sink at the default capacity. pub fn with_stdout_only() -> Self { Self::new(Self::DEFAULT_BUFFER_CAPACITY, vec![SinkConfig::StdOut]) } + /// A single in-memory sink at the default capacity (no stdout output). #[must_use] - /// Create an `EventBusConfig` with a single in-memory sink (silent stdout) at the default capacity. pub fn with_memory_sink() -> Self { - // Memory sink intentionally omits stdout so callers get a silent capture by default. Self::new(Self::DEFAULT_BUFFER_CAPACITY, vec![SinkConfig::Memory]) } + /// Append `sink` unless it is already present. #[must_use] - /// Add a sink to this configuration, ignoring duplicates. pub fn add_sink(mut self, sink: SinkConfig) -> Self { if !self.sinks.contains(&sink) { self.sinks.push(sink); @@ -215,17 +210,17 @@ impl EventBusConfig { self } - /// Returns the configured broadcast buffer capacity. + /// Broadcast channel capacity. pub fn buffer_capacity(&self) -> usize { self.buffer_capacity } - /// Returns the configured sink list. + /// Configured sink list. pub fn sinks(&self) -> &[SinkConfig] { &self.sinks } - /// Return deterministic metadata entries for this event bus configuration. + /// Deterministic metadata entries describing this configuration. #[must_use] pub fn metadata_signature(&self) -> Vec { let mut parts = vec![format!("event_buffer:{}", self.buffer_capacity)]; @@ -233,40 +228,37 @@ impl EventBusConfig { self.sinks .iter() .enumerate() - .map(|(index, sink)| format!("event_sink:{index}:{sink:?}")), + .map(|(i, s)| format!("event_sink:{i}:{s:?}")), ); parts.extend(self.diagnostics.metadata_signature()); parts } + /// Override the diagnostics configuration. #[must_use] - /// Override the diagnostics configuration for this event bus. pub fn with_diagnostics(mut self, diagnostics: DiagnosticsConfig) -> Self { self.diagnostics = diagnostics.with_default_capacity(self.buffer_capacity); self } + /// Construct and return the configured [`EventBus`]. #[must_use] - /// Build and return the configured [`EventBus`]. pub fn build_event_bus(&self) -> EventBus { - let mut sinks: Vec> = if self.sinks.is_empty() { + let sinks: Vec> = if self.sinks.is_empty() { vec![Box::new(StdOutSink::default())] } else { self.sinks .iter() - .map(|sink| match sink { + .map(|s| match s { SinkConfig::StdOut => Box::new(StdOutSink::default()) as Box, SinkConfig::Memory => Box::new(MemorySink::new()) as Box, }) .collect() }; - if sinks.is_empty() { - sinks.push(Box::new(StdOutSink::default())); - } EventBus::with_capacity_and_diag( sinks, - self.buffer_capacity(), - self.diagnostics.effective_capacity(self.buffer_capacity()), + self.buffer_capacity, + self.diagnostics.effective_capacity(self.buffer_capacity), self.diagnostics.enabled, self.diagnostics.emit_to_events, ) @@ -279,43 +271,42 @@ impl Default for EventBusConfig { } } -/// Configuration controlling the diagnostics (sink health) broadcast channel. +/// Settings for the diagnostics broadcast channel (sink health reporting). #[derive(Clone, Debug, PartialEq, Eq)] pub struct DiagnosticsConfig { /// Whether sink diagnostics are enabled. pub enabled: bool, - /// Optional override for the diagnostics channel capacity; falls back to the event bus capacity. + /// Channel capacity; falls back to the event bus capacity when `None`. pub buffer_capacity: Option, - /// Whether diagnostics should also be forwarded into the main event stream. + /// Forward diagnostics events into the main event stream. pub emit_to_events: bool, } impl DiagnosticsConfig { - fn normalize_capacity(capacity: usize) -> usize { - capacity.max(1) + fn min_one(n: usize) -> usize { + n.max(1) } - /// Create a default `DiagnosticsConfig` with the given event bus capacity. + /// Default settings tied to a specific event bus capacity. pub fn default_with_capacity(event_bus_capacity: usize) -> Self { Self { enabled: true, - buffer_capacity: Some(Self::normalize_capacity(event_bus_capacity)), + buffer_capacity: Some(Self::min_one(event_bus_capacity)), emit_to_events: false, } } - /// Fill in the buffer capacity from `event_bus_capacity` if not already set. + /// Set `buffer_capacity` from `event_bus_capacity` if it is not already provided. pub fn with_default_capacity(mut self, event_bus_capacity: usize) -> Self { - if self.buffer_capacity.is_none() { - self.buffer_capacity = Some(Self::normalize_capacity(event_bus_capacity)); - } + self.buffer_capacity + .get_or_insert_with(|| Self::min_one(event_bus_capacity)); self } - /// Return the effective diagnostics channel capacity, falling back to `event_bus_capacity`. + /// Effective channel capacity, falling back to `event_bus_capacity`. pub fn effective_capacity(&self, event_bus_capacity: usize) -> usize { self.buffer_capacity - .unwrap_or_else(|| Self::normalize_capacity(event_bus_capacity)) + .unwrap_or_else(|| Self::min_one(event_bus_capacity)) } fn metadata_signature(&self) -> Vec { @@ -324,8 +315,7 @@ impl DiagnosticsConfig { format!( "diagnostics_capacity:{}", self.buffer_capacity - .map(|capacity| capacity.to_string()) - .unwrap_or_default() + .map_or_else(String::new, |c| c.to_string()) ), format!("diagnostics_emit_to_events:{}", self.emit_to_events), ] diff --git a/src/runtimes/session.rs b/src/runtimes/session.rs index 3196144..0ee6112 100644 --- a/src/runtimes/session.rs +++ b/src/runtimes/session.rs @@ -1,73 +1,45 @@ -//! Session state management for workflow execution. -//! -//! This module defines the core types for managing session state during workflow -//! execution, including state persistence across steps and session initialization. - +//! Session state types for workflow execution. use crate::schedulers::{Scheduler, SchedulerState}; use crate::state::VersionedState; use crate::types::NodeKind; -/// Session state that needs to be persisted across steps. -/// -/// Contains all the information needed to resume a workflow from a checkpoint, -/// including the versioned state, current step number, execution frontier, -/// and scheduler state. -/// -/// # Examples -/// -/// ```rust -/// use weavegraph::runtimes::SessionState; -/// use weavegraph::state::VersionedState; -/// use weavegraph::types::NodeKind; -/// use weavegraph::schedulers::{Scheduler, SchedulerState}; +/// Persistent state for a single workflow execution session. /// -/// let session = SessionState { -/// state: VersionedState::new_with_user_message("Hello"), -/// step: 0, -/// frontier: vec![NodeKind::Custom("start".into())], -/// scheduler: Scheduler::new(4), -/// scheduler_state: SchedulerState::default(), -/// }; -/// -/// assert_eq!(session.step, 0); -/// ``` +/// Carries everything needed to resume execution from a checkpoint: +/// versioned channel state, current step, the active node frontier, and +/// scheduler bookkeeping. #[derive(Debug, Clone)] pub struct SessionState { - /// The versioned state containing messages and extra data. + /// Versioned channel state (messages, extras). pub state: VersionedState, - /// The current step number in the workflow execution. + /// Step counter at which this snapshot was taken. pub step: u64, - /// The current execution frontier - nodes to be processed next. + /// Nodes scheduled for the next execution round. pub frontier: Vec, - /// The scheduler managing concurrent node execution. + /// Scheduler instance managing concurrency limits. pub scheduler: Scheduler, - /// Internal scheduler tracking state. + /// Mutable bookkeeping owned by the scheduler between steps. pub scheduler_state: SchedulerState, } -/// Indicates how a session was initialized. -/// -/// Used to inform callers whether they're working with a fresh session -/// or one that was resumed from a checkpoint. +/// Whether a session was started fresh or resumed from a checkpoint. #[derive(Debug, Clone, PartialEq, Eq)] pub enum SessionInit { - /// A brand new session was created. + /// No prior checkpoint exists; execution starts from step 0. Fresh, - /// An existing session was resumed from a checkpoint. + /// A checkpoint was found and loaded at this step. Resumed { - /// The step number at which the session was checkpointed. + /// Step number recorded in the loaded checkpoint. checkpoint_step: u64, }, } -/// Snapshot of channel versions for tracking state evolution. -/// -/// Used to detect state changes between steps and enable version-based -/// optimizations in the scheduler. +/// Channel version snapshot used to detect state changes between steps. #[derive(Debug, Clone)] pub struct StateVersions { - /// Version counter for the messages channel. + /// Current version of the messages channel. pub messages_version: u32, - /// Version counter for the extra data channel. + /// Current version of the extras channel. pub extra_version: u32, } + diff --git a/src/runtimes/streaming.rs b/src/runtimes/streaming.rs index d28aca6..2740805 100644 --- a/src/runtimes/streaming.rs +++ b/src/runtimes/streaming.rs @@ -1,44 +1,40 @@ -//! Event stream management for workflow execution. -//! -//! This module handles the lifecycle of event streams during workflow -//! execution, including finalization and cleanup. - +//! Event stream lifecycle helpers for workflow execution. use crate::event_bus::{Event, EventBus, INVOCATION_END_SCOPE, STREAM_END_SCOPE}; -/// Internal reason for ending an event stream. +/// Internal reason a workflow's event stream is being closed. pub(crate) enum StreamEndReason { - /// The workflow completed successfully. + /// Workflow ran to completion. Completed { - /// The final step number. + /// Final step number. step: u64, }, - /// The workflow ended due to an error. + /// Workflow halted due to an error. Error { - /// The step at which the error occurred (if known). + /// Step at which the error occurred, if known. step: Option, - /// Description of the error. + /// Human-readable error description. error: String, }, } impl StreamEndReason { - /// Format the stream end reason as a diagnostic message. - pub fn format_message(&self, session_id: &str) -> String { + fn format_message(&self, session_id: &str) -> String { match self { - StreamEndReason::Completed { step } => { + Self::Completed { step } => { format!("session={session_id} status=completed step={step}") } - StreamEndReason::Error { step, error } => step - .map(|s| format!("session={session_id} status=error step={s} error={error}")) - .unwrap_or_else(|| format!("session={session_id} status=error error={error}")), + Self::Error { step, error } => match step { + Some(s) => format!("session={session_id} status=error step={s} error={error}"), + None => format!("session={session_id} status=error error={error}"), + }, } } } -/// Handles event stream finalization for a workflow session. +/// Emit a stream termination event and optionally close the event channel. /// -/// This function emits a stream termination event and optionally closes -/// the event channel. Called when a workflow completes or errors. +/// Called when a workflow completes or errors. Closes the channel only when +/// `event_stream_taken` is true, then resets the flag. pub(crate) fn finalize_event_stream( event_bus: &EventBus, session_id: &str, @@ -66,7 +62,7 @@ pub(crate) fn finalize_event_stream( } } -/// Emit a logical invocation completion marker without closing the event channel. +/// Emit an invocation completion marker without closing the event channel. pub(crate) fn emit_invocation_end(event_bus: &EventBus, session_id: &str, reason: StreamEndReason) { let message = reason.format_message(session_id); @@ -83,3 +79,4 @@ pub(crate) fn emit_invocation_end(event_bus: &EventBus, session_id: &str, reason ); } } + diff --git a/src/runtimes/types.rs b/src/runtimes/types.rs index 0651029..6f2476c 100644 --- a/src/runtimes/types.rs +++ b/src/runtimes/types.rs @@ -1,103 +1,50 @@ /*! -Type-safe wrappers for runtime execution identifiers and values. +Type-safe identifiers for runtime execution tracking. -This module provides newtype patterns to improve type safety and prevent -common errors like passing session IDs where step numbers are expected. -These types are specific to the runtime execution layer. - -For core workflow types (node kinds, channel types), see [`crate::types`]. - -# Design Philosophy - -Runtime types are **execution infrastructure** - they manage how workflows -are executed, tracked, and persisted, but don't define what workflows *are*. - -- **Type Safety**: Prevent mixing up session IDs, step numbers, etc. -- **Domain Modeling**: Make execution concepts explicit in the type system -- **Evolution**: Allow runtime infrastructure to evolve independently - -# Examples - -```rust -use weavegraph::runtimes::types::{SessionId, StepNumber}; - -// Type-safe session management -let session = SessionId::generate(); -let step = StepNumber::zero(); - -// The compiler prevents mixing these up -// process_session(step); // ✗ Compile error! -// process_step(session); // ✗ Compile error! -``` +Newtype wrappers that prevent mixing session IDs with step numbers or +with raw strings and integers. For core workflow types (node kinds, +channel types), see [`crate::types`]. */ use serde::{Deserialize, Serialize}; use std::fmt; -/// Type-safe wrapper for session identifiers. -/// -/// This prevents accidentally passing arbitrary strings where session IDs -/// are expected and provides utilities for generating valid session IDs. +/// A type-safe session identifier. /// -/// # Examples +/// Wraps a `String` so the compiler rejects raw strings wherever a +/// `SessionId` is expected. /// /// ```rust /// use weavegraph::runtimes::types::SessionId; /// -/// let session_id = SessionId::new("my_session"); -/// let generated_id = SessionId::generate(); -/// -/// // Type safety - can't accidentally pass a string -/// // process_session(session_id); // ✓ OK -/// // process_session("my_session"); // ✗ Compile error +/// let id = SessionId::new("my_session"); +/// let generated = SessionId::generate(); +/// assert_ne!(id.as_str(), generated.as_str()); /// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct SessionId(String); impl SessionId { - /// Create a new session ID from a string. - /// - /// # Parameters - /// - /// * `id` - The session identifier string - /// - /// # Returns - /// - /// A type-safe `SessionId` wrapper + /// Wrap a string as a session ID. #[must_use] pub fn new(id: impl Into) -> Self { Self(id.into()) } - /// Generate a new unique session ID. - /// - /// Uses the ID generator utilities to create a unique identifier - /// suitable for session tracking. - /// - /// # Returns - /// - /// A newly generated unique `SessionId` + /// Generate a unique session ID. #[must_use] pub fn generate() -> Self { use crate::utils::id_generator::IdGenerator; Self(IdGenerator::new().generate_session_id()) } - /// Get the inner string value. - /// - /// # Returns - /// - /// A reference to the underlying string + /// Borrow the inner string slice. #[must_use] pub fn as_str(&self) -> &str { &self.0 } - /// Convert into the inner string value. - /// - /// # Returns - /// - /// The underlying string + /// Consume the wrapper, returning the inner `String`. #[must_use] pub fn into_string(self) -> String { self.0 @@ -106,7 +53,7 @@ impl SessionId { impl fmt::Display for SessionId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) + f.write_str(&self.0) } } @@ -118,7 +65,7 @@ impl From for SessionId { impl From<&str> for SessionId { fn from(s: &str) -> Self { - Self(s.to_string()) + Self(s.to_owned()) } } @@ -128,73 +75,47 @@ impl AsRef for SessionId { } } -/// Type-safe wrapper for workflow step numbers. -/// -/// Prevents confusion between step numbers and other numeric values, -/// and provides utilities for step arithmetic. +/// A type-safe workflow step counter. /// -/// # Examples +/// Wraps a `u64` to prevent mixing step numbers with other numeric values. +/// Ordering is derived so step numbers can be compared and sorted directly. /// /// ```rust /// use weavegraph::runtimes::types::StepNumber; /// /// let step = StepNumber::new(5); -/// let next_step = step.next(); -/// assert_eq!(next_step.value(), 6); +/// assert_eq!(step.next().value(), 6); +/// assert!(StepNumber::zero().is_initial()); /// ``` #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] pub struct StepNumber(u64); impl StepNumber { - /// Create a new step number. - /// - /// # Parameters - /// - /// * `step` - The step number value - /// - /// # Returns - /// - /// A type-safe `StepNumber` wrapper + /// Create a step number from a raw value. #[must_use] pub fn new(step: u64) -> Self { Self(step) } - /// Create step number zero (initial step). - /// - /// # Returns - /// - /// A `StepNumber` representing step 0 + /// The initial step (step 0). #[must_use] pub fn zero() -> Self { Self(0) } - /// Get the numeric value of this step. - /// - /// # Returns - /// - /// The underlying u64 value + /// The underlying `u64` value. #[must_use] pub fn value(self) -> u64 { self.0 } - /// Get the next step number. - /// - /// # Returns - /// - /// A new `StepNumber` incremented by 1 + /// The next step, saturating at `u64::MAX`. #[must_use] pub fn next(self) -> Self { Self(self.0.saturating_add(1)) } - /// Check if this is the initial step (step 0). - /// - /// # Returns - /// - /// `true` if this is step 0 + /// Returns `true` if this is step 0. #[must_use] pub fn is_initial(self) -> bool { self.0 == 0 @@ -203,7 +124,7 @@ impl StepNumber { impl fmt::Display for StepNumber { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) + fmt::Display::fmt(&self.0, f) } } From 67b85a28d077805b57579b6474071e494591131e Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 13:14:30 -0400 Subject: [PATCH 02/15] heavy revision work for base modules --- src/app.rs | 662 ++++++++++--------------------------------------- src/control.rs | 30 +-- src/message.rs | 154 +++--------- src/node.rs | 252 ++++++------------- src/state.rs | 546 ++++++++-------------------------------- src/types.rs | 214 ++++------------ 6 files changed, 402 insertions(+), 1456 deletions(-) diff --git a/src/app.rs b/src/app.rs index 636159d..43ed4c2 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,7 +1,7 @@ -//! Application layer providing the high-level [`App`] entry point for workflow invocation. +//! High-level [`App`] entry point for workflow invocation. //! -//! `App` manages node registration, graph compilation, and dispatches execution to -//! an [`AppRunner`]. +//! `App` manages node registration, graph compilation, and delegates execution +//! to an [`AppRunner`]. use rustc_hash::FxHashMap; use std::sync::Arc; @@ -23,18 +23,16 @@ use thiserror::Error; use tokio::task::JoinHandle; use tracing::instrument; -/// Orchestrates graph execution and applies reducers at barriers. +/// Central coordination point for workflow execution. /// -/// `App` is the central coordination point for workflow execution, managing: -/// - Node graph topology (nodes, edges, conditional routing) -/// - State reduction through configurable reducers -/// - Runtime configuration and checkpointing +/// Holds the compiled graph topology (nodes, edges, conditional routing), +/// the reducer registry, and runtime configuration. +/// Construct via [`GraphBuilder`](crate::graphs::GraphBuilder). /// /// # Examples /// /// ```rust,no_run /// use weavegraph::graphs::GraphBuilder; -/// use weavegraph::runtimes::CheckpointerType; /// use weavegraph::state::VersionedState; /// use weavegraph::types::NodeKind; /// use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; @@ -55,8 +53,7 @@ use tracing::instrument; /// .add_edge(NodeKind::Custom("process".into()), NodeKind::End) /// .compile()?; /// -/// let initial_state = VersionedState::new_with_user_message("Hello"); -/// let final_state = app.invoke(initial_state).await?; +/// let final_state = app.invoke(VersionedState::new_with_user_message("Hello")).await?; /// # Ok(()) /// # } /// ``` @@ -69,22 +66,20 @@ pub struct App { runtime_config: RuntimeConfig, } -/// Combined handle exposing the configured event bus and a single subscription. +/// Event bus handle and initial subscription, obtained from [`App::event_stream`]. /// -/// Obtained from [`App::event_stream()`], it lets callers attach additional sinks -/// before execution starts or choose how to consume the broadcast feed (async -/// stream, blocking iterator, or timed polling). +/// Lets callers attach additional sinks before execution starts, or consume the +/// broadcast feed as an async stream, blocking iterator, or timed poll. pub struct AppEventStream { event_bus: EventBus, event_stream: Option, } -/// Errors returned when accessing an [`AppEventStream`] after its subscription -/// has already been consumed. +/// Error returned when an [`AppEventStream`] subscription is accessed after being consumed. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum AppEventStreamError { - /// The event stream has already been taken from this invocation handle. + /// The event stream has already been taken from this handle. #[error("event stream has already been taken")] #[cfg_attr( feature = "diagnostics", @@ -100,25 +95,24 @@ type AppEventStreamResult = Result; /// Handle for a streaming workflow invocation. /// -/// Dropping the handle aborts the workflow task. Use [`join`](InvocationHandle::join) -/// to await graceful completion; the paired event stream will emit a diagnostic with -/// scope [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE) before closing. +/// Dropping the handle aborts the workflow task. Call [`join`](InvocationHandle::join) +/// to await graceful completion; the paired event stream emits a diagnostic with scope +/// [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE) before closing. pub struct InvocationHandle { join_handle: Option>>, } /// Result of applying node partials at a barrier. /// -/// The outcome aggregates channel and error information in a deterministic -/// order so downstream consumers (runner, checkpointers, tests) observe stable -/// behaviour across executions. +/// Aggregates channel and error information in deterministic order so downstream +/// consumers observe stable behaviour across executions. #[derive(Debug, Clone, Default)] pub struct BarrierOutcome { - /// Channel identifiers that were updated during the barrier. + /// Channel identifiers updated during this barrier. pub updated_channels: Vec<&'static str>, - /// Aggregated error events emitted by nodes in the superstep. + /// Error events emitted by nodes in the superstep. pub errors: Vec, - /// Frontier manipulation commands emitted during the barrier. + /// Frontier commands emitted during the barrier. pub frontier_commands: Vec<(NodeKind, FrontierCommand)>, } @@ -140,15 +134,14 @@ pub struct GraphMetadata { pub reducer_signature: Vec, } -fn hash_parts(parts: &[String]) -> String { - const FNV_OFFSET: u64 = 0xcbf29ce484222325; - const FNV_PRIME: u64 = 0x100000001b3; - - let mut hash = FNV_OFFSET; +fn fnv1a_hex(parts: &[String]) -> String { + const OFFSET: u64 = 0xcbf29ce484222325; + const PRIME: u64 = 0x100000001b3; + let mut hash = OFFSET; for part in parts { - for byte in part.as_bytes().iter().copied().chain([0xff]) { + for byte in part.bytes().chain([0xff]) { hash ^= u64::from(byte); - hash = hash.wrapping_mul(FNV_PRIME); + hash = hash.wrapping_mul(PRIME); } } format!("{hash:016x}") @@ -167,9 +160,7 @@ impl AppEventStream { &self.event_bus } - /// Mutable access to the underlying broadcast subscription. - /// - /// Returns an error if the stream was already consumed by another accessor. + /// Mutable reference to the underlying broadcast subscription. pub fn event_stream(&mut self) -> AppEventStreamResult<&mut EventStream> { self.event_stream .as_mut() @@ -177,8 +168,6 @@ impl AppEventStream { } /// Consume the handle and return the raw event stream. - /// - /// Subsequent calls will error with [`AppEventStreamError::AlreadyTaken`]. pub fn into_stream(mut self) -> AppEventStreamResult { self.event_stream .take() @@ -199,16 +188,12 @@ impl AppEventStream { Ok((self.event_bus, stream)) } - /// Consume and convert the stream into a blocking iterator. - /// - /// Fails if the stream was already taken through another accessor. + /// Convert into a blocking iterator. pub fn into_blocking_iter(self) -> AppEventStreamResult { Ok(self.into_stream()?.into_blocking_iter()) } - /// Consume and convert the stream into an async iterator. - /// - /// Fails if the stream was already taken through another accessor. + /// Convert into a boxed async stream. pub fn into_async_stream( self, ) -> AppEventStreamResult> { @@ -216,8 +201,6 @@ impl AppEventStream { } /// Await the next event with a timeout, skipping lag notifications. - /// - /// Fails if the stream was already taken through another accessor. pub async fn next_timeout( &mut self, duration: std::time::Duration, @@ -227,44 +210,34 @@ impl AppEventStream { } impl InvocationHandle { - /// Abort the underlying workflow task. `join` will return a join error afterwards. - /// - /// Equivalent to dropping the handle explicitly. + /// Abort the underlying workflow task. pub fn abort(&self) { if let Some(handle) = &self.join_handle { handle.abort(); } } - /// Returns true if the underlying workflow task has completed or aborted. + /// Returns `true` if the underlying task has finished or been aborted. #[must_use] pub fn is_finished(&self) -> bool { - self.join_handle - .as_ref() - .map(|h| h.is_finished()) - .unwrap_or(true) + self.join_handle.as_ref().is_none_or(|h| h.is_finished()) } /// Await the workflow result. /// - /// # Errors - /// - /// Returns [`RunnerError::JoinHandleConsumed`] if `join()` was already called, - /// or [`RunnerError::Join`] if the underlying task panicked or was cancelled. + /// Returns [`RunnerError::JoinHandleConsumed`] if called more than once, + /// or [`RunnerError::Join`] if the task panicked or was cancelled. pub async fn join(mut self) -> Result { let handle = self .join_handle .take() .ok_or(RunnerError::JoinHandleConsumed)?; - match handle.await { - Ok(result) => result, - Err(err) => Err(RunnerError::Join(err)), - } + handle.await.map_err(RunnerError::Join)? } } impl App { - /// Internal (crate) factory to build an App while keeping nodes/edges private. + /// Build an `App` from pre-validated graph components. pub(crate) fn from_parts( nodes: FxHashMap>, edges: FxHashMap>, @@ -281,115 +254,80 @@ impl App { } } - /// Returns a reference to the conditional edges in this graph. - /// - /// Conditional edges enable dynamic routing based on runtime state, - /// allowing workflows to branch based on computed conditions. Predicates - /// return a String which is interpreted as the next target node: - /// - "End" and "Start" are recognized as virtual endpoints - /// - any other string is treated as the name of a custom node - /// - /// At runtime, targets are validated before being pushed to the frontier. - /// Unknown custom targets are skipped with a warning, preserving progress. - /// - /// # Returns - /// A slice of conditional edge specifications. + /// Conditional edges registered in the graph. #[must_use] - pub fn conditional_edges(&self) -> &Vec { + pub fn conditional_edges(&self) -> &[crate::graphs::ConditionalEdge] { &self.conditional_edges } - /// Returns a reference to the nodes registry. - /// - /// Provides access to all registered node implementations in the graph. - /// Nodes are keyed by their `NodeKind` identifier. - /// - /// # Returns - /// A map from `NodeKind` to the corresponding `Node` implementation. + /// Nodes registered in the graph, keyed by [`NodeKind`]. #[must_use] pub fn nodes(&self) -> &FxHashMap> { &self.nodes } - /// Returns a reference to the unconditional edges in this graph. - /// - /// Unconditional edges define the static topology of the workflow graph, - /// specifying which nodes can transition to which other nodes. - /// - /// # Returns - /// A map from source `NodeKind` to a list of destination `NodeKind`s. + /// Unconditional edges in the graph. #[must_use] pub fn edges(&self) -> &FxHashMap> { &self.edges } - /// Returns a reference to the runtime configuration. - /// - /// Runtime configuration includes checkpointer settings, session IDs, - /// and other execution parameters. - /// - /// # Returns - /// The current runtime configuration. + /// Runtime configuration for this app instance. #[must_use] pub fn runtime_config(&self) -> &RuntimeConfig { &self.runtime_config } - /// Return the Weavegraph crate version compiled into this binary. + /// Weavegraph crate version compiled into this binary. #[must_use] pub fn weavegraph_version(&self) -> &'static str { env!("CARGO_PKG_VERSION") } - /// Return metadata describing this graph definition. + /// Metadata describing this compiled graph definition. /// - /// The hash covers registered node IDs, unconditional edges, conditional - /// edge sources/counts, and reducer definition labels. Conditional edge - /// predicate closure bodies are opaque to Rust, so changing a predicate - /// implementation without changing its registration shape is not detectable. + /// The hash covers node IDs, unconditional edges, conditional edge sources/counts, + /// and reducer labels. Predicate closure bodies are opaque, so a predicate change + /// with the same registration shape is undetectable. #[must_use] pub fn graph_metadata(&self) -> GraphMetadata { let mut parts = vec!["weavegraph-graph-v1".to_string()]; - let mut nodes: Vec = self.nodes.keys().map(NodeKind::encode).collect(); - nodes.sort(); - parts.extend(nodes.iter().map(|node| format!("node:{node}"))); + let mut node_keys: Vec = self.nodes.keys().map(NodeKind::encode).collect(); + node_keys.sort(); + parts.extend(node_keys.iter().map(|n| format!("node:{n}"))); - let mut edges: Vec = self + let mut edge_keys: Vec = self .edges .iter() .flat_map(|(from, targets)| { targets .iter() - .map(move |target| format!("edge:{}->{}", from.encode(), target.encode())) + .map(move |to| format!("edge:{}->{}", from.encode(), to.encode())) }) .collect(); - edges.sort(); - parts.extend(edges); + edge_keys.sort(); + parts.extend(edge_keys); let mut conditional_sources: Vec = self .conditional_edges .iter() - .map(|edge| edge.from().encode()) + .map(|e| e.from().encode()) .collect(); conditional_sources.sort(); parts.extend( conditional_sources .iter() .enumerate() - .map(|(index, from)| format!("conditional:{index}:{from}")), + .map(|(i, from)| format!("conditional:{i}:{from}")), ); let reducer_signature = self.reducer_registry.definition_signature(); - parts.extend( - reducer_signature - .iter() - .map(|entry| format!("reducer:{entry}")), - ); + parts.extend(reducer_signature.iter().map(|r| format!("reducer:{r}"))); GraphMetadata { weavegraph_version: self.weavegraph_version().to_string(), - graph_hash: hash_parts(&parts), + graph_hash: fnv1a_hex(&parts), node_count: self.nodes.len(), edge_count: self.edges.values().map(Vec::len).sum(), conditional_edge_count: self.conditional_edges.len(), @@ -397,17 +335,16 @@ impl App { } } - /// Return only the graph definition hash. + /// The graph definition hash without the full metadata. #[must_use] pub fn graph_definition_hash(&self) -> String { self.graph_metadata().graph_hash } - /// Create a subscription to the configured event bus without starting execution. + /// Build a subscription to the configured event bus without starting execution. /// - /// This is the low-level entry point when you want to inspect the stream or - /// register additional sinks before running the workflow (e.g. in tests or - /// fully-custom server integrations). + /// Useful for attaching additional sinks or observing events before running + /// the workflow — for example in tests or custom server integrations. /// /// ```no_run /// use futures_util::StreamExt; @@ -459,50 +396,40 @@ impl App { (checkpointer_type, custom_checkpointer) } - /// Internal helper that centralises runner setup for the public `invoke*` helpers. - /// - /// - `R` represents any auxiliary handle the caller wants to extract alongside the run - /// result (for example, a `flume::Receiver` when wiring a channel). - /// - `F` is a closure that is invoked exactly once to construct the `EventBus` - /// together with that auxiliary handle. Using `FnOnce` lets the closure move - /// ownership of channels or sink vectors. - /// - /// The helper resolves the effective checkpointer configuration, spins up an - /// `AppRunner`, executes the session, and returns both the workflow result and the - /// caller-provided handle so wrappers can keep their surface area small and - /// consistent. - async fn invoke_with_bus_builder( + async fn build_runner( &self, - initial_state: VersionedState, + event_bus: EventBus, autosave: bool, checkpointer_override: Option, - build_event_bus: F, - ) -> (Result, R) - where - F: FnOnce() -> (EventBus, R), - { - let (event_bus, output) = build_event_bus(); + ) -> AppRunner { let (checkpointer_type, custom_checkpointer) = self.resolve_checkpointer(checkpointer_override); - - let mut runner_builder = AppRunner::builder() + let builder = AppRunner::builder() .app(self.clone()) .autosave(autosave) .event_bus(event_bus) .start_listener(true); - - runner_builder = if let Some(custom) = custom_checkpointer { - runner_builder.checkpointer_custom(custom) - } else { - runner_builder.checkpointer(checkpointer_type) + let builder = match custom_checkpointer { + Some(custom) => builder.checkpointer_custom(custom), + None => builder.checkpointer(checkpointer_type), }; + builder.build().await + } - let runner = runner_builder.build().await; - + async fn invoke_with_bus_builder( + &self, + initial_state: VersionedState, + autosave: bool, + checkpointer_override: Option, + build_event_bus: F, + ) -> (Result, R) + where + F: FnOnce() -> (EventBus, R), + { + let (event_bus, output) = build_event_bus(); + let runner = self.build_runner(event_bus, autosave, checkpointer_override).await; let session_id = self.next_session_id(); - let result = Self::run_session(runner, session_id, initial_state).await; - - (result, output) + (Self::run_session(runner, session_id, initial_state).await, output) } /// Invoke the workflow asynchronously while streaming events to the caller. @@ -510,13 +437,9 @@ impl App { /// Returns a join handle for the workflow outcome and an [`EventStream`] that yields /// every event emitted during execution. The stream closes after emitting a /// diagnostic with scope [`STREAM_END_SCOPE`](crate::event_bus::STREAM_END_SCOPE). - /// Any sinks configured on the runtime event bus continue to receive events. - /// - /// # Cancellation /// - /// Dropping the [`InvocationHandle`] (or calling [`InvocationHandle::abort`]) stops - /// the workflow immediately. Dropping the event stream does **not** cancel the run; - /// use the handle if you want to interrupt execution when the client disconnects. + /// Dropping the [`InvocationHandle`] stops the workflow. Dropping the stream alone + /// does not; use the handle to interrupt execution on client disconnect. /// /// ```no_run /// use futures_util::StreamExt; @@ -563,134 +486,33 @@ impl App { &self, initial_state: VersionedState, ) -> (InvocationHandle, EventStream) { - let (checkpointer_type, custom_checkpointer) = self.resolve_checkpointer(None); - - let event_handle = self.event_stream(); - // SAFETY: We just created the event handle via event_stream(), so the stream - // is guaranteed to be unconsumed. If this somehow fails, it indicates a bug - // in the AppEventStream implementation. - let (event_bus, event_stream) = event_handle.split().unwrap_or_else(|_| { - unreachable!("fresh App::event_stream() always yields unused stream") - }); - - let mut runner_builder = AppRunner::builder() - .app(self.clone()) - .autosave(true) - .event_bus(event_bus) - .start_listener(true); - - runner_builder = if let Some(custom) = custom_checkpointer { - runner_builder.checkpointer_custom(custom) - } else { - runner_builder.checkpointer(checkpointer_type) - }; - - let runner = runner_builder.build().await; - + let (event_bus, event_stream) = self + .event_stream() + .split() + .unwrap_or_else(|_| unreachable!("fresh event_stream() always owns its stream")); + let runner = self.build_runner(event_bus, true, None).await; let session_id = self.next_session_id(); let join = tokio::spawn(Self::run_session(runner, session_id, initial_state)); - - ( - InvocationHandle { - join_handle: Some(join), - }, - event_stream, - ) + (InvocationHandle { join_handle: Some(join) }, event_stream) } - /// Execute the entire workflow until completion or no nodes remain. + /// Execute the workflow to completion using the runtime-configured event bus. /// - /// This is the primary entry point for simple workflow execution. It creates an - /// `AppRunner` with the runtime-configured event bus (stdout sink by default), - /// manages session state, and coordinates execution through to completion. - /// - /// # Event Handling - /// - /// This method uses the **EventBus defined on the `RuntimeConfig`**. Out of the box - /// that means a stdout sink only, but you can customise the configuration when - /// building the app. - /// - /// For streaming-first scenarios consider [`invoke_streaming`](Self::invoke_streaming), + /// For streaming scenarios consider [`invoke_streaming`](Self::invoke_streaming), /// [`invoke_with_channel`](Self::invoke_with_channel), or - /// [`invoke_with_sinks`](Self::invoke_with_sinks). Drop down to - /// [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - /// when you need per-request isolation or bespoke runner lifecycle management. - /// - /// See [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - /// for streaming events to custom sinks. - /// - /// # Parameters - /// * `initial_state` - The starting state for workflow execution - /// - /// # Returns - /// * `Ok(VersionedState)` - The final state after workflow completion - /// * `Err(RunnerError)` - If execution fails due to node errors, - /// checkpointer issues, or other runtime problems - /// - /// # Examples - /// - /// ## Simple Execution (Default EventBus) + /// [`invoke_with_sinks`](Self::invoke_with_sinks). For per-request isolation or + /// custom lifecycle control use + /// [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder). /// /// ```rust,no_run /// use weavegraph::state::VersionedState; /// use weavegraph::channels::Channel; /// # use weavegraph::app::App; /// # async fn example(app: App) -> Result<(), Box> { - /// let initial = VersionedState::new_with_user_message("Start workflow"); - /// let final_state = app.invoke(initial).await?; - /// println!("Workflow completed with {} messages", final_state.messages.len()); + /// let final_state = app.invoke(VersionedState::new_with_user_message("Start")).await?; /// # Ok(()) /// # } /// ``` - /// - /// ## Custom Event Streaming (Use AppRunner) - /// - /// For streaming events to web clients, use `AppRunner` with a custom EventBus: - /// - /// ```rust,no_run - /// use weavegraph::event_bus::{EventBus, ChannelSink}; - /// use weavegraph::runtimes::{AppRunner, CheckpointerType}; - /// use weavegraph::state::VersionedState; - /// # use weavegraph::app::App; - /// # async fn example(app: App) -> Result<(), Box> { - /// // Create channel for streaming events - /// let (tx, rx) = flume::unbounded(); - /// - /// // Create EventBus with custom sink - /// let bus = EventBus::with_sinks(vec![ - /// Box::new(ChannelSink::new(tx)) - /// ]); - /// - /// // Use AppRunner with custom EventBus - /// let mut runner = AppRunner::builder() - /// .app(app) - /// .checkpointer(CheckpointerType::InMemory) - /// .autosave(false) - /// .event_bus(bus) - /// .build() - /// .await; - /// - /// let session_id = "my-session".to_string(); - /// let initial = VersionedState::new_with_user_message("Process this"); - /// runner.create_session(session_id.clone(), initial).await?; - /// - /// // Events now stream to the channel - /// tokio::spawn(async move { - /// while let Ok(event) = rx.recv_async().await { - /// println!("Event: {:?}", event); - /// } - /// }); - /// - /// runner.run_until_complete(&session_id).await?; - /// # Ok(()) - /// # } - /// ``` - /// - /// # Workflow Lifecycle - /// 1. Creates an `AppRunner` with the configured checkpointer and event bus - /// 2. Initializes or resumes a session - /// 3. Executes supersteps until End nodes or empty frontier - /// 4. Returns the final accumulated state #[instrument(skip(self, initial_state), err)] pub async fn invoke( &self, @@ -703,117 +525,29 @@ impl App { .0 } - /// Execute workflow with event streaming to a channel. - /// - /// This is a convenience method that combines `AppRunner::builder()` - /// with channel creation and management. It's ideal for simple use cases where - /// you want to stream events without manually managing the EventBus. - /// - /// # When to Use This - /// - /// - Simple scripts or CLI tools that need event streaming - /// - Single-execution scenarios (not web servers) - /// - You want both the final state AND the event stream + /// Execute the workflow and stream events to a [`flume`] channel. /// - /// # When NOT to Use This - /// - /// - Web servers with per-request streaming (use `AppRunner::builder()`) - /// - Need multiple EventSinks beyond ChannelSink (use `invoke_with_sinks()`) - /// - Need fine-grained control over EventBus lifecycle - /// - /// The runtime-configured sinks remain active; this helper simply appends a channel - /// sink so you can consume events alongside any existing logging destinations. - /// - /// # Returns - /// - /// Returns a tuple of: - /// - `Result` - Final workflow state - /// - `flume::Receiver` - Stream of events from workflow execution - /// - /// # Examples - /// - /// ## Basic Usage + /// Builds an `EventBus` from the runtime configuration and appends a + /// [`ChannelSink`], then returns both the execution result and the receiver. /// /// ```rust,no_run /// use weavegraph::state::VersionedState; /// # use weavegraph::app::App; /// # async fn example(app: App) -> Result<(), Box> { - /// // Execute with streaming /// let (result, events) = app.invoke_with_channel( /// VersionedState::new_with_user_message("Process this") /// ).await; /// - /// // Process events in parallel with execution /// tokio::spawn(async move { /// while let Ok(event) = events.recv_async().await { - /// println!("Event: {:?}", event); + /// println!("Event: {event:?}"); /// } /// }); /// /// let final_state = result?; - /// println!("Workflow completed!"); /// # Ok(()) /// # } /// ``` - /// - /// ## With Structured Event Processing - /// - /// ```rust,no_run - /// use weavegraph::event_bus::Event; - /// use weavegraph::state::VersionedState; - /// # use weavegraph::app::App; - /// # async fn example(app: App) -> Result<(), Box> { - /// let (result_future, events) = app.invoke_with_channel( - /// VersionedState::new_with_user_message("Analyze data") - /// ).await; - /// - /// // Collect all events - /// let event_collector = tokio::spawn(async move { - /// let mut collected = Vec::new(); - /// while let Ok(event) = events.recv_async().await { - /// match &event { - /// Event::Node(ne) => { - /// if let Some(node_id) = ne.node_id() { - /// println!("Node {}: {}", node_id, ne.message()); - /// } - /// } - /// Event::Diagnostic(de) => { - /// println!("Diagnostic: {}", de.message()); - /// } - /// Event::LLM(llm) => { - /// println!( - /// "LLM stream {}: {}", - /// llm.stream_id().unwrap_or("default"), - /// llm.chunk() - /// ); - /// } - /// } - /// collected.push(event); - /// } - /// collected - /// }); - /// - /// let final_state = result_future?; - /// let all_events = event_collector.await?; - /// println!("Captured {} events", all_events.len()); - /// # Ok(()) - /// # } - /// ``` - /// - /// # Architecture - /// - /// This method internally: - /// 1. Creates a `flume::unbounded()` channel - /// 2. Builds an EventBus from the runtime configuration and appends a `ChannelSink` - /// 3. Uses `AppRunner::builder()` with the custom EventBus - /// 4. Returns both the execution result and receiver - /// - /// # See Also - /// - /// - [`invoke_with_sinks()`](Self::invoke_with_sinks) - For multiple EventSinks - /// - [`invoke_streaming()`](Self::invoke_streaming) - Async `EventStream` helper - /// - [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - For web servers - /// - [`invoke()`](Self::invoke) - Simple execution without streaming #[instrument(skip(self, initial_state))] pub async fn invoke_with_channel( &self, @@ -831,40 +565,9 @@ impl App { .await } - /// Execute workflow with custom EventSinks for advanced streaming patterns. - /// - /// This convenience method allows you to specify multiple EventSinks while - /// still maintaining the simplicity of a single method call. Use this when - /// you need more control over event handling than `invoke_with_channel()` - /// provides, but don't need the full flexibility of `AppRunner`. - /// - /// # When to Use This - /// - /// - Need multiple sinks (e.g., stdout + channel + file) - /// - Want to configure EventBus but don't need per-request isolation - /// - Building a CLI tool with rich event handling - /// - /// # When NOT to Use This - /// - /// - Web servers with per-request streaming (use `AppRunner::builder()`) - /// - Need to create EventBus instances per HTTP request - /// - Require fine-grained control over runner lifecycle - /// - /// Sinks configured on the `RuntimeConfig` remain active; the provided collection is - /// appended so you can layer additional destinations without rebuilding the app. + /// Execute the workflow with additional event sinks. /// - /// # Parameters - /// - /// - `initial_state` - Starting state for workflow execution - /// - `sinks` - Vector of boxed EventSink implementations - /// - /// # Returns - /// - /// Final workflow state after completion - /// - /// # Examples - /// - /// ## Multiple Sinks + /// Runtime-configured sinks remain active; the provided sinks are appended. /// /// ```rust,no_run /// use weavegraph::event_bus::{ChannelSink, StdOutSink}; @@ -876,53 +579,13 @@ impl App { /// let final_state = app.invoke_with_sinks( /// VersionedState::new_with_user_message("Process data"), /// vec![ - /// Box::new(StdOutSink::default()), // Server logs - /// Box::new(ChannelSink::new(tx)), // Client stream + /// Box::new(StdOutSink::default()), + /// Box::new(ChannelSink::new(tx)), /// ], /// ).await?; - /// - /// // Process events from channel - /// tokio::spawn(async move { - /// while let Ok(event) = rx.recv_async().await { - /// println!("Client sees: {:?}", event); - /// } - /// }); - /// - /// println!("Workflow completed!"); - /// # Ok(()) - /// # } - /// ``` - /// - /// ## Custom Sink Implementation - /// - /// ```rust,no_run - /// use weavegraph::event_bus::{EventSink, Event}; - /// # use weavegraph::app::App; - /// # use weavegraph::state::VersionedState; - /// - /// struct MetricsSink; - /// - /// impl EventSink for MetricsSink { - /// fn handle(&mut self, event: &Event) -> std::io::Result<()> { - /// // Send to metrics system - /// Ok(()) - /// } - /// } - /// - /// # async fn example(app: App) -> Result<(), Box> { - /// let final_state = app.invoke_with_sinks( - /// VersionedState::new_with_user_message("Monitored workflow"), - /// vec![Box::new(MetricsSink)], - /// ).await?; /// # Ok(()) /// # } /// ``` - /// - /// # See Also - /// - /// - [`invoke_with_channel()`](Self::invoke_with_channel) - Simpler channel-only variant - /// - [`invoke_streaming()`](Self::invoke_streaming) - Async `EventStream` without channels - /// - [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - Full control #[instrument(skip(self, initial_state, sinks), err)] pub async fn invoke_with_sinks( &self, @@ -940,12 +603,6 @@ impl App { .0 } - /// Generate the session identifier for the next invocation. - /// - /// Prefers an explicit session id from the runtime configuration and - /// falls back to a randomly generated identifier when none is supplied. - /// Consolidating this logic helps keep new entry points from accidentally - /// reusing the same hard-coded id. fn next_session_id(&self) -> String { self.runtime_config .session_id @@ -953,17 +610,16 @@ impl App { .unwrap_or_else(|| IdGenerator::new().generate_run_id()) } - /// Drive a workflow session to completion, resuming from checkpoints when available. async fn run_session( mut runner: AppRunner, session_id: String, initial_state: VersionedState, ) -> Result { - let init_state = runner + let init = runner .create_session(session_id.clone(), initial_state) .await?; - if let SessionInit::Resumed { checkpoint_step } = init_state { + if let SessionInit::Resumed { checkpoint_step } = init { tracing::info!( session = %session_id, checkpoint_step, @@ -976,47 +632,10 @@ impl App { /// Merge node outputs and apply state reductions after a superstep. /// - /// This method coordinates the barrier synchronization phase of workflow - /// execution, where all node outputs from a superstep are collected, - /// merged, and applied to the global state via registered reducers. The - /// returned [`BarrierOutcome`] captures channel updates, aggregated errors, - /// and frontier commands in a stable order so downstream consumers can rely - /// on deterministic behaviour. - /// - /// # Parameters - /// * `state` - Mutable reference to the current versioned state - /// * `run_ids` - Slice of node kinds that executed in this superstep - /// * `node_partials` - Vector of partial updates from each executed node - /// - /// # Returns - /// * `Ok(Vec<&'static str>)` - Names of channels that were updated - /// * `Err(Box)` - If reducer application fails - /// - /// # State Management - /// - Aggregates messages, extra data, and errors from all nodes - /// - Applies registered reducers to merge updates into global state - /// - Intelligently bumps version numbers only when content changes - /// - Preserves deterministic merge behavior for reproducible execution - /// - /// # Examples - /// - /// ```rust,no_run - /// # use weavegraph::app::App; - /// # use weavegraph::node::NodePartial; - /// # use weavegraph::state::VersionedState; - /// # use weavegraph::types::NodeKind; - /// # use weavegraph::message::{Message, Role}; - /// # async fn example(app: App, state: &mut VersionedState) -> Result<(), String> { - /// let partials = vec![NodePartial::new().with_messages(vec![ - /// Message::with_role(Role::Assistant, "test"), - /// ])]; - /// let outcome = app.apply_barrier(state, &[NodeKind::Custom("process".into())], partials).await - /// .map_err(|e| format!("Error: {}", e))?; - /// println!("Updated channels: {:?}", outcome.updated_channels); - /// println!("Errors emitted: {}", outcome.errors.len()); - /// # Ok(()) - /// # } - /// ``` + /// Collects messages, extra data, errors, and frontier commands from each partial, + /// runs the reducer registry over the merged update, and bumps channel versions only + /// when content changes. Returns a [`BarrierOutcome`] with stable ordering so + /// downstream consumers observe deterministic behaviour across executions. #[instrument(skip(self, state, run_ids, node_partials), err)] pub async fn apply_barrier( &self, @@ -1029,10 +648,11 @@ impl App { let mut errors_all: Vec = Vec::new(); let mut frontier_commands: Vec<(NodeKind, FrontierCommand)> = Vec::new(); - for (i, p) in node_partials.iter().enumerate() { - let fallback = NodeKind::Custom("?".to_string()); - let nid = run_ids.get(i).unwrap_or(&fallback); - + let unknown = NodeKind::Custom("?".to_string()); + for (p, nid) in node_partials + .iter() + .zip(run_ids.iter().chain(std::iter::repeat(&unknown))) + { if let Some(ms) = &p.messages && !ms.is_empty() { @@ -1044,10 +664,9 @@ impl App { && !ex.is_empty() { tracing::debug!(node = ?nid, keys = ex.len(), "Node produced extra data"); - // Sort keys to keep the merged map deterministic across runs. - let mut sorted_pairs: Vec<_> = ex.iter().collect(); - sorted_pairs.sort_by(|(left, _), (right, _)| left.cmp(right)); - for (k, v) in sorted_pairs { + let mut sorted: Vec<_> = ex.iter().collect(); + sorted.sort_by(|(a, _), (b, _)| a.cmp(b)); + for (k, v) in sorted { extra_all.insert(k.clone(), v.clone()); } } @@ -1073,52 +692,30 @@ impl App { } } - // Sort aggregated errors so downstream consumers observe a stable order. errors_all.sort_by(|a, b| { - let key_a = scope_sort_key(&a.scope); - let key_b = scope_sort_key(&b.scope); - key_a - .cmp(&key_b) + scope_sort_key(&a.scope) + .cmp(&scope_sort_key(&b.scope)) .then_with(|| a.when.cmp(&b.when)) .then_with(|| a.error.message.cmp(&b.error.message)) }); - let errors_for_state = if errors_all.is_empty() { - None - } else { - Some(errors_all.clone()) - }; - - let merged_updates = NodePartial { - messages: if msgs_all.is_empty() { - None - } else { - Some(msgs_all) - }, - extra: if extra_all.is_empty() { - None - } else { - Some(extra_all) - }, - errors: errors_for_state, + let merged = NodePartial { + messages: (!msgs_all.is_empty()).then_some(msgs_all), + extra: (!extra_all.is_empty()).then_some(extra_all), + errors: (!errors_all.is_empty()).then(|| errors_all.clone()), frontier: None, }; - // Record before-states for version bump decisions let msgs_before_len = state.messages.len(); let msgs_before_ver = state.messages.version(); let extra_before = state.extra.snapshot(); let extra_before_ver = state.extra.version(); - // Apply reducers (they do NOT bump versions) - self.reducer_registry - .apply_all(&mut *state, &merged_updates)?; + self.reducer_registry.apply_all(&mut *state, &merged)?; - // Detect changes & bump versions responsibly let mut updated: Vec<&'static str> = Vec::new(); - let msgs_changed = state.messages.len() != msgs_before_len; - if msgs_changed { + if state.messages.len() != msgs_before_len { state .messages .set_version(msgs_before_ver.saturating_add(1)); @@ -1135,8 +732,7 @@ impl App { } let extra_after = state.extra.snapshot(); - let extra_changed = extra_after != extra_before; - if extra_changed { + if extra_after != extra_before { state.extra.set_version(extra_before_ver.saturating_add(1)); tracing::info!( target: "weavegraph::app", diff --git a/src/control.rs b/src/control.rs index 0dfe3f9..8288d06 100644 --- a/src/control.rs +++ b/src/control.rs @@ -1,29 +1,23 @@ -//! Control-flow primitives emitted by nodes to influence subsequent scheduling. -//! -//! Frontier commands are kept separate from state updates so nodes can -//! express routing intent without mutating application state directly. The -//! barrier aggregates these directives in a deterministic order and the runner -//! reconciles them with unconditional / conditional edges. +//! Control-flow directives emitted by nodes to shape the next scheduling frontier. use crate::types::NodeKind; -/// Route identifier used by frontier commands. +/// A destination node referenced by a frontier routing command. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum NodeRoute { - /// Route to another node in the graph. + /// Concrete node in the graph. Node(NodeKind), } impl NodeRoute { - /// Return the concrete `NodeKind` for this route. + /// Borrow the underlying [`NodeKind`]. #[must_use] pub fn kind(&self) -> &NodeKind { - match self { - NodeRoute::Node(kind) => kind, - } + let Self::Node(k) = self; + k } - /// Clone the underlying `NodeKind`. + /// Return an owned copy of the underlying [`NodeKind`]. #[must_use] pub fn to_node_kind(&self) -> NodeKind { self.kind().clone() @@ -31,16 +25,16 @@ impl NodeRoute { } impl From for NodeRoute { - fn from(kind: NodeKind) -> Self { - NodeRoute::Node(kind) + fn from(k: NodeKind) -> Self { + Self::Node(k) } } -/// Command emitted by a node to manipulate the next frontier. +/// Directive emitted by a node to modify how the barrier builds the next frontier. #[derive(Clone, Debug, PartialEq, Eq)] pub enum FrontierCommand { - /// Append additional routes to the existing frontier calculation. + /// Add routes alongside those produced by normal edge resolution. Append(Vec), - /// Replace the default routes emitted for the node. + /// Replace normal edge routes with exactly these routes. Replace(Vec), } diff --git a/src/message.rs b/src/message.rs index eb82e6f..6e5d612 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,66 +1,43 @@ -//! Message types representing chat turns and content in a workflow conversation. +//! Chat message types for conversation turns and roles. + use serde::{Deserialize, Serialize}; use std::fmt; -/// The role of a message sender in a conversation. -/// -/// This enum represents the standard roles used in chat-based AI interactions. -/// For custom roles not covered by the standard variants, use [`Role::Custom`]. -/// -/// # Serialization -/// -/// Roles serialize to/from lowercase strings for JSON compatibility: -/// - `Role::User` ↔ `"user"` -/// - `Role::Assistant` ↔ `"assistant"` -/// - `Role::System` ↔ `"system"` -/// - `Role::Tool` ↔ `"tool"` -/// - `Role::Custom("foo")` ↔ `"foo"` -/// -/// # Examples -/// -/// ``` -/// use weavegraph::message::Role; +/// Participant role in a conversation turn. /// -/// let role = Role::User; -/// assert_eq!(role.as_str(), "user"); -/// -/// let parsed: Role = "assistant".into(); -/// assert_eq!(parsed, Role::Assistant); -/// -/// // Custom roles for extensibility -/// let custom = Role::Custom("function".to_string()); -/// assert_eq!(custom.as_str(), "function"); -/// ``` +/// Roles serialize to/from lowercase strings: +/// `User` ↔ `"user"`, `Assistant` ↔ `"assistant"`, `System` ↔ `"system"`, +/// `Tool` ↔ `"tool"`, `Custom("foo")` ↔ `"foo"`. #[derive(Clone, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] #[serde(into = "String", try_from = "String")] pub enum Role { - /// User input message role. + /// Human / end-user turn. #[default] User, - /// AI assistant response message role. + /// Model response turn. Assistant, - /// System prompt or instruction message role. + /// System-level instruction turn. System, - /// Tool/function call result message role. + /// Tool or function result turn. Tool, - /// Custom role for extensibility (e.g., "function", "context"). + /// Any role not covered by the standard variants. Custom(String), } impl Role { - /// Returns the string representation of this role. + /// String form of this role. #[must_use] pub fn as_str(&self) -> &str { match self { - Role::User => "user", - Role::Assistant => "assistant", - Role::System => "system", - Role::Tool => "tool", - Role::Custom(s) => s.as_str(), + Self::User => "user", + Self::Assistant => "assistant", + Self::System => "system", + Self::Tool => "tool", + Self::Custom(s) => s.as_str(), } } - /// Returns true if this role matches the given string. + /// Returns `true` if this role's string form equals `role_str`. #[must_use] pub fn matches(&self, role_str: &str) -> bool { self.as_str() == role_str @@ -69,129 +46,69 @@ impl Role { impl fmt::Display for Role { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.as_str()) + f.write_str(self.as_str()) } } impl From<&str> for Role { fn from(s: &str) -> Self { match s { - "user" => Role::User, - "assistant" => Role::Assistant, - "system" => Role::System, - "tool" => Role::Tool, - other => Role::Custom(other.to_string()), + "user" => Self::User, + "assistant" => Self::Assistant, + "system" => Self::System, + "tool" => Self::Tool, + other => Self::Custom(other.to_owned()), } } } impl From for Role { fn from(s: String) -> Self { - Role::from(s.as_str()) + Self::from(s.as_str()) } } impl From for String { fn from(role: Role) -> Self { - role.as_str().to_string() + role.as_str().to_owned() } } -/// A message in a conversation, containing a role and text content. -/// -/// Messages are the primary data structure for representing chat interactions, -/// AI conversations, and communication between nodes in the workflow system. -/// Each message has a role (typically "user", "assistant", or "system") and -/// text content. -/// -/// # Examples -/// -/// ``` -/// use weavegraph::message::{Message, Role}; -/// -/// // Using typed roles (recommended) -/// let user_msg = Message::with_role(Role::User, "What is the weather?"); -/// let assistant_msg = Message::with_role(Role::Assistant, "It's sunny today!"); -/// let system_msg = Message::with_role(Role::System, "You are a helpful assistant."); -/// -/// // Using Role enum directly -/// let msg = Message::with_role(Role::User, "Hello!"); -/// assert_eq!(msg.role, Role::User); -/// -/// // For custom roles -/// let function_msg = Message::with_role(Role::Custom("function".into()), "Result: 42"); -/// ``` +/// A single message in a conversation: a role paired with text content. #[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)] pub struct Message { - /// The role of the message sender. - /// - /// This field is serialized as a string for backward compatibility. - #[serde(with = "role_serde")] + /// Who sent the message. pub role: Role, - /// The text content of the message. + /// Text body of the message. pub content: String, } -mod role_serde { - use super::Role; - use serde::{Deserialize, Deserializer, Serializer}; - - pub fn serialize(role: &Role, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(role.as_str()) - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let s = String::deserialize(deserializer)?; - Ok(Role::from(s)) - } -} - impl Message { - /// Creates a new message with a typed [`Role`] and content. - /// - /// This is the recommended way to create messages with standard roles. - /// - /// # Examples - /// - /// ``` - /// use weavegraph::message::{Message, Role}; - /// - /// let msg = Message::with_role(Role::Assistant, "Hello!"); - /// assert_eq!(msg.role, Role::Assistant); - /// ``` + /// Construct a message with an explicit role and content. #[must_use] pub fn with_role(role: Role, content: &str) -> Self { - Self { - role, - content: content.to_string(), - } + Self { role, content: content.to_owned() } } - /// Creates a user message with the specified content. + /// Construct a `User` message. #[must_use] pub fn user(content: &str) -> Self { Self::with_role(Role::User, content) } - /// Creates an assistant message with the specified content. + /// Construct an `Assistant` message. #[must_use] pub fn assistant(content: &str) -> Self { Self::with_role(Role::Assistant, content) } - /// Creates a system message with the specified content. + /// Construct a `System` message. #[must_use] pub fn system(content: &str) -> Self { Self::with_role(Role::System, content) } - /// Creates a tool message with the specified content. + /// Construct a `Tool` message. #[must_use] pub fn tool(content: &str) -> Self { Self::with_role(Role::Tool, content) @@ -254,7 +171,6 @@ mod tests { #[test] fn test_message_backward_compatibility() { - // Old-style JSON should still parse let json = r#"{"role": "user", "content": "hello"}"#; let msg: Message = serde_json::from_str(json).unwrap(); assert_eq!(msg.role, Role::User); diff --git a/src/node.rs b/src/node.rs index e1c3892..b90301f 100644 --- a/src/node.rs +++ b/src/node.rs @@ -1,14 +1,13 @@ //! Node execution framework for the Weavegraph workflow system. //! -//! This module provides the core abstractions for executable workflow nodes, -//! including the [`Node`] trait, execution context, state updates, and error handling. -// Standard library and external crates +//! Provides the core abstractions for executable workflow nodes: +//! the [`Node`] trait, execution context, partial state updates, and error types. + use async_trait::async_trait; use rustc_hash::FxHashMap; use serde_json; use thiserror::Error; -// Internal crate modules use crate::channels::errors::ErrorEvent; use crate::control::{FrontierCommand, NodeRoute}; use crate::event_bus::{Event, EventEmitter, LLMStreamingEvent}; @@ -18,64 +17,29 @@ use crate::types::NodeKind; use crate::utils::clock::Clock; use std::sync::Arc; -// ============================================================================ -// Core Trait -// ============================================================================ - -/// Core trait defining executable workflow nodes. -/// -/// The `Node` trait represents a single unit of computation within a workflow. -/// Nodes receive the current state snapshot and execution context, perform -/// their work, and return partial state updates. +/// Core trait for executable workflow nodes. /// -/// # Design Principles +/// A node receives the current state snapshot and execution context, performs +/// its work, and returns partial state updates. /// -/// - **Stateless**: Nodes should be stateless and deterministic -/// - **Focused**: Each node should have a single, well-defined responsibility -/// - **Composable**: Nodes should be easily combined into larger workflows -/// - **Observable**: Use the context to emit events for monitoring and debugging +/// # Error handling /// -/// # Error Handling +/// - **Fatal errors**: return `Err(NodeError)` to halt execution. +/// - **Recoverable errors**: push into [`NodePartial::errors`] and return `Ok`. /// -/// Nodes can handle errors in two ways: -/// 1. **Fatal errors**: Return `Err(NodeError)` to stop workflow execution -/// 2. **Recoverable errors**: Add to `NodePartial.errors` and return `Ok` -/// -/// # Examples +/// # Example /// /// ```rust,no_run /// use weavegraph::node::{Node, NodeContext, NodePartial, NodeError}; /// use weavegraph::state::StateSnapshot; -/// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; /// use async_trait::async_trait; /// -/// struct ValidationNode { -/// required_fields: Vec, -/// } +/// struct GreetNode; /// /// #[async_trait] -/// impl Node for ValidationNode { -/// async fn run(&self, snapshot: StateSnapshot, ctx: NodeContext) -> Result { -/// ctx.emit("validation", "Starting validation")?; -/// -/// for field in &self.required_fields { -/// if !snapshot.extra.contains_key(field) { -/// return Err(NodeError::ValidationFailed(format!("Missing field: {}", field))); -/// } -/// } -/// -/// // Demonstrate the fluent API for success with warnings -/// if snapshot.messages.is_empty() { -/// let warning = ErrorEvent { -/// error: WeaveError { -/// message: "No messages to validate, but continuing".to_string(), -/// ..Default::default() -/// }, -/// ..Default::default() -/// }; -/// return Ok(NodePartial::new().with_errors(vec![warning])); -/// } -/// +/// impl Node for GreetNode { +/// async fn run(&self, _snapshot: StateSnapshot, ctx: NodeContext) -> Result { +/// ctx.emit("greet", "hello")?; /// Ok(NodePartial::default()) /// } /// } @@ -90,14 +54,10 @@ pub trait Node: Send + Sync { ) -> Result; } -// ============================================================================ -// Execution Context -// ============================================================================ - /// Execution context passed to nodes during workflow execution. /// -/// Provides nodes with access to their execution environment, including step -/// information, node identity, and communication channels for observability. +/// Carries the node's identity, current step, and communication channels +/// for observability. #[derive(Clone, Debug)] #[non_exhaustive] pub struct NodeContext { @@ -109,12 +69,12 @@ pub struct NodeContext { pub event_emitter: Arc, /// Optional runtime clock for deterministic tests and replay. pub clock: Option>, - /// Optional invocation or run identifier attached to node events. + /// Optional invocation identifier attached to node events. pub invocation_id: Option, } impl NodeContext { - /// Construct a node context with no runtime clock or invocation metadata. + /// Construct a context with no clock or invocation metadata. pub fn new( node_id: impl Into, step: u64, @@ -129,13 +89,13 @@ impl NodeContext { } } - /// Return the current runtime clock timestamp in Unix milliseconds, if configured. + /// Current runtime clock timestamp in Unix milliseconds, if configured. #[must_use] pub fn now_unix_ms(&self) -> Option { - self.clock.as_ref().map(|clock| clock.now_unix_ms()) + self.clock.as_ref().map(|c| c.now_unix_ms()) } - /// Return the invocation identifier, if configured. + /// Invocation identifier, if configured. #[must_use] pub fn invocation_id(&self) -> Option<&str> { self.invocation_id.as_deref() @@ -143,49 +103,37 @@ impl NodeContext { /// Emit a node-scoped event enriched with this context's metadata. /// - /// Creates structured events that include the node's ID and step information, - /// making them traceable in the workflow execution log. + /// Attaches `invocation_id` and `now_unix_ms` when available, making + /// emitted events traceable in the execution log. pub fn emit( &self, scope: impl Into, message: impl Into, - ) -> Result<(), NodeContextError> { - self.emit_node(scope, message) - } - - /// Emit a node event using this context's node identifier and step metadata. - pub fn emit_node( - &self, - scope: impl Into, - message: impl Into, ) -> Result<(), NodeContextError> { let mut metadata = FxHashMap::default(); - if let Some(invocation_id) = &self.invocation_id { + if let Some(id) = &self.invocation_id { metadata.insert( - "invocation_id".to_string(), - serde_json::Value::String(invocation_id.clone()), + "invocation_id".to_owned(), + serde_json::Value::String(id.clone()), ); } - if let Some(now_unix_ms) = self.now_unix_ms() { - metadata.insert("now_unix_ms".to_string(), serde_json::json!(now_unix_ms)); + if let Some(ts) = self.now_unix_ms() { + metadata.insert("now_unix_ms".to_owned(), serde_json::json!(ts)); } - - if metadata.is_empty() { - self.emit_event(Event::node_message_with_meta( - self.node_id.clone(), - self.step, - scope, - message, - )) + let event = if metadata.is_empty() { + Event::node_message_with_meta(self.node_id.clone(), self.step, scope, message) } else { - self.emit_event(Event::node_message_with_metadata( + Event::node_message_with_metadata( self.node_id.clone(), self.step, scope, message, metadata, - )) - } + ) + }; + self.event_emitter + .emit(event) + .map_err(|_| NodeContextError::EventBusUnavailable) } /// Emit a diagnostic event for general workflow telemetry. @@ -233,7 +181,7 @@ impl NodeContext { self.emit_event(Event::LLM(event)) } - /// Emit an LLM error event with the provided error message. + /// Emit an LLM error event. pub fn emit_llm_error( &self, session_id: Option, @@ -256,82 +204,52 @@ impl NodeContext { } } -// ============================================================================ -// State Updates -// ============================================================================ - /// Partial state updates returned by node execution. /// -/// Represents the changes a node wants to make to the workflow state. -/// All fields are optional, allowing nodes to update only the state aspects -/// they care about. The workflow runtime merges these partial updates. +/// All fields are optional; nodes update only the aspects of state they care +/// about. The workflow runtime merges these partial updates at each barrier. /// -/// # Examples +/// # Example /// /// ```rust /// use weavegraph::node::NodePartial; /// use weavegraph::message::{Message, Role}; -/// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; /// use serde_json::json; /// use weavegraph::utils::collections::new_extra_map; /// -/// // Simple message-only response -/// let partial = NodePartial::new() -/// .with_messages(vec![Message::with_role(Role::Assistant, "Done")]); -/// -/// // Rich response with metadata /// let mut extra = new_extra_map(); /// extra.insert("status".to_string(), json!("success")); -/// extra.insert("duration_ms".to_string(), json!(150)); -/// let partial = NodePartial::new() -/// .with_messages(vec![Message::with_role( -/// Role::Assistant, -/// "Processing complete", -/// )]) -/// .with_extra(extra); /// -/// // Response with warnings -/// let errors = vec![ErrorEvent { -/// error: WeaveError { -/// message: "Low confidence result".to_string(), -/// ..Default::default() -/// }, -/// ..Default::default() -/// }]; /// let partial = NodePartial::new() -/// .with_messages(vec![Message::with_role( -/// Role::Assistant, -/// "Result with warnings", -/// )]) -/// .with_errors(errors); +/// .with_messages(vec![Message::with_role(Role::Assistant, "Done")]) +/// .with_extra(extra); /// ``` #[derive(Clone, Debug, Default)] pub struct NodePartial { - /// Messages to add to the workflow's message history. + /// Messages to append to the workflow's message history. pub messages: Option>, - /// Additional key-value data to merge into the workflow's extra storage. + /// Key-value data to merge into the workflow's extra storage. pub extra: Option>, - /// Errors to add to the workflow's error collection. + /// Errors to record in the workflow's error collection. pub errors: Option>, - /// Frontier commands emitted by the node to influence subsequent routing. + /// Frontier command to influence subsequent routing. pub frontier: Option, } impl NodePartial { - /// Create an empty `NodePartial` with all fields set to `None`. + /// Create an empty partial with all fields unset. pub fn new() -> Self { - Self { - ..Default::default() - } + Self::default() } - /// Create a `NodePartial` with one or more messages. + + /// Set the messages to append. #[must_use] pub fn with_messages(mut self, messages: Vec) -> Self { self.messages = Some(messages); self } - /// Create a `NodePartial` with extra data. + /// Set the extra data to merge. #[must_use] pub fn with_extra(mut self, extra: FxHashMap) -> Self { self.extra = Some(extra); @@ -341,9 +259,8 @@ impl NodePartial { /// Insert a typed value into this partial's extra updates. /// /// The value is serialized to JSON and stored under the key returned by - /// [`StateKey::storage_key`]. If this partial already contains extra data, - /// the typed slot is merged into it and any existing value at the same - /// storage key is replaced. + /// [`StateKey::storage_key`]. Any existing value at the same storage key + /// is replaced. pub fn with_typed_extra( mut self, key: StateKey, @@ -361,16 +278,16 @@ impl NodePartial { Ok(self) } - /// Create a `NodePartial` with one or more errors. + /// Set the errors to record. #[must_use] pub fn with_errors(mut self, errors: Vec) -> Self { self.errors = Some(errors); self } - /// Replace the default frontier with the provided list of targets. + /// Replace the default frontier with the provided targets. /// - /// The runner will skip conditional edges for the originating node when a + /// The runner skips conditional edges for the originating node when a /// replace command is present. #[must_use] pub fn with_frontier_replace(mut self, targets: I) -> Self @@ -382,10 +299,10 @@ impl NodePartial { self } - /// Append additional targets to the frontier alongside the default routes. + /// Append additional targets to the frontier alongside default routes. /// - /// The default unconditional edges remain in place and the supplied - /// routes are appended in-order for deterministic processing. + /// Default unconditional edges remain in place; the supplied routes are + /// appended in order for deterministic processing. #[must_use] pub fn with_frontier_append(mut self, targets: I) -> Self where @@ -403,14 +320,12 @@ impl NodePartial { self } - /// Remove the given extra keys from state on the next barrier application. + /// Write `null` markers for the given extra keys. /// - /// Writes `serde_json::Value::Null` markers into the partial. [`MapMerge`](crate::reducers::MapMerge) - /// (the built-in extra reducer) follows JSON Merge Patch semantics (RFC 7396) and - /// **deletes** keys whose incoming value is `null`, so no separate cleanup reducer - /// is needed. + /// [`MapMerge`](crate::reducers::MapMerge) follows JSON Merge Patch semantics + /// (RFC 7396) and **deletes** keys whose incoming value is `null`. /// - /// # Examples + /// # Example /// /// ```rust /// use weavegraph::node::NodePartial; @@ -431,16 +346,12 @@ impl NodePartial { self } - /// Remove a single typed extra key from state on the next barrier application. + /// Write a `null` marker for a single typed extra key. /// - /// Typed companion to [`clear_extra_keys`](Self::clear_extra_keys). The storage key - /// is derived from the `StateKey`'s `(namespace, name, schema_version)` triple so - /// that the same constant used to write a value can be used to delete it. + /// Typed companion to [`clear_extra_keys`](Self::clear_extra_keys). The storage + /// key is derived from the `StateKey`'s `(namespace, name, schema_version)` triple. /// - /// [`MapMerge`](crate::reducers::MapMerge) deletes keys with null values (RFC 7396), - /// so no separate cleanup reducer is needed. - /// - /// # Examples + /// # Example /// /// ```rust /// use weavegraph::node::NodePartial; @@ -457,15 +368,11 @@ impl NodePartial { } } -// ============================================================================ -// Error Types -// ============================================================================ - -/// Errors that can occur when using NodeContext methods. +/// Errors that can occur when using [`NodeContext`] methods. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum NodeContextError { - /// Event could not be sent due to event bus disconnection or capacity issues. + /// Event could not be sent — event bus is disconnected or at capacity. #[error("failed to emit event: event bus unavailable")] #[cfg_attr( feature = "diagnostics", @@ -477,16 +384,15 @@ pub enum NodeContextError { EventBusUnavailable, } -/// Errors that can occur during node execution. +/// Fatal errors returned by node execution. /// -/// `NodeError` represents fatal errors that should halt workflow execution. -/// For recoverable errors that should be tracked but not halt execution, -/// use `NodePartial.errors` instead. +/// For recoverable errors that should be tracked without halting execution, +/// use [`NodePartial::errors`] instead. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] #[non_exhaustive] pub enum NodeError { - /// Expected input data is missing from the state snapshot. + /// Expected input data is absent from the state snapshot. #[error("missing expected input: {what}")] #[cfg_attr( feature = "diagnostics", @@ -496,7 +402,7 @@ pub enum NodeError { ) )] MissingInput { - /// Description of the missing input data. + /// Description of the missing input. what: &'static str, }, @@ -504,7 +410,7 @@ pub enum NodeError { #[error("provider error ({provider}): {message}")] #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::provider)))] Provider { - /// Name of the external provider that produced the error. + /// Name of the provider that failed. provider: &'static str, /// Human-readable description of the error. message: String, @@ -515,7 +421,7 @@ pub enum NodeError { #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::node::other)))] Other(#[from] Box), - /// JSON serialization/deserialization error. + /// JSON serialization or deserialization error. #[error(transparent)] #[cfg_attr( feature = "diagnostics", @@ -551,12 +457,12 @@ impl NodeError { } } -/// Canonical node result type for framework and user node implementations. +/// Canonical result type for node implementations. pub type NodeResult = std::result::Result; -/// Extension trait for ergonomic conversion into [`NodeError`]. +/// Ergonomic conversion into [`NodeError`] for the `?` operator. pub trait NodeResultExt { - /// Convert any error type into [`NodeError::Other`] for `?` propagation. + /// Convert any error into [`NodeError::Other`]. fn node_err(self) -> NodeResult; } diff --git a/src/state.rs b/src/state.rs index c6befb0..5116960 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,39 +1,10 @@ -//! State management for the Weavegraph workflow framework. +//! Versioned state management for workflow execution. //! -//! This module provides versioned state management with multiple channels -//! for different types of workflow data. State is managed through versioned -//! channels that support snapshotting, deep cloning, and persistence. +//! State is organized into three independent channels — messages, extras, and errors — +//! each carrying its own version number for change detection and optimistic concurrency. //! -//! # Core Types -//! -//! - [`VersionedState`]: The main state container with versioned channels -//! - [`StateSnapshot`]: Immutable snapshot of state at a point in time -//! -//! # Channels -//! -//! State is organized into three main channels: -//! - **Messages**: Conversation messages and chat data -//! - **Extra**: Custom metadata and intermediate results -//! - **Errors**: Error events and diagnostic information -//! -//! # Examples -//! -//! ```rust -//! use weavegraph::state::VersionedState; -//! use weavegraph::channels::Channel; -//! use serde_json::json; -//! -//! // Create initial state with user message -//! let mut state = VersionedState::new_with_user_message("Hello, world!"); -//! -//! // Add some metadata -//! state.extra.get_mut().insert("user_id".to_string(), json!("user123")); -//! -//! // Take snapshot for processing -//! let snapshot = state.snapshot(); -//! assert_eq!(snapshot.messages.len(), 1); -//! assert_eq!(snapshot.extra.get("user_id"), Some(&json!("user123"))); -//! ``` +//! The main types are [`VersionedState`] (mutable runtime state) and [`StateSnapshot`] +//! (point-in-time read-only view passed to nodes during execution). use rustc_hash::FxHashMap; use serde::{Serialize, de::DeserializeOwned}; @@ -49,62 +20,46 @@ use crate::{ /// Lifecycle classification for a state slot. /// -/// A slot's lifecycle is **metadata** — it does not change the storage key or -/// affect `PartialEq` / `Hash` comparisons. Two `StateKey` values with the -/// same `(namespace, name, schema_version)` but different lifecycle annotations -/// refer to the same underlying storage slot and compare as equal. +/// Lifecycle is **metadata** — it does not affect the storage key or identity comparisons. +/// Two `StateKey` values with the same `(namespace, name, schema_version)` but different +/// lifecycle annotations refer to the same slot and compare as equal. /// -/// Lifecycle is consumed by [`StateNormalizeProfile`](crate::runtimes::replay::StateNormalizeProfile) -/// and by [`NodePartial::clear_typed_extra_key`](crate::node::NodePartial::clear_typed_extra_key) +/// Consumed by [`StateNormalizeProfile`](crate::runtimes::replay::StateNormalizeProfile) +/// and [`NodePartial::clear_typed_extra_key`](crate::node::NodePartial::clear_typed_extra_key) /// to distinguish durable state from per-invocation scratch values. /// /// # Registration-time conflict detection /// -/// When you register a key with a lifecycle annotation (e.g. via -/// `StateNormalizeProfile::ignore_key`), the profile detects and panics on -/// conflicting annotations for the same storage key. This catches the common -/// mistake of defining the same slot constant twice with different lifecycle -/// policies. +/// When you register a key with a lifecycle annotation, the profile detects and panics on +/// conflicting annotations for the same storage key. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StateLifecycle { - /// The slot contains durable state that persists across invocations. - /// - /// This is the default. + /// Persists across invocations. This is the default. Durable, - /// The slot contains per-invocation scratch data that should be excluded - /// from durable state comparisons and resume normalization. + /// Per-invocation scratch data excluded from durable comparisons and resume normalization. InvocationScoped, } -/// A schema-versioned key for typed values stored in [`VersionedState::extra`]. +/// Schema-versioned key for typed values stored in [`VersionedState::extra`]. /// -/// `StateKey` is a thin helper over the JSON-compatible `extra` map. Domain -/// crates can define constants and use them from nodes, reducers, tests, and -/// replay code without repeating string literals. +/// Domain crates define constants using `StateKey` so nodes, reducers, tests, and replay +/// code can reference typed slots without repeating string literals. /// /// # Equality and hashing /// -/// `PartialEq`, `Eq`, and `Hash` are based solely on `(namespace, name, -/// schema_version)`. The `lifecycle` field is metadata and is **excluded** -/// from equality comparisons, so that two keys for the same slot compare equal -/// regardless of their lifecycle annotation. -/// -/// # Examples +/// `PartialEq`, `Eq`, and `Hash` are based solely on `(namespace, name, schema_version)`. +/// The `lifecycle` field is excluded so two keys for the same slot compare equal regardless +/// of lifecycle annotation. /// /// ```rust /// use serde::{Deserialize, Serialize}; /// use weavegraph::state::{StateKey, StateLifecycle}; /// /// #[derive(Serialize, Deserialize)] -/// struct PortfolioSnapshot { -/// cash: i64, -/// } +/// struct PortfolioSnapshot { cash: i64 } /// -/// const PORTFOLIO: StateKey = -/// StateKey::new("wq", "portfolio_snapshot", 1); -/// -/// const CURRENT_EVENT: StateKey = -/// StateKey::new("wq", "event", 1).invocation_scoped(); +/// const PORTFOLIO: StateKey = StateKey::new("wq", "portfolio_snapshot", 1); +/// const CURRENT_EVENT: StateKey = StateKey::new("wq", "event", 1).invocation_scoped(); /// /// assert_eq!(PORTFOLIO.storage_key(), "wq:portfolio_snapshot:v1"); /// assert_eq!(CURRENT_EVENT.lifecycle(), StateLifecycle::InvocationScoped); @@ -163,20 +118,17 @@ impl StateKey { } } - /// Return a copy of this key annotated as [`StateLifecycle::InvocationScoped`]. + /// Return a copy annotated as [`StateLifecycle::InvocationScoped`]. /// - /// The returned key compares equal to the original — lifecycle is metadata, - /// not identity. Use this when defining constants that represent - /// per-invocation scratch slots so that normalization profiles and cleanup - /// helpers can distinguish them from durable state. + /// The returned key compares equal to the original. Use this when defining constants + /// for per-invocation scratch slots so normalization and cleanup helpers can distinguish + /// them from durable state. /// /// ```rust /// use weavegraph::state::{StateKey, StateLifecycle}; /// - /// const TICK_EVENT: StateKey = - /// StateKey::new("wq", "tick_event", 1).invocation_scoped(); - /// - /// assert_eq!(TICK_EVENT.lifecycle(), StateLifecycle::InvocationScoped); + /// const TICK: StateKey = StateKey::new("wq", "tick", 1).invocation_scoped(); + /// assert_eq!(TICK.lifecycle(), StateLifecycle::InvocationScoped); /// ``` #[must_use] pub const fn invocation_scoped(mut self) -> Self { @@ -208,11 +160,10 @@ impl StateKey { self.schema_version } - /// Return the concrete `extra` map key used for storage. + /// Return the `extra` map key used for storage: `namespace:name:v{schema_version}`. /// - /// The format is `namespace:name:v{schema_version}`. Changing the schema - /// version intentionally writes to a different slot, avoiding silent - /// collisions between incompatible payload shapes. + /// Bumping the schema version writes to a new slot, preventing silent collisions + /// between incompatible payload shapes. #[must_use] pub fn storage_key(&self) -> String { format!("{}:{}:v{}", self.namespace, self.name, self.schema_version) @@ -224,117 +175,81 @@ impl StateKey { #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] #[non_exhaustive] pub enum StateSlotError { - /// The requested typed slot was not present in the state. + /// The requested slot was absent. #[error("state slot not found: {key}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::state::slot_missing)) )] Missing { - /// The concrete storage key that was not found. + /// Concrete storage key that was not found. key: String, }, - /// A typed slot value could not be serialized to JSON. + /// A slot value could not be serialized to JSON. #[error("failed to serialize state slot {key}: {source}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::state::slot_serialize)) )] Serialize { - /// The concrete storage key being written. + /// Concrete storage key being written. key: String, - /// The underlying serde error. + /// Underlying serde serialization error. #[source] source: serde_json::Error, }, - /// A typed slot value could not be deserialized from JSON. + /// A slot value could not be deserialized from JSON. #[error("failed to deserialize state slot {key}: {source}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::state::slot_deserialize)) )] Deserialize { - /// The concrete storage key being read. + /// Concrete storage key being read. key: String, - /// The underlying serde error. + /// Underlying serde deserialization error. #[source] source: serde_json::Error, }, } -/// The main state container for workflow execution. -/// -/// `VersionedState` manages three independent channels of versioned data: -/// messages, custom extras, and error events. Each channel maintains its own -/// version number for optimistic concurrency control and change detection. -/// -/// # Channels -/// -/// - **messages**: Chat messages and conversation data ([`MessagesChannel`]) -/// - **extra**: Custom metadata and intermediate results ([`ExtrasChannel`]) -/// - **errors**: Error events and diagnostics ([`ErrorsChannel`]) +/// Runtime state container for workflow execution. /// -/// # Examples +/// Manages three independent channels of versioned data: messages, custom extras, and +/// error events. Each channel tracks its own version number for change detection. /// /// ```rust /// use weavegraph::state::VersionedState; -/// use weavegraph::message::{Message, Role}; /// use weavegraph::channels::Channel; /// use serde_json::json; /// -/// // Initialize with user message -/// let mut state = VersionedState::new_with_user_message("Process this data"); +/// let mut state = VersionedState::new_with_user_message("Process this"); +/// state.add_extra("session_id", json!("sess_123")); /// -/// // Add metadata -/// state.extra.get_mut().insert("session_id".to_string(), json!("sess_123")); -/// state.extra.get_mut().insert("priority".to_string(), json!("high")); -/// -/// // Add assistant response -/// state -/// .messages -/// .get_mut() -/// .push(Message::with_role(Role::Assistant, "Processing your request...")); -/// -/// // Take snapshot /// let snapshot = state.snapshot(); -/// assert_eq!(snapshot.messages.len(), 2); -/// assert_eq!(snapshot.extra.len(), 2); +/// assert_eq!(snapshot.messages.len(), 1); +/// assert_eq!(snapshot.extra.get("session_id"), Some(&json!("sess_123"))); +/// ``` +/// let snapshot = state.snapshot(); +/// assert_eq!(snapshot.messages.len(), 1); +/// assert_eq!(snapshot.extra.get("session_id"), Some(&json!("sess_123"))); /// ``` #[derive(Clone, Debug, PartialEq, Eq)] pub struct VersionedState { - /// Message channel containing conversation data + /// Conversation messages. pub messages: MessagesChannel, - /// Extra channel for custom metadata and intermediate results + /// Custom metadata and intermediate results. pub extra: ExtrasChannel, - /// Error channel for diagnostic information + /// Error events and diagnostics. pub errors: ErrorsChannel, } -/// Immutable snapshot of workflow state at a specific point in time. +/// Immutable point-in-time view of workflow state. /// -/// `StateSnapshot` provides a read-only view of the state that nodes can -/// safely access during execution without affecting the underlying state. -/// It contains cloned data from the messages and extra channels along with -/// their version numbers. -/// -/// # Fields -/// -/// - `messages`: Cloned message data at snapshot time -/// - `messages_version`: Version of messages channel when snapshot was taken -/// - `extra`: Cloned extra data at snapshot time -/// - `extra_version`: Version of extra channel when snapshot was taken -/// - `errors`: Cloned error events at snapshot time -/// - `errors_version`: Version of errors channel when snapshot was taken -/// -/// # Usage -/// -/// Snapshots are automatically created by [`VersionedState::snapshot()`] and -/// passed to nodes during workflow execution. Nodes should treat snapshots -/// as immutable input data. -/// -/// # Examples +/// Created by [`VersionedState::snapshot`] and passed to nodes during execution. +/// Contains cloned channel data and version numbers at the moment of the snapshot. /// /// ```rust /// use weavegraph::state::VersionedState; @@ -345,84 +260,47 @@ pub struct VersionedState { /// state.extra.get_mut().insert("key".to_string(), json!("value")); /// /// let snapshot = state.snapshot(); -/// -/// // Snapshot is independent of original state /// state.extra.get_mut().clear(); +/// /// assert_eq!(snapshot.extra.get("key"), Some(&json!("value"))); /// assert!(state.extra.snapshot().is_empty()); /// ``` #[derive(Clone, Debug)] pub struct StateSnapshot { - /// Messages at the time of snapshot + /// Messages at snapshot time. pub messages: Vec, - /// Version of messages channel when snapshot was taken + /// Version of the messages channel when snapshotted. pub messages_version: u32, - /// Extra data at the time of snapshot + /// Extra data at snapshot time. pub extra: FxHashMap, - /// Version of extra channel when snapshot was taken + /// Version of the extra channel when snapshotted. pub extra_version: u32, - /// Error events at the time of snapshot + /// Error events at snapshot time. pub errors: Vec, - /// Version of errors channel when snapshot was taken + /// Version of the errors channel when snapshotted. pub errors_version: u32, } impl VersionedState { - /// Creates a new versioned state initialized with a user message. - /// - /// This is the primary constructor for starting workflow execution. - /// text as the first user message. - /// - /// # Parameters - /// - /// - `user_text`: The initial user message content - /// - /// # Returns - /// - /// A new `VersionedState` with: - /// - One user message in the messages channel - /// - Empty extra and error channels - /// - All channels initialized to version 1 - /// - /// # Examples + /// Construct state initialized with a single user message. /// /// ```rust /// use weavegraph::state::VersionedState; /// - /// let state = VersionedState::new_with_user_message("Analyze this data"); - /// let snapshot = state.snapshot(); - /// - /// assert_eq!(snapshot.messages.len(), 1); - /// assert_eq!(snapshot.messages[0].role, weavegraph::message::Role::User); - /// assert_eq!(snapshot.messages[0].content, "Analyze this data"); - /// assert_eq!(snapshot.messages_version, 1); - /// assert!(snapshot.extra.is_empty()); + /// let state = VersionedState::new_with_user_message("Analyze this"); + /// let snap = state.snapshot(); + /// assert_eq!(snap.messages.len(), 1); + /// assert_eq!(snap.messages[0].role, weavegraph::message::Role::User); /// ``` pub fn new_with_user_message(user_text: &str) -> Self { - let messages = vec![Message::with_role(Role::User, user_text)]; Self { - messages: MessagesChannel::new(messages, 1), + messages: MessagesChannel::new(vec![Message::with_role(Role::User, user_text)], 1), extra: ExtrasChannel::default(), errors: ErrorsChannel::default(), } } - /// Creates a new versioned state initialized with a vector of messages. - /// - /// This constructor is useful for starting a workflow with an existing chat history. - /// - /// # Parameters - /// - /// - `messages`: The initial messages content - /// - /// # Returns - /// - /// A new `VersionedState` with: - /// - Multiple messages in the messages channel - /// - Empty extra and error channels - /// - All channels initialized to version 1 - /// - /// # Examples + /// Construct state initialized with an existing message list. /// /// ```rust /// use weavegraph::state::VersionedState; @@ -430,17 +308,10 @@ impl VersionedState { /// /// let messages = vec![ /// Message::with_role(Role::User, "Explain error handling in Rust"), - /// Message::with_role( - /// Role::Assistant, - /// "Use Result and the ? operator to propagate errors cleanly.", - /// ), + /// Message::with_role(Role::Assistant, "Use Result and the ? operator."), /// ]; /// let state = VersionedState::new_with_messages(messages); - /// let snapshot = state.snapshot(); - /// - /// assert_eq!(snapshot.messages.len(), 2); - /// assert_eq!(snapshot.messages_version, 1); - /// assert!(snapshot.extra.is_empty()); + /// assert_eq!(state.snapshot().messages.len(), 2); /// ``` pub fn new_with_messages(messages: Vec) -> Self { Self { @@ -450,60 +321,12 @@ impl VersionedState { } } - /// Creates a builder for constructing VersionedState with fluent API. - /// - /// The builder pattern provides an ergonomic way to construct state - /// with custom initial data, versions, and multiple messages. - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// use weavegraph::channels::Channel; - /// use serde_json::json; - /// - /// let state = VersionedState::builder() - /// .with_user_message("Hello, assistant!") - /// .with_assistant_message("Hello! How can I help you?") - /// .with_extra("session_id", json!("session_123")) - /// .with_extra("priority", json!("high")) - /// .build(); - /// - /// let snapshot = state.snapshot(); - /// assert_eq!(snapshot.messages.len(), 2); - /// assert_eq!(snapshot.extra.len(), 2); - /// ``` + /// Return a builder for fluent state construction. pub fn builder() -> VersionedStateBuilder { - VersionedStateBuilder::new() + VersionedStateBuilder::default() } - /// Convenience method for adding a message to the state. - /// - /// This method adds a message with the specified role and content - /// to the messages channel. The version is not automatically incremented - /// as that's handled by the barrier system. - /// - /// # Parameters - /// - /// - `role`: The role of the message sender (e.g., "user", "assistant", "system") - /// - `content`: The message content - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// - /// let mut state = VersionedState::new_with_user_message("Initial message"); - /// state.add_message( - /// weavegraph::message::Role::Assistant.as_str(), - /// "I understand your request.", - /// ); - /// - /// let snapshot = state.snapshot(); - /// assert_eq!(snapshot.messages.len(), 2); - /// assert_eq!(snapshot.messages[1].role, weavegraph::message::Role::Assistant); - /// ``` - #[must_use = "consider using the returned self for method chaining"] + /// Append a message to the messages channel. pub fn add_message(&mut self, role: &str, content: &str) -> &mut Self { self.messages .get_mut() @@ -511,43 +334,13 @@ impl VersionedState { self } - /// Convenience method for adding metadata to the extra channel. - /// - /// This method adds a key-value pair to the extra channel for custom - /// metadata and intermediate results. The version is not automatically - /// incremented as that's handled by the barrier system. - /// - /// # Parameters - /// - /// - `key`: The metadata key - /// - `value`: The metadata value as a serde_json::Value - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// use weavegraph::channels::Channel; - /// use serde_json::json; - /// - /// let mut state = VersionedState::new_with_user_message("Test"); - /// state.add_extra("user_id", json!("user_123")) - /// .add_extra("timestamp", json!(1234567890)); - /// - /// let snapshot = state.snapshot(); - /// assert_eq!(snapshot.extra.len(), 2); - /// assert_eq!(snapshot.extra.get("user_id"), Some(&json!("user_123"))); - /// ``` - #[must_use = "consider using the returned self for method chaining"] + /// Insert a key-value pair into the extra channel. pub fn add_extra(&mut self, key: &str, value: Value) -> &mut Self { - self.extra.get_mut().insert(key.to_string(), value); + self.extra.get_mut().insert(key.to_owned(), value); self } - /// Adds a typed value to the extra channel using a schema-versioned key. - /// - /// The value is serialized to JSON and stored under - /// [`StateKey::storage_key`]. The channel version is still advanced by the - /// normal barrier system during graph execution. + /// Serialize `value` and insert it under `key.storage_key()` in the extra channel. pub fn add_typed_extra( &mut self, key: StateKey, @@ -563,24 +356,7 @@ impl VersionedState { Ok(self) } - /// Creates an immutable snapshot of the current state. - /// - /// This method clones the current channel data and version numbers, - /// creating a point-in-time view that is safe to access concurrently - /// while the original state may be modified. - /// - /// # Returns - /// - /// A [`StateSnapshot`] containing cloned data from messages and extra - /// channels along with their current version numbers. - /// - /// # Performance - /// - /// This operation clones all channel data, so it has O(n) complexity - /// relative to the amount of data in the channels. For large states, - /// consider whether the snapshot is necessary. - /// - /// # Examples + /// Clone all channel data and version numbers into an immutable [`StateSnapshot`]. /// /// ```rust /// use weavegraph::state::VersionedState; @@ -589,14 +365,10 @@ impl VersionedState { /// /// let mut state = VersionedState::new_with_user_message("Test"); /// state.extra.get_mut().insert("status".to_string(), json!("processing")); - /// /// let snapshot = state.snapshot(); - /// - /// // Snapshot is independent - mutations don't affect it /// state.extra.get_mut().insert("status".to_string(), json!("complete")); /// /// assert_eq!(snapshot.extra.get("status"), Some(&json!("processing"))); - /// assert_eq!(state.extra.snapshot().get("status"), Some(&json!("complete"))); /// ``` pub fn snapshot(&self) -> StateSnapshot { StateSnapshot { @@ -611,10 +383,9 @@ impl VersionedState { } impl StateSnapshot { - /// Read an optional typed value from the extra channel. + /// Deserialize an optional typed value from the extra channel. /// - /// Returns `Ok(None)` when the slot is absent. Deserialization errors are - /// reported with the concrete storage key. + /// Returns `Ok(None)` when the slot is absent. pub fn get_typed( &self, key: StateKey, @@ -623,8 +394,8 @@ impl StateSnapshot { self.extra .get(&storage_key) .cloned() - .map(|value| { - serde_json::from_value(value).map_err(|source| StateSlotError::Deserialize { + .map(|v| { + serde_json::from_value(v).map_err(|source| StateSlotError::Deserialize { key: storage_key, source, }) @@ -632,9 +403,9 @@ impl StateSnapshot { .transpose() } - /// Read a required typed value from the extra channel. + /// Deserialize a required typed value from the extra channel. /// - /// Use this when a node cannot proceed without a specific typed slot. + /// Returns [`StateSlotError::Missing`] when the slot is absent. pub fn require_typed( &self, key: StateKey, @@ -645,14 +416,7 @@ impl StateSnapshot { } } -/// Builder for constructing VersionedState with fluent API. -/// -/// `VersionedStateBuilder` provides an ergonomic way to construct workflow state -/// with custom initial data, multiple messages, and metadata. This is particularly -/// useful when setting up complex initial states for testing or when restoring -/// state from persistence. -/// -/// # Examples +/// Fluent builder for [`VersionedState`]. /// /// ```rust /// use weavegraph::state::VersionedState; @@ -662,14 +426,10 @@ impl StateSnapshot { /// let state = VersionedState::builder() /// .with_user_message("What's the weather like?") /// .with_assistant_message("I'll help you check the weather.") -/// .with_system_message("Weather API access enabled") /// .with_extra("location", json!("New York")) -/// .with_extra("units", json!("celsius")) /// .build(); /// -/// let snapshot = state.snapshot(); -/// assert_eq!(snapshot.messages.len(), 3); -/// assert_eq!(snapshot.extra.len(), 2); +/// assert_eq!(state.snapshot().messages.len(), 2); /// ``` #[derive(Debug, Default)] pub struct VersionedStateBuilder { @@ -678,121 +438,37 @@ pub struct VersionedStateBuilder { } impl VersionedStateBuilder { - /// Creates a new empty builder. - fn new() -> Self { - Self::default() - } - - /// Adds a user message to the builder. - /// - /// # Parameters - /// - /// - `content`: The user message content - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// - /// let state = VersionedState::builder() - /// .with_user_message("Hello") - /// .build(); - /// ``` + /// Append a user message. pub fn with_user_message(mut self, content: &str) -> Self { self.messages.push(Message::with_role(Role::User, content)); self } - /// Adds an assistant message to the builder. - /// - /// # Parameters - /// - /// - `content`: The assistant message content - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// - /// let state = VersionedState::builder() - /// .with_user_message("Hello") - /// .with_assistant_message("Hi there!") - /// .build(); - /// ``` + /// Append an assistant message. pub fn with_assistant_message(mut self, content: &str) -> Self { - self.messages - .push(Message::with_role(Role::Assistant, content)); + self.messages.push(Message::with_role(Role::Assistant, content)); self } - /// Adds a system message to the builder. - /// - /// # Parameters - /// - /// - `content`: The system message content - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// - /// let state = VersionedState::builder() - /// .with_system_message("Session started") - /// .with_user_message("Hello") - /// .build(); - /// ``` + /// Append a system message. pub fn with_system_message(mut self, content: &str) -> Self { - self.messages - .push(Message::with_role(Role::System, content)); + self.messages.push(Message::with_role(Role::System, content)); self } - /// Adds a custom message with specified role to the builder. - /// - /// # Parameters - /// - /// - `role`: The message role - /// - `content`: The message content - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// - /// let state = VersionedState::builder() - /// .with_message("function", "API call result") - /// .build(); - /// ``` + /// Append a message with a custom role. pub fn with_message(mut self, role: &str, content: &str) -> Self { - self.messages - .push(Message::with_role(Role::from(role), content)); + self.messages.push(Message::with_role(Role::from(role), content)); self } - /// Adds metadata to the extra channel. - /// - /// # Parameters - /// - /// - `key`: The metadata key - /// - `value`: The metadata value - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// use serde_json::json; - /// - /// let state = VersionedState::builder() - /// .with_user_message("Hello") - /// .with_extra("session_id", json!("sess_123")) - /// .build(); - /// ``` + /// Insert a key-value pair into the extra channel. pub fn with_extra(mut self, key: &str, value: Value) -> Self { - self.extra.insert(key.to_string(), value); + self.extra.insert(key.to_owned(), value); self } - /// Adds a typed value to the extra channel using a schema-versioned key. + /// Serialize `value` and insert it under `key.storage_key()`. pub fn with_typed_extra( mut self, key: StateKey, @@ -808,31 +484,7 @@ impl VersionedStateBuilder { Ok(self) } - /// Builds the final VersionedState. - /// - /// Creates a new VersionedState with all the configured messages and metadata. - /// All channels are initialized with version 1. If no messages were added, - /// the messages channel will be empty. - /// - /// # Returns - /// - /// A fully constructed `VersionedState` - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::state::VersionedState; - /// use serde_json::json; - /// - /// let state = VersionedState::builder() - /// .with_user_message("Hello") - /// .with_extra("key", json!("value")) - /// .build(); - /// - /// let snapshot = state.snapshot(); - /// assert_eq!(snapshot.messages.len(), 1); - /// assert_eq!(snapshot.extra.len(), 1); - /// ``` + /// Construct the [`VersionedState`]. All channels start at version 1. pub fn build(self) -> VersionedState { VersionedState { messages: MessagesChannel::new(self.messages, 1), diff --git a/src/types.rs b/src/types.rs index 467d953..484fcb8 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,105 +1,37 @@ -//! Core types for the Weavegraph workflow framework. +//! Core domain types for workflow graphs. //! -//! This module defines the fundamental types used throughout the Weavegraph system -//! for identifying nodes and channels in workflow graphs. These are the core -//! domain concepts that define what a workflow *is*. -//! -//! For runtime execution types (session IDs, step numbers), see [`crate::runtimes::types`]. -//! -//! # Key Types -//! -//! - [`NodeKind`]: Identifies different types of nodes in a workflow graph -//! - [`ChannelType`]: Identifies different types of data channels for state management -//! -//! # Type Organization -//! -//! Weavegraph organizes types by conceptual domain: -//! -//! - **Core types** (this module): Fundamental workflow concepts (`NodeKind`, `ChannelType`) -//! - **Runtime types** ([`crate::runtimes::types`]): Execution infrastructure (`SessionId`, `StepNumber`) -//! - **Utility types**: Domain-specific helpers in their respective modules -//! -//! # Examples -//! -//! ```rust -//! use weavegraph::types::{NodeKind, ChannelType}; -//! -//! // Create different types of nodes -//! let start = NodeKind::Start; -//! let custom = NodeKind::Custom("ProcessData".to_string()); -//! let end = NodeKind::End; -//! -//! // Encode for persistence -//! let encoded = custom.encode(); -//! assert_eq!(encoded, "Custom:ProcessData"); -//! -//! // Work with channels -//! let msg_channel = ChannelType::Message; -//! println!("Channel: {}", msg_channel); -//! ``` +//! Defines [`NodeKind`] (node identity) and [`ChannelType`] (state channel category). +//! For runtime execution types such as session IDs and step numbers, see +//! [`crate::runtimes::types`]. use serde::{Deserialize, Serialize}; use std::fmt; -/// Identifies the type of a node within a workflow graph. -/// -/// `NodeKind` serves as a unique identifier for nodes in the workflow execution graph. -/// It provides special handling for common workflow patterns (start/end nodes) while -/// allowing arbitrary custom node types through the `Other` variant. +/// Node identity within a workflow graph. /// -/// # Persistence -/// -/// `NodeKind` supports serialization for checkpointing and persistence through both -/// serde and the [`encode`](Self::encode)/[`decode`](Self::decode) methods. -/// -/// # Examples +/// `NodeKind` labels each node in the execution graph. `Start` and `End` are virtual +/// bookend nodes with no user implementation; `Custom` covers all application-defined +/// nodes. Supports serde and the [`encode`](Self::encode)/[`decode`](Self::decode) +/// round-trip for persistence. /// /// ```rust /// use weavegraph::types::NodeKind; /// -/// // Special workflow nodes -/// let start = NodeKind::Start; -/// let end = NodeKind::End; -/// -/// // Custom application nodes /// let processor = NodeKind::Custom("DataProcessor".to_string()); -/// -/// // Persistence round-trip -/// let encoded = processor.encode(); -/// let decoded = NodeKind::decode(&encoded); -/// assert_eq!(processor, decoded); +/// assert_eq!(NodeKind::decode(&processor.encode()), processor); /// ``` #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum NodeKind { - /// Entry point node that begins workflow execution. - /// - /// Start nodes are virtual nodes that should not be implemented, they have no incoming edges and serve as the initial - /// frontier for workflow execution. - /// The first edge for each graph execution must start from a virtual Start node. + /// Virtual entry node. Has no incoming edges; seeds the initial frontier. Start, - - /// Terminal node that completes workflow execution. - /// - /// End nodes are virtual nodes that should not be implemented, they have no outgoing edges and signal - /// the completion of a workflow branch. + /// Virtual terminal node. Has no outgoing edges; signals branch completion. End, - - /// Custom node type identified by a user-defined string. - /// - /// The string should be descriptive and unique within the workflow. - /// Common patterns include function names, service names, or step descriptions. + /// Application-defined node identified by a user-supplied name. Custom(String), } impl NodeKind { - /// Encode a NodeKind into its persisted string form. - /// - /// The encoding format is human-readable and forward-compatible: - /// - `Start` → `"Start"` - /// - `End` → `"End"` - /// - `Custom("X")` → `"Custom:X"` - /// - /// # Examples + /// Encode into a persisted string: `"Start"`, `"End"`, or `"Custom:X"`. /// /// ```rust /// # use weavegraph::types::NodeKind; @@ -109,163 +41,113 @@ impl NodeKind { #[must_use] pub fn encode(&self) -> String { match self { - NodeKind::Start => "Start".to_string(), - NodeKind::End => "End".to_string(), - NodeKind::Custom(s) => format!("Custom:{s}"), + Self::Start => "Start".to_owned(), + Self::End => "End".to_owned(), + Self::Custom(s) => format!("Custom:{s}"), } } - /// Decode a persisted string form back into a NodeKind. - /// - /// This method provides forward compatibility by falling back to - /// `Other(s)` for any unrecognized format. + /// Decode a persisted string back into a `NodeKind`. /// - /// # Examples + /// Unrecognized strings that lack a `"Custom:"` prefix are treated as + /// `Custom(s)` for forward compatibility. /// /// ```rust /// # use weavegraph::types::NodeKind; /// assert_eq!(NodeKind::decode("Start"), NodeKind::Start); /// assert_eq!(NodeKind::decode("Custom:Processor"), NodeKind::Custom("Processor".to_string())); - /// - /// // Forward compatibility - unknown formats become Other /// assert_eq!(NodeKind::decode("Unknown"), NodeKind::Custom("Unknown".to_string())); /// ``` pub fn decode(s: &str) -> Self { - if s == "Start" { - NodeKind::Start - } else if s == "End" { - NodeKind::End - } else if let Some(rest) = s.strip_prefix("Custom:") { - NodeKind::Custom(rest.to_string()) - } else { - NodeKind::Custom(s.to_string()) + match s { + "Start" => Self::Start, + "End" => Self::End, + _ => Self::Custom(s.strip_prefix("Custom:").unwrap_or(s).to_owned()), } } - /// Returns `true` if this is a [`Start`](Self::Start) node. + /// Return `true` if this is a [`Start`](Self::Start) node. #[must_use] pub fn is_start(&self) -> bool { matches!(self, Self::Start) } - /// Returns `true` if this is an [`End`](Self::End) node. + /// Return `true` if this is an [`End`](Self::End) node. #[must_use] pub fn is_end(&self) -> bool { matches!(self, Self::End) } - /// Returns `true` if this is a custom node. + /// Return `true` if this is a [`Custom`](Self::Custom) node. #[must_use] pub fn is_custom(&self) -> bool { matches!(self, Self::Custom(_)) } - /// Convert this NodeKind into a predicate target string. - /// - /// This is a convenience for building conditional edge predicates - /// without relying on raw string literals. - /// - /// Examples - /// ```rust - /// # use weavegraph::types::NodeKind; - /// let nk = NodeKind::Custom("route".into()); - /// assert_eq!(nk.as_target(), "route"); - /// ``` + /// Return the `Display` string of this node for use as a conditional-edge target. #[must_use] pub fn as_target(&self) -> String { self.to_string() } - /// Convenience: return the canonical target string for the Start endpoint. - /// - /// Examples - /// ```rust - /// # use weavegraph::types::NodeKind; - /// assert_eq!(NodeKind::start_target(), "Start"); - /// ``` + /// Return the canonical target string for the `Start` endpoint. #[must_use] pub fn start_target() -> String { - "Start".to_string() + "Start".to_owned() } - /// Convenience: return the canonical target string for the End endpoint. - /// - /// Examples - /// ```rust - /// # use weavegraph::types::NodeKind; - /// assert_eq!(NodeKind::end_target(), "End"); - /// ``` + /// Return the canonical target string for the `End` endpoint. #[must_use] pub fn end_target() -> String { - "End".to_string() + "End".to_owned() } } impl fmt::Display for NodeKind { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Start => write!(f, "Start"), - Self::End => write!(f, "End"), - Self::Custom(name) => write!(f, "{}", name), + Self::Start => f.write_str("Start"), + Self::End => f.write_str("End"), + Self::Custom(name) => f.write_str(name), } } } -// Developer Experience: allow using string literals where a NodeKind is expected. impl From<&str> for NodeKind { fn from(s: &str) -> Self { match s { - "Start" => NodeKind::Start, - "End" => NodeKind::End, - other => NodeKind::Custom(other.to_string()), + "Start" => Self::Start, + "End" => Self::End, + other => Self::Custom(other.to_owned()), } } } -/// Identifies the type of data channel used for state management. +/// State channel category. /// -/// `ChannelType` represents the different categories of state data that -/// can be managed within the workflow system. Each channel type has -/// its own reducer and update semantics. -/// -/// # Examples +/// Each variant corresponds to an independent channel in [`VersionedState`](crate::state::VersionedState) +/// with its own reducer and versioning semantics. /// /// ```rust /// use weavegraph::types::ChannelType; -/// -/// let msg_channel = ChannelType::Message; -/// let err_channel = ChannelType::Error; -/// let meta_channel = ChannelType::Extra; -/// -/// println!("Processing {} channel", msg_channel); +/// assert_eq!(ChannelType::Message.to_string(), "message"); /// ``` #[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)] pub enum ChannelType { - /// Channel for chat messages and conversation data. - /// - /// Manages the sequence of messages that flow through the workflow, - /// including user inputs, assistant responses, and system notifications. + /// Conversation messages: user inputs, assistant responses, system notifications. Message, - - /// Channel for error events and diagnostic information. - /// - /// Collects both fatal errors that halt execution and non-fatal - /// errors that should be tracked for debugging and monitoring. + /// Error events: fatal halts and non-fatal diagnostics. Error, - - /// Channel for custom metadata and intermediate results. - /// - /// Provides a flexible key-value store for custom data that nodes - /// need to share, including configuration and intermediate computations. + /// Custom key-value metadata and intermediate node outputs. Extra, } impl fmt::Display for ChannelType { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Message => write!(f, "message"), - Self::Error => write!(f, "error"), - Self::Extra => write!(f, "extra"), + Self::Message => f.write_str("message"), + Self::Error => f.write_str("error"), + Self::Extra => f.write_str("extra"), } } } From f3ef77981da46966aebbc7b024a2671688be392d Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 14:37:21 -0400 Subject: [PATCH 03/15] heavy revision work for channels and event_bus --- src/channels/errors.rs | 264 ++++++----------------- src/channels/errors_channel.rs | 13 +- src/channels/extras.rs | 32 ++- src/channels/messages.rs | 32 ++- src/channels/mod.rs | 2 +- src/event_bus/bus.rs | 358 ++++++++---------------------- src/event_bus/diagnostics.rs | 68 +++--- src/event_bus/emitter.rs | 20 +- src/event_bus/event.rs | 384 ++++++++++++++++++--------------- src/event_bus/hub.rs | 154 +++++-------- src/event_bus/mod.rs | 16 +- src/event_bus/sink.rs | 327 +++++----------------------- tests/event_bus.rs | 53 ++--- 13 files changed, 567 insertions(+), 1156 deletions(-) diff --git a/src/channels/errors.rs b/src/channels/errors.rs index b35edc6..d7c6fca 100644 --- a/src/channels/errors.rs +++ b/src/channels/errors.rs @@ -1,54 +1,32 @@ -//! Error event types used to capture and propagate structured errors through the workflow. +//! Error event types for structured error capture and propagation through the workflow. use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use crate::telemetry::{FormatterMode, PlainFormatter, TelemetryFormatter}; -// Avoid depending on serde for NodeKind by using encoded string form for kind. - -/// Represents an error event with scope, error details, tags, and context. -/// -/// # JSON Serialization Format +/// A workflow error event capturing scope, error payload, tags, and context. /// -/// `ErrorEvent` serializes to JSON with the following structure: +/// Serializes to JSON with a tagged-union `scope` field: /// /// ```json /// { /// "when": "2025-11-02T10:30:00Z", -/// "scope": { -/// "scope": "node", -/// "kind": "Parser", -/// "step": 1 -/// }, +/// "scope": { "scope": "node", "kind": "Parser", "step": 1 }, /// "error": { /// "message": "Failed to parse input", -/// "cause": { -/// "message": "Invalid JSON syntax", -/// "cause": null, -/// "details": {"line": 3, "column": 15} -/// }, -/// "details": {"input_length": 1024} +/// "cause": { "message": "Invalid JSON syntax", "details": {"line": 3} } /// }, -/// "tags": ["validation", "retryable"], -/// "context": { -/// "file": "/tmp/input.json", -/// "user_id": 12345 -/// } +/// "tags": ["validation"], +/// "context": {"file": "/tmp/input.json"} /// } /// ``` /// -/// The `scope` field uses a tagged union format with a discriminator field named `"scope"`. -/// Supported scope variants are: -/// - `"node"`: Requires `kind` (string) and `step` (u64) -/// - `"scheduler"`: Requires `step` (u64) -/// - `"runner"`: Requires `session` (string) and `step` (u64) -/// - `"app"`: No additional fields +/// Scope variants: `"node"` (kind, step), `"scheduler"` (step), +/// `"runner"` (session, step), `"app"`. /// -/// See `docs/schemas/error_event.json` for the complete JSON Schema specification. +/// See `docs/schemas/error_event.json` for the full JSON Schema. /// -/// # Examples -/// -/// Using constructors and builders: +/// # Example /// /// ``` /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; @@ -58,7 +36,6 @@ use crate::telemetry::{FormatterMode, PlainFormatter, TelemetryFormatter}; /// .with_tag("validation") /// .with_context(json!({"line": 42})); /// -/// // Serialize to JSON /// let json_str = serde_json::to_string(&event).unwrap(); /// ``` #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] @@ -66,199 +43,116 @@ pub struct ErrorEvent { /// Timestamp at which the error occurred. #[serde(default = "chrono::Utc::now")] pub when: DateTime, - /// Scope identifying where in the workflow the error originated. + /// Where in the workflow the error originated. #[serde(default)] pub scope: ErrorScope, - /// Structured error payload describing the failure. + /// Structured error payload. #[serde(default)] pub error: WeaveError, - /// Arbitrary string tags for filtering and categorization. + /// String tags for filtering and categorization. #[serde(default)] pub tags: Vec, - /// Optional additional context data as a JSON value. + /// Optional structured context metadata. #[serde(default)] pub context: serde_json::Value, } impl ErrorEvent { - /// Create a node-scoped error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::node("my_node", 1, WeaveError::msg("Something failed")); - /// ``` - pub fn node>(kind: S, step: u64, error: WeaveError) -> Self { + fn with_scope(scope: ErrorScope, error: WeaveError) -> Self { Self { when: Utc::now(), - scope: ErrorScope::Node { - kind: kind.into(), - step, - }, + scope, error, tags: Vec::new(), context: serde_json::Value::Null, } } - /// Create a scheduler-scoped error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::scheduler(5, WeaveError::msg("Scheduling conflict")); - /// ``` + /// Creates a node-scoped error event. + pub fn node>(kind: S, step: u64, error: WeaveError) -> Self { + Self::with_scope(ErrorScope::Node { kind: kind.into(), step }, error) + } + + /// Creates a scheduler-scoped error event. pub fn scheduler(step: u64, error: WeaveError) -> Self { - Self { - when: Utc::now(), - scope: ErrorScope::Scheduler { step }, - error, - tags: Vec::new(), - context: serde_json::Value::Null, - } + Self::with_scope(ErrorScope::Scheduler { step }, error) } - /// Create a runner-scoped error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::runner("session_123", 10, WeaveError::msg("Runtime error")); - /// ``` + /// Creates a runner-scoped error event. pub fn runner>(session: S, step: u64, error: WeaveError) -> Self { - Self { - when: Utc::now(), - scope: ErrorScope::Runner { - session: session.into(), - step, - }, - error, - tags: Vec::new(), - context: serde_json::Value::Null, - } + Self::with_scope(ErrorScope::Runner { session: session.into(), step }, error) } - /// Create an app-scoped error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::app(WeaveError::msg("Application startup failed")); - /// ``` + /// Creates an app-scoped error event. pub fn app(error: WeaveError) -> Self { - Self { - when: Utc::now(), - scope: ErrorScope::App, - error, - tags: Vec::new(), - context: serde_json::Value::Null, - } + Self::with_scope(ErrorScope::App, error) } - /// Add multiple tags to this error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::node("my_node", 1, WeaveError::msg("Invalid input")) - /// .with_tags(vec!["validation".to_string(), "critical".to_string()]); - /// ``` + /// Replaces the tag list. pub fn with_tags(mut self, tags: Vec) -> Self { self.tags = tags; self } - /// Add a single tag to this error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// - /// let err = ErrorEvent::node("my_node", 1, WeaveError::msg("Invalid input")) - /// .with_tag("validation"); - /// ``` + /// Appends a single tag. pub fn with_tag>(mut self, tag: S) -> Self { self.tags.push(tag.into()); self } - /// Add context metadata to this error event. - /// - /// # Example - /// ``` - /// use weavegraph::channels::errors::{ErrorEvent, WeaveError}; - /// use serde_json::json; - /// - /// let err = ErrorEvent::node("my_node", 1, WeaveError::msg("Invalid input")) - /// .with_context(json!({"field": "username", "value": ""})); - /// ``` + /// Attaches context metadata. pub fn with_context(mut self, context: serde_json::Value) -> Self { self.context = context; self } } -/// Scope metadata describing where an [`ErrorEvent`] originated. +/// Where an [`ErrorEvent`] originated in the workflow. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] #[serde(tag = "scope", rename_all = "snake_case")] pub enum ErrorScope { - /// Error originated in a node execution. + /// Error occurred in a node execution. Node { /// Node kind identifier. kind: String, - /// Workflow step at which the error occurred. + /// Step number at which the error occurred. step: u64, }, - /// Error originated in the scheduler. + /// Error occurred in the scheduler. Scheduler { - /// Workflow step at which the error occurred. + /// Step number at which the error occurred. step: u64, }, - /// Error originated in the runner. + /// Error occurred in the runner. Runner { - /// Session identifier associated with the error. + /// Session identifier. session: String, - /// Workflow step at which the error occurred. + /// Step number at which the error occurred. step: u64, }, - /// Error originated at the application level (default). + /// Error occurred at the application level (default). #[default] App, } -/// Structured error payload used by [`ErrorEvent`]. +/// Structured error payload for an [`ErrorEvent`]. /// -/// This type supports nested causes and optional machine-readable details. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +/// Supports nested cause chains and optional machine-readable details. +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)] pub struct WeaveError { - /// Primary human-readable error message. + /// Human-readable error message. pub message: String, - /// Optional nested cause for error chaining. + /// Nested cause for error chaining. #[serde(skip_serializing_if = "Option::is_none")] pub cause: Option>, - /// Optional structured metadata attached to the error. + /// Optional structured metadata. #[serde(default)] pub details: serde_json::Value, } -impl Default for WeaveError { - fn default() -> Self { - WeaveError { - message: String::new(), - cause: None, - details: serde_json::Value::Null, - } - } -} - impl std::fmt::Display for WeaveError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.message) + f.write_str(&self.message) } } @@ -269,83 +163,61 @@ impl std::error::Error for WeaveError { } impl WeaveError { - /// Construct an error from a message. + /// Constructs an error from a message. pub fn msg>(m: M) -> Self { - WeaveError { - message: m.into(), - cause: None, - details: serde_json::Value::Null, - } + Self { message: m.into(), ..Default::default() } } - /// Attach structured details to this error. + /// Attaches structured details. pub fn with_details(mut self, details: serde_json::Value) -> Self { self.details = details; self } - /// Attach a nested cause to this error. + /// Attaches a nested cause. pub fn with_cause(mut self, cause: WeaveError) -> Self { self.cause = Some(Box::new(cause)); self } } -/// Format error events with explicit color mode control. +/// Formats error events with explicit color mode control. /// -/// This function allows you to control whether ANSI color codes are included in the output: -/// - [`FormatterMode::Auto`]: Auto-detects TTY capability (checks stderr) -/// - [`FormatterMode::Colored`]: Always includes color codes -/// - [`FormatterMode::Plain`]: Never includes color codes +/// - [`FormatterMode::Auto`]: auto-detects TTY capability on stderr +/// - [`FormatterMode::Colored`]: always emits ANSI color codes +/// - [`FormatterMode::Plain`]: never emits ANSI color codes /// -/// # Examples +/// # Example /// /// ``` /// use weavegraph::channels::errors::{ErrorEvent, WeaveError, pretty_print_with_mode}; /// use weavegraph::telemetry::FormatterMode; /// -/// let events = vec![ -/// ErrorEvent::node("parser", 1, WeaveError::msg("Parse failed")) -/// ]; +/// let events = vec![ErrorEvent::node("parser", 1, WeaveError::msg("Parse failed"))]; /// -/// // Force plain output (no colors) for log files /// let plain = pretty_print_with_mode(&events, FormatterMode::Plain); -/// assert!(!plain.contains("\x1b[")); // No ANSI codes -/// -/// // Force colored output -/// let colored = pretty_print_with_mode(&events, FormatterMode::Colored); +/// assert!(!plain.contains("\x1b[")); /// ``` pub fn pretty_print_with_mode(events: &[ErrorEvent], mode: FormatterMode) -> String { - let formatter = PlainFormatter::with_mode(mode); - let renders = formatter.render_errors(events); - let mut out = String::new(); - for (idx, render) in renders.into_iter().enumerate() { - if idx > 0 { - out.push('\n'); - } - for line in render.lines { - out.push_str(&line); - } - } - out + PlainFormatter::with_mode(mode) + .render_errors(events) + .into_iter() + .map(|r| r.join_lines()) + .collect::>() + .join("\n") } -/// Format error events as human-readable text with auto-detected color support. +/// Formats error events with auto-detected color support. /// -/// Colors are automatically enabled when stderr is a TTY and disabled otherwise. -/// For explicit control over color output, use [`pretty_print_with_mode`]. +/// Colors are enabled when stderr is a TTY. For explicit control, use [`pretty_print_with_mode`]. /// -/// # Examples +/// # Example /// /// ``` /// use weavegraph::channels::errors::{ErrorEvent, WeaveError, pretty_print}; /// -/// let events = vec![ -/// ErrorEvent::node("parser", 1, WeaveError::msg("Parse failed")) -/// ]; -/// +/// let events = vec![ErrorEvent::node("parser", 1, WeaveError::msg("Parse failed"))]; /// let output = pretty_print(&events); -/// // Colors automatically detected based on stderr TTY status /// ``` pub fn pretty_print(events: &[ErrorEvent]) -> String { pretty_print_with_mode(events, FormatterMode::Auto) diff --git a/src/channels/errors_channel.rs b/src/channels/errors_channel.rs index 3bac2f4..8006de1 100644 --- a/src/channels/errors_channel.rs +++ b/src/channels/errors_channel.rs @@ -1,3 +1,4 @@ +//! Channel implementation for accumulating [`ErrorEvent`] entries. use super::Channel; use super::errors::ErrorEvent; use serde::{Deserialize, Serialize}; @@ -10,21 +11,15 @@ pub struct ErrorsChannel { } impl ErrorsChannel { - /// Create a new `ErrorsChannel` with the given events and version counter. + /// Creates a new `ErrorsChannel` with the given events and version counter. pub fn new(events: Vec, version: u32) -> Self { - Self { - value: events, - version, - } + Self { value: events, version } } } impl Default for ErrorsChannel { fn default() -> Self { - Self { - value: Vec::new(), - version: 1, - } + Self { value: Vec::new(), version: 1 } } } diff --git a/src/channels/extras.rs b/src/channels/extras.rs index e2af80f..05137c2 100644 --- a/src/channels/extras.rs +++ b/src/channels/extras.rs @@ -1,9 +1,11 @@ +//! Channel implementation for arbitrary key-value extra data. use rustc_hash::FxHashMap; use super::Channel; use crate::types::ChannelType; type ChannelValue = FxHashMap; + /// Channel that stores arbitrary key-value extra data for the workflow state. #[derive(Clone, Debug, PartialEq, Eq)] pub struct ExtrasChannel { @@ -12,12 +14,15 @@ pub struct ExtrasChannel { } impl ExtrasChannel { - /// Create a new `ExtrasChannel` with the given map and version counter. + /// Creates a new `ExtrasChannel` with the given map and version counter. pub fn new(extras: ChannelValue, version: u32) -> Self { - Self { - value: extras, - version, - } + Self { value: extras, version } + } +} + +impl Default for ExtrasChannel { + fn default() -> Self { + Self { value: FxHashMap::default(), version: 1 } } } @@ -42,24 +47,15 @@ impl Channel for ExtrasChannel { self.version } - fn get_mut(&mut self) -> &mut ChannelValue { - &mut self.value + fn set_version(&mut self, version: u32) { + self.version = version; } - fn set_version(&mut self, version: u32) { - self.version = version + fn get_mut(&mut self) -> &mut ChannelValue { + &mut self.value } fn persistent(&self) -> bool { true } } - -impl Default for ExtrasChannel { - fn default() -> Self { - Self { - value: FxHashMap::default(), - version: 1, - } - } -} diff --git a/src/channels/messages.rs b/src/channels/messages.rs index 3c89905..205d9f4 100644 --- a/src/channels/messages.rs +++ b/src/channels/messages.rs @@ -1,7 +1,9 @@ +//! Channel implementation for the ordered conversation message list. use super::Channel; use crate::{message::Message, types::ChannelType}; type ChannelValue = Vec; + /// Channel that stores the ordered list of conversation messages. #[derive(Clone, Debug, PartialEq, Eq)] pub struct MessagesChannel { @@ -10,12 +12,15 @@ pub struct MessagesChannel { } impl MessagesChannel { - /// Create a new `MessagesChannel` with the given messages and version counter. + /// Creates a new `MessagesChannel` with the given messages and version counter. pub fn new(messages: ChannelValue, version: u32) -> Self { - Self { - value: messages, - version, - } + Self { value: messages, version } + } +} + +impl Default for MessagesChannel { + fn default() -> Self { + Self { value: Vec::new(), version: 1 } } } @@ -40,24 +45,15 @@ impl Channel for MessagesChannel { self.version } - fn get_mut(&mut self) -> &mut ChannelValue { - &mut self.value + fn set_version(&mut self, version: u32) { + self.version = version; } - fn set_version(&mut self, version: u32) { - self.version = version + fn get_mut(&mut self) -> &mut ChannelValue { + &mut self.value } fn persistent(&self) -> bool { true } } - -impl Default for MessagesChannel { - fn default() -> Self { - Self { - value: Vec::new(), - version: 1, - } - } -} diff --git a/src/channels/mod.rs b/src/channels/mod.rs index e71d184..7df9c3e 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -28,7 +28,7 @@ pub trait Channel: Sync + Send { /// Returns the current version counter. fn version(&self) -> u32; /// Sets the version counter to the given value. - fn set_version(&mut self, version: u32) -> (); + fn set_version(&mut self, version: u32); /// Returns a mutable reference to the underlying value. fn get_mut(&mut self) -> &mut T; /// Returns `true` if this channel's data should be persisted across steps. diff --git a/src/event_bus/bus.rs b/src/event_bus/bus.rs index a4fabf9..1e9f821 100644 --- a/src/event_bus/bus.rs +++ b/src/event_bus/bus.rs @@ -1,8 +1,11 @@ //! [`EventBus`] implementation: fan-out broadcast to registered [`EventSink`] workers. + +use std::collections::HashMap; use std::io; -use std::sync::Arc; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; + +use chrono::Utc; use tokio::sync::{broadcast, oneshot}; use tokio::task; @@ -10,21 +13,18 @@ use super::diagnostics::{DiagnosticsStream, HealthState, SinkDiagnostic, SinkHea use super::emitter::EventEmitter; use super::hub::{EventHub, EventHubMetrics, EventStream}; use super::sink::{EventSink, StdOutSink}; -use chrono::Utc; -/// Central event broadcasting system for workflow execution events. -/// -/// `EventBus` receives events from workflow nodes and broadcasts them to multiple -/// sinks (stdout, channels, files, monitoring systems, etc.). It's the backbone -/// of Weavegraph's observability and streaming capabilities. +const DEFAULT_BUFFER_CAPACITY: usize = 1024; + +/// Central event broadcasting system that fans out workflow execution events to registered sinks. /// -/// # Architecture +/// `EventBus` is the observability backbone of Weavegraph. Workflow nodes emit events via +/// [`EventEmitter`]; the bus delivers each event to every registered [`EventSink`] in a +/// dedicated background worker. /// -/// The EventBus is owned by [`AppRunner`](crate::runtimes::runner::AppRunner), not -/// [`App`](crate::app::App). This design allows: -/// - Multiple runners to share the same graph with different event configurations -/// - Per-request event isolation in web servers -/// - Flexible sink composition +/// Owned by [`AppRunner`](crate::runtimes::runner::AppRunner) rather than +/// [`App`](crate::app::App), so multiple runners can share the same graph with isolated +/// event configurations, and per-request isolation in web servers is straightforward. /// /// ```text /// Workflow Nodes @@ -38,136 +38,19 @@ use chrono::Utc; /// Sink Sink Sink Sink /// ``` /// -/// # Usage Patterns -/// -/// ## Default EventBus (Stdout Only) -/// -/// When using [`App::invoke()`](crate::app::App::invoke), a default EventBus -/// with stdout sink is created automatically: -/// -/// ```rust,no_run -/// # use weavegraph::app::App; -/// # use weavegraph::state::VersionedState; -/// # async fn example(app: App) -> Result<(), Box> { -/// // Events automatically go to stdout -/// let result = app.invoke(VersionedState::new_with_user_message("Hello")).await?; -/// # Ok(()) -/// # } -/// ``` -/// -/// ## Custom EventBus (Streaming to Web Clients) -/// -/// For streaming events to web clients, create a custom EventBus and pass it to -/// [`AppRunner`](crate::runtimes::runner::AppRunner): -/// -/// ```rust,no_run -/// use weavegraph::event_bus::{EventBus, ChannelSink, StdOutSink}; -/// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// use weavegraph::state::VersionedState; -/// # use weavegraph::app::App; -/// # async fn example(app: App) -> Result<(), Box> { -/// -/// // Create channel for streaming -/// let (tx, rx) = flume::unbounded(); -/// -/// // Create EventBus with multiple sinks -/// let bus = EventBus::with_sinks(vec![ -/// Box::new(StdOutSink::default()), // Server logs -/// Box::new(ChannelSink::new(tx)), // Client streaming -/// ]); -/// -/// // Pass EventBus to AppRunner -/// let mut runner = AppRunner::builder() -/// .app(app) -/// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) -/// .event_bus(bus) -/// .build() -/// .await; -/// -/// let session_id = "client-123".to_string(); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("Process this") -/// ).await?; -/// -/// // Consume events from channel -/// tokio::spawn(async move { -/// while let Ok(event) = rx.recv_async().await { -/// // Send to web client via SSE, WebSocket, etc. -/// println!("Event: {:?}", event); -/// } -/// }); -/// -/// runner.run_until_complete(&session_id).await?; -/// # Ok(()) -/// # } -/// ``` -/// -/// ## Per-Request Isolation (Web Server Pattern) -/// -/// Create a new EventBus for each HTTP request to isolate events: -/// -/// ```rust,no_run -/// use std::sync::Arc; -/// use weavegraph::event_bus::{EventBus, ChannelSink}; -/// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// use weavegraph::state::VersionedState; -/// # use weavegraph::app::App; -/// # async fn handle_request(app: Arc) -> Result<(), Box> { -/// -/// // Each request gets its own EventBus and channel -/// let (tx, rx) = flume::unbounded(); -/// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); -/// -/// // Reuse the App, create new runner with isolated EventBus -/// let mut runner = AppRunner::builder() -/// .app(Arc::try_unwrap(app).unwrap_or_else(|arc| (*arc).clone())) -/// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) -/// .event_bus(bus) -/// .build() -/// .await; -/// -/// // Run workflow - events are isolated to this request -/// let session_id = uuid::Uuid::new_v4().to_string(); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("User query") -/// ).await?; -/// runner.run_until_complete(&session_id).await?; -/// # Ok(()) -/// # } -/// ``` -/// /// # Available Sinks /// -/// - [`StdOutSink`](crate::event_bus::StdOutSink) - Write to stdout (default) -/// - [`ChannelSink`](crate::event_bus::ChannelSink) - Stream to async channels -/// - [`MemorySink`](crate::event_bus::MemorySink) - Capture for testing +/// - [`StdOutSink`](crate::event_bus::StdOutSink) — write events to stdout (default) +/// - [`ChannelSink`](crate::event_bus::ChannelSink) — stream events to async channels +/// - [`MemorySink`](crate::event_bus::MemorySink) — capture events for testing /// - Custom sinks implementing [`EventSink`](crate::event_bus::EventSink) -/// -/// # See Also -/// -/// - [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - How to use custom EventBus -/// - [`ChannelSink`](crate::event_bus::ChannelSink) - For streaming events -/// - Example: `examples/streaming_events.rs` - Complete streaming demonstration -const DEFAULT_BUFFER_CAPACITY: usize = 1024; - -/// Central event broadcasting system that fans out workflow events to registered sinks. -/// -/// Create with [`EventBus::with_sink`] or [`EventBus::with_sinks`] and pass to -/// [`AppRunner`](crate::runtimes::runner::AppRunner) for per-request event isolation. pub struct EventBus { sinks: Arc>>, hub: Arc, started: AtomicBool, - /// Generation counter tracking listener restarts so stale workers exit. generation: Arc, - /// Diagnostics broadcast channel for sink errors. diagnostics_tx: broadcast::Sender, - /// In-memory health tracking per sink name. - health: Arc>>, + health: Arc>>, diagnostics_enabled: bool, diagnostics_emit_to_events: bool, } @@ -180,10 +63,7 @@ impl Default for EventBus { impl EventBus { /// Create an `EventBus` with a single sink. - pub fn with_sink(sink: T) -> Self - where - T: EventSink + 'static, - { + pub fn with_sink(sink: T) -> Self { Self::with_sinks(vec![Box::new(sink)]) } @@ -205,20 +85,14 @@ impl EventBus { ) -> Self { let hub = EventHub::new(buffer_capacity); let entries = sinks.into_iter().map(SinkEntry::new).collect(); - let (diagnostics_tx, _) = if diagnostics_enabled { - broadcast::channel(diagnostics_capacity.max(1)) - } else { - // Create a tiny channel and immediately drop the sender at drop time when EventBus drops. - // Publishing will be skipped when disabled, so the capacity is largely irrelevant. - broadcast::channel(1) - }; + let (diagnostics_tx, _) = broadcast::channel(diagnostics_capacity.max(1)); Self { sinks: Arc::new(Mutex::new(entries)), hub, started: AtomicBool::new(false), generation: Arc::new(AtomicU64::new(0)), diagnostics_tx, - health: Arc::new(Mutex::new(std::collections::HashMap::new())), + health: Arc::new(Mutex::new(HashMap::new())), diagnostics_enabled, diagnostics_emit_to_events, } @@ -231,21 +105,17 @@ impl EventBus { /// Attach a new sink to the hub, starting a worker immediately if the bus is live. pub fn add_boxed_sink(&self, sink: Box) { - let mut sinks_guard = self.sinks.lock().expect("EventBus sinks mutex poisoned"); + let mut sinks = self.sinks.lock().expect("EventBus sinks mutex poisoned"); let mut entry = SinkEntry::new(sink); if self.started.load(Ordering::SeqCst) { - let generation = self.generation.load(Ordering::SeqCst); entry.spawn_worker( - self.hub.clone(), + Arc::clone(&self.hub), Arc::clone(&self.generation), - generation, - self.diagnostics_tx.clone(), - Arc::clone(&self.health), - self.diagnostics_enabled, - self.diagnostics_emit_to_events, + self.generation.load(Ordering::SeqCst), + self.worker_diag(), ); } - sinks_guard.push(entry); + sinks.push(entry); } /// Return an [`EventEmitter`] handle for publishing events to this bus. @@ -264,11 +134,9 @@ impl EventBus { self.hub.subscribe() } - /// Access a broadcast stream of sink diagnostics. + /// Return a broadcast stream of [`SinkDiagnostic`] entries emitted when sinks error. /// - /// The returned stream mirrors the `EventStream` consumption model but carries - /// `SinkDiagnostic` entries emitted when sinks report errors. This stream is - /// isolated from the main event flow to avoid feedback loops. + /// Isolated from the main event flow to avoid feedback loops. pub fn diagnostics(&self) -> DiagnosticsStream { DiagnosticsStream::new(self.diagnostics_tx.subscribe()) } @@ -278,8 +146,8 @@ impl EventBus { let health = self.health.lock().expect("EventBus health mutex poisoned"); health .iter() - .map(|(sink, state)| SinkHealth { - sink: sink.clone(), + .map(|(name, state)| SinkHealth { + sink: name.clone(), error_count: state.error_count, last_error: state.last_error.clone(), last_error_at: state.last_error_at, @@ -287,7 +155,7 @@ impl EventBus { .collect() } - /// Spawn workers for every registered sink. Safe to call multiple times. + /// Start all registered sink workers. No-op if already started. pub fn listen_for_events(&self) { if self.started.swap(true, Ordering::SeqCst) { return; @@ -296,35 +164,34 @@ impl EventBus { let generation = self.generation.load(Ordering::SeqCst); for entry in sinks.iter_mut() { entry.spawn_worker( - self.hub.clone(), + Arc::clone(&self.hub), Arc::clone(&self.generation), generation, - self.diagnostics_tx.clone(), - Arc::clone(&self.health), - self.diagnostics_enabled, - self.diagnostics_emit_to_events, + self.worker_diag(), ); } } + fn worker_diag(&self) -> WorkerDiag { + WorkerDiag { + tx: self.diagnostics_tx.clone(), + health: Arc::clone(&self.health), + enabled: self.diagnostics_enabled, + emit_as_events: self.diagnostics_emit_to_events, + } + } + /// Signal all sink workers to stop pulling from the hub. pub async fn stop_listener(&self) { if !self.started.swap(false, Ordering::SeqCst) { return; } self.generation.fetch_add(1, Ordering::SeqCst); - let workers = { + let workers: Vec = { let mut sinks = self.sinks.lock().expect("EventBus sinks mutex poisoned"); - let mut collected = Vec::with_capacity(sinks.len()); - for entry in sinks.iter_mut() { - if let Some(worker) = entry.worker.take() { - collected.push(worker); - } - } - collected + sinks.iter_mut().filter_map(|e| e.worker.take()).collect() }; - for worker in workers { - let SinkWorker { shutdown, handle } = worker; + for SinkWorker { shutdown, handle } in workers { let _ = shutdown.send(()); let _ = handle.await; } @@ -350,7 +217,6 @@ impl Drop for EventBus { struct SinkEntry { sink: Arc>>, - /// Resolved once at registration to avoid recomputing on error paths. name: String, worker: Option, } @@ -358,10 +224,9 @@ struct SinkEntry { impl SinkEntry { fn new(sink: Box) -> Self { let candidate = sink.name(); - let default_marker: &str = std::any::type_name::(); - // Prefer implementor override; otherwise fall back to the dynamic concrete type name. - let name = if candidate == default_marker { - std::any::type_name_of_val(&*sink).to_string() + let trait_default = std::any::type_name::(); + let name = if candidate == trait_default { + std::any::type_name_of_val(&*sink).to_owned() } else { candidate }; @@ -372,56 +237,24 @@ impl SinkEntry { } } - #[allow(clippy::too_many_arguments)] fn spawn_worker( &mut self, hub: Arc, - generation_state: Arc, - active_generation: u64, - diagnostics_tx: broadcast::Sender, - health: Arc>>, - diagnostics_enabled: bool, - diagnostics_emit_to_events: bool, + generation_counter: Arc, + spawned_generation: u64, + diag: WorkerDiag, ) { if self.worker.is_some() { return; } - // Each worker holds an `Arc` to the sink so consumers can add/remove sinks without - // racing the async tasks we spawn here. let sink = Arc::clone(&self.sink); let sink_name = self.name.clone(); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); let mut stream = hub.subscribe(); - let de_enabled = diagnostics_enabled; - let de_emit = diagnostics_emit_to_events; - let hub_clone = Arc::clone(&hub); + let WorkerDiag { tx: diagnostics_tx, health, enabled: diagnostics_enabled, emit_as_events: emit_diagnostics_as_events } = diag; let handle = task::spawn(async move { - fn record_sink_error( - health: &Arc>>, - diagnostics_tx: &broadcast::Sender, - sink_name: &str, - err_msg: &str, - diagnostics_enabled: bool, - ) { - if diagnostics_enabled { - let mut map = health.lock().expect("health mutex poisoned"); - let entry = map.entry(sink_name.to_string()).or_default(); - entry.error_count = entry.error_count.saturating_add(1); - entry.last_error = Some(err_msg.to_string()); - entry.last_error_at = Some(Utc::now()); - let occurrence = entry.error_count; - drop(map); - let _ = diagnostics_tx.send(SinkDiagnostic { - sink: sink_name.to_string(), - error: err_msg.to_string(), - when: Utc::now(), - occurrence, - }); - } - } loop { - // Bail out early if the bus has been stopped/restarted since this worker spawned. - if generation_state.load(Ordering::SeqCst) != active_generation { + if generation_counter.load(Ordering::SeqCst) != spawned_generation { break; } tokio::select! { @@ -429,53 +262,41 @@ impl SinkEntry { event = stream.recv() => match event { Ok(event) => { let sink = Arc::clone(&sink); - let sink_name = sink_name.clone(); - let diagnostics_tx = diagnostics_tx.clone(); - let health = Arc::clone(&health); - let diagnostics_enabled = de_enabled; - let hub_for_emit = Arc::clone(&hub_clone); - let diagnostics_emit_to_events = de_emit; - // Dispatch potentially blocking sink logic onto the dedicated - // blocking pool so we never park the async runtime thread. let dispatch = task::spawn_blocking(move || -> io::Result<()> { - let mut guard = sink.lock().expect("sink mutex poisoned"); - guard.handle(&event) + sink.lock().expect("sink mutex poisoned").handle(&event) }); - match dispatch.await { - Ok(Ok(())) => {} - Ok(Err(err)) => { - let err_msg = err.to_string(); - tracing::error!( - target: "weavegraph::event_bus", - error = %err_msg, - sink = %sink_name, - "event sink reported an error while handling event" - ); - record_sink_error(&health, &diagnostics_tx, &sink_name, &err_msg, diagnostics_enabled); - if diagnostics_emit_to_events { - let _ = hub_for_emit.publish(super::event::Event::diagnostic( - "event_bus.sink_error", - format!("{sink}: {err}", sink=sink_name, err=err_msg), - )); - } - } - Err(err) => { - let err_msg = err.to_string(); - tracing::error!( - target: "weavegraph::event_bus", - error = %err_msg, - sink = %sink_name, - "event sink worker task failed to join" - ); - // Treat join failures as sink errors for health/diagnostics - record_sink_error(&health, &diagnostics_tx, &sink_name, &err_msg, diagnostics_enabled); - if diagnostics_emit_to_events { - let _ = hub_for_emit.publish(super::event::Event::diagnostic( - "event_bus.sink_join_error", - format!("{sink}: {err}", sink=sink_name, err=err_msg), - )); - } - } + let (label, err_msg) = match dispatch.await { + Ok(Ok(())) => continue, + Ok(Err(e)) => ("event_bus.sink_error", e.to_string()), + Err(e) => ("event_bus.sink_join_error", e.to_string()), + }; + tracing::error!( + target: "weavegraph::event_bus", + error = %err_msg, + sink = %sink_name, + %label, + "sink worker error" + ); + if diagnostics_enabled { + let mut map = health.lock().expect("health mutex poisoned"); + let state = map.entry(sink_name.clone()).or_default(); + state.error_count = state.error_count.saturating_add(1); + state.last_error = Some(err_msg.clone()); + state.last_error_at = Some(Utc::now()); + let occurrence = state.error_count; + drop(map); + let _ = diagnostics_tx.send(SinkDiagnostic { + sink: sink_name.clone(), + error: err_msg.clone(), + when: Utc::now(), + occurrence, + }); + } + if emit_diagnostics_as_events { + let _ = hub.publish(super::event::Event::diagnostic( + label, + format!("{sink_name}: {err_msg}"), + )); } } Err(tokio::sync::broadcast::error::RecvError::Closed) => break, @@ -502,3 +323,10 @@ struct SinkWorker { shutdown: oneshot::Sender<()>, handle: task::JoinHandle<()>, } + +struct WorkerDiag { + tx: broadcast::Sender, + health: Arc>>, + enabled: bool, + emit_as_events: bool, +} diff --git a/src/event_bus/diagnostics.rs b/src/event_bus/diagnostics.rs index 66bac15..054cc37 100644 --- a/src/event_bus/diagnostics.rs +++ b/src/event_bus/diagnostics.rs @@ -1,102 +1,100 @@ -//! Sink health diagnostics: per-sink error tracking and the diagnostics broadcast stream. +//! Per-sink health tracking and the diagnostics broadcast stream. + use std::time::Duration; use chrono::{DateTime, Utc}; use futures_util::stream::{self, BoxStream, StreamExt}; use serde::{Deserialize, Serialize}; -use tokio::sync::broadcast::{self, Receiver, error}; +use tokio::sync::broadcast::{Receiver, error::{RecvError, TryRecvError}}; use tokio::time::timeout; -/// A single diagnostic entry emitted when a sink reports an error. +/// A single error event emitted when a sink fails. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct SinkDiagnostic { - /// Logical sink identifier. Defaults to the sink's type name unless overridden. + /// Logical sink identifier. pub sink: String, - /// Human-readable error message produced by the sink. + /// Error message produced by the sink. pub error: String, - /// Timestamp for when the error was observed. + /// Timestamp of the failure. pub when: DateTime, - /// Monotonic occurrence counter for this sink's errors. + /// Monotonically increasing error count for this sink. pub occurrence: u64, } -/// Public snapshot type representing per-sink health. +/// Health snapshot for a single sink. #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub struct SinkHealth { - /// Name of the sink this health snapshot belongs to. + /// Sink identifier. pub sink: String, - /// Total number of errors encountered by this sink. + /// Total errors recorded. pub error_count: u64, + /// Most recent error message, if any. #[serde(skip_serializing_if = "Option::is_none")] - /// Description of the most recent error, if any. pub last_error: Option, - #[serde(skip_serializing_if = "Option::is_none")] /// Timestamp of the most recent error, if any. + #[serde(skip_serializing_if = "Option::is_none")] pub last_error_at: Option>, } -/// Internal accumulator for health tracking. +/// Internal per-sink error accumulator. #[derive(Debug, Default, Clone)] pub struct HealthState { - /// Running count of errors recorded for the sink. + /// Total errors recorded. pub error_count: u64, - /// Description of the most recent error, if any. + /// Last error message. pub last_error: Option, - /// Timestamp of the most recent error, if any. + /// Timestamp of the last error. pub last_error_at: Option>, } -/// Stream wrapper for sink diagnostics, mirroring the EventStream API surface. +/// Broadcast receiver wrapper for sink diagnostics. #[derive(Debug)] pub struct DiagnosticsStream { receiver: Receiver, } impl DiagnosticsStream { - /// Create a new `DiagnosticsStream` from a broadcast receiver. + /// Wrap a broadcast receiver. pub fn new(receiver: Receiver) -> Self { Self { receiver } } - /// Receive the next diagnostic, awaiting if necessary. - pub async fn recv(&mut self) -> Result { + /// Await the next diagnostic. + pub async fn recv(&mut self) -> Result { self.receiver.recv().await } - /// Try to receive a diagnostic without awaiting. - pub fn try_recv(&mut self) -> Result { + /// Poll for the next diagnostic without blocking. + pub fn try_recv(&mut self) -> Result { self.receiver.try_recv() } - /// Consume this wrapper, returning the inner broadcast receiver. + /// Unwrap into the inner broadcast receiver. pub fn into_inner(self) -> Receiver { self.receiver } - /// Convert into a boxed async stream of diagnostics. + /// Convert into a boxed async stream, silently skipping lagged messages. pub fn into_async_stream(self) -> BoxStream<'static, SinkDiagnostic> { - let receiver = self.receiver; - stream::unfold(receiver, |mut receiver| async move { + stream::unfold(self.receiver, |mut rx| async move { loop { - match receiver.recv().await { - Ok(diag) => return Some((diag, receiver)), - // Skip lagged notifications and keep draining - Err(error::RecvError::Lagged(_)) => continue, - Err(error::RecvError::Closed) => return None, + match rx.recv().await { + Ok(diag) => return Some((diag, rx)), + Err(RecvError::Lagged(_)) => continue, + Err(RecvError::Closed) => return None, } } }) .boxed() } - /// Wait up to `duration` for the next diagnostic. + /// Wait up to `duration` for the next diagnostic, skipping lag notifications. pub async fn next_timeout(&mut self, duration: Duration) -> Option { loop { match timeout(duration, self.recv()).await { Ok(Ok(diag)) => return Some(diag), - Ok(Err(error::RecvError::Lagged(_))) => continue, - Ok(Err(error::RecvError::Closed)) => return None, - Err(_) => return None, + Ok(Err(RecvError::Lagged(_))) => continue, + Ok(Err(RecvError::Closed)) | Err(_) => return None, } } } diff --git a/src/event_bus/emitter.rs b/src/event_bus/emitter.rs index d4aef33..58f4b02 100644 --- a/src/event_bus/emitter.rs +++ b/src/event_bus/emitter.rs @@ -1,20 +1,22 @@ //! [`EventEmitter`] trait and [`EmitterError`] for publishing events to the bus. -use std::fmt; + +use std::fmt::Debug; + use thiserror::Error; use super::event::Event; -/// Trait representing an abstract event emitter that workflow nodes can clone. -pub trait EventEmitter: Send + Sync + fmt::Debug { - /// Emit an event in a synchronous, non-blocking manner. +/// Synchronous, non-blocking event publisher injected into workflow nodes. +pub trait EventEmitter: Send + Sync + Debug { + /// Publish an event. fn emit(&self, event: Event) -> Result<(), EmitterError>; } -/// Errors that can occur when emitting an event. +/// Error returned when event emission fails. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum EmitterError { - /// The event hub has been shut down and no longer accepts events. + /// The event hub has been shut down. #[error("event hub closed")] #[cfg_attr( feature = "diagnostics", @@ -37,8 +39,8 @@ pub enum EmitterError { } impl EmitterError { - /// Construct an [`EmitterError::Other`] from any string-convertible error message. - pub fn other(error: impl Into) -> Self { - Self::Other(error.into()) + /// Construct an [`EmitterError::Other`] from any string-like error message. + pub fn other(msg: impl Into) -> Self { + Self::Other(msg.into()) } } diff --git a/src/event_bus/event.rs b/src/event_bus/event.rs index 735fac3..cec9d2e 100644 --- a/src/event_bus/event.rs +++ b/src/event_bus/event.rs @@ -1,4 +1,4 @@ -//! Core event types emitted by workflow nodes and the framework itself. +//! Event types emitted by workflow nodes and the framework. use std::fmt; use chrono::{DateTime, Utc}; @@ -6,44 +6,35 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -/// Scope constant marking the end of a streaming invocation. -/// -/// An event with this scope is emitted by the framework when the event stream closes -/// so that consumers can detect clean stream termination. +/// Emitted when the event stream closes, signalling clean stream termination. pub const STREAM_END_SCOPE: &str = "__weavegraph_stream_end__"; -/// Scope constant marking the end of one logical invocation while a stream stays open. -/// -/// Iterative runners emit this after each [`AppRunner::invoke_next`](crate::runtimes::AppRunner::invoke_next) -/// call so subscribers can separate logical inputs without treating the event bus as closed. +/// Emitted after each [`AppRunner::invoke_next`](crate::runtimes::AppRunner::invoke_next) call +/// so subscribers can separate logical inputs without treating the bus as closed. pub const INVOCATION_END_SCOPE: &str = "__weavegraph_invocation_end__"; -/// Scope constant for diagnostic events emitted by the framework. -/// -/// Use this scope when emitting internal diagnostic information -/// to distinguish framework diagnostics from user node events. -/// Consumers can filter on this scope to capture framework-level -/// telemetry without polluting the main event stream. +/// Scope for internal framework diagnostics; filter on this to separate telemetry +/// from user-emitted node events. pub const DIAGNOSTIC_SCOPE: &str = "__weavegraph_diagnostic__"; -/// A workflow event that can be emitted by nodes or the framework itself. +/// An event that can be emitted by a workflow node or by the framework itself. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum Event { - /// A structured event emitted by a workflow node. + /// A structured event from a workflow node. Node(NodeEvent), - /// A framework-internal diagnostic event. + /// A framework-internal diagnostic. Diagnostic(DiagnosticEvent), - /// An LLM streaming chunk or final/error marker. + /// An LLM streaming chunk, final marker, or error. LLM(LLMStreamingEvent), } impl Event { - /// Create a node event with only a scope and message (no node ID or step). + /// Construct a node event with only a scope and message. pub fn node_message(scope: impl Into, message: impl Into) -> Self { Event::Node(NodeEvent::new(None, None, scope.into(), message.into())) } - /// Create a node event with full metadata (node ID, step, scope, message). + /// Construct a node event with node ID, step, scope, and message. pub fn node_message_with_meta( node_id: impl Into, step: u64, @@ -58,7 +49,7 @@ impl Event { )) } - /// Create a node event with full metadata and additional runtime labels. + /// Construct a node event with full metadata including runtime labels. pub fn node_message_with_metadata( node_id: impl Into, step: u64, @@ -67,52 +58,44 @@ impl Event { metadata: FxHashMap, ) -> Self { Event::Node( - NodeEvent::new( - Some(node_id.into()), - Some(step), - scope.into(), - message.into(), - ) - .with_metadata(metadata), + NodeEvent::new(Some(node_id.into()), Some(step), scope.into(), message.into()) + .with_metadata(metadata), ) } - /// Create a diagnostic event with the given scope and message. + /// Construct a diagnostic event. pub fn diagnostic(scope: impl Into, message: impl Into) -> Self { - Event::Diagnostic(DiagnosticEvent { - scope: scope.into(), - message: message.into(), - }) + Event::Diagnostic(DiagnosticEvent { scope: scope.into(), message: message.into() }) } - /// Return the scope label string if the event carries one. + /// Returns the scope label for this event. pub fn scope_label(&self) -> Option<&str> { match self { - Event::Node(node) => Some(node.scope()), - Event::Diagnostic(diag) => Some(diag.scope()), - Event::LLM(llm) => Some(llm.scope().as_ref()), + Event::Node(n) => Some(n.scope()), + Event::Diagnostic(d) => Some(d.scope()), + Event::LLM(l) => Some(l.scope().as_ref()), } } - /// Return the primary message text for this event. + /// Returns the primary message text. pub fn message(&self) -> &str { match self { - Event::Node(node) => node.message(), - Event::Diagnostic(diag) => diag.message(), - Event::LLM(llm) => llm.chunk(), + Event::Node(n) => n.message(), + Event::Diagnostic(d) => d.message(), + Event::LLM(l) => l.chunk(), } } - /// Convert event to structured JSON value with normalized schema. + /// Serialises the event to a normalised JSON value. /// - /// Returns a JSON object with the following structure: + /// The returned object has the shape: /// ```json /// { /// "type": "node" | "diagnostic" | "llm", - /// "scope": "scope_label", - /// "message": "event_message", - /// "timestamp": "2025-11-03T12:34:56.789Z", - /// "metadata": { /* variant-specific fields */ } + /// "scope": "", + /// "message": "", + /// "timestamp": "", + /// "metadata": { ... } /// } /// ``` /// @@ -133,51 +116,36 @@ impl Event { pub fn to_json_value(&self) -> serde_json::Value { use serde_json::json; - let (event_type, metadata) = match self { - Event::Node(node) => { - let mut meta = serde_json::Map::new(); - for (key, value) in node.metadata() { - meta.insert(key.clone(), value.clone()); - } - if let Some(node_id) = node.node_id() { - meta.insert("node_id".to_string(), json!(node_id)); + let (event_type, metadata, timestamp) = match self { + Event::Node(n) => { + let mut meta: serde_json::Map = + n.metadata().iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + if let Some(id) = n.node_id() { + meta.insert("node_id".to_owned(), json!(id)); } - if let Some(step) = node.step() { - meta.insert("step".to_string(), json!(step)); + if let Some(step) = n.step() { + meta.insert("step".to_owned(), json!(step)); } - ("node", Value::Object(meta)) + ("node", Value::Object(meta), Utc::now()) } - Event::Diagnostic(_) => { - let meta = serde_json::Map::new(); - ("diagnostic", Value::Object(meta)) - } - Event::LLM(llm) => { - let mut meta = serde_json::Map::new(); - if let Some(session_id) = llm.session_id() { - meta.insert("session_id".to_string(), json!(session_id)); - } - if let Some(node_id) = llm.node_id() { - meta.insert("node_id".to_string(), json!(node_id)); + Event::Diagnostic(_) => ("diagnostic", json!({}), Utc::now()), + Event::LLM(l) => { + let mut meta: serde_json::Map = + l.metadata().iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + if let Some(id) = l.session_id() { + meta.insert("session_id".to_owned(), json!(id)); } - if let Some(stream_id) = llm.stream_id() { - meta.insert("stream_id".to_string(), json!(stream_id)); + if let Some(id) = l.node_id() { + meta.insert("node_id".to_owned(), json!(id)); } - meta.insert("is_final".to_string(), json!(llm.is_final())); - - // Include LLM metadata fields - for (key, value) in llm.metadata() { - meta.insert(key.clone(), value.clone()); + if let Some(id) = l.stream_id() { + meta.insert("stream_id".to_owned(), json!(id)); } - - ("llm", Value::Object(meta)) + meta.insert("is_final".to_owned(), json!(l.is_final())); + ("llm", Value::Object(meta), l.timestamp()) } }; - let timestamp = match self { - Event::LLM(llm) => llm.timestamp(), - _ => Utc::now(), - }; - json!({ "type": event_type, "scope": self.scope_label(), @@ -187,7 +155,7 @@ impl Event { }) } - /// Convert event to compact JSON string representation. + /// Serialises the event to a compact JSON string. /// /// # Example /// @@ -202,9 +170,7 @@ impl Event { serde_json::to_string(&self.to_json_value()) } - /// Convert event to pretty-printed JSON string with indentation. - /// - /// Useful for debugging and log files where human readability is important. + /// Serialises the event to an indented JSON string. /// /// # Example /// @@ -223,20 +189,20 @@ impl Event { impl fmt::Display for Event { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Event::Node(node) => match (node.node_id(), node.step()) { - (Some(id), Some(step)) => write!(f, "[{id}@{step}] {}", node.message()), - (Some(id), None) => write!(f, "[{id}] {}", node.message()), - (None, Some(step)) => write!(f, "[step {step}] {}", node.message()), - (None, None) => write!(f, "{}", node.message()), + Event::Node(n) => match (n.node_id(), n.step()) { + (Some(id), Some(step)) => write!(f, "[{id}@{step}] {}", n.message()), + (Some(id), None) => write!(f, "[{id}] {}", n.message()), + (None, Some(step)) => write!(f, "[step {step}] {}", n.message()), + (None, None) => write!(f, "{}", n.message()), }, - Event::Diagnostic(diag) => write!(f, "{}", diag.message()), - Event::LLM(llm) => { - if let Some(stream_id) = llm.stream_id() { - write!(f, "[LLM {stream_id}] {}", llm.chunk()) - } else if let Some(node_id) = llm.node_id() { - write!(f, "[LLM {node_id}] {}", llm.chunk()) + Event::Diagnostic(d) => write!(f, "{}", d.message()), + Event::LLM(l) => { + if let Some(id) = l.stream_id() { + write!(f, "[LLM {id}] {}", l.chunk()) + } else if let Some(id) = l.node_id() { + write!(f, "[LLM {id}] {}", l.chunk()) } else { - write!(f, "{}", llm.chunk()) + write!(f, "{}", l.chunk()) } } } @@ -255,18 +221,12 @@ pub struct NodeEvent { } impl NodeEvent { - /// Create a new `NodeEvent` with optional node ID and step number. + /// Create a new node event. pub fn new(node_id: Option, step: Option, scope: String, message: String) -> Self { - Self { - node_id, - step, - scope, - message, - metadata: FxHashMap::default(), - } + Self { node_id, step, scope, message, metadata: FxHashMap::default() } } - /// Returns the node identifier, if set. + /// Returns the node ID, if set. pub fn node_id(&self) -> Option<&str> { self.node_id.as_deref() } @@ -276,29 +236,29 @@ impl NodeEvent { self.step } - /// Returns the scope label for this event. + /// Returns the scope label. pub fn scope(&self) -> &str { &self.scope } - /// Returns the event message text. + /// Returns the message text. pub fn message(&self) -> &str { &self.message } - /// Returns the metadata map attached to this node event. + /// Returns the metadata map. pub fn metadata(&self) -> &FxHashMap { &self.metadata } - /// Return a new node event with the given metadata map. + /// Replace the metadata map and return `self`. pub fn with_metadata(mut self, metadata: FxHashMap) -> Self { self.metadata = metadata; self } } -/// A framework-internal diagnostic event emitted outside normal node execution. +/// A framework-internal diagnostic event, emitted outside normal node execution. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct DiagnosticEvent { scope: String, @@ -306,7 +266,7 @@ pub struct DiagnosticEvent { } impl DiagnosticEvent { - /// Returns the scope label for this diagnostic event. + /// Returns the scope label. pub fn scope(&self) -> &str { &self.scope } @@ -320,28 +280,28 @@ impl DiagnosticEvent { /// Scope discriminant for LLM streaming events. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum LLMStreamingEventScope { - /// An in-progress streaming session (default scope). + /// An in-progress streaming session (default). Streaming, /// A single text chunk within a streaming response. Chunk, - /// The final chunk marking the end of the stream. + /// The final chunk, marking end-of-stream. Final, - /// An error event terminating the stream. + /// An error event that terminates the stream. Error, } impl AsRef for LLMStreamingEventScope { fn as_ref(&self) -> &str { match self { - LLMStreamingEventScope::Chunk => "chunk", - LLMStreamingEventScope::Streaming => "stream", - LLMStreamingEventScope::Final => STREAM_END_SCOPE, - LLMStreamingEventScope::Error => "error", + Self::Streaming => "stream", + Self::Chunk => "chunk", + Self::Final => STREAM_END_SCOPE, + Self::Error => "error", } } } -/// An event carrying an LLM response chunk, final marker, or error from a streaming session. +/// An LLM event carrying a response chunk, a final marker, or an error. #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct LLMStreamingEvent { session_id: Option, @@ -355,31 +315,12 @@ pub struct LLMStreamingEvent { } impl LLMStreamingEvent { - #[allow(clippy::too_many_arguments)] - /// Create a new `LLMStreamingEvent` with full field control. - pub fn new( - session_id: Option, - node_id: Option, - stream_id: Option, - chunk: impl Into, - is_final: bool, - scope: Option, - metadata: FxHashMap, - timestamp: DateTime, - ) -> Self { - Self { - session_id, - node_id, - stream_id, - chunk: chunk.into(), - is_final, - scope: scope.unwrap_or(LLMStreamingEventScope::Streaming), - metadata, - timestamp, - } + /// Return a builder for constructing a new event; `chunk` is the only required field. + pub fn builder(chunk: impl Into) -> LLMStreamingEventBuilder { + LLMStreamingEventBuilder::new(chunk) } - /// Create a chunk event representing a partial LLM response. + /// Create a partial-chunk event. pub fn chunk_event( session_id: Option, node_id: Option, @@ -387,19 +328,19 @@ impl LLMStreamingEvent { chunk: impl Into, metadata: FxHashMap, ) -> Self { - Self::new( + Self { session_id, node_id, stream_id, - chunk, - false, - Some(LLMStreamingEventScope::Chunk), + chunk: chunk.into(), + is_final: false, + scope: LLMStreamingEventScope::Chunk, metadata, - Utc::now(), - ) + timestamp: Utc::now(), + } } - /// Create a final event marking the end of an LLM streaming session. + /// Create a final-chunk event marking the end of the stream. pub fn final_event( session_id: Option, node_id: Option, @@ -407,88 +348,175 @@ impl LLMStreamingEvent { chunk: impl Into, metadata: FxHashMap, ) -> Self { - Self::new( + Self { session_id, node_id, stream_id, - chunk, - true, - Some(LLMStreamingEventScope::Final), + chunk: chunk.into(), + is_final: true, + scope: LLMStreamingEventScope::Final, metadata, - Utc::now(), - ) + timestamp: Utc::now(), + } } - /// Create an error event marking a failed LLM streaming session. + /// Create an error event marking a failed streaming session. pub fn error_event( session_id: Option, node_id: Option, stream_id: Option, error_message: impl Into, ) -> Self { - let mut metadata = FxHashMap::default(); - metadata.insert("severity".to_string(), Value::String("error".to_string())); - Self::new( + let metadata = [("severity".to_owned(), Value::String("error".to_owned()))] + .into_iter() + .collect(); + Self { session_id, node_id, stream_id, - error_message, - true, - Some(LLMStreamingEventScope::Error), + chunk: error_message.into(), + is_final: true, + scope: LLMStreamingEventScope::Error, metadata, - Utc::now(), - ) + timestamp: Utc::now(), + } } - /// Returns the session identifier, if set. + /// Returns the session ID, if set. pub fn session_id(&self) -> Option<&str> { self.session_id.as_deref() } - /// Returns the node identifier, if set. + /// Returns the node ID, if set. pub fn node_id(&self) -> Option<&str> { self.node_id.as_deref() } - /// Returns the stream identifier, if set. + /// Returns the stream ID, if set. pub fn stream_id(&self) -> Option<&str> { self.stream_id.as_deref() } - /// Returns the text chunk carried by this event. + /// Returns the text chunk. pub fn chunk(&self) -> &str { &self.chunk } - /// Returns `true` if this event marks the final chunk of the stream. + /// Returns `true` if this is the final event in the stream. pub fn is_final(&self) -> bool { self.is_final } - /// Returns the scope of this streaming event. + /// Returns the scope discriminant. pub fn scope(&self) -> &LLMStreamingEventScope { &self.scope } - /// Returns the metadata map attached to this event. + /// Returns the metadata map. pub fn metadata(&self) -> &FxHashMap { &self.metadata } - /// Returns the timestamp at which this event was created. + /// Returns the event creation timestamp. pub fn timestamp(&self) -> DateTime { self.timestamp } - /// Return a new event with the given metadata map replacing the existing one. + /// Replace the metadata map and return `self`. pub fn with_metadata(mut self, metadata: FxHashMap) -> Self { self.metadata = metadata; self } - /// Return a new event with the given timestamp replacing the existing one. + /// Replace the timestamp and return `self`. pub fn with_timestamp(mut self, timestamp: DateTime) -> Self { self.timestamp = timestamp; self } } + +/// Builder for [`LLMStreamingEvent`]. +/// +/// Obtain one via [`LLMStreamingEvent::builder`]. All fields except `chunk` are optional; +/// `timestamp` defaults to [`Utc::now`] when [`build`](Self::build) is called. +pub struct LLMStreamingEventBuilder { + session_id: Option, + node_id: Option, + stream_id: Option, + chunk: String, + is_final: bool, + scope: LLMStreamingEventScope, + metadata: FxHashMap, + timestamp: Option>, +} + +impl LLMStreamingEventBuilder { + fn new(chunk: impl Into) -> Self { + Self { + session_id: None, + node_id: None, + stream_id: None, + chunk: chunk.into(), + is_final: false, + scope: LLMStreamingEventScope::Streaming, + metadata: FxHashMap::default(), + timestamp: None, + } + } + + /// Set the session ID. + pub fn session_id(mut self, id: impl Into) -> Self { + self.session_id = Some(id.into()); + self + } + + /// Set the node ID. + pub fn node_id(mut self, id: impl Into) -> Self { + self.node_id = Some(id.into()); + self + } + + /// Set the stream ID. + pub fn stream_id(mut self, id: impl Into) -> Self { + self.stream_id = Some(id.into()); + self + } + + /// Mark the event as final. + pub fn is_final(mut self, v: bool) -> Self { + self.is_final = v; + self + } + + /// Set the scope discriminant (defaults to [`LLMStreamingEventScope::Streaming`]). + pub fn scope(mut self, s: LLMStreamingEventScope) -> Self { + self.scope = s; + self + } + + /// Set the metadata map. + pub fn metadata(mut self, m: FxHashMap) -> Self { + self.metadata = m; + self + } + + /// Fix the creation timestamp; defaults to [`Utc::now`] if omitted. + pub fn timestamp(mut self, ts: DateTime) -> Self { + self.timestamp = Some(ts); + self + } + + /// Consume the builder and return the event. + pub fn build(self) -> LLMStreamingEvent { + LLMStreamingEvent { + session_id: self.session_id, + node_id: self.node_id, + stream_id: self.stream_id, + chunk: self.chunk, + is_final: self.is_final, + scope: self.scope, + metadata: self.metadata, + timestamp: self.timestamp.unwrap_or_else(Utc::now), + } + } +} diff --git a/src/event_bus/hub.rs b/src/event_bus/hub.rs index 89e0eac..9e3e899 100644 --- a/src/event_bus/hub.rs +++ b/src/event_bus/hub.rs @@ -49,36 +49,25 @@ impl EventHub { /// /// Returns [`EmitterError::Closed`] if the hub has been shut down. pub fn publish(&self, event: Event) -> Result<(), EmitterError> { - match self.current_sender() { - Some(sender) => match sender.send(event) { - Ok(_) => Ok(()), - Err(broadcast::error::SendError(event)) => { - drop(event); - Err(EmitterError::Closed) - } - }, - None => Err(EmitterError::Closed), - } + self.current_sender() + .ok_or(EmitterError::Closed) + .and_then(|s| s.send(event).map(|_| ()).map_err(|_| EmitterError::Closed)) } /// Subscribe to a fresh receiver. /// - /// If the hub has already been closed, this returns a closed receiver to keep - /// downstream code simple. + /// If the hub has already been closed, returns a closed receiver so downstream + /// code can proceed uniformly. pub fn subscribe(self: &Arc) -> EventStream { - let receiver = self - .current_sender() - .map(|sender| sender.subscribe()) - .unwrap_or_else(|| { - let (sender, receiver) = broadcast::channel(self.capacity.max(1)); - drop(sender); - receiver - }); - EventStream { - receiver, - hub: Arc::clone(self), - shutdown: None, - } + let receiver = match self.current_sender() { + Some(s) => s.subscribe(), + None => { + let (tx, rx) = broadcast::channel(self.capacity); + drop(tx); + rx + } + }; + EventStream { receiver, hub: Arc::clone(self), shutdown: None } } /// Returns the configured buffer capacity of the underlying broadcast channel. @@ -101,40 +90,28 @@ impl EventHub { /// Create a [`HubEmitter`] that publishes events to this hub. pub fn emitter(self: &Arc) -> HubEmitter { - HubEmitter { - hub: Arc::clone(self), - } + HubEmitter { hub: Arc::clone(self) } } /// Close the hub and signal all subscribers that no further events will arrive. pub fn close(&self) { - let _ = self - .sender - .write() - .expect("EventHub sender RwLock poisoned") - .take(); + self.sender.write().expect("hub sender lock poisoned").take(); } fn current_sender(&self) -> Option> { - self.sender - .read() - .expect("EventHub sender RwLock poisoned") - .clone() + self.sender.read().expect("hub sender lock poisoned").clone() } fn record_lag(&self, missed: u64) { if missed == 0 { return; } - let increment = usize::try_from(missed).unwrap_or(usize::MAX); - let total = self - .dropped_events - .fetch_add(increment, Ordering::Relaxed) - .saturating_add(increment); + let n = usize::try_from(missed).unwrap_or(usize::MAX); + let prev = self.dropped_events.fetch_add(n, Ordering::Relaxed); tracing::warn!( target: "weavegraph::event_bus", missed, - total_dropped = total, + total_dropped = prev.saturating_add(n), "event stream lagged; dropped events" ); } @@ -164,24 +141,22 @@ impl EventStream { /// Receive the next event, awaiting if the channel is empty. pub async fn recv(&mut self) -> Result { match self.receiver.recv().await { - Ok(event) => Ok(event), - Err(broadcast::error::RecvError::Lagged(missed)) => { - self.hub.record_lag(missed); - Err(broadcast::error::RecvError::Lagged(missed)) + Err(broadcast::error::RecvError::Lagged(n)) => { + self.hub.record_lag(n); + Err(broadcast::error::RecvError::Lagged(n)) } - Err(err) => Err(err), + result => result, } } /// Try to receive an event without blocking; returns immediately if none is available. pub fn try_recv(&mut self) -> Result { match self.receiver.try_recv() { - Ok(event) => Ok(event), - Err(broadcast::error::TryRecvError::Lagged(missed)) => { - self.hub.record_lag(missed); - Err(broadcast::error::TryRecvError::Lagged(missed)) + Err(broadcast::error::TryRecvError::Lagged(n)) => { + self.hub.record_lag(n); + Err(broadcast::error::TryRecvError::Lagged(n)) } - Err(err) => Err(err), + result => result, } } @@ -192,76 +167,52 @@ impl EventStream { /// Convert this stream into a synchronous blocking iterator. pub fn into_blocking_iter(self) -> BlockingEventIter { - BlockingEventIter { - receiver: self.receiver, - hub: self.hub, - } + BlockingEventIter { receiver: self.receiver, hub: self.hub } } /// Attach a shutdown watch channel; the stream ends when the watch value becomes `true`. pub fn with_shutdown(mut self, shutdown: watch::Receiver) -> Self { - // Consumers can share a `watch` channel to terminate the stream early when - // the producer side shuts down (e.g. HTTP connection dropped). self.shutdown = Some(shutdown); self } /// Convert this stream into a pinned `BoxStream` for use with async combinators. pub fn into_async_stream(self) -> BoxStream<'static, Event> { - // Convert the broadcast receiver into a boxed stream so callers can plug it into - // combinators without worrying about pinning or generics at the call site. - let EventStream { - receiver, - hub, - shutdown, - } = self; - stream::unfold((receiver, hub, shutdown), |(mut receiver, hub, mut shutdown)| async move { - loop { - if let Some(ref mut shutdown_rx) = shutdown { - tokio::select! { - biased; - changed = shutdown_rx.changed() => { - if changed.is_ok() && *shutdown_rx.borrow() { - return None; + let EventStream { receiver, hub, shutdown } = self; + stream::unfold( + (receiver, hub, shutdown), + |(mut receiver, hub, mut shutdown)| async move { + loop { + let recv_result = if let Some(ref mut rx) = shutdown { + tokio::select! { + biased; + changed = rx.changed() => { + if changed.is_ok() && *rx.borrow() { return None; } + continue; } - continue; - } - recv = receiver.recv() => { - match recv { - Ok(event) => return Some((event, (receiver, hub.clone(), shutdown))), - Err(broadcast::error::RecvError::Lagged(missed)) => { - hub.record_lag(missed); - continue; - } - Err(broadcast::error::RecvError::Closed) => return None, - } - } - } - } else { - match receiver.recv().await { - Ok(event) => return Some((event, (receiver, hub.clone(), shutdown))), - Err(broadcast::error::RecvError::Lagged(missed)) => { - hub.record_lag(missed); - continue; + result = receiver.recv() => result, } + } else { + receiver.recv().await + }; + match recv_result { + Ok(ev) => return Some((ev, (receiver, hub, shutdown))), + Err(broadcast::error::RecvError::Lagged(n)) => hub.record_lag(n), Err(broadcast::error::RecvError::Closed) => return None, } } - } - }) + }, + ) .boxed() } /// Receive the next event, waiting at most `duration`; returns `None` on timeout or close. pub async fn next_timeout(&mut self, duration: Duration) -> Option { - // Keep polling until we either obtain an event, the channel closes, or the - // deadline elapses. Lagged notifications simply increment drop metrics and retry. loop { match timeout(duration, self.recv()).await { Ok(Ok(event)) => return Some(event), Ok(Err(broadcast::error::RecvError::Lagged(_))) => continue, - Ok(Err(broadcast::error::RecvError::Closed)) => return None, - Err(_) => return None, + Ok(Err(broadcast::error::RecvError::Closed)) | Err(_) => return None, } } } @@ -280,10 +231,7 @@ impl Iterator for BlockingEventIter { loop { match self.receiver.blocking_recv() { Ok(event) => return Some(event), - Err(broadcast::error::RecvError::Lagged(missed)) => { - self.hub.record_lag(missed); - continue; - } + Err(broadcast::error::RecvError::Lagged(n)) => self.hub.record_lag(n), Err(broadcast::error::RecvError::Closed) => return None, } } diff --git a/src/event_bus/mod.rs b/src/event_bus/mod.rs index 2bb37f5..f2f25fa 100644 --- a/src/event_bus/mod.rs +++ b/src/event_bus/mod.rs @@ -1,17 +1,5 @@ -//! Event bus utilities providing fan-out, sinks, and subscriber APIs. -//! -//! The module is organised around a broadcast-based [`EventHub`] and helpers for -//! configuring sinks (`EventBus`) and consuming the resulting [`EventStream`]. -//! -//! # JSON Serialization -//! -//! Events can be serialized to JSON using: -//! - [`Event::to_json_value()`] - Structured JSON value with normalized schema -//! - [`Event::to_json_string()`] - Compact JSON string -//! - [`Event::to_json_pretty()`] - Pretty-printed JSON for debugging -//! -//! The [`JsonLinesSink`] provides machine-readable JSON Lines output for log -//! aggregation systems and monitoring tools. +//! Event bus: broadcast-based [`EventHub`], sink configuration via [`EventBus`], +//! and subscriber APIs over the resulting [`EventStream`]. pub mod bus; pub mod diagnostics; diff --git a/src/event_bus/sink.rs b/src/event_bus/sink.rs index 3786e78..f78dfee 100644 --- a/src/event_bus/sink.rs +++ b/src/event_bus/sink.rs @@ -1,4 +1,4 @@ -//! [`EventSink`] trait and built-in sink implementations: stdout, in-memory, channel, and JSON lines. +//! [`EventSink`] trait and built-in implementations: stdout, in-memory, channel, and JSON lines. use flume; use std::any::type_name; use std::fs::File; @@ -9,24 +9,23 @@ use std::sync::{Arc, Mutex}; use super::event::Event; use crate::telemetry::{PlainFormatter, TelemetryFormatter}; -/// Abstraction over an output target that consumes full Event objects. +/// Output target that consumes structured [`Event`] objects. pub trait EventSink: Sync + Send { - /// Handle a structured event. Sink decides how to serialize/format it. + /// Deliver a structured event to this sink. /// - /// Implementations are allowed to perform blocking I/O; the event bus will - /// hand the call off to `spawn_blocking` to keep the async runtime responsive. + /// Implementations may perform blocking I/O; the event bus dispatches via + /// `spawn_blocking` so callers are not stalled. fn handle(&mut self, event: &Event) -> IoResult<()>; - /// A stable, human-friendly identifier for this sink instance. + /// Human-readable identifier for this sink instance. /// - /// Defaults to the concrete type name; implementors may override to provide - /// shorter names or include configuration context. + /// Defaults to the concrete type name; override to include configuration context. fn name(&self) -> String { type_name::().to_string() } } -/// Stdout sink with optional formatting. +/// Stdout sink backed by an optional [`TelemetryFormatter`]. pub struct StdOutSink { handle: Stdout, formatter: F, @@ -42,7 +41,7 @@ impl Default for StdOutSink { } impl StdOutSink { - /// Create a `StdOutSink` that formats events using the given `TelemetryFormatter`. + /// Build a `StdOutSink` using `formatter` to render each event. pub fn with_formatter(formatter: F) -> Self { Self { handle: io::stdout(), @@ -53,98 +52,55 @@ impl StdOutSink { impl EventSink for StdOutSink { fn handle(&mut self, event: &Event) -> IoResult<()> { - let rendered = self.formatter.render_event(event).join_lines(); - self.handle.write_all(rendered.as_bytes())?; + let text = self.formatter.render_event(event).join_lines(); + self.handle.write_all(text.as_bytes())?; self.handle.flush() } } -/// In-memory sink for testing and snapshots. +/// In-memory sink that accumulates events for inspection or testing. #[derive(Clone, Default)] pub struct MemorySink { entries: Arc>>, } impl MemorySink { - /// Create a new, empty `MemorySink`. + /// Create an empty `MemorySink`. pub fn new() -> Self { Self::default() } - /// Get a snapshot of all captured events. Clones the internal buffer so callers - /// can inspect state without holding the mutex. + /// Return a snapshot of all captured events. + /// + /// Clones the internal buffer so callers do not hold the mutex. pub fn snapshot(&self) -> Vec { - self.entries - .lock() - .expect("MemorySink mutex poisoned") - .clone() + self.entries.lock().expect("MemorySink mutex poisoned").clone() } - /// Clear all captured events. + /// Discard all captured events. pub fn clear(&self) { - self.entries - .lock() - .expect("MemorySink mutex poisoned") - .clear(); + self.entries.lock().expect("MemorySink mutex poisoned").clear(); } } impl EventSink for MemorySink { fn handle(&mut self, event: &Event) -> IoResult<()> { - self.entries - .lock() - .expect("MemorySink mutex poisoned") - .push(event.clone()); + self.entries.lock().expect("MemorySink mutex poisoned").push(event.clone()); Ok(()) } } -/// JSON Lines (JSONL) sink for machine-readable structured logging. -/// -/// Outputs one JSON object per line, ideal for: -/// - Log aggregation systems (ELK, Splunk, DataDog) -/// - Stream processing pipelines -/// - Automated testing with structured assertions -/// - Integration with monitoring tools -/// -/// # Format -/// -/// Each event is serialized to a single line of JSON using the normalized schema: -/// ```json -/// {"type":"node","scope":"routing","message":"Processing","timestamp":"2025-11-03T12:34:56Z","metadata":{"node_id":"router","step":5}} -/// {"type":"diagnostic","scope":"system","message":"Ready","timestamp":"2025-11-03T12:34:57Z","metadata":{}} -/// ``` -/// -/// # Examples +/// Sink that writes one JSON object per line (JSONL / JSON Lines format). /// -/// ## Write to stdout -/// -/// ```rust,no_run -/// use weavegraph::event_bus::{EventBus, JsonLinesSink}; -/// -/// let sink = JsonLinesSink::to_stdout(); -/// let bus = EventBus::with_sinks(vec![Box::new(sink)]); -/// // Events will be written as JSON lines to stdout -/// ``` +/// Each event serializes to a single output line, suitable for log aggregation +/// pipelines (ELK, Splunk, DataDog), stream processors, and structured test assertions. /// -/// ## Write to file +/// # Example /// /// ```rust,no_run /// use weavegraph::event_bus::{EventBus, JsonLinesSink}; /// -/// let sink = JsonLinesSink::to_file("events.jsonl").unwrap(); -/// let bus = EventBus::with_sinks(vec![Box::new(sink)]); -/// // Events will be written to events.jsonl -/// ``` -/// -/// ## Pretty-printed output -/// -/// ```rust,no_run -/// use weavegraph::event_bus::JsonLinesSink; -/// use std::io; -/// -/// let sink = JsonLinesSink::with_pretty_print(Box::new(io::stdout())); -/// // Events will be pretty-printed (not valid JSONL, but human-readable) +/// let bus = EventBus::with_sinks(vec![Box::new(JsonLinesSink::to_stdout())]); /// ``` pub struct JsonLinesSink { handle: Box, @@ -152,83 +108,29 @@ pub struct JsonLinesSink { } impl JsonLinesSink { - /// Create a new JsonLinesSink with a custom writer. - /// - /// # Parameters - /// - /// * `handle` - Any writer implementing Write + Send - /// - /// # Example - /// - /// ```rust - /// use weavegraph::event_bus::JsonLinesSink; - /// use std::io::Cursor; - /// - /// let buffer = Cursor::new(Vec::new()); - /// let sink = JsonLinesSink::new(Box::new(buffer)); - /// ``` + /// Create a compact (one-line-per-event) sink writing to `handle`. pub fn new(handle: Box) -> Self { - Self { - handle, - pretty: false, - } + Self { handle, pretty: false } } - /// Create a JsonLinesSink with pretty-printing enabled. - /// - /// Note: Pretty-printed output is NOT valid JSON Lines format - /// (which requires one JSON object per line). Use this for debugging - /// and human-readable logs only. - /// - /// # Example + /// Create a pretty-printed sink writing to `handle`. /// - /// ```rust - /// use weavegraph::event_bus::JsonLinesSink; - /// use std::io::Cursor; - /// - /// let buffer = Cursor::new(Vec::new()); - /// let sink = JsonLinesSink::with_pretty_print(Box::new(buffer)); - /// ``` + /// Pretty-printed output spans multiple lines and is **not** valid JSONL. + /// Use for debugging and human-readable logs only. pub fn with_pretty_print(handle: Box) -> Self { - Self { - handle, - pretty: true, - } + Self { handle, pretty: true } } - /// Create a JsonLinesSink writing to stdout. - /// - /// # Example - /// - /// ```rust,no_run - /// use weavegraph::event_bus::JsonLinesSink; - /// - /// let sink = JsonLinesSink::to_stdout(); - /// ``` + /// Create a compact sink writing to stdout. pub fn to_stdout() -> Self { Self::new(Box::new(io::stdout())) } - /// Create a JsonLinesSink writing to a file. - /// - /// # Parameters - /// - /// * `path` - Path to the output file (will be created or truncated) + /// Create a compact sink writing to `path` (created or truncated). /// - /// # Errors - /// - /// Returns an error if the file cannot be created or opened. - /// - /// # Example - /// - /// ```rust,no_run - /// use weavegraph::event_bus::JsonLinesSink; - /// - /// let sink = JsonLinesSink::to_file("events.jsonl").unwrap(); - /// ``` + /// Returns an error if the file cannot be opened. pub fn to_file(path: impl AsRef) -> IoResult { - let file = File::create(path)?; - Ok(Self::new(Box::new(file))) + Ok(Self::new(Box::new(File::create(path)?))) } } @@ -241,7 +143,7 @@ impl EventSink for JsonLinesSink { } .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - writeln!(self.handle, "{}", json)?; + writeln!(self.handle, "{json}")?; self.handle.flush() } @@ -254,185 +156,52 @@ impl EventSink for JsonLinesSink { } } -/// Channel-based sink for streaming events to async consumers. -/// -/// `ChannelSink` forwards events to a flume channel, enabling real-time -/// event streaming to web clients, monitoring systems, or any async consumer. +/// Sink that forwards events to a [`flume`] channel for async consumers. /// -/// # Use Cases +/// Enables real-time streaming to web clients, dashboards, or monitoring systems. +/// Wire the paired receiver to your consumer before the graph starts running. /// -/// - **Server-Sent Events (SSE)**: Stream workflow progress to web browsers -/// - **WebSocket**: Real-time bidirectional communication -/// - **Live Dashboards**: Monitor workflow execution in real-time -/// - **Logging Services**: Forward events to centralized logging -/// - **Monitoring**: Send metrics to observability platforms +/// ⚠️ Must be injected via `AppRunner` — `App::invoke()` creates its own internal +/// bus and will not route events through this sink. /// -/// # Integration Pattern -/// -/// ⚠️ **Important**: `ChannelSink` must be passed to `AppRunner`, not used with `App.invoke()`: -/// -/// ```text -/// ❌ WRONG: -/// let bus = EventBus::default(); -/// bus.add_sink(ChannelSink::new(tx)); -/// graph.invoke(state).await; // Uses its OWN EventBus! -/// -/// ✅ CORRECT: -/// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); -/// let runner = AppRunner::builder().app(app).event_bus(bus).build().await; -/// runner.run_until_complete(&session_id).await; -/// ``` -/// -/// # Examples -/// -/// ## Basic Streaming +/// # Example /// /// ```rust,no_run /// use weavegraph::event_bus::{EventBus, ChannelSink}; /// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// use weavegraph::state::VersionedState; /// # use weavegraph::app::App; /// # async fn example(app: App) -> Result<(), Box> { /// -/// // Create channel /// let (tx, rx) = flume::unbounded(); -/// -/// // Create EventBus with ChannelSink /// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); /// -/// // Use AppRunner with custom EventBus /// let mut runner = AppRunner::builder() /// .app(app) /// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) /// .event_bus(bus) /// .build() /// .await; /// -/// let session_id = "my-session".to_string(); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("Process this") -/// ).await?; -/// -/// // Consume events in parallel /// tokio::spawn(async move { /// while let Ok(event) = rx.recv_async().await { -/// println!("Event: {:?}", event); +/// println!("{event:?}"); /// } /// }); /// -/// runner.run_until_complete(&session_id).await?; -/// # Ok(()) -/// # } -/// ``` -/// -/// ## Web Server Pattern (Per-Request Isolation) -/// -/// ```rust,no_run -/// use std::sync::Arc; -/// use weavegraph::event_bus::{EventBus, ChannelSink}; -/// use weavegraph::runtimes::{AppRunner, CheckpointerType}; -/// use weavegraph::state::VersionedState; -/// # use weavegraph::app::App; -/// # async fn handle_request(app: Arc, request_id: String) -> Result<(), Box> { -/// -/// // Each request gets its own channel and EventBus -/// let (tx, rx) = flume::unbounded(); -/// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); -/// -/// // Create isolated runner for this request -/// let mut runner = AppRunner::builder() -/// .app(Arc::try_unwrap(app).unwrap_or_else(|arc| (*arc).clone())) -/// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) -/// .event_bus(bus) -/// .build() -/// .await; -/// -/// let session_id = format!("request-{}", request_id); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("User query") -/// ).await?; -/// -/// // Events are isolated to this request's channel -/// runner.run_until_complete(&session_id).await?; +/// runner.run_until_complete("session").await?; /// # Ok(()) /// # } /// ``` /// -/// ## Server-Sent Events (SSE) with Axum -/// -/// ```rust,ignore -/// use axum::response::sse::{Event as SseEvent, Sse}; -/// use futures_util::stream::Stream; -/// -/// async fn stream_workflow( -/// State(app): State> -/// ) -> Sse>> { -/// let (tx, rx) = flume::unbounded(); -/// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); -/// -/// tokio::spawn(async move { -/// let mut runner = AppRunner::builder() -/// .app(Arc::try_unwrap(app).unwrap_or_else(|arc| (*arc).clone())) -/// .checkpointer(CheckpointerType::InMemory) -/// .autosave(false) -/// .event_bus(bus) -/// .build() -/// .await; -/// -/// let session_id = uuid::Uuid::new_v4().to_string(); -/// runner.create_session( -/// session_id.clone(), -/// VersionedState::new_with_user_message("Query") -/// ).await.ok(); -/// runner.run_until_complete(&session_id).await.ok(); -/// }); -/// -/// // Convert flume receiver to stream -/// let stream = rx.into_stream().map(|event| { -/// Ok(SseEvent::default().json_data(event).unwrap()) -/// }); -/// -/// Sse::new(stream) -/// } -/// ``` -/// -/// # Error Handling -/// -/// If the receiver is dropped, `handle()` returns an error which is logged by the EventBus -/// but doesn't stop event broadcasting to other sinks. -/// -/// # See Also -/// -/// - [`AppRunner::builder()`](crate::runtimes::runner::AppRunner::builder) - How to inject custom EventBus -/// - [`EventBus::with_sinks()`](crate::event_bus::EventBus::with_sinks) - Create EventBus with sinks -/// - Example: `examples/streaming_events.rs` - Complete working example +/// If the receiver is dropped before the graph finishes, `handle()` returns a +/// [`BrokenPipe`](io::ErrorKind::BrokenPipe) error. The event bus logs it and +/// continues broadcasting to any remaining sinks. pub struct ChannelSink { tx: flume::Sender, } impl ChannelSink { - /// Create a new ChannelSink that forwards events to the given channel. - /// - /// # Parameters - /// - /// * `tx` - The sender side of an unbounded flume channel - /// - /// # Returns - /// - /// A ChannelSink ready to be added to an EventBus. - /// - /// # Example - /// - /// ```rust,no_run - /// use weavegraph::event_bus::{EventBus, ChannelSink}; - /// - /// let (tx, rx) = flume::unbounded(); - /// let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); - /// ``` + /// Create a `ChannelSink` that forwards events through `tx`. pub fn new(tx: flume::Sender) -> Self { Self { tx } } diff --git a/tests/event_bus.rs b/tests/event_bus.rs index 754b41c..8fa90ee 100644 --- a/tests/event_bus.rs +++ b/tests/event_bus.rs @@ -710,17 +710,17 @@ fn event_strategy() -> impl Strategy { .prop_map( |(session_id, node_id, stream_id, chunk, metadata, is_final)| { let meta: FxHashMap = metadata.into_iter().collect(); - let event = LLMStreamingEvent::new( - session_id, - node_id, - stream_id, - chunk, - is_final, - None, - meta, - Utc::now(), - ); - Event::LLM(event) + let mut b = LLMStreamingEvent::builder(chunk).is_final(is_final).metadata(meta); + if let Some(id) = session_id { + b = b.session_id(id); + } + if let Some(id) = node_id { + b = b.node_id(id); + } + if let Some(id) = stream_id { + b = b.stream_id(id); + } + Event::LLM(b.build()) }, ); @@ -805,16 +805,13 @@ fn test_llm_event_to_json_value() { metadata.insert("token_count".to_string(), json!(42)); let timestamp = Utc::now(); - let llm_event = LLMStreamingEvent::new( - Some("session-123".to_string()), - Some("node-abc".to_string()), - Some("stream-xyz".to_string()), - "Thinking step by step...".to_string(), - false, - None, - metadata, - timestamp, - ); + let llm_event = LLMStreamingEvent::builder("Thinking step by step...") + .session_id("session-123") + .node_id("node-abc") + .stream_id("stream-xyz") + .metadata(metadata) + .timestamp(timestamp) + .build(); let event = Event::LLM(llm_event); let json = event.to_json_value(); @@ -831,16 +828,10 @@ fn test_llm_event_to_json_value() { #[test] fn test_llm_event_final_chunk() { - let llm_event = LLMStreamingEvent::new( - None, - None, - Some("stream-999".to_string()), - "Final chunk".to_string(), - true, - None, - FxHashMap::default(), - Utc::now(), - ); + let llm_event = LLMStreamingEvent::builder("Final chunk") + .stream_id("stream-999") + .is_final(true) + .build(); let event = Event::LLM(llm_event); let json = event.to_json_value(); From 14113cb919258b5501a73eb5074aabbe9ef6eb0d Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 15:38:15 -0400 Subject: [PATCH 04/15] heavy revision work for llm and graphs --- src/graphs/builder.rs | 366 ++++++---------------------------- src/graphs/compilation.rs | 351 +++++++++++--------------------- src/graphs/edges.rs | 72 ++----- src/graphs/iteration.rs | 228 ++++++--------------- src/graphs/mod.rs | 126 +----------- src/graphs/petgraph_compat.rs | 205 +++++-------------- src/llm/mod.rs | 5 +- src/llm/rig_adapter.rs | 35 ++-- src/llm/traits.rs | 18 +- 9 files changed, 333 insertions(+), 1073 deletions(-) diff --git a/src/graphs/builder.rs b/src/graphs/builder.rs index 9528be4..988a29b 100644 --- a/src/graphs/builder.rs +++ b/src/graphs/builder.rs @@ -1,7 +1,4 @@ -//! GraphBuilder implementation for constructing workflow graphs. -//! -//! This module contains the main GraphBuilder type and its fluent API -//! for constructing workflow graphs with nodes, edges, and configuration. +//! GraphBuilder — fluent API for constructing workflow graphs. use rustc_hash::FxHashMap; use std::sync::Arc; @@ -12,8 +9,7 @@ use crate::reducers::{Reducer, ReducerRegistry}; use crate::runtimes::{EventBusConfig, RuntimeConfig}; use crate::types::{ChannelType, NodeKind}; -/// Type alias for the internal parts of a GraphBuilder. -/// Used to reduce type complexity in the `into_parts()` method. +// Deconstructed builder state passed to the compiler. type GraphParts = ( FxHashMap>, FxHashMap>, @@ -22,26 +18,17 @@ type GraphParts = ( ReducerRegistry, ); -/// Builder for constructing workflow graphs with fluent API. +/// Fluent builder for workflow graphs. /// -/// `GraphBuilder` provides a builder pattern for constructing workflow graphs -/// by adding nodes, edges, and configuration before compiling to an executable -/// [`App`](crate::app::App). The builder ensures type safety and provides clear error messages -/// for common configuration mistakes. +/// Chain `add_node`, `add_edge`, and configuration calls, then call +/// [`compile`](Self::compile) to produce an executable [`App`](crate::app::App). /// -/// # Required Configuration -/// -/// Every graph must have: -/// - At least one executable node added via [`GraphBuilder::add_node`](Self::add_node) -/// - Edges connecting from `NodeKind::Start` to define entry points -/// - Edges connecting to `NodeKind::End` to define exit points -/// -/// Note: `NodeKind::Start` and `NodeKind::End` are virtual endpoints and should -/// never be registered with `add_node`. They exist only for structural definition. +/// `NodeKind::Start` and `NodeKind::End` are virtual endpoints — never register +/// them with `add_node`, but connect edges to/from them to define entry and +/// exit points. /// /// # Examples /// -/// ## Basic Usage /// ``` /// use weavegraph::graphs::GraphBuilder; /// use weavegraph::types::NodeKind; @@ -54,54 +41,22 @@ type GraphParts = ( /// # } /// # } /// -/// // Linear workflow: Start -> worker -> End /// let app = GraphBuilder::new() /// .add_node(NodeKind::Custom("worker".into()), MyNode) /// .add_edge(NodeKind::Start, NodeKind::Custom("worker".into())) /// .add_edge(NodeKind::Custom("worker".into()), NodeKind::End) /// .compile(); /// ``` -/// -/// ## Conditional Routing -/// ``` -/// use weavegraph::graphs::{GraphBuilder, EdgePredicate}; -/// use weavegraph::types::NodeKind; -/// use std::sync::Arc; -/// -/// # struct MyNode; -/// # #[async_trait::async_trait] -/// # impl weavegraph::node::Node for MyNode { -/// # async fn run(&self, _: weavegraph::state::StateSnapshot, _: weavegraph::node::NodeContext) -> Result { -/// # Ok(weavegraph::node::NodePartial::default()) -/// # } -/// # } -/// -/// let route_by_count: EdgePredicate = Arc::new(|snapshot| { -/// if snapshot.messages.len() > 5 { -/// vec!["heavy_processing".to_string()] -/// } else { -/// vec!["light_processing".to_string()] -/// } -/// }); -/// -/// let app = GraphBuilder::new() -/// .add_node(NodeKind::Custom("heavy_processing".into()), MyNode) -/// .add_node(NodeKind::Custom("light_processing".into()), MyNode) -/// .add_conditional_edge(NodeKind::Start, route_by_count) -/// .add_edge(NodeKind::Custom("heavy_processing".into()), NodeKind::End) -/// .add_edge(NodeKind::Custom("light_processing".into()), NodeKind::End) -/// .compile(); -/// ``` pub struct GraphBuilder { - /// Registry of all nodes in the graph, keyed by their identifier. + // Node registry keyed by identifier. nodes: FxHashMap>, - /// Unconditional edges defining static graph topology. + // Unconditional edges: source → targets. edges: FxHashMap>, - /// Conditional edges for dynamic routing based on state. + // Conditional edges for state-driven routing. conditional_edges: Vec, - /// Runtime configuration for the compiled application. + // Runtime configuration carried into the compiled app. runtime_config: RuntimeConfig, - /// Reducer registry for channel update operations. + // Reducer registry for channel update operations. reducer_registry: ReducerRegistry, } @@ -112,20 +67,7 @@ impl Default for GraphBuilder { } impl GraphBuilder { - /// Creates a new, empty graph builder. - /// - /// The builder starts with no nodes, edges, or configuration. - /// Use the fluent API methods to add components before calling - /// [`compile`](Self::compile). - /// - /// # Examples - /// - /// ``` - /// use weavegraph::graphs::GraphBuilder; - /// - /// let builder = GraphBuilder::new(); - /// // Add nodes, edges, and configuration... - /// ``` + /// Creates an empty builder. #[must_use] pub fn new() -> Self { Self { @@ -137,50 +79,15 @@ impl GraphBuilder { } } - /// Adds a conditional edge to the graph. - /// - /// Conditional edges enable dynamic routing based on the current state. - /// When execution reaches the `from` node, the `predicate` function is - /// evaluated with the current [`StateSnapshot`](crate::state::StateSnapshot) and returns the target - /// node names for routing. + /// Registers a node with the given identifier. /// - /// # Parameters - /// - /// - `from`: The source node for the conditional edge - /// - `predicate`: Function that determines target nodes based on state - #[must_use] - pub fn add_conditional_edge(mut self, from: NodeKind, predicate: EdgePredicate) -> Self { - self.conditional_edges - .push(ConditionalEdge::new(from, predicate)); - self - } - - /// Adds a node to the graph. - /// - /// NOTE: `NodeKind::Start` and `NodeKind::End` are virtual structural endpoints. - /// If either is passed to `add_node`, the registration is ignored and a warning - /// is emitted. They are not stored in the node registry and are never executed; - /// the scheduler skips them automatically while still allowing edges from - /// `Start` and to `End` for topology. - /// - /// Registers a node implementation with the given identifier. Each node - /// must have a unique [`NodeKind`] identifier within the graph. The node - /// implementation must implement the [`Node`] trait. - /// - /// # Parameters - /// - /// - `id`: Unique identifier for this node in the graph - /// - `node`: Implementation of the [`Node`] trait + /// Attempting to register `NodeKind::Start` or `NodeKind::End` logs a + /// warning and is otherwise a no-op — those identifiers are virtual. #[must_use] pub fn add_node(mut self, id: NodeKind, node: impl Node + 'static) -> Self { - // Ignore attempts to register virtual Start/End node kinds; emit a warning. match id { NodeKind::Start | NodeKind::End => { - tracing::warn!( - ?id, - "Ignoring registration of virtual node kind (Start/End are virtual)" - ); - // Do not insert into registry. + tracing::warn!(?id, "Ignoring registration of virtual node kind (Start/End are virtual)"); } _ => { self.nodes.insert(id, Arc::new(node)); @@ -189,60 +96,41 @@ impl GraphBuilder { self } - /// Adds an unconditional edge between two nodes. - /// - /// Creates a direct connection from one node to another. When the `from` - /// node completes execution, the scheduler will consider the `to` node - /// for execution in the next step. Multiple edges from the same node - /// create fan-out patterns, while multiple edges to the same node - /// create fan-in patterns. - /// - /// # Parameters - /// - /// - `from`: Source node identifier - /// - `to`: Target node identifier + /// Adds an unconditional edge from `from` to `to`. #[must_use] pub fn add_edge(mut self, from: NodeKind, to: NodeKind) -> Self { self.edges.entry(from).or_default().push(to); self } - /// Configures runtime settings for the compiled application. + /// Adds a conditional edge. /// - /// Runtime configuration controls execution behavior such as concurrency - /// limits, checkpointing, and session management. If not specified, - /// default configuration is used. - /// - /// # Parameters - /// - /// - `runtime_config`: Configuration for the compiled application + /// When execution reaches `from`, `predicate` is called with the current + /// [`StateSnapshot`](crate::state::StateSnapshot) and returns the names of + /// the next nodes to activate. + #[must_use] + pub fn add_conditional_edge(mut self, from: NodeKind, predicate: EdgePredicate) -> Self { + self.conditional_edges.push(ConditionalEdge::new(from, predicate)); + self + } + + /// Sets the runtime configuration for the compiled app. #[must_use] pub fn with_runtime_config(mut self, runtime_config: RuntimeConfig) -> Self { self.runtime_config = runtime_config; self } - /// Overrides only the event bus configuration while keeping other runtime settings. + /// Overrides the event bus configuration while keeping all other runtime settings. #[must_use] pub fn with_event_bus_config(mut self, config: EventBusConfig) -> Self { - let mut runtime_config = self.runtime_config.clone(); - runtime_config.event_bus = config; - self.runtime_config = runtime_config; + self.runtime_config.event_bus = config; self } - /// Registers a custom reducer for a specific channel. - /// - /// This method enables registration of custom reducers to extend or replace - /// the default reducer behavior for a channel. Multiple reducers can be - /// registered for the same channel and will be applied in registration order. - /// - /// # Parameters - /// - /// - `channel`: The channel type to register the reducer for - /// - `reducer`: The reducer implementation wrapped in Arc + /// Appends a custom reducer for `channel`. /// - /// # Examples + /// Multiple reducers for the same channel are applied in registration order. /// /// ``` /// use std::sync::Arc; @@ -271,58 +159,16 @@ impl GraphBuilder { self } - /// Replaces the entire reducer registry with a custom one. - /// - /// This method allows complete control over reducer configuration by - /// replacing the default registry. Useful when you need fine-grained - /// control over reducer ordering or want to start with an empty registry. - /// - /// # Parameters - /// - /// - `registry`: The reducer registry to use - /// - /// # Examples - /// - /// ``` - /// use std::sync::Arc; - /// use weavegraph::graphs::GraphBuilder; - /// use weavegraph::reducers::{ReducerRegistry, AddMessages}; - /// use weavegraph::types::{ChannelType, NodeKind}; - /// - /// # struct MyNode; - /// # #[async_trait::async_trait] - /// # impl weavegraph::node::Node for MyNode { - /// # async fn run(&self, _: weavegraph::state::StateSnapshot, _: weavegraph::node::NodeContext) -> Result { - /// # Ok(weavegraph::node::NodePartial::default()) - /// # } - /// # } - /// - /// let custom_registry = ReducerRegistry::new() - /// .with_reducer(ChannelType::Message, Arc::new(AddMessages)); - /// - /// let app = GraphBuilder::new() - /// .add_node(NodeKind::Custom("worker".into()), MyNode) - /// .with_reducer_registry(custom_registry) - /// .add_edge(NodeKind::Start, NodeKind::Custom("worker".into())) - /// .add_edge(NodeKind::Custom("worker".into()), NodeKind::End) - /// .compile(); - /// ``` + /// Replaces the entire reducer registry. #[must_use] pub fn with_reducer_registry(mut self, registry: ReducerRegistry) -> Self { self.reducer_registry = registry; self } - // ========================================================================= - // Iterators (petgraph-style API) - // ========================================================================= - - /// Returns an iterator over all registered nodes in the graph. - /// - /// This iterates over custom nodes only; virtual `Start` and `End` nodes - /// are not included as they are not stored in the registry. + /// Iterates over all registered node identifiers. /// - /// # Examples + /// Virtual `Start` and `End` nodes are not included. /// /// ``` /// use weavegraph::graphs::GraphBuilder; @@ -338,25 +184,18 @@ impl GraphBuilder { /// /// let builder = GraphBuilder::new() /// .add_node(NodeKind::Custom("A".into()), MyNode) - /// .add_node(NodeKind::Custom("B".into()), MyNode) - /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) - /// .add_edge(NodeKind::Custom("A".into()), NodeKind::Custom("B".into())) - /// .add_edge(NodeKind::Custom("B".into()), NodeKind::End); + /// .add_node(NodeKind::Custom("B".into()), MyNode); /// - /// let node_count = builder.nodes().count(); - /// assert_eq!(node_count, 2); + /// assert_eq!(builder.nodes().count(), 2); /// ``` pub fn nodes(&self) -> super::iteration::NodesIter<'_> { super::iteration::NodesIter::new(self.nodes.keys()) } - /// Returns an iterator over all edges in the graph as (source, target) pairs. + /// Iterates over all unconditional edges as `(source, target)` pairs. /// - /// Includes edges from/to virtual `Start` and `End` nodes. - /// The iteration order is not deterministic due to hash map iteration; - /// use [`topological_sort`](Self::topological_sort) for ordered traversal. - /// - /// # Examples + /// Iteration order is not deterministic; use [`topological_sort`](Self::topological_sort) + /// for ordered traversal. /// /// ``` /// use weavegraph::graphs::GraphBuilder; @@ -375,51 +214,29 @@ impl GraphBuilder { /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) /// .add_edge(NodeKind::Custom("A".into()), NodeKind::End); /// - /// let edge_count = builder.edges().count(); - /// assert_eq!(edge_count, 2); + /// assert_eq!(builder.edges().count(), 2); /// ``` pub fn edges(&self) -> super::iteration::EdgesIter<'_> { super::iteration::EdgesIter::new(&self.edges) } - /// Returns the number of registered nodes in the graph. - /// - /// Does not include virtual `Start` and `End` nodes. + /// Returns the number of registered nodes (excludes virtual nodes). #[must_use] pub fn node_count(&self) -> usize { self.nodes.len() } - /// Returns the number of edges in the graph. - /// - /// Counts all edges including those from/to virtual nodes. + /// Returns the total number of unconditional edges. #[must_use] pub fn edge_count(&self) -> usize { self.edges.values().map(|v| v.len()).sum() } - // ========================================================================= - // Graph Algorithms - // ========================================================================= - - /// Returns a topologically sorted list of all nodes in the graph. + /// Returns a topologically sorted node list. /// - /// The result includes virtual `Start` (always first) and `End` (always last) - /// nodes along with all custom nodes. Nodes at the same topological level - /// are sorted lexicographically for deterministic ordering. - /// - /// This is useful for: - /// - Deterministic iteration over nodes - /// - Dependency analysis - /// - Visualization and debugging - /// - /// # Note - /// - /// This method assumes the graph is acyclic. If the graph contains cycles, - /// the result will exclude nodes involved in cycles. Use [`compile`](Self::compile) - /// to validate the graph before relying on topological sort. - /// - /// # Examples + /// `NodeKind::Start` is always first; `NodeKind::End` is always last. + /// Nodes at the same depth are sorted lexicographically for determinism. + /// Nodes in cycles (if any) are excluded from the result. /// /// ``` /// use weavegraph::graphs::GraphBuilder; @@ -443,49 +260,15 @@ impl GraphBuilder { /// let sorted = builder.topological_sort(); /// assert_eq!(sorted[0], NodeKind::Start); /// assert_eq!(sorted[sorted.len() - 1], NodeKind::End); - /// - /// // A comes before B due to edge A -> B - /// let a_pos = sorted.iter().position(|n| n == &NodeKind::Custom("A".into())).unwrap(); - /// let b_pos = sorted.iter().position(|n| n == &NodeKind::Custom("B".into())).unwrap(); - /// assert!(a_pos < b_pos); /// ``` #[must_use] pub fn topological_sort(&self) -> Vec { super::iteration::topological_sort(&self.edges) } - // ========================================================================= - // petgraph Compatibility (feature-gated) - // ========================================================================= - - /// Converts the graph to a petgraph `DiGraph` for advanced algorithms. - /// - /// This is useful for: - /// - Advanced graph algorithms (shortest path, max flow, etc.) - /// - Graph analysis and metrics - /// - Integration with petgraph ecosystem tools + /// Converts the graph to a petgraph `DiGraph`. /// - /// # Feature Gate - /// - /// This method requires the `petgraph-compat` feature: - /// ```toml - /// weavegraph = { features = ["petgraph-compat"] } - /// ``` - /// - /// # Examples - /// - /// ```ignore - /// use weavegraph::graphs::GraphBuilder; - /// use petgraph::algo::is_cyclic_directed; - /// - /// let builder = GraphBuilder::new() - /// .add_node(NodeKind::Custom("A".into()), MyNode) - /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) - /// .add_edge(NodeKind::Custom("A".into()), NodeKind::End); - /// - /// let pg = builder.to_petgraph(); - /// assert!(!is_cyclic_directed(&pg.graph)); - /// ``` + /// Requires the `petgraph-compat` feature. #[cfg(feature = "petgraph-compat")] #[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))] #[must_use] @@ -493,33 +276,9 @@ impl GraphBuilder { super::petgraph_compat::to_petgraph(&self.edges) } - /// Exports the graph to DOT format for visualization. - /// - /// The output can be rendered using Graphviz (`dot -Tpng graph.dot -o graph.png`) - /// or online tools like . - /// - /// # Feature Gate - /// - /// This method requires the `petgraph-compat` feature: - /// ```toml - /// weavegraph = { features = ["petgraph-compat"] } - /// ``` - /// - /// # Examples - /// - /// ```ignore - /// use weavegraph::graphs::GraphBuilder; - /// use std::fs; + /// Exports the graph to DOT format for Graphviz rendering. /// - /// let builder = GraphBuilder::new() - /// .add_node(NodeKind::Custom("A".into()), MyNode) - /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) - /// .add_edge(NodeKind::Custom("A".into()), NodeKind::End); - /// - /// let dot = builder.to_dot(); - /// fs::write("workflow.dot", &dot)?; - /// // Then run: dot -Tpng workflow.dot -o workflow.png - /// ``` + /// Requires the `petgraph-compat` feature. #[cfg(feature = "petgraph-compat")] #[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))] #[must_use] @@ -527,14 +286,9 @@ impl GraphBuilder { super::petgraph_compat::to_dot(&self.edges) } - /// Checks if the graph contains cycles using petgraph's algorithm. - /// - /// This provides an alternative to the built-in cycle detection for - /// cross-verification or when you need petgraph's specific behavior. - /// - /// # Feature Gate + /// Checks for cycles using petgraph's algorithm. /// - /// This method requires the `petgraph-compat` feature. + /// Requires the `petgraph-compat` feature. #[cfg(feature = "petgraph-compat")] #[cfg_attr(docsrs, doc(cfg(feature = "petgraph-compat")))] #[must_use] @@ -542,11 +296,6 @@ impl GraphBuilder { super::petgraph_compat::is_cyclic(&self.edges) } - // ========================================================================= - // Internal Helpers - // ========================================================================= - - /// Extracts the components for compilation (internal use only). pub(super) fn into_parts(self) -> GraphParts { ( self.nodes, @@ -557,14 +306,15 @@ impl GraphBuilder { ) } - // Internal read-only accessors for validation in sibling modules pub(super) fn nodes_ref(&self) -> &FxHashMap> { &self.nodes } + pub(super) fn edges_ref(&self) -> &FxHashMap> { &self.edges } - pub(super) fn conditional_edges_ref(&self) -> &Vec { + + pub(super) fn conditional_edges_ref(&self) -> &[ConditionalEdge] { &self.conditional_edges } } diff --git a/src/graphs/compilation.rs b/src/graphs/compilation.rs index 96493d0..284a5a8 100644 --- a/src/graphs/compilation.rs +++ b/src/graphs/compilation.rs @@ -1,21 +1,19 @@ //! Graph compilation logic and validation. -//! -//! This module contains the logic for compiling a GraphBuilder into an -//! executable App, including structural validation and actionable errors. use crate::app::App; use crate::types::NodeKind; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::collections::VecDeque; /// Errors that can occur when compiling a graph. #[derive(Debug, thiserror::Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum GraphCompileError { - /// No entry edge was defined from the virtual Start node. + /// No entry edge defined from the virtual Start node. #[error("missing entry: no edge or conditional edge originates from Start")] MissingEntry, - /// An edge references a node that is not registered in the graph. + /// An edge references an unregistered node. #[error("unknown node referenced in edge: {0}")] UnknownNode(NodeKind), @@ -23,53 +21,40 @@ pub enum GraphCompileError { #[error("invalid edge: cannot originate from End")] EdgeFromEnd, - /// A cycle was detected in the graph. + /// A cycle was detected among unconditional edges. #[error("cycle detected in graph: {}", .cycle.iter().map(|n| n.to_string()).collect::>().join(" -> "))] CycleDetected { - /// The cycle path showing nodes forming the cycle. + /// Nodes forming the cycle, in traversal order. cycle: Vec, }, - /// One or more nodes are unreachable from the Start node. + /// One or more nodes are unreachable from Start. #[error("unreachable nodes detected (no path from Start): {}", .nodes.iter().map(|n| n.to_string()).collect::>().join(", "))] UnreachableNodes { - /// List of nodes with no path from Start. + /// Nodes with no path from Start. nodes: Vec, }, - /// One or more nodes have no path to the End node. + /// One or more nodes have no path to End. #[error("nodes with no path to End: {}", .nodes.iter().map(|n| n.to_string()).collect::>().join(", "))] NoPathToEnd { - /// List of nodes that cannot reach End. + /// Nodes that cannot reach End. nodes: Vec, }, - /// A duplicate edge was detected. + /// A duplicate unconditional edge was detected. #[error("duplicate edge detected: {} -> {}", .from, .to)] DuplicateEdge { - /// The source node of the duplicate edge. + /// Source node. from: NodeKind, - /// The target node of the duplicate edge. + /// Target node. to: NodeKind, }, } -/// Compilation logic for GraphBuilder. impl super::builder::GraphBuilder { - /// Compiles the graph into an executable application. + /// Validates the graph and produces an executable [`App`]. /// - /// Validates the graph configuration and converts it into an [`App`] that - /// can execute workflows. This method performs validation checks to prevent - /// common topology issues (missing entry, cycles, unknown nodes, duplicates). - /// - /// # Returns - /// - /// - `Ok(App)`: Successfully compiled application ready for execution - /// - `Err(GraphCompileError)`: Structural validation failed; inspect the variant - /// - /// # Examples - /// - /// Basic pattern with error propagation: /// ``` /// use weavegraph::graphs::GraphBuilder; /// use weavegraph::types::NodeKind; @@ -89,69 +74,84 @@ impl super::builder::GraphBuilder { /// .compile()?; /// # Ok::<_, weavegraph::graphs::GraphCompileError>(()) /// ``` - /// - /// Explicit handling with pattern matching: - /// ``` - /// use weavegraph::graphs::{GraphBuilder, GraphCompileError}; - /// use weavegraph::types::NodeKind; - /// - /// let result = GraphBuilder::new() - /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) - /// .compile(); - /// - /// match result { - /// Ok(_app) => {} - /// Err(GraphCompileError::MissingEntry) => { - /// eprintln!("graph has no entry edge from Start"); - /// } - /// Err(GraphCompileError::UnknownNode(nk)) => { - /// eprintln!("unknown node referenced: {nk}"); - /// } - /// Err(e) => { - /// eprintln!("graph validation failed: {e}"); - /// } - /// } - /// ``` pub fn compile(self) -> Result { - // Validate without consuming self self.validate()?; - let (nodes, edges, conditional_edges, runtime_config, reducer_registry) = self.into_parts(); - Ok(App::from_parts( - nodes, - edges, - conditional_edges, - runtime_config, - reducer_registry, - )) + Ok(App::from_parts(nodes, edges, conditional_edges, runtime_config, reducer_registry)) } - /// Detects cycles in the graph using DFS with color marking. + /// Validates the graph for structural correctness. /// - /// Only checks unconditional edges, as conditional edge targets are runtime-determined. - /// Returns the first cycle found as a path of nodes. - fn detect_cycle(&self) -> Option> { - #[derive(Clone, Copy, PartialEq)] - enum Color { - White, // Not visited - Gray, // Currently visiting - Black, // Fully visited + /// Checks in order: + /// - At least one entry edge from Start (unconditional or conditional) + /// - No cycle in unconditional edges + /// - All registered nodes reachable from Start (skipped when conditional edges exist) + /// - All registered nodes have a path to End (skipped when conditional edges exist) + /// - No duplicate unconditional edges + /// - No edge from End; all Custom node references are registered + pub fn validate(&self) -> Result<(), GraphCompileError> { + let has_start_edge = self + .edges_ref() + .get(&NodeKind::Start) + .is_some_and(|v| !v.is_empty()) + || self + .conditional_edges_ref() + .iter() + .any(|ce| ce.from() == &NodeKind::Start); + + if !has_start_edge { + return Err(GraphCompileError::MissingEntry); } - let mut colors: FxHashMap = FxHashMap::default(); - let mut path: Vec = Vec::new(); + if let Some(cycle) = self.detect_cycle() { + return Err(GraphCompileError::CycleDetected { cycle }); + } + + if self.conditional_edges_ref().is_empty() { + let unreachable = self.detect_unreachable_nodes(); + if !unreachable.is_empty() { + return Err(GraphCompileError::UnreachableNodes { nodes: unreachable }); + } - // Initialize all nodes as White - for from in self.edges_ref().keys() { - colors.entry(from.clone()).or_insert(Color::White); + let no_path = self.detect_no_path_to_end(); + if !no_path.is_empty() { + return Err(GraphCompileError::NoPathToEnd { nodes: no_path }); + } + } + + if let Some((from, to)) = self.detect_duplicate_edge() { + return Err(GraphCompileError::DuplicateEdge { from, to }); } - for tos in self.edges_ref().values() { + + for (from, tos) in self.edges_ref() { + if matches!(from, NodeKind::End) { + return Err(GraphCompileError::EdgeFromEnd); + } + if let NodeKind::Custom(_) = from + && !self.nodes_ref().contains_key(from) + { + return Err(GraphCompileError::UnknownNode(from.clone())); + } for to in tos { - colors.entry(to.clone()).or_insert(Color::White); + if let NodeKind::Custom(_) = to + && !self.nodes_ref().contains_key(to) + { + return Err(GraphCompileError::UnknownNode(to.clone())); + } } } - // DFS helper function + Ok(()) + } + + fn detect_cycle(&self) -> Option> { + #[derive(Clone, Copy, PartialEq)] + enum Color { + White, + Gray, + Black, + } + fn dfs( node: &NodeKind, colors: &mut FxHashMap, @@ -161,26 +161,21 @@ impl super::builder::GraphBuilder { colors.insert(node.clone(), Color::Gray); path.push(node.clone()); - if let Some(neighbors) = edges.get(node) { - for neighbor in neighbors { - match colors.get(neighbor).copied().unwrap_or(Color::White) { - Color::White => { - if let Some(cycle) = dfs(neighbor, colors, path, edges) { - return Some(cycle); - } + for neighbor in edges.get(node).into_iter().flatten() { + match colors.get(neighbor).copied().unwrap_or(Color::White) { + Color::White => { + if let Some(cycle) = dfs(neighbor, colors, path, edges) { + return Some(cycle); } - Color::Gray => { - // Found a back edge - extract the cycle - if let Some(cycle_start) = path.iter().position(|n| n == neighbor) { - let mut cycle = path[cycle_start..].to_vec(); - cycle.push(neighbor.clone()); // Complete the cycle - return Some(cycle); - } - } - Color::Black => { - // Already fully explored, skip + } + Color::Gray => { + if let Some(start) = path.iter().position(|n| n == neighbor) { + let mut cycle = path[start..].to_vec(); + cycle.push(neighbor.clone()); + return Some(cycle); } } + Color::Black => {} } } @@ -189,9 +184,18 @@ impl super::builder::GraphBuilder { None } - // Try DFS from each unvisited node - for node in colors.clone().keys() { - if colors.get(node).copied().unwrap_or(Color::White) == Color::White + let mut colors: FxHashMap = self + .edges_ref() + .iter() + .flat_map(|(from, tos)| std::iter::once(from).chain(tos)) + .map(|n| (n.clone(), Color::White)) + .collect(); + + let nodes: Vec = colors.keys().cloned().collect(); + let mut path = Vec::new(); + + for node in &nodes { + if colors.get(node).copied() == Some(Color::White) && let Some(cycle) = dfs(node, &mut colors, &mut path, self.edges_ref()) { return Some(cycle); @@ -201,187 +205,72 @@ impl super::builder::GraphBuilder { None } - /// Detects unreachable nodes (nodes with no path from Start). - /// - /// Only checks unconditional edges. Returns registered Custom nodes that - /// cannot be reached from Start via unconditional edges. fn detect_unreachable_nodes(&self) -> Vec { - use std::collections::VecDeque; - - let mut reachable: FxHashMap = FxHashMap::default(); + let mut reachable: FxHashSet = FxHashSet::default(); let mut queue: VecDeque = VecDeque::new(); - // Start BFS from Start node queue.push_back(NodeKind::Start); - reachable.insert(NodeKind::Start, true); + reachable.insert(NodeKind::Start); while let Some(node) = queue.pop_front() { - if let Some(neighbors) = self.edges_ref().get(&node) { - for neighbor in neighbors { - if !reachable.contains_key(neighbor) { - reachable.insert(neighbor.clone(), true); - queue.push_back(neighbor.clone()); - } + for neighbor in self.edges_ref().get(&node).into_iter().flatten() { + if reachable.insert(neighbor.clone()) { + queue.push_back(neighbor.clone()); } } } - // Find registered Custom nodes that are not reachable let mut unreachable: Vec = self .nodes_ref() .keys() - .filter(|node| !reachable.contains_key(node)) + .filter(|n| !reachable.contains(*n)) .cloned() .collect(); - - unreachable.sort_by_key(|a| a.to_string()); + unreachable.sort_by_key(|n| n.to_string()); unreachable } - /// Detects nodes with no path to End. - /// - /// Only checks unconditional edges. Returns registered Custom nodes that - /// cannot reach End via unconditional edges. fn detect_no_path_to_end(&self) -> Vec { - use std::collections::VecDeque; - - // Build reverse graph (for backward traversal from End) - let mut reverse_edges: FxHashMap> = FxHashMap::default(); + let mut reverse: FxHashMap> = FxHashMap::default(); for (from, tos) in self.edges_ref() { for to in tos { - reverse_edges - .entry(to.clone()) - .or_default() - .push(from.clone()); + reverse.entry(to.clone()).or_default().push(from.clone()); } } - let mut can_reach_end: FxHashMap = FxHashMap::default(); + let mut can_reach_end: FxHashSet = FxHashSet::default(); let mut queue: VecDeque = VecDeque::new(); - // Start BFS from End node (backward) queue.push_back(NodeKind::End); - can_reach_end.insert(NodeKind::End, true); + can_reach_end.insert(NodeKind::End); while let Some(node) = queue.pop_front() { - if let Some(predecessors) = reverse_edges.get(&node) { - for predecessor in predecessors { - if !can_reach_end.contains_key(predecessor) { - can_reach_end.insert(predecessor.clone(), true); - queue.push_back(predecessor.clone()); - } + for pred in reverse.get(&node).into_iter().flatten() { + if can_reach_end.insert(pred.clone()) { + queue.push_back(pred.clone()); } } } - // Find registered Custom nodes that cannot reach End let mut no_path: Vec = self .nodes_ref() .keys() - .filter(|node| !can_reach_end.contains_key(node)) + .filter(|n| !can_reach_end.contains(*n)) .cloned() .collect(); - - no_path.sort_by_key(|a| a.to_string()); + no_path.sort_by_key(|n| n.to_string()); no_path } - /// Detects duplicate edges in the graph. - /// - /// Returns the first duplicate edge found. fn detect_duplicate_edge(&self) -> Option<(NodeKind, NodeKind)> { - use rustc_hash::FxHashSet; - for (from, tos) in self.edges_ref() { - let mut seen: FxHashSet = FxHashSet::default(); + let mut seen: FxHashSet<&NodeKind> = FxHashSet::default(); for to in tos { - if !seen.insert(to.clone()) { - // Found a duplicate + if !seen.insert(to) { return Some((from.clone(), to.clone())); } } } None } - - /// Validates the graph for common structural issues. - /// - /// Validation rules: - /// - There must be at least one entry edge from Start (unconditional or conditional) - /// - No edge may originate from End - /// - Any Custom node referenced by an edge (as from/to) must be registered - /// - The graph must not contain cycles (checked on unconditional edges only) - /// - All registered nodes must be reachable from Start (unconditional edges only) - /// - All registered nodes must have a path to End (unconditional edges only) - /// - No duplicate edges are allowed - pub fn validate(&self) -> Result<(), GraphCompileError> { - // Rule 1: Entry edge from Start exists (either unconditional or conditional) - let has_start_edge = self - .edges_ref() - .get(&NodeKind::Start) - .map(|v| !v.is_empty()) - .unwrap_or(false) - || self - .conditional_edges_ref() - .iter() - .any(|ce| ce.from() == &NodeKind::Start); - - if !has_start_edge { - return Err(GraphCompileError::MissingEntry); - } - - // Rule 2: Detect cycles in unconditional edges - if let Some(cycle) = self.detect_cycle() { - return Err(GraphCompileError::CycleDetected { cycle }); - } - - // Rule 3 and 4: Reachability validations (skip when conditional edges exist) - let has_conditional = !self.conditional_edges_ref().is_empty(); - if !has_conditional { - // Detect unreachable nodes - let unreachable = self.detect_unreachable_nodes(); - if !unreachable.is_empty() { - return Err(GraphCompileError::UnreachableNodes { nodes: unreachable }); - } - - // Detect nodes with no path to End - let no_path_to_end = self.detect_no_path_to_end(); - if !no_path_to_end.is_empty() { - return Err(GraphCompileError::NoPathToEnd { - nodes: no_path_to_end, - }); - } - } - - // Rule 5: Detect duplicate edges - if let Some((from, to)) = self.detect_duplicate_edge() { - return Err(GraphCompileError::DuplicateEdge { from, to }); - } - - // Rule 6 and 7: Validate each unconditional edge - for (from, tos) in self.edges_ref() { - // End cannot have outgoing edges - if matches!(from, NodeKind::End) { - return Err(GraphCompileError::EdgeFromEnd); - } - - // If from is Custom, it must be registered - if let NodeKind::Custom(_) = from - && !self.nodes_ref().contains_key(from) - { - return Err(GraphCompileError::UnknownNode(from.clone())); - } - - for to in tos { - // If to is Custom, it must be registered - if let NodeKind::Custom(_) = to - && !self.nodes_ref().contains_key(to) - { - return Err(GraphCompileError::UnknownNode(to.clone())); - } - } - } - - Ok(()) - } } diff --git a/src/graphs/edges.rs b/src/graphs/edges.rs index 1f1df0b..ab1c70a 100644 --- a/src/graphs/edges.rs +++ b/src/graphs/edges.rs @@ -1,18 +1,13 @@ //! Edge types and routing predicates for conditional graph flow. -//! -//! This module contains the types and predicates used for dynamic routing -//! in workflow graphs, including conditional edges that can route based -//! on runtime state evaluation. use crate::types::NodeKind; use std::sync::Arc; /// Predicate function for conditional edge routing. /// -/// Takes a [`StateSnapshot`](crate::state::StateSnapshot) and returns target node names to determine -/// which nodes should be executed next. Predicates are used with -/// [`GraphBuilder::add_conditional_edge`](crate::graphs::GraphBuilder::add_conditional_edge) to create dynamic routing based -/// on the current state. +/// Receives a [`StateSnapshot`](crate::state::StateSnapshot) and returns the names of target +/// nodes to execute next. Used with +/// [`GraphBuilder::add_conditional_edge`](crate::graphs::GraphBuilder::add_conditional_edge). /// /// # Examples /// @@ -21,41 +16,21 @@ use std::sync::Arc; /// use weavegraph::types::NodeKind; /// use std::sync::Arc; /// -/// // Route based on message count, using NodeKind helpers for targets -/// let route_by_messages: EdgePredicate = Arc::new(|snapshot| { +/// let route: EdgePredicate = Arc::new(|snapshot| { /// if snapshot.messages.len() > 5 { /// vec![NodeKind::Custom("many_messages".into()).as_target()] /// } else { /// vec![NodeKind::Custom("few_messages".into()).as_target()] /// } /// }); -/// -/// // Route based on extra data - fan out to multiple nodes and optionally End -/// let route_by_error: EdgePredicate = Arc::new(|snapshot| { -/// if snapshot.extra.get("error").is_some() { -/// vec![ -/// NodeKind::Custom("error_handler".into()).as_target(), -/// NodeKind::Custom("logger".into()).as_target(), -/// ] -/// } else { -/// vec![NodeKind::end_target()] -/// } -/// }); /// ``` pub type EdgePredicate = Arc Vec + Send + Sync + 'static>; -/// A conditional edge that routes based on a predicate function. -/// -/// Conditional edges allow dynamic routing in workflows based on the current -/// state. When the scheduler encounters a conditional edge, it evaluates the -/// predicate function and routes to the returned target nodes. -/// -/// # Purpose +/// A conditional edge that routes to nodes determined by a predicate. /// -/// This type encapsulates conditional routing logic to enable clean builder patterns -/// and maintain consistency with other edge types. The private fields ensure that -/// conditional edges are constructed through proper APIs rather than direct field access. +/// When the scheduler encounters a conditional edge it calls the predicate +/// with the current state and dispatches to the returned target nodes. /// /// # Examples /// @@ -75,37 +50,14 @@ pub type EdgePredicate = /// ``` #[derive(Clone)] pub struct ConditionalEdge { - /// The source node for this conditional edge. + // Source node for this edge. from: NodeKind, - /// The predicate function that determines target node. + // Routing predicate. predicate: EdgePredicate, } impl ConditionalEdge { - /// Creates a new conditional edge. - /// - /// This is the preferred way to construct conditional edges, providing a clean - /// API that works with the builder pattern while ensuring proper encapsulation. - /// - /// # Parameters - /// - /// - `from`: The source node identifier - /// - `predicate`: The routing predicate function - /// - /// # Examples - /// - /// ``` - /// use weavegraph::graphs::{ConditionalEdge, EdgePredicate}; - /// use weavegraph::types::NodeKind; - /// use std::sync::Arc; - /// - /// let predicate: EdgePredicate = Arc::new(|_snapshot| { - /// vec![NodeKind::Custom("target_node".into()).as_target()] - /// }); - /// - /// let edge = ConditionalEdge::new(NodeKind::Custom("source".into()), predicate.clone()); - /// let edge2 = ConditionalEdge::new(NodeKind::Start, predicate); - /// ``` + /// Creates a conditional edge from `from` to the targets returned by `predicate`. pub fn new(from: impl Into, predicate: EdgePredicate) -> Self { Self { from: from.into(), @@ -113,12 +65,12 @@ impl ConditionalEdge { } } - /// Returns the source node of this conditional edge. + /// Returns the source node. pub fn from(&self) -> &NodeKind { &self.from } - /// Returns the predicate function of this conditional edge. + /// Returns the routing predicate. pub fn predicate(&self) -> &EdgePredicate { &self.predicate } diff --git a/src/graphs/iteration.rs b/src/graphs/iteration.rs index 58d0a1c..6954a85 100644 --- a/src/graphs/iteration.rs +++ b/src/graphs/iteration.rs @@ -1,62 +1,13 @@ -//! Graph iteration utilities and algorithms. -//! -//! This module provides idiomatic iterators and common graph algorithms -//! for inspecting and analyzing workflow graphs. Inspired by petgraph's -//! visit module patterns. -//! -//! # Iterators -//! -//! - [`NodesIter`]: Iterate over all nodes in the graph -//! - [`EdgesIter`]: Iterate over all edges as (source, target) pairs -//! -//! # Algorithms -//! -//! - [`topological_sort`](crate::graphs::GraphBuilder::topological_sort): Deterministic node ordering -//! -//! # Examples -//! -//! ``` -//! use weavegraph::graphs::GraphBuilder; -//! use weavegraph::types::NodeKind; -//! -//! # struct MyNode; -//! # #[async_trait::async_trait] -//! # impl weavegraph::node::Node for MyNode { -//! # async fn run(&self, _: weavegraph::state::StateSnapshot, _: weavegraph::node::NodeContext) -> Result { -//! # Ok(weavegraph::node::NodePartial::default()) -//! # } -//! # } -//! -//! let builder = GraphBuilder::new() -//! .add_node(NodeKind::Custom("A".into()), MyNode) -//! .add_node(NodeKind::Custom("B".into()), MyNode) -//! .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) -//! .add_edge(NodeKind::Custom("A".into()), NodeKind::Custom("B".into())) -//! .add_edge(NodeKind::Custom("B".into()), NodeKind::End); -//! -//! // Iterate over nodes -//! for node_kind in builder.nodes() { -//! println!("Node: {:?}", node_kind); -//! } -//! -//! // Iterate over edges -//! for (from, to) in builder.edges() { -//! println!("Edge: {:?} -> {:?}", from, to); -//! } -//! -//! // Get deterministic topological ordering -//! let sorted = builder.topological_sort(); -//! println!("Topological order: {:?}", sorted); -//! ``` +//! Graph iteration utilities and topological ordering. use crate::types::NodeKind; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashMap; use std::collections::VecDeque; -/// Iterator over node kinds in a graph. +/// Iterator over node kinds registered in a graph. /// -/// Yields each registered custom node kind. Does not include virtual -/// `Start` or `End` nodes as they are not stored in the node registry. +/// Yields each custom node kind. Virtual `Start` and `End` nodes are not +/// included — they are not stored in the node registry. /// /// # Examples /// @@ -71,16 +22,14 @@ use std::collections::VecDeque; /// # Ok(weavegraph::node::NodePartial::default()) /// # } /// # } -/// /// let builder = GraphBuilder::new() -/// .add_node(NodeKind::Custom("A".into()), MyNode) +/// .add_node(NodeKind::Custom("A".into()), MyNode) /// .add_node(NodeKind::Custom("B".into()), MyNode) /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) /// .add_edge(NodeKind::Custom("A".into()), NodeKind::Custom("B".into())) /// .add_edge(NodeKind::Custom("B".into()), NodeKind::End); /// -/// let nodes: Vec<_> = builder.nodes().collect(); -/// assert_eq!(nodes.len(), 2); +/// assert_eq!(builder.nodes().count(), 2); /// ``` pub struct NodesIter<'a> { inner: std::collections::hash_map::Keys<'a, NodeKind, std::sync::Arc>, @@ -110,13 +59,12 @@ impl<'a> Iterator for NodesIter<'a> { } } -impl<'a> ExactSizeIterator for NodesIter<'a> {} +impl ExactSizeIterator for NodesIter<'_> {} -/// Iterator over edges in a graph as (source, target) pairs. +/// Iterator over edges as `(source, target)` pairs. /// -/// Yields each edge in the graph, including edges from/to virtual -/// `Start` and `End` nodes. The iteration order is not guaranteed -/// to be deterministic due to hash map iteration. +/// Yields every edge including edges from/to virtual `Start` and `End` +/// nodes. Iteration order is not guaranteed to be deterministic. /// /// # Examples /// @@ -131,14 +79,12 @@ impl<'a> ExactSizeIterator for NodesIter<'a> {} /// # Ok(weavegraph::node::NodePartial::default()) /// # } /// # } -/// /// let builder = GraphBuilder::new() /// .add_node(NodeKind::Custom("A".into()), MyNode) /// .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) /// .add_edge(NodeKind::Custom("A".into()), NodeKind::End); /// -/// let edges: Vec<_> = builder.edges().collect(); -/// assert_eq!(edges.len(), 2); +/// assert_eq!(builder.edges().count(), 2); /// ``` pub struct EdgesIter<'a> { outer: std::collections::hash_map::Iter<'a, NodeKind, Vec>, @@ -169,94 +115,71 @@ impl<'a> Iterator for EdgesIter<'a> { if let Some(to) = self.current_targets.next() { return Some((self.current_from.unwrap(), to)); } - match self.outer.next() { - Some((from, targets)) => { - self.current_from = Some(from); - self.current_targets = targets.iter(); - } - None => return None, - } + let (from, targets) = self.outer.next()?; + self.current_from = Some(from); + self.current_targets = targets.iter(); } } } -/// Performs Kahn's algorithm for topological sorting. +/// Topological ordering via Kahn's algorithm. /// -/// Returns nodes in topological order (dependencies before dependents). -/// Virtual `Start` node is always first, `End` is always last. -/// Ties are broken lexicographically for deterministic ordering. +/// Returns all nodes with dependencies before dependents. `Start` is always +/// first, `End` always last; ties among custom nodes are broken +/// lexicographically for determinism. /// /// # Panics /// -/// This function assumes the graph is acyclic. If called on a graph with -/// cycles, it will return a partial ordering that excludes cycle members. -/// Use [`GraphBuilder::compile`] to validate acyclicity before calling. +/// Assumes an acyclic graph. On a cyclic graph the result is a partial +/// ordering that excludes cycle members. Use [`GraphBuilder::compile`] to +/// validate acyclicity first. pub(super) fn topological_sort(edges: &FxHashMap>) -> Vec { - // Build in-degree map and collect all nodes let mut in_degree: FxHashMap = FxHashMap::default(); - let mut all_nodes: FxHashSet = FxHashSet::default(); - - // Collect all nodes from edges - for (from, tos) in edges { - all_nodes.insert(from.clone()); + for (from, targets) in edges { in_degree.entry(from.clone()).or_insert(0); - for to in tos { - all_nodes.insert(to.clone()); + for to in targets { *in_degree.entry(to.clone()).or_insert(0) += 1; } } - // Initialize queue with nodes that have in-degree 0 - // Use a Vec and sort for deterministic ordering - let mut queue: VecDeque = VecDeque::new(); - let mut zero_in_degree: Vec<_> = in_degree + let mut seeds: Vec = in_degree .iter() - .filter(|entry| *entry.1 == 0) + .filter(|&(_, °)| deg == 0) .map(|(node, _)| node.clone()) .collect(); + seeds.sort_by(node_order); + let mut queue = VecDeque::from(seeds); - // Sort for deterministic ordering - Start always first - zero_in_degree.sort_by(|a, b| match (a, b) { - (NodeKind::Start, _) => std::cmp::Ordering::Less, - (_, NodeKind::Start) => std::cmp::Ordering::Greater, - (NodeKind::End, _) => std::cmp::Ordering::Greater, - (_, NodeKind::End) => std::cmp::Ordering::Less, - (NodeKind::Custom(a_name), NodeKind::Custom(b_name)) => a_name.cmp(b_name), - }); - - queue.extend(zero_in_degree); - - let mut result: Vec = Vec::with_capacity(all_nodes.len()); - + let mut result = Vec::with_capacity(in_degree.len()); while let Some(node) = queue.pop_front() { - result.push(node.clone()); - - if let Some(neighbors) = edges.get(&node) { - // Collect neighbors that become zero in-degree after removing this node - let mut new_zero: Vec = Vec::new(); - for neighbor in neighbors { - if let Some(deg) = in_degree.get_mut(neighbor) { + if let Some(targets) = edges.get(&node) { + let mut newly_free: Vec = targets + .iter() + .filter_map(|t| { + let deg = in_degree.get_mut(t)?; *deg = deg.saturating_sub(1); - if *deg == 0 { - new_zero.push(neighbor.clone()); - } - } - } - // Sort new zero-degree nodes for determinism - new_zero.sort_by(|a, b| match (a, b) { - (NodeKind::Start, _) => std::cmp::Ordering::Less, - (_, NodeKind::Start) => std::cmp::Ordering::Greater, - (NodeKind::End, _) => std::cmp::Ordering::Greater, - (_, NodeKind::End) => std::cmp::Ordering::Less, - (NodeKind::Custom(a_name), NodeKind::Custom(b_name)) => a_name.cmp(b_name), - }); - queue.extend(new_zero); + (*deg == 0).then(|| t.clone()) + }) + .collect(); + newly_free.sort_by(node_order); + queue.extend(newly_free); } + result.push(node); } result } +fn node_order(a: &NodeKind, b: &NodeKind) -> std::cmp::Ordering { + match (a, b) { + (NodeKind::Start, _) => std::cmp::Ordering::Less, + (_, NodeKind::Start) => std::cmp::Ordering::Greater, + (NodeKind::End, _) => std::cmp::Ordering::Greater, + (_, NodeKind::End) => std::cmp::Ordering::Less, + (NodeKind::Custom(a), NodeKind::Custom(b)) => a.cmp(b), + } +} + #[cfg(test)] mod tests { use super::*; @@ -265,76 +188,45 @@ mod tests { fn test_topological_sort_linear() { let mut edges: FxHashMap> = FxHashMap::default(); edges.insert(NodeKind::Start, vec![NodeKind::Custom("A".into())]); - edges.insert( - NodeKind::Custom("A".into()), - vec![NodeKind::Custom("B".into())], - ); + edges.insert(NodeKind::Custom("A".into()), vec![NodeKind::Custom("B".into())]); edges.insert(NodeKind::Custom("B".into()), vec![NodeKind::End]); let sorted = topological_sort(&edges); - // Start should be first, End should be last assert_eq!(sorted[0], NodeKind::Start); - assert_eq!(sorted[sorted.len() - 1], NodeKind::End); + assert_eq!(*sorted.last().unwrap(), NodeKind::End); - // A should come before B - let a_pos = sorted - .iter() - .position(|n| n == &NodeKind::Custom("A".into())) - .unwrap(); - let b_pos = sorted - .iter() - .position(|n| n == &NodeKind::Custom("B".into())) - .unwrap(); + let a_pos = sorted.iter().position(|n| n == &NodeKind::Custom("A".into())).unwrap(); + let b_pos = sorted.iter().position(|n| n == &NodeKind::Custom("B".into())).unwrap(); assert!(a_pos < b_pos); } #[test] fn test_topological_sort_diamond() { - // Start -> A, B -> C -> End (diamond pattern) let mut edges: FxHashMap> = FxHashMap::default(); edges.insert( NodeKind::Start, vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())], ); - edges.insert( - NodeKind::Custom("A".into()), - vec![NodeKind::Custom("C".into())], - ); - edges.insert( - NodeKind::Custom("B".into()), - vec![NodeKind::Custom("C".into())], - ); + edges.insert(NodeKind::Custom("A".into()), vec![NodeKind::Custom("C".into())]); + edges.insert(NodeKind::Custom("B".into()), vec![NodeKind::Custom("C".into())]); edges.insert(NodeKind::Custom("C".into()), vec![NodeKind::End]); let sorted = topological_sort(&edges); assert_eq!(sorted[0], NodeKind::Start); - assert_eq!(sorted[sorted.len() - 1], NodeKind::End); + assert_eq!(*sorted.last().unwrap(), NodeKind::End); - // A and B should both come before C - let a_pos = sorted - .iter() - .position(|n| n == &NodeKind::Custom("A".into())) - .unwrap(); - let b_pos = sorted - .iter() - .position(|n| n == &NodeKind::Custom("B".into())) - .unwrap(); - let c_pos = sorted - .iter() - .position(|n| n == &NodeKind::Custom("C".into())) - .unwrap(); + let a_pos = sorted.iter().position(|n| n == &NodeKind::Custom("A".into())).unwrap(); + let b_pos = sorted.iter().position(|n| n == &NodeKind::Custom("B".into())).unwrap(); + let c_pos = sorted.iter().position(|n| n == &NodeKind::Custom("C".into())).unwrap(); assert!(a_pos < c_pos); assert!(b_pos < c_pos); - - // A should come before B due to lexicographic ordering assert!(a_pos < b_pos); } #[test] fn test_topological_sort_deterministic() { - // Multiple runs should produce the same order let mut edges: FxHashMap> = FxHashMap::default(); edges.insert( NodeKind::Start, diff --git a/src/graphs/mod.rs b/src/graphs/mod.rs index f15735e..47c72bf 100644 --- a/src/graphs/mod.rs +++ b/src/graphs/mod.rs @@ -1,21 +1,8 @@ -//! Graph definition and compilation for workflow execution. +//! Graph construction, compilation, and iteration for workflow execution. //! -//! This module provides the core graph building functionality for creating -//! workflow graphs with nodes, edges, and conditional routing. The main -//! entry point is [`GraphBuilder`], which uses a builder pattern to -//! construct workflows that compile into executable [`App`](crate::app::App) instances. -//! -//! # Core Concepts -//! -//! - **Nodes**: Executable units of work implementing the [`Node`](crate::node::Node) trait -//! - **Edges**: Connections between nodes defining execution flow -//! - **Conditional Edges**: Dynamic routing based on state predicates -//! - **Virtual Endpoints**: `NodeKind::Start` and `NodeKind::End` for structural definition -//! - **Compilation**: Validation and conversion to executable [`App`](crate::app::App) -//! -//! # Graph Iteration -//! -//! The module provides petgraph-style iterators for inspecting graph structure: +//! The main entry point is [`GraphBuilder`], which assembles nodes, edges, and +//! conditional routes into a workflow that compiles into an executable +//! [`App`](crate::app::App). //! //! ``` //! use weavegraph::graphs::GraphBuilder; @@ -28,117 +15,13 @@ //! # Ok(weavegraph::node::NodePartial::default()) //! # } //! # } -//! -//! let builder = GraphBuilder::new() -//! .add_node(NodeKind::Custom("A".into()), MyNode) -//! .add_node(NodeKind::Custom("B".into()), MyNode) -//! .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) -//! .add_edge(NodeKind::Custom("A".into()), NodeKind::Custom("B".into())) -//! .add_edge(NodeKind::Custom("B".into()), NodeKind::End); -//! -//! // Iterate over registered nodes -//! for node in builder.nodes() { -//! println!("Node: {:?}", node); -//! } -//! -//! // Iterate over edges as (from, to) pairs -//! for (from, to) in builder.edges() { -//! println!("Edge: {:?} -> {:?}", from, to); -//! } -//! -//! // Get deterministic topological ordering -//! let sorted = builder.topological_sort(); -//! ``` -//! -//! # Quick Start -//! -//! ``` -//! use weavegraph::graphs::GraphBuilder; -//! use weavegraph::types::NodeKind; -//! use weavegraph::node::{Node, NodeContext, NodePartial, NodeError}; -//! use weavegraph::state::StateSnapshot; -//! use async_trait::async_trait; -//! -//! // Define a simple node -//! struct MyNode; -//! -//! #[async_trait] -//! impl Node for MyNode { -//! async fn run(&self, _: StateSnapshot, _: NodeContext) -> Result { -//! Ok(NodePartial::default()) -//! } -//! } -//! -//! // Build a simple workflow (virtual Start/End): -//! // Start (virtual) -> process -> End (virtual) -//! let app = GraphBuilder::new() -//! .add_node(NodeKind::Custom("process".into()), MyNode) -//! .add_edge(NodeKind::Start, NodeKind::Custom("process".into())) -//! .add_edge(NodeKind::Custom("process".into()), NodeKind::End) -//! .compile(); -//! ``` -//! -//! # Advanced Usage -//! -//! ## Conditional Routing -//! -//! ``` -//! use weavegraph::graphs::{GraphBuilder, EdgePredicate}; -//! use weavegraph::types::NodeKind; -//! use std::sync::Arc; -//! -//! // Create a predicate that routes based on message count -//! let route_by_messages: EdgePredicate = Arc::new(|snapshot| { -//! if snapshot.messages.len() > 5 { -//! vec!["process".to_string()] -//! } else { -//! vec!["skip".to_string()] -//! } -//! }); -//! -//! # struct MyNode; -//! # #[async_trait::async_trait] -//! # impl weavegraph::node::Node for MyNode { -//! # async fn run(&self, _: weavegraph::state::StateSnapshot, _: weavegraph::node::NodeContext) -> Result { -//! # Ok(weavegraph::node::NodePartial::default()) -//! # } -//! # } -//! //! let app = GraphBuilder::new() //! .add_node(NodeKind::Custom("process".into()), MyNode) -//! .add_node(NodeKind::Custom("skip".into()), MyNode) -//! // Basic structural edge from virtual Start //! .add_edge(NodeKind::Start, NodeKind::Custom("process".into())) -//! .add_conditional_edge(NodeKind::Start, route_by_messages) //! .add_edge(NodeKind::Custom("process".into()), NodeKind::End) -//! .add_edge(NodeKind::Custom("skip".into()), NodeKind::End) //! .compile(); //! ``` -//! -//! ## petgraph Integration -//! -//! With the `petgraph-compat` feature, you can convert graphs to petgraph format -//! for advanced algorithms and DOT visualization: -//! -//! ```ignore -//! // Enable with: weavegraph = { features = ["petgraph-compat"] } -//! use weavegraph::graphs::GraphBuilder; -//! -//! let builder = GraphBuilder::new() -//! .add_node(NodeKind::Custom("A".into()), MyNode) -//! .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) -//! .add_edge(NodeKind::Custom("A".into()), NodeKind::End); -//! -//! // Convert to petgraph for analysis -//! let pg = builder.to_petgraph(); -//! assert!(!petgraph::algo::is_cyclic_directed(&pg.graph)); -//! -//! // Export to DOT for visualization -//! let dot = builder.to_dot(); -//! std::fs::write("workflow.dot", dot)?; -//! ``` -// Internal module declarations mod builder; mod compilation; mod edges; @@ -147,7 +30,6 @@ mod iteration; #[cfg(feature = "petgraph-compat")] mod petgraph_compat; -// Public re-exports for backward compatibility pub use builder::GraphBuilder; pub use compilation::GraphCompileError; pub use edges::{ConditionalEdge, EdgePredicate}; diff --git a/src/graphs/petgraph_compat.rs b/src/graphs/petgraph_compat.rs index 6abce24..97fb825 100644 --- a/src/graphs/petgraph_compat.rs +++ b/src/graphs/petgraph_compat.rs @@ -1,220 +1,134 @@ //! Optional petgraph compatibility layer. //! -//! This module provides conversion between Weavegraph's graph representation -//! and petgraph's `DiGraph` type, enabling use of petgraph's rich algorithm -//! library for advanced analysis and DOT visualization. -//! -//! # Feature Gate -//! -//! This module is only available when the `petgraph-compat` feature is enabled: -//! -//! ```toml -//! [dependencies] -//! weavegraph = { version = "0.1", features = ["petgraph-compat"] } -//! ``` -//! -//! # Examples -//! -//! ## Convert to petgraph for analysis +//! Converts Weavegraph's edge map to petgraph's [`DiGraph`], enabling petgraph's +//! algorithm library and DOT visualization. Requires the `petgraph-compat` feature. //! //! ```ignore //! use weavegraph::graphs::GraphBuilder; //! use weavegraph::types::NodeKind; -//! -//! let builder = GraphBuilder::new() -//! .add_node(NodeKind::Custom("A".into()), MyNode) -//! .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) -//! .add_edge(NodeKind::Custom("A".into()), NodeKind::End); -//! -//! let petgraph = builder.to_petgraph(); -//! -//! // Use petgraph algorithms //! use petgraph::algo::is_cyclic_directed; -//! assert!(!is_cyclic_directed(&petgraph)); -//! ``` -//! -//! ## Export to DOT format -//! -//! ```ignore -//! use weavegraph::graphs::GraphBuilder; -//! use weavegraph::types::NodeKind; //! //! let builder = GraphBuilder::new() -//! .add_node(NodeKind::Custom("A".into()), MyNode) +//! .add_node(NodeKind::Custom("A".into()), my_node) //! .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) //! .add_edge(NodeKind::Custom("A".into()), NodeKind::End); //! -//! let dot = builder.to_dot(); -//! println!("{}", dot); -//! // Outputs: -//! // digraph { -//! // 0 [ label = "Start" ] -//! // 1 [ label = "A" ] -//! // 2 [ label = "End" ] -//! // 0 -> 1 [ ] -//! // 1 -> 2 [ ] -//! // } +//! let pg = builder.to_petgraph(); +//! assert!(!is_cyclic_directed(&pg.graph)); +//! println!("{}", builder.to_dot()); //! ``` use crate::types::NodeKind; use petgraph::graph::{DiGraph, NodeIndex}; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; -/// A petgraph-compatible directed graph representation of a Weavegraph workflow. -/// -/// Node weights are `NodeKind` values, edge weights are unit type `()`. +/// Directed graph with `NodeKind` node weights and unit edge weights. pub type WeaveDiGraph = DiGraph; -/// Mapping from NodeKind to petgraph NodeIndex. -/// -/// Useful for looking up nodes in the converted graph when you need to -/// perform queries or modifications. +/// Maps each `NodeKind` to its petgraph [`NodeIndex`]. pub type NodeIndexMap = FxHashMap; -/// Result of converting a Weavegraph to petgraph format. -/// -/// Contains both the graph and a mapping from NodeKind to petgraph indices -/// for convenient lookup. +/// Result of a Weavegraph-to-petgraph conversion. #[derive(Debug, Clone)] pub struct PetgraphConversion { - /// The petgraph directed graph. + /// The converted directed graph. pub graph: WeaveDiGraph, - /// Mapping from Weavegraph NodeKind to petgraph NodeIndex. + /// Lookup map from `NodeKind` to petgraph index. pub index_map: NodeIndexMap, } impl PetgraphConversion { - /// Look up the petgraph index for a NodeKind. + /// Returns the petgraph index for `node`, if present. #[must_use] pub fn index_of(&self, node: &NodeKind) -> Option { self.index_map.get(node).copied() } - /// Get the NodeKind at a petgraph index. + /// Returns the `NodeKind` at `index`, if present. #[must_use] pub fn node_at(&self, index: NodeIndex) -> Option<&NodeKind> { self.graph.node_weight(index) } } -/// Convert a Weavegraph edge map to a petgraph DiGraph. +/// Converts a Weavegraph edge map to a petgraph [`DiGraph`]. /// -/// This is the internal conversion function used by `GraphBuilder::to_petgraph()`. -/// It creates nodes for all NodeKinds referenced in edges (including Start/End) -/// and preserves the edge topology. +/// Nodes are ordered deterministically: `Start` first, `End` last, +/// `Custom` nodes alphabetically between them. pub(super) fn to_petgraph(edges: &FxHashMap>) -> PetgraphConversion { - let mut graph = DiGraph::new(); - let mut index_map: NodeIndexMap = FxHashMap::default(); + let unique: FxHashSet = edges + .iter() + .flat_map(|(from, tos)| std::iter::once(from).chain(tos)) + .cloned() + .collect(); - // Collect all unique nodes - let mut all_nodes: Vec = Vec::new(); - for (from, tos) in edges { - if !index_map.contains_key(from) { - all_nodes.push(from.clone()); - index_map.insert(from.clone(), NodeIndex::new(0)); // placeholder - } - for to in tos { - if !index_map.contains_key(to) { - all_nodes.push(to.clone()); - index_map.insert(to.clone(), NodeIndex::new(0)); // placeholder - } - } - } - - // Sort for deterministic node indices (Start first, End last, custom alphabetically) + let mut all_nodes: Vec = unique.into_iter().collect(); all_nodes.sort_by(|a, b| match (a, b) { (NodeKind::Start, _) => std::cmp::Ordering::Less, (_, NodeKind::Start) => std::cmp::Ordering::Greater, (NodeKind::End, _) => std::cmp::Ordering::Greater, (_, NodeKind::End) => std::cmp::Ordering::Less, - (NodeKind::Custom(a_name), NodeKind::Custom(b_name)) => a_name.cmp(b_name), + (NodeKind::Custom(a), NodeKind::Custom(b)) => a.cmp(b), }); - // Add nodes to graph and update index map - for node in &all_nodes { - let idx = graph.add_node(node.clone()); - index_map.insert(node.clone(), idx); - } + let mut graph = DiGraph::new(); + let index_map: NodeIndexMap = all_nodes + .into_iter() + .map(|node| { + let idx = graph.add_node(node.clone()); + (node, idx) + }) + .collect(); - // Add edges for (from, tos) in edges { let from_idx = index_map[from]; for to in tos { - let to_idx = index_map[to]; - graph.add_edge(from_idx, to_idx, ()); + graph.add_edge(from_idx, index_map[to], ()); } } PetgraphConversion { graph, index_map } } -/// Export a graph to DOT format for visualization. -/// -/// The DOT output can be rendered using Graphviz tools (`dot`, `neato`, etc.) -/// or online viewers like https://dreampuf.github.io/GraphvizOnline/. -/// -/// # Node Labels -/// -/// - `Start` → "Start" (with special styling) -/// - `End` → "End" (with special styling) -/// - `Custom("name")` → "name" +/// Renders the graph to DOT format for Graphviz visualization. /// -/// # Examples -/// -/// ```ignore -/// let dot = to_dot(&edges); -/// std::fs::write("workflow.dot", dot)?; -/// // Then: dot -Tpng workflow.dot -o workflow.png -/// ``` +/// `Start` and `End` nodes receive distinct fill colors. Output is +/// deterministic across calls because node indices are stable. pub(super) fn to_dot(edges: &FxHashMap>) -> String { + use petgraph::visit::EdgeRef; use std::fmt::Write; let conversion = to_petgraph(edges); - let mut output = String::new(); + let mut out = String::new(); - writeln!(output, "digraph {{").unwrap(); - writeln!(output, " rankdir=TB;").unwrap(); - writeln!(output, " node [shape=box, style=rounded];").unwrap(); + writeln!(out, "digraph {{").unwrap(); + writeln!(out, " rankdir=TB;").unwrap(); + writeln!(out, " node [shape=box, style=rounded];").unwrap(); - // Write nodes with labels and styling for idx in conversion.graph.node_indices() { - let node = conversion.graph.node_weight(idx).unwrap(); + let node = &conversion.graph[idx]; let (label, style) = match node { NodeKind::Start => ("Start", " style=\"filled\" fillcolor=\"lightgreen\""), NodeKind::End => ("End", " style=\"filled\" fillcolor=\"lightcoral\""), NodeKind::Custom(name) => (name.as_str(), ""), }; - writeln!( - output, - " {} [ label=\"{}\"{} ];", - idx.index(), - label, - style - ) - .unwrap(); + writeln!(out, " {} [ label=\"{}\"{} ];", idx.index(), label, style).unwrap(); } - writeln!(output).unwrap(); + writeln!(out).unwrap(); - // Write edges - for edge in conversion.graph.edge_indices() { - let (from, to) = conversion.graph.edge_endpoints(edge).unwrap(); - writeln!(output, " {} -> {};", from.index(), to.index()).unwrap(); + for edge in conversion.graph.edge_references() { + writeln!(out, " {} -> {};", edge.source().index(), edge.target().index()).unwrap(); } - writeln!(output, "}}").unwrap(); - - output + writeln!(out, "}}").unwrap(); + out } -/// Check for cycles using petgraph's algorithm. -/// -/// This provides an alternative cycle detection implementation that can be -/// used for validation fallback or cross-verification. +/// Returns `true` if the graph contains a directed cycle. #[must_use] pub fn is_cyclic(edges: &FxHashMap>) -> bool { - let conversion = to_petgraph(edges); - petgraph::algo::is_cyclic_directed(&conversion.graph) + petgraph::algo::is_cyclic_directed(&to_petgraph(edges).graph) } #[cfg(test)] @@ -237,20 +151,17 @@ mod tests { ); edges.insert( NodeKind::Custom("B".into()), - vec![NodeKind::Custom("A".into())], // cycle! + vec![NodeKind::Custom("A".into())], ); edges } #[test] fn test_to_petgraph_linear() { - let edges = make_linear_graph(); - let conversion = to_petgraph(&edges); + let conversion = to_petgraph(&make_linear_graph()); assert_eq!(conversion.graph.node_count(), 3); assert_eq!(conversion.graph.edge_count(), 2); - - // Check nodes exist assert!(conversion.index_of(&NodeKind::Start).is_some()); assert!(conversion.index_of(&NodeKind::Custom("A".into())).is_some()); assert!(conversion.index_of(&NodeKind::End).is_some()); @@ -258,21 +169,17 @@ mod tests { #[test] fn test_is_cyclic_linear() { - let edges = make_linear_graph(); - assert!(!is_cyclic(&edges)); + assert!(!is_cyclic(&make_linear_graph())); } #[test] fn test_is_cyclic_with_cycle() { - let edges = make_cyclic_graph(); - assert!(is_cyclic(&edges)); + assert!(is_cyclic(&make_cyclic_graph())); } #[test] fn test_to_dot_output() { - let edges = make_linear_graph(); - let dot = to_dot(&edges); - + let dot = to_dot(&make_linear_graph()); assert!(dot.contains("digraph {")); assert!(dot.contains("Start")); assert!(dot.contains("End")); @@ -281,9 +188,7 @@ mod tests { #[test] fn test_deterministic_indices() { - // Same graph should produce same indices across calls let edges = make_linear_graph(); - let conv1 = to_petgraph(&edges); let conv2 = to_petgraph(&edges); diff --git a/src/llm/mod.rs b/src/llm/mod.rs index 496e2ca..31c406e 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -1,7 +1,4 @@ -//! Framework-agnostic LLM abstractions and optional adapters. -//! -//! This module defines provider traits that are independent of any specific -//! LLM SDK. The Rig adapter is available behind the `rig` feature. +//! Framework-agnostic LLM provider traits and optional SDK adapters. pub mod traits; diff --git a/src/llm/rig_adapter.rs b/src/llm/rig_adapter.rs index 42c7d38..d7f3dcd 100644 --- a/src/llm/rig_adapter.rs +++ b/src/llm/rig_adapter.rs @@ -1,4 +1,4 @@ -//! Adapter implementing the weavegraph LLM traits for the [Rig](https://github.com/0xPlaygrounds/rig) framework. +//! [`From`] conversions between weavegraph [`Message`] and Rig's [`RigMessage`]. use crate::message::{Message, Role}; use rig::completion::message::{ AssistantContent, Message as RigMessage, ToolResultContent, UserContent, @@ -9,8 +9,7 @@ impl From for RigMessage { match msg.role { Role::User => RigMessage::user(msg.content), Role::Assistant => RigMessage::assistant(msg.content), - // Rig's core completion history is user/assistant-focused; map - // non-native roles to user for compatibility. + // Rig history is user/assistant only; fold remaining roles into user. Role::System | Role::Tool | Role::Custom(_) => RigMessage::user(msg.content), } } @@ -21,27 +20,21 @@ impl From for Message { match msg { RigMessage::User { content } => Message::with_role( Role::User, - &content - .iter() - .find_map(extract_user_content_text) - .unwrap_or_default(), + &content.iter().find_map(user_text).unwrap_or_default(), ), RigMessage::Assistant { content, .. } => Message::with_role( Role::Assistant, - &content - .iter() - .find_map(extract_assistant_content_text) - .unwrap_or_default(), + &content.iter().find_map(assistant_text).unwrap_or_default(), ), } } } -fn extract_user_content_text(content: &UserContent) -> Option { +fn user_text(content: &UserContent) -> Option { match content { - UserContent::Text(text) => Some(text.text.clone()), - UserContent::ToolResult(result) => result.content.iter().find_map(|chunk| match chunk { - ToolResultContent::Text(text) => Some(text.text.clone()), + UserContent::Text(t) => Some(t.text.clone()), + UserContent::ToolResult(r) => r.content.iter().find_map(|chunk| match chunk { + ToolResultContent::Text(t) => Some(t.text.clone()), ToolResultContent::Image(_) => None, }), UserContent::Image(_) @@ -51,11 +44,11 @@ fn extract_user_content_text(content: &UserContent) -> Option { } } -fn extract_assistant_content_text(content: &AssistantContent) -> Option { +fn assistant_text(content: &AssistantContent) -> Option { match content { - AssistantContent::Text(text) => Some(text.text.clone()), - AssistantContent::Reasoning(reasoning) => reasoning.reasoning.first().cloned(), - AssistantContent::ToolCall(call) => Some(format!("[tool_call:{}]", call.function.name)), + AssistantContent::Text(t) => Some(t.text.clone()), + AssistantContent::Reasoning(r) => r.reasoning.first().cloned(), + AssistantContent::ToolCall(c) => Some(format!("[tool_call:{}]", c.function.name)), AssistantContent::Image(_) => None, } } @@ -66,7 +59,7 @@ mod tests { fn first_user_text(msg: RigMessage) -> Option { match msg { - RigMessage::User { content } => content.iter().find_map(extract_user_content_text), + RigMessage::User { content } => content.iter().find_map(user_text), RigMessage::Assistant { .. } => None, } } @@ -87,7 +80,7 @@ mod tests { match assistant { RigMessage::Assistant { content, .. } => { assert_eq!( - content.iter().find_map(extract_assistant_content_text), + content.iter().find_map(assistant_text), Some("a".to_string()) ); } diff --git a/src/llm/traits.rs b/src/llm/traits.rs index bfdcd83..39fca6c 100644 --- a/src/llm/traits.rs +++ b/src/llm/traits.rs @@ -1,34 +1,34 @@ -//! Framework-agnostic traits for LLM providers (non-streaming and streaming). +//! Framework-agnostic traits for LLM providers. use crate::message::Message; use async_trait::async_trait; use futures_util::stream::BoxStream; -/// Unified error type for framework-agnostic LLM providers. +/// Boxed error type for LLM provider operations. pub type LlmError = Box; /// Completed response from an LLM provider. #[derive(Clone, Debug, Default)] pub struct LlmResponse { - /// The generated text content returned by the LLM. + /// Generated text content. pub content: String, - /// Optional provider-specific metadata (token counts, finish reason, etc.). + /// Provider metadata (token counts, finish reason, etc.). pub metadata: serde_json::Value, } -/// Trait for non-streaming LLM providers. +/// Non-streaming LLM provider. #[async_trait] pub trait LlmProvider: Send + Sync { - /// Execute a chat completion request over the provided message history. + /// Run a chat completion over `messages`. async fn chat(&self, messages: &[Message]) -> Result; } -/// Trait for streaming LLM providers. +/// Streaming LLM provider. #[async_trait] pub trait LlmStreamProvider: LlmProvider { - /// Stream chunk type produced by the provider. + /// Token chunk type produced during streaming. type Chunk: Send + 'static; - /// Execute a streaming chat completion request. + /// Stream a chat completion over `messages`. async fn chat_stream( &self, messages: &[Message], From 486b9a998aaa928d427107d2211265ec2486d5ba Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 16:08:22 -0400 Subject: [PATCH 05/15] heavy revision work for reducers, scheduler, and telemetry --- src/reducers/add_errors.rs | 12 +- src/reducers/add_messages.rs | 7 +- src/reducers/map_merge.rs | 33 +- src/reducers/mod.rs | 23 +- src/reducers/reducer_registry.rs | 105 ++----- src/schedulers/scheduler.rs | 525 ++++++------------------------- src/telemetry/mod.rs | 228 +++++--------- 7 files changed, 239 insertions(+), 694 deletions(-) diff --git a/src/reducers/add_errors.rs b/src/reducers/add_errors.rs index 1ac3cac..96c2dfb 100644 --- a/src/reducers/add_errors.rs +++ b/src/reducers/add_errors.rs @@ -1,17 +1,17 @@ -//! Reducer that appends incoming [`ErrorEvent`](crate::channels::errors::ErrorEvent) entries to the errors channel. +//! Reducer that appends [`ErrorEvent`](crate::channels::errors::ErrorEvent) entries to the errors channel. use super::Reducer; use crate::{channels::Channel, node::NodePartial, state::VersionedState}; -/// Reducer that appends error events from a [`NodePartial`](crate::node::NodePartial) to the state errors channel. -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +/// Appends each incoming error event onto the state errors channel. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AddErrors; impl Reducer for AddErrors { fn apply(&self, state: &mut VersionedState, update: &NodePartial) { - if let Some(error_events) = &update.errors - && !error_events.is_empty() + if let Some(errors) = &update.errors + && !errors.is_empty() { - state.errors.get_mut().extend_from_slice(error_events); + state.errors.get_mut().extend_from_slice(errors); } } } diff --git a/src/reducers/add_messages.rs b/src/reducers/add_messages.rs index 4569470..3c4d041 100644 --- a/src/reducers/add_messages.rs +++ b/src/reducers/add_messages.rs @@ -1,15 +1,14 @@ -//! Reducer that appends incoming messages to the messages channel. +//! Reducer that appends messages from a [`NodePartial`] to the messages channel. use super::Reducer; use crate::{channels::Channel, node::NodePartial, state::VersionedState}; -/// Reducer that appends messages from a [`NodePartial`](crate::node::NodePartial) to the state messages channel. -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +/// Appends each incoming message onto the state messages channel. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct AddMessages; impl Reducer for AddMessages { fn apply(&self, state: &mut VersionedState, update: &NodePartial) { if let Some(msgs) = &update.messages { - // Append new messages without cloning the entire vector state.messages.get_mut().extend_from_slice(msgs); } } diff --git a/src/reducers/map_merge.rs b/src/reducers/map_merge.rs index 71063e4..3405514 100644 --- a/src/reducers/map_merge.rs +++ b/src/reducers/map_merge.rs @@ -1,33 +1,34 @@ -//! Reducer that shallow-merges incoming extra key-value pairs into the extras channel. +//! Reducer that shallow-merges extra key-value pairs into the extras channel. //! -//! Follows the [JSON Merge Patch](https://www.rfc-editor.org/rfc/rfc7396) convention: -//! an incoming `null` value **removes** the key from state rather than setting it to null. -//! This is what makes [`NodePartial::clear_extra_keys`](crate::node::NodePartial::clear_extra_keys) +//! Follows [JSON Merge Patch](https://www.rfc-editor.org/rfc/rfc7396) (RFC 7396): +//! a `null` value removes the corresponding key from state rather than writing null. +//! This gives [`NodePartial::clear_extra_keys`](crate::node::NodePartial::clear_extra_keys) //! and [`NodePartial::clear_typed_extra_key`](crate::node::NodePartial::clear_typed_extra_key) -//! functional without requiring a separate cleanup reducer. +//! full key-deletion semantics without a separate cleanup reducer. use super::Reducer; use crate::{channels::Channel, node::NodePartial, state::VersionedState}; -/// Reducer that merges extra key-value pairs from a [`NodePartial`](crate::node::NodePartial) into the state extras channel. +/// Merges extra key-value pairs from a [`NodePartial`] into the state extras channel. /// -/// Uses JSON Merge Patch semantics (RFC 7396): an incoming `null` value **removes** the -/// key from state rather than writing a null entry. This means +/// Uses JSON Merge Patch semantics (RFC 7396): an incoming `null` deletes the key; +/// any other value overwrites it. This makes /// [`NodePartial::clear_extra_keys`](crate::node::NodePartial::clear_extra_keys) and /// [`NodePartial::clear_typed_extra_key`](crate::node::NodePartial::clear_typed_extra_key) -/// fully delete the key — no separate cleanup reducer is needed. -#[derive(Debug, PartialEq, Clone, Hash, Eq)] +/// perform a complete key deletion — no cleanup reducer required. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct MapMerge; + impl Reducer for MapMerge { fn apply(&self, state: &mut VersionedState, update: &NodePartial) { - if let Some(extras_update) = &update.extra - && !extras_update.is_empty() + if let Some(patch) = &update.extra + && !patch.is_empty() { - let state_map = state.extra.get_mut(); - for (k, v) in extras_update.iter() { + let map = state.extra.get_mut(); + for (k, v) in patch { if v.is_null() { - state_map.remove(k); + map.remove(k); } else { - state_map.insert(k.clone(), v.clone()); + map.insert(k.clone(), v.clone()); } } } diff --git a/src/reducers/mod.rs b/src/reducers/mod.rs index 5606c4f..b3641eb 100644 --- a/src/reducers/mod.rs +++ b/src/reducers/mod.rs @@ -1,4 +1,4 @@ -//! State reducers that apply [`NodePartial`] updates to [`VersionedState`]. +//! Reducers apply [`NodePartial`] deltas to [`VersionedState`] channel-by-channel. mod add_errors; mod add_messages; mod map_merge; @@ -14,26 +14,25 @@ use crate::state::VersionedState; use crate::types::ChannelType; use thiserror::Error; -/// Unified reducer trait: every reducer mutates VersionedState using a NodePartial delta. -/// Channels currently implemented: messages (append) and extra (shallow JSON map merge). +/// Applies a [`NodePartial`] delta to [`VersionedState`] for one channel. pub trait Reducer: Send + Sync { - /// Stable-ish reducer identity included in graph definition metadata. + /// Stable identity string included in graph-definition metadata. /// - /// The default is the concrete Rust type path. Custom reducers can override this - /// with a durable label when the type path is too noisy for audit manifests. + /// Defaults to the concrete Rust type path. Override with a fixed label when + /// the type path is too verbose for audit manifests. fn definition_label(&self) -> &'static str { std::any::type_name::() } - /// Apply the partial update `update` to `state`, mutating it in place. + /// Mutate `state` in-place using the delta in `update`. fn apply(&self, state: &mut VersionedState, update: &NodePartial); } -/// Errors that can occur when applying reducers to workflow state. +/// Errors produced by the reducer pipeline. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum ReducerError { - /// No reducer is registered for the specified channel type. + /// No reducer is registered for the given channel. #[error("no reducers registered for channel: {0:?}")] #[cfg_attr( feature = "diagnostics", @@ -41,16 +40,16 @@ pub enum ReducerError { )] UnknownChannel(ChannelType), - /// A reducer failed while applying an update to a channel. + /// A reducer returned an error while applying an update. #[error("reducer apply failed for channel {channel:?}: {message}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::reducers::apply_failed)) )] Apply { - /// The channel type for which the reducer failed. + /// Channel the reducer failed on. channel: ChannelType, - /// Human-readable description of the failure. + /// Description of the failure. message: String, }, } diff --git a/src/reducers/reducer_registry.rs b/src/reducers/reducer_registry.rs index d6ac94c..cd6a201 100644 --- a/src/reducers/reducer_registry.rs +++ b/src/reducers/reducer_registry.rs @@ -9,84 +9,45 @@ use crate::{ }; use tracing::instrument; -/// Registry mapping channel types to ordered lists of reducers. +/// Maps channel types to ordered lists of reducers applied during state updates. #[derive(Clone)] pub struct ReducerRegistry { reducer_map: FxHashMap>>, } -/// Guard that checks whether a NodePartial actually has meaningful data -/// for the specified channel. This lets the registry skip invoking -/// reducers when there is nothing to do. -fn channel_guard(channel: &ChannelType, partial: &NodePartial) -> bool { +fn channel_has_data(channel: &ChannelType, partial: &NodePartial) -> bool { match channel { - ChannelType::Message => partial - .messages - .as_ref() - .map(|v| !v.is_empty()) - .unwrap_or(false), - ChannelType::Extra => partial - .extra - .as_ref() - .map(|m| !m.is_empty()) - .unwrap_or(false), - ChannelType::Error => partial - .errors - .as_ref() - .map(|v| !v.is_empty()) - .unwrap_or(false), + ChannelType::Message => partial.messages.as_ref().is_some_and(|v| !v.is_empty()), + ChannelType::Extra => partial.extra.as_ref().is_some_and(|m| !m.is_empty()), + ChannelType::Error => partial.errors.as_ref().is_some_and(|v| !v.is_empty()), } } impl Default for ReducerRegistry { fn default() -> Self { - let mut registry = Self::new(); - registry - .register(ChannelType::Message, Arc::new(AddMessages)) - .register(ChannelType::Extra, Arc::new(MapMerge)) - .register(ChannelType::Error, Arc::new(AddErrors)); - registry + Self::new() + .with_reducer(ChannelType::Message, Arc::new(AddMessages)) + .with_reducer(ChannelType::Extra, Arc::new(MapMerge)) + .with_reducer(ChannelType::Error, Arc::new(AddErrors)) } } impl ReducerRegistry { - /// Creates a new empty reducer registry. + /// Creates an empty registry. pub fn new() -> Self { Self { reducer_map: FxHashMap::default(), } } - /// Registers a reducer for a specific channel type. - /// - /// This method allows dynamic registration of reducers at runtime. - /// Multiple reducers can be registered for the same channel and will - /// be applied in registration order. - /// - /// # Parameters - /// - `channel`: The channel type to register the reducer for - /// - `reducer`: The reducer implementation wrapped in Arc - /// - /// # Returns - /// A mutable reference to self for method chaining + /// Appends `reducer` to the list for `channel`. Multiple reducers run in registration order. pub fn register(&mut self, channel: ChannelType, reducer: Arc) -> &mut Self { self.reducer_map.entry(channel).or_default().push(reducer); self } - /// Builder-style method for registering a reducer. - /// - /// This is a convenience method that consumes self and returns it, - /// enabling fluent API usage when constructing a ReducerRegistry. + /// Owned builder variant of [`register`](Self::register). /// - /// # Parameters - /// - `channel`: The channel type to register the reducer for - /// - `reducer`: The reducer implementation wrapped in Arc - /// - /// # Returns - /// Self for method chaining - /// - /// # Examples /// ``` /// use std::sync::Arc; /// use weavegraph::reducers::{ReducerRegistry, AddMessages}; @@ -100,61 +61,61 @@ impl ReducerRegistry { self } - /// Return a deterministic summary of registered reducers for metadata hashing. + /// Returns a sorted, deterministic summary of all registered reducer labels. /// - /// Reducer labels are recorded in registration order for each channel. It - /// changes when reducers are added, removed, reordered, or replaced with a - /// reducer that reports a different [`Reducer::definition_label`]. + /// The signature changes when reducers are added, removed, reordered, or replaced + /// with a reducer that reports a different [`Reducer::definition_label`]. #[must_use] pub fn definition_signature(&self) -> Vec { - let mut signature: Vec = self + let mut entries: Vec = self .reducer_map .iter() .map(|(channel, reducers)| { let labels = reducers .iter() .enumerate() - .map(|(index, reducer)| format!("{index}:{}", reducer.definition_label())) + .map(|(i, r)| format!("{i}:{}", r.definition_label())) .collect::>() .join(","); - format!("{}:[{}]", channel, labels) + format!("{channel}:[{labels}]") }) .collect(); - signature.sort(); - signature + entries.sort(); + entries } + /// Applies all reducers registered for `channel_type` to `state`. + /// + /// Returns `Ok(())` immediately if `to_update` carries no data for the channel. + /// Returns `Err(ReducerError::UnknownChannel)` if no reducers are registered. #[instrument(skip(self, state, to_update), err)] - /// Apply all reducers for `channel_type` to `state` using `to_update` as the delta. pub fn try_update( &self, channel_type: ChannelType, state: &mut VersionedState, to_update: &NodePartial, ) -> Result<(), ReducerError> { - // Skip if the partial has no applicable data for this channel. - if !channel_guard(&channel_type, to_update) { + if !channel_has_data(&channel_type, to_update) { return Ok(()); } - - if let Some(reducers) = self.reducer_map.get(&channel_type) { - for reducer in reducers { - reducer.apply(state, to_update); + match self.reducer_map.get(&channel_type) { + Some(reducers) => { + for reducer in reducers { + reducer.apply(state, to_update); + } + Ok(()) } - Ok(()) - } else { - Err(ReducerError::UnknownChannel(channel_type)) + None => Err(ReducerError::UnknownChannel(channel_type)), } } + /// Applies all registered reducers across every channel to `state`. #[instrument(skip(self, state, merged_updates), err)] - /// Apply all registered reducers across all channels to `state`. pub fn apply_all( &self, state: &mut VersionedState, merged_updates: &NodePartial, ) -> Result<(), ReducerError> { - // Iterate all registered channels; try_update will skip via guard if no data. for channel in self.reducer_map.keys() { self.try_update(channel.clone(), state, merged_updates)?; } diff --git a/src/schedulers/scheduler.rs b/src/schedulers/scheduler.rs index c75b85e..85e2462 100644 --- a/src/schedulers/scheduler.rs +++ b/src/schedulers/scheduler.rs @@ -2,18 +2,9 @@ #![allow(unused_assignments)] //! Frontier-based scheduler with version gating and bounded concurrency. //! -//! This module provides the core scheduling logic for the Weavegraph workflow framework. -//! The scheduler manages concurrent execution of nodes while ensuring version-based -//! consistency and preventing unnecessary re-execution through intelligent gating. -//! -//! # Core Concepts -//! -//! - **Frontier**: The set of nodes eligible for execution in the current step -//! - **Version Gating**: Skip nodes that have already processed the current state -//! - **Bounded Concurrency**: Control parallel execution with configurable limits -//! - **Superstep**: A single execution phase over the frontier -//! -//! # Examples +//! Manages concurrent node execution; nodes are skipped when they have already +//! processed the current channel versions. The scheduler is stateless — all +//! tracking lives in [`SchedulerState`]. //! //! ```rust //! use weavegraph::channels::Channel; @@ -21,19 +12,11 @@ //! use weavegraph::state::VersionedState; //! //! # async fn example() -> Result<(), Box> { -//! // Create scheduler with concurrency limit //! let scheduler = Scheduler::new(4); -//! let mut sched_state = SchedulerState::default(); -//! -//! // Check if a node should run based on version changes -//! let mut state = VersionedState::builder().build(); -//! state.messages.set_version(2); -//! let snapshot = state.snapshot(); -//! -//! let should_run = scheduler.should_run(&sched_state, "node_id", &snapshot); -//! if should_run { -//! println!("Node should run - state has changed"); -//! } +//! let mut state = SchedulerState::default(); +//! let mut vs = VersionedState::builder().build(); +//! vs.messages.set_version(2); +//! assert!(scheduler.should_run(&state, "my_node", &vs.snapshot())); //! # Ok(()) //! # } //! ``` @@ -49,61 +32,31 @@ use std::sync::Arc; use thiserror::Error; use tracing::instrument; -/// Result of executing a single superstep in the scheduler. -/// -/// This structure provides comprehensive information about what happened -/// during a superstep execution, including which nodes ran, which were -/// skipped, and their outputs. This enables detailed monitoring and -/// debugging of workflow execution. -/// -/// # Fields -/// -/// - `ran_nodes`: Nodes that executed, preserving scheduling order -/// - `skipped_nodes`: Nodes that were skipped (End nodes or version-gated) -/// - `outputs`: Results from executed nodes as (NodeKind, NodePartial) pairs -/// -/// # Examples -/// -/// ```rust -/// use weavegraph::schedulers::StepRunResult; -/// use weavegraph::types::NodeKind; -/// -/// fn analyze_step_result(result: &StepRunResult) { -/// println!("Executed {} nodes, skipped {}", -/// result.ran_nodes.len(), -/// result.skipped_nodes.len()); -/// -/// for (node_kind, partial) in &result.outputs { -/// if let Some(messages) = &partial.messages { -/// println!("Node {:?} produced {} messages", node_kind, messages.len()); -/// } -/// } -/// } -/// ``` +/// Execution summary for a single superstep. #[derive(Debug, Clone)] pub struct StepRunResult { - /// Nodes that were executed this step, in the order they were scheduled. + /// Nodes that executed this step, in scheduling order. pub ran_nodes: Vec, - /// Nodes that were skipped this step (End nodes or no new versions seen). + /// Nodes skipped this step (structural or version-gated). pub skipped_nodes: Vec, - /// Outputs from nodes that ran: (node_kind, NodePartial) + /// Outputs from executed nodes as `(node_kind, partial)` pairs. pub outputs: Vec<(NodeKind, NodePartial)>, } -/// Runtime context passed to a scheduler superstep. +/// Runtime context injected into each superstep. #[derive(Clone)] #[non_exhaustive] pub struct SchedulerRunContext { - /// Event emitter injected into node contexts. + /// Event emitter forwarded to each node context. pub event_emitter: Arc, - /// Optional runtime clock injected into node contexts. + /// Optional clock forwarded to each node context. pub clock: Option>, - /// Optional invocation identifier injected into node contexts. + /// Optional invocation identifier forwarded to each node context. pub invocation_id: Option, } impl SchedulerRunContext { - /// Create scheduler runtime context with only an event emitter. + /// Build a context with only an event emitter; clock and invocation ID default to `None`. #[must_use] pub fn new(event_emitter: Arc) -> Self { Self { @@ -113,7 +66,7 @@ impl SchedulerRunContext { } } - /// Attach a runtime clock. + /// Attach a clock. #[must_use] pub fn with_clock(mut self, clock: Arc) -> Self { self.clock = Some(clock); @@ -128,133 +81,54 @@ impl SchedulerRunContext { } } -/// Tracks version information for nodes to enable intelligent scheduling. -/// -/// The scheduler uses this state to determine whether a node needs to run -/// based on whether it has already processed the current state versions. -/// This prevents unnecessary re-execution and enables efficient incremental -/// processing. +/// Version-tracking state that drives the scheduler's execution gate. /// -/// # Internal Structure -/// -/// The `versions_seen` map uses a two-level structure: -/// - Outer key: Node identifier (string representation of NodeKind) -/// - Inner key: Channel name ("messages", "extra", etc.) -/// - Value: Last version number the node processed for that channel -/// -/// # Examples +/// `versions_seen[node_id][channel]` records the last version a node consumed +/// per channel. When a snapshot's version exceeds that value the node runs +/// again; otherwise it is skipped. /// /// ```rust /// use weavegraph::channels::Channel; /// use weavegraph::schedulers::{Scheduler, SchedulerState}; /// use weavegraph::state::VersionedState; /// -/// let mut sched_state = SchedulerState::default(); /// let scheduler = Scheduler::new(2); -/// -/// // Simulate a snapshot with version changes -/// let mut state = VersionedState::builder().build(); -/// state.messages.set_version(3); -/// let snapshot = state.snapshot(); -/// -/// // Record that a node has seen these versions -/// scheduler.record_seen(&mut sched_state, "node_a", &snapshot); -/// -/// // Later checks will use this information for gating -/// assert!(!scheduler.should_run(&sched_state, "node_a", &snapshot)); +/// let mut state = SchedulerState::default(); +/// let mut vs = VersionedState::builder().build(); +/// vs.messages.set_version(3); +/// let snap = vs.snapshot(); +/// scheduler.record_seen(&mut state, "node_a", &snap); +/// assert!(!scheduler.should_run(&state, "node_a", &snap)); /// ``` #[derive(Debug, Default, Clone)] pub struct SchedulerState { - /// `versions_seen[node_id][channel]` stores the last version observed when the node ran. + /// `versions_seen[node_id][channel]` — last version the node processed. pub versions_seen: FxHashMap>, } -/// High-performance frontier scheduler with version gating and bounded concurrency. -/// -/// The `Scheduler` is the core execution engine for workflow steps. It manages -/// parallel node execution while ensuring consistency through version-based -/// gating. The scheduler is stateless by design - all persistence is handled -/// through the separate `SchedulerState`. -/// -/// # Architecture -/// -/// The scheduler implements a "superstep" execution model: -/// 1. **Frontier Analysis**: Determine which nodes are eligible to run -/// 2. **Version Gating**: Skip nodes that have already processed current state -/// 3. **Concurrent Execution**: Run eligible nodes with bounded parallelism -/// 4. **Result Collection**: Gather outputs preserving execution metadata +/// Frontier scheduler with version gating and bounded concurrency. /// -/// # Performance Characteristics -/// -/// - **Concurrency**: Configurable parallelism with `concurrency_limit` -/// - **Efficiency**: Version gating prevents redundant node execution -/// - **Scalability**: Stateless design enables easy distribution -/// - **Determinism**: Execution order is preserved in results -/// -/// # Examples +/// Stateless execution engine — all tracking lives in [`SchedulerState`]. +/// Eligible nodes run concurrently up to `concurrency_limit`; structural nodes +/// and nodes that have already processed the current state are skipped. /// /// ```rust /// use weavegraph::schedulers::Scheduler; /// -/// // Create scheduler with specific concurrency limit -/// let scheduler = Scheduler::new(8); // Max 8 concurrent nodes -/// assert_eq!(scheduler.concurrency_limit, 8); -/// -/// // Zero concurrency defaults to 1 for safety -/// let safe_scheduler = Scheduler::new(0); -/// assert_eq!(safe_scheduler.concurrency_limit, 1); +/// assert_eq!(Scheduler::new(8).concurrency_limit, 8); +/// assert_eq!(Scheduler::new(0).concurrency_limit, 1); // zero clamps to 1 /// ``` #[derive(Debug, Default, Clone)] pub struct Scheduler { - /// Maximum number of nodes that may execute concurrently in a single superstep. + /// Maximum nodes that may run concurrently in a single superstep. pub concurrency_limit: usize, } -/// Errors that can occur during scheduler execution. -/// -/// This enum represents the various failure modes that can happen during -/// workflow execution in the scheduler. Each variant provides specific -/// context about the failure to enable appropriate error handling and -/// debugging. -/// -/// # Error Handling -/// -/// These errors typically indicate either: -/// - **Node failures**: Issues within individual node execution -/// - **System failures**: Problems with the async runtime or task management -/// -/// # Examples -/// -/// ```rust -/// use weavegraph::schedulers::SchedulerError; -/// use weavegraph::node::NodeError; -/// use weavegraph::types::NodeKind; -/// -/// fn handle_scheduler_error(error: SchedulerError) { -/// match error { -/// SchedulerError::NodeNotFound { kind, step } => { -/// eprintln!("Node {:?} not found at step {}", kind, step); -/// // Handle missing node (graph consistency issue) -/// } -/// SchedulerError::NodeRun { kind, step, source } => { -/// eprintln!("Node {:?} failed at step {}: {}", kind, step, source); -/// // Handle node-specific failure -/// } -/// SchedulerError::Join(join_error) => { -/// eprintln!("Task coordination failed: {}", join_error); -/// // Handle system-level failure -/// } -/// } -/// } -/// ``` +/// Errors raised during scheduler execution. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum SchedulerError { - /// A node in the frontier was not found in the registry. - /// - /// This error indicates a graph consistency issue where the frontier - /// contains a node that doesn't exist in the node registry. This should - /// not occur with properly compiled graphs. + /// A frontier node was absent from the node registry. #[error("node {kind:?} in frontier not found in registry at step {step}")] #[cfg_attr( feature = "diagnostics", @@ -264,72 +138,40 @@ pub enum SchedulerError { ) )] NodeNotFound { - /// The node kind that was expected in the registry. + /// The node kind that was missing. kind: NodeKind, - /// The workflow step at which the lookup failed. + /// Workflow step at which the lookup failed. step: u64, }, - /// A node failed during execution. - /// - /// This error occurs when a node's `run` method returns an error. - /// It includes the node kind, execution step, and the underlying node error - /// to provide comprehensive context for debugging. - /// - /// # Fields - /// - `kind`: The type of node that failed - /// - `step`: The workflow step number when the failure occurred - /// - `source`: The underlying `NodeError` that caused the failure + /// A node's `run` method returned an error. #[error("node run error at step {step} for {kind:?}: {source}")] #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::node)))] NodeRun { - /// The node kind that encountered an error. + /// The node kind that failed. kind: NodeKind, - /// The workflow step at which the node failed. + /// Workflow step at which the failure occurred. step: u64, + /// Underlying node error. #[source] - /// The underlying node error. source: NodeError, }, - /// A task join operation failed. - /// - /// This error occurs when there's a problem with the async task coordination, - /// such as a task being cancelled or panicking. This typically indicates - /// a system-level issue rather than a node logic problem. - /// - /// Common causes: - /// - Task panic during execution - /// - Runtime shutdown during execution - /// - Task cancellation due to timeout or external signal + /// An async task join failed (panic or cancellation). #[error("task join error: {0}")] #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::scheduler::join)))] Join(#[from] tokio::task::JoinError), } impl Scheduler { - /// Create a new scheduler with the specified concurrency limit. - /// - /// If concurrency_limit is 0, it will be set to 1 to ensure at least - /// one concurrent task can run. - /// - /// # Parameters - /// * `concurrency_limit` - Maximum number of concurrent node executions - /// - /// # Returns - /// A new Scheduler instance configured with the given concurrency limit + /// Create a scheduler; `concurrency_limit` of 0 is clamped to 1. #[must_use] pub fn new(concurrency_limit: usize) -> Self { Self { - concurrency_limit: if concurrency_limit == 0 { - 1 - } else { - concurrency_limit - }, + concurrency_limit: concurrency_limit.max(1), } } - /// Helper to expose channel versions as generic (name, version) pairs. #[inline] fn channel_versions(snap: &StateSnapshot) -> [(&'static str, u64); 2] { [ @@ -338,37 +180,13 @@ impl Scheduler { ] } - /// Decide if a node should run given the pre-barrier snapshot. - /// - /// Returns true if any channel version increased since this node last ran. - /// This enables efficient incremental execution by skipping nodes that - /// have already processed the current state. - /// - /// # Parameters - /// * `state` - Current scheduler state tracking execution history - /// * `node_id` - Identifier of the node to check - /// * `snap` - Current state snapshot with version information - /// - /// # Returns - /// `true` if the node should run, `false` if it can be skipped + /// Return `true` if the node should run given the snapshot's current versions. #[must_use] pub fn should_run(&self, state: &SchedulerState, node_id: &str, snap: &StateSnapshot) -> bool { - let channels = Self::channel_versions(snap); - self.should_run_with(state, node_id, &channels) + self.should_run_with(state, node_id, &Self::channel_versions(snap)) } - /// Generic form of should_run: decide based on provided (channel_name, version) pairs. - /// - /// This method provides the core scheduling logic and can be used with - /// custom channel configurations or for testing purposes. - /// - /// # Parameters - /// * `state` - Current scheduler state tracking execution history - /// * `node_id` - Identifier of the node to check - /// * `channels` - Array of (channel_name, version) pairs to check - /// - /// # Returns - /// `true` if the node should run based on version changes, `false` otherwise + /// Return `true` if any channel version exceeds what `node_id` last processed. #[must_use] pub fn should_run_with( &self, @@ -376,89 +194,27 @@ impl Scheduler { node_id: &str, channels: &[(&str, u64)], ) -> bool { - let seen = match state.versions_seen.get(node_id) { - Some(v) => v, - None => return true, // never ran -> run + let Some(seen) = state.versions_seen.get(node_id) else { + return true; }; - for (name, ver) in channels.iter() { - let last = seen.get::(name).copied().unwrap_or(0); - if *ver > last { - return true; - } - } - false + channels + .iter() + .any(|&(name, ver)| ver > seen.get(name).copied().unwrap_or(0)) } - /// Record the versions seen for a node at the start of its execution. - /// - /// This method updates the scheduler state to track which versions of each - /// channel a node has processed. This information is used by version gating - /// to determine whether a node needs to run again in future supersteps. - /// - /// The versions are captured from the pre-barrier snapshot to ensure - /// consistency with the state the node actually processes. - /// - /// # Parameters - /// * `state` - Mutable scheduler state to update with version information - /// * `node_id` - String identifier of the node (typically `format!("{:?}", node_kind)`) - /// * `snap` - State snapshot containing current channel versions - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::channels::Channel; - /// use weavegraph::schedulers::{Scheduler, SchedulerState}; - /// use weavegraph::state::VersionedState; - /// - /// let scheduler = Scheduler::new(2); - /// let mut state = SchedulerState::default(); - /// - /// let mut snapshot_state = VersionedState::builder().build(); - /// snapshot_state.messages.set_version(5); - /// snapshot_state.extra.set_version(3); - /// let snapshot = snapshot_state.snapshot(); - /// - /// // Record that node_a has processed versions 5 and 3 - /// scheduler.record_seen(&mut state, "node_a", &snapshot); - /// - /// // Future checks will use this information - /// assert!(!scheduler.should_run(&state, "node_a", &snapshot)); - /// ``` + /// Record that `node_id` has processed the snapshot's channel versions. pub fn record_seen(&self, state: &mut SchedulerState, node_id: &str, snap: &StateSnapshot) { - let channels = Self::channel_versions(snap); - self.record_seen_with(state, node_id, &channels); + self.record_seen_with(state, node_id, &Self::channel_versions(snap)); } - /// Generic form of record_seen: store versions for provided channel/version pairs. - /// - /// This is the low-level version tracking method that allows recording - /// arbitrary channel versions. It's used internally by `record_seen` and - /// can be used directly for custom channel configurations or testing. - /// - /// The method updates the internal `versions_seen` map structure: - /// `versions_seen[node_id][channel_name] = version` - /// - /// # Parameters - /// * `state` - Mutable scheduler state to update - /// * `node_id` - String identifier of the node - /// * `channels` - Slice of (channel_name, version) pairs to record - /// - /// # Examples + /// Record arbitrary `(channel, version)` pairs as seen by `node_id`. /// /// ```rust /// use weavegraph::schedulers::{Scheduler, SchedulerState}; /// - /// let scheduler = Scheduler::new(2); /// let mut state = SchedulerState::default(); - /// - /// // Record custom channel versions - /// let channels = [("messages", 10), ("custom_channel", 5)]; - /// scheduler.record_seen_with(&mut state, "node_x", &channels); - /// - /// // Verify the versions were recorded - /// let node_versions = &state.versions_seen["node_x"]; - /// assert_eq!(node_versions["messages"], 10); - /// assert_eq!(node_versions["custom_channel"], 5); + /// Scheduler::new(1).record_seen_with(&mut state, "n", &[("messages", 7)]); + /// assert_eq!(state.versions_seen["n"]["messages"], 7); /// ``` pub fn record_seen_with( &self, @@ -466,128 +222,46 @@ impl Scheduler { node_id: &str, channels: &[(&str, u64)], ) { - let entry = state.versions_seen.entry(node_id.to_string()).or_default(); - for (name, ver) in channels.iter() { - entry.insert((*name).to_string(), *ver); + let entry = state.versions_seen.entry(node_id.to_owned()).or_default(); + for &(name, ver) in channels { + entry.insert(name.to_owned(), ver); } } - /// Execute a single superstep over a frontier with bounded concurrency. - /// - /// This is the core execution method of the scheduler. It processes a frontier - /// of nodes, applying version gating to skip unnecessary work, and executes - /// eligible nodes concurrently with the configured parallelism limit. - /// - /// # Execution Flow - /// - /// 1. **Frontier Partitioning**: Separate nodes into "to run" vs "skipped" - /// - Skip `NodeKind::End` nodes (terminal nodes) - /// - Skip nodes that have already processed current state versions - /// 2. **Task Creation**: Build async tasks for eligible nodes - /// 3. **Concurrent Execution**: Run tasks with bounded parallelism - /// 4. **Result Collection**: Gather outputs, preserving execution metadata - /// 5. **Version Recording**: Update state with processed versions - /// - /// # Version Gating - /// - /// Nodes are skipped if they have already processed the current state versions, - /// preventing redundant computation. This enables efficient incremental execution - /// in long-running workflows. - /// - /// # Concurrency Model - /// - /// - **Bounded Parallelism**: Respects `concurrency_limit` to control resource usage - /// - **Unordered Completion**: Tasks may complete out of order for efficiency - /// - **Deterministic Results**: `ran_nodes` preserves scheduling order - /// - /// # Parameters - /// * `state` - Mutable scheduler state for version tracking - /// * `nodes` - Registry mapping node kinds to their implementations - /// * `frontier` - Vector of nodes eligible for execution this step - /// * `snap` - Pre-barrier state snapshot for version gating - /// * `step` - Current workflow step number (for context and logging) - /// * `run_context` - Runtime context injected into node execution - /// - /// # Returns - /// * `Ok(StepRunResult)` - Execution results with ran/skipped nodes and outputs - /// * `Err(SchedulerError)` - Node execution failure or task coordination error - /// - /// # Examples + /// Execute one superstep over `frontier` with bounded concurrency. /// - /// ```rust - /// use weavegraph::channels::Channel; - /// use weavegraph::event_bus::EventBus; - /// use weavegraph::schedulers::{Scheduler, SchedulerRunContext, SchedulerState}; - /// use weavegraph::state::VersionedState; - /// use weavegraph::types::NodeKind; - /// use rustc_hash::FxHashMap; - /// use std::sync::Arc; - /// - /// # async fn example() -> Result<(), Box> { - /// let scheduler = Scheduler::new(4); - /// let mut state = SchedulerState::default(); - /// let nodes = FxHashMap::default(); // Node registry - /// let event_bus = EventBus::default(); - /// - /// let frontier = vec![NodeKind::Start, NodeKind::Custom("process".into())]; - /// let snapshot = VersionedState::builder().build().snapshot(); - /// - /// let result = scheduler.superstep( - /// &mut state, - /// &nodes, - /// frontier, - /// snapshot, - /// 1, - /// SchedulerRunContext::new(event_bus.get_emitter()), - /// ).await?; - /// - /// println!("Executed {} nodes, skipped {}", - /// result.ran_nodes.len(), - /// result.skipped_nodes.len()); - /// # Ok(()) - /// # } - /// ``` - /// - /// # Error Handling - /// - /// - **Node Failures**: If any node returns an error, the entire superstep fails - /// - **Task Panics**: Panicking nodes result in `SchedulerError::Join` - /// - **Missing Nodes**: Panics if frontier contains nodes not in registry + /// Structural nodes (`Start`, `End`) and version-gated nodes are skipped. + /// Remaining nodes run concurrently up to `concurrency_limit`. Returns the + /// step summary or the first error encountered. #[instrument(skip(self, state, nodes, frontier, snap, run_context))] pub async fn superstep( &self, state: &mut SchedulerState, - nodes: &FxHashMap>, // registry - frontier: Vec, // frontier for this step - snap: StateSnapshot, // pre-barrier snapshot + nodes: &FxHashMap>, + frontier: Vec, + snap: StateSnapshot, step: u64, run_context: SchedulerRunContext, ) -> Result { - // Partition frontier into to_run vs skipped using a skip predicate and version gating. let channels = Self::channel_versions(&snap); - // Skip virtual Start and End nodes (they are not executed, only structural) - let skip_predicate = |k: &NodeKind| matches!(k, NodeKind::Start | NodeKind::End); let mut to_run: Vec = Vec::new(); + let mut to_run_ids: Vec = Vec::new(); let mut skipped_kinds: Vec = Vec::new(); - for k in frontier.into_iter() { - if skip_predicate(&k) { - skipped_kinds.push(k); + + for kind in frontier { + if matches!(kind, NodeKind::Start | NodeKind::End) { + skipped_kinds.push(kind); continue; } - let id_str = format!("{:?}", k); - if self.should_run_with(state, &id_str, &channels) { - to_run.push(k); + let id = format!("{kind:?}"); + if self.should_run_with(state, &id, &channels) { + to_run_ids.push(id); + to_run.push(kind); } else { - skipped_kinds.push(k); + skipped_kinds.push(kind); } } - // Build tasks for the nodes to run. - let to_run_ids: Vec = to_run.iter().map(|k| format!("{:?}", k)).collect(); - - // Pre-validate all nodes exist in registry before creating tasks. - // This allows us to return early with a proper error rather than - // encountering issues mid-execution. for kind in &to_run { if !nodes.contains_key(kind) { return Err(SchedulerError::NodeNotFound { @@ -597,48 +271,33 @@ impl Scheduler { } } - let tasks = to_run_ids + let tasks: Vec<_> = to_run .iter() - .cloned() - .zip(to_run.clone().into_iter()) - .map(|(id_str, kind)| { - // SAFETY: We validated all nodes exist above, so this unwrap is safe. - let node = nodes.get(&kind).unwrap().clone(); - let event_emitter = Arc::clone(&run_context.event_emitter); - let clock = run_context.clock.clone(); - let invocation_id = run_context.invocation_id.clone(); + .zip(&to_run_ids) + .map(|(kind, id)| { + let node = nodes[kind].clone(); let ctx = NodeContext { - node_id: id_str.clone(), + node_id: id.clone(), step, - event_emitter, - clock, - invocation_id, + event_emitter: Arc::clone(&run_context.event_emitter), + clock: run_context.clock.clone(), + invocation_id: run_context.invocation_id.clone(), }; let s = snap.clone(); - async move { - // Return Result and let caller collect - let out = node.run(s, ctx).await; - (kind, out) - } - }); + let kind = kind.clone(); + async move { (kind, node.run(s, ctx).await) } + }) + .collect(); - // Execute with bounded concurrency; completion order may differ. let mut outputs: Vec<(NodeKind, NodePartial)> = Vec::new(); let mut stream = stream::iter(tasks).buffer_unordered(self.concurrency_limit); while let Some((kind, res)) = stream.next().await { match res { Ok(part) => outputs.push((kind, part)), - Err(e) => { - return Err(SchedulerError::NodeRun { - kind, - step, - source: e, - }); - } + Err(e) => return Err(SchedulerError::NodeRun { kind, step, source: e }), } } - // Record versions seen for nodes that ran. for id in &to_run_ids { self.record_seen_with(state, id, &channels); } diff --git a/src/telemetry/mod.rs b/src/telemetry/mod.rs index b362e1c..b598fa6 100644 --- a/src/telemetry/mod.rs +++ b/src/telemetry/mod.rs @@ -1,85 +1,65 @@ -//! Telemetry formatting utilities for rendering workflow events as human-readable or machine-readable output. -use crate::channels::errors::ErrorEvent; +//! Telemetry formatting: renders workflow events and errors as human-readable or machine-readable text. + +use crate::channels::errors::{ErrorEvent, WeaveError}; use crate::event_bus::Event; use std::io::IsTerminal; use std::sync::OnceLock; -/// ANSI escape code for green context text in telemetry output. -pub const CONTEXT_COLOR: &str = "\x1b[32m"; // green -/// ANSI escape code for magenta line text in telemetry output. -pub const LINE_COLOR: &str = "\x1b[35m"; // magenta / dark pink -/// ANSI escape code to reset terminal color after colored output. +/// ANSI green — used for scope context labels. +pub const CONTEXT_COLOR: &str = "\x1b[32m"; +/// ANSI magenta — used for event and error body lines. +pub const LINE_COLOR: &str = "\x1b[35m"; +/// ANSI reset — clears all color attributes. pub const RESET_COLOR: &str = "\x1b[0m"; -static IS_STDERR_TERMINAL: OnceLock = OnceLock::new(); +static STDERR_IS_TERMINAL: OnceLock = OnceLock::new(); -fn get_is_stderr_terminal() -> bool { - *IS_STDERR_TERMINAL.get_or_init(|| std::io::stderr().is_terminal()) +fn stderr_is_terminal() -> bool { + *STDERR_IS_TERMINAL.get_or_init(|| std::io::stderr().is_terminal()) } -/// Formatter color mode for telemetry output. -/// -/// Controls whether ANSI color codes are included in formatted output: -/// - [`FormatterMode::Auto`]: Automatically detects TTY capability via `stderr.is_terminal()` -/// - [`FormatterMode::Colored`]: Always include color codes (for forced color output) -/// - [`FormatterMode::Plain`]: Never include color codes (for logs/files) -/// -/// # Examples -/// ``` -/// use weavegraph::telemetry::FormatterMode; -/// -/// // Auto-detect based on TTY -/// let mode = FormatterMode::auto_detect(); +/// Color mode for telemetry output. /// -/// // Force colored output -/// let mode = FormatterMode::Colored; -/// -/// // Force plain output for logging -/// let mode = FormatterMode::Plain; -/// ``` +/// - [`Auto`](FormatterMode::Auto): detect TTY on first use via `stderr.is_terminal()` +/// - [`Colored`](FormatterMode::Colored): always emit ANSI codes +/// - [`Plain`](FormatterMode::Plain): never emit ANSI codes #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum FormatterMode { - /// Auto-detect TTY capability (checks `stderr.is_terminal()`) + /// Detect TTY on first use; resolves to `Colored` or `Plain`. #[default] Auto, - /// Always include ANSI color codes + /// Always emit ANSI color codes. Colored, - /// Never include ANSI color codes + /// Never emit ANSI color codes. Plain, } impl FormatterMode { - /// Auto-detect formatter mode based on stderr TTY capability. - /// - /// Returns `FormatterMode::Colored` if stderr is a terminal, otherwise `FormatterMode::Plain`. + /// Resolve `Auto` against the current stderr TTY and return a concrete mode. pub fn auto_detect() -> Self { - if get_is_stderr_terminal() { - FormatterMode::Colored + if stderr_is_terminal() { + Self::Colored } else { - FormatterMode::Plain + Self::Plain } } - /// Returns true if this mode should use colored output. - /// - /// For `Auto` mode, performs TTY detection on each call. + /// Return `true` if this mode produces colored output. pub fn is_colored(&self) -> bool { match self { - FormatterMode::Auto => get_is_stderr_terminal(), - FormatterMode::Colored => true, - FormatterMode::Plain => false, + Self::Auto => stderr_is_terminal(), + Self::Colored => true, + Self::Plain => false, } } } -// Default is derived; Auto is the default variant. - -/// Rendered output for a telemetry item that can be consumed by sinks. +/// Rendered output for a single telemetry item. #[derive(Clone, Debug, Default)] pub struct EventRender { - /// Optional context prefix shown before the event lines. + /// Optional scope label shown before the event lines. pub context: Option, - /// One or more formatted lines for this event. + /// Formatted output lines for this event. pub lines: Vec, } @@ -90,67 +70,40 @@ impl EventRender { } } -/// Trait for formatting workflow events and errors into rendered output. +/// Formats workflow events and errors into [`EventRender`] values. pub trait TelemetryFormatter: Send + Sync { - /// Render a single [`Event`] into an [`EventRender`]. + /// Render a single [`Event`]. fn render_event(&self, event: &Event) -> EventRender; /// Render a slice of [`ErrorEvent`]s, one [`EventRender`] per error. fn render_errors(&self, errors: &[ErrorEvent]) -> Vec; } -/// Plain text formatter with optional ANSI color codes. -/// -/// Color output is controlled by [`FormatterMode`]: -/// - `Auto`: Uses color when stderr is a TTY -/// - `Colored`: Always uses color -/// - `Plain`: Never uses color -/// -/// # Examples -/// ``` -/// use weavegraph::telemetry::{PlainFormatter, FormatterMode}; +/// Plain-text formatter with optional ANSI color support. /// -/// // Auto-detect TTY -/// let formatter = PlainFormatter::new(); -/// -/// // Force colored output -/// let formatter = PlainFormatter::with_mode(FormatterMode::Colored); -/// -/// // Force plain output (no colors) -/// let formatter = PlainFormatter::with_mode(FormatterMode::Plain); -/// ``` +/// Color output is governed by [`FormatterMode`]. pub struct PlainFormatter { mode: FormatterMode, } impl PlainFormatter { - /// Create a new formatter with auto-detected color mode. + /// Create a formatter with auto-detected color mode. pub fn new() -> Self { Self { mode: FormatterMode::Auto, } } - /// Create a new formatter with explicit color mode. + /// Create a formatter with an explicit color mode. pub fn with_mode(mode: FormatterMode) -> Self { Self { mode } } - /// Get color prefix string based on current mode. - fn color<'a>(&self, ansi_code: &'a str) -> &'a str { - if self.mode.is_colored() { - ansi_code - } else { - "" - } + fn paint<'a>(&self, code: &'a str) -> &'a str { + if self.mode.is_colored() { code } else { "" } } - /// Get reset color string based on current mode. fn reset(&self) -> &str { - if self.mode.is_colored() { - RESET_COLOR - } else { - "" - } + self.paint(RESET_COLOR) } } @@ -160,93 +113,66 @@ impl Default for PlainFormatter { } } -fn format_error_chain( - error: &crate::channels::errors::WeaveError, - indent: usize, - use_color: bool, -) -> Vec { - let mut lines = Vec::new(); - if let Some(cause) = &error.cause { - let indent_str = " ".repeat(indent); - if use_color { - lines.push(format!( - "{LINE_COLOR}{}cause: {}{RESET_COLOR}\n", - indent_str, cause.message - )); - } else { - lines.push(format!("{}cause: {}\n", indent_str, cause.message)); - } - lines.extend(format_error_chain(cause, indent + 1, use_color)); - } +fn cause_chain(error: &WeaveError, depth: usize, fmt: &PlainFormatter) -> Vec { + let Some(cause) = &error.cause else { + return Vec::new(); + }; + let indent = " ".repeat(depth); + let mut lines = vec![format!( + "{}{}cause: {}{}\n", + fmt.paint(LINE_COLOR), + indent, + cause.message, + fmt.reset() + )]; + lines.extend(cause_chain(cause, depth + 1, fmt)); lines } impl TelemetryFormatter for PlainFormatter { fn render_event(&self, event: &Event) -> EventRender { - let line = if self.mode.is_colored() { - format!("{LINE_COLOR}{}{RESET_COLOR}\n", event) - } else { - format!("{}\n", event) - }; + let line = format!("{}{}{}\n", self.paint(LINE_COLOR), event, self.reset()); EventRender { - context: event.scope_label().map(|s| s.to_string()), + context: event.scope_label().map(str::to_owned), lines: vec![line], } } fn render_errors(&self, errors: &[ErrorEvent]) -> Vec { - let use_color = self.mode.is_colored(); errors .iter() .enumerate() .map(|(i, e)| { - let mut lines = Vec::new(); - let scope_str = if use_color { - format!("{}{:?}{}", self.color(CONTEXT_COLOR), e.scope, self.reset()) - } else { - format!("{:?}", e.scope) - }; - lines.push(format!("[{}] {} | {}\n", i, e.when, scope_str)); - - if use_color { + let scope = format!( + "{}{:?}{}", + self.paint(CONTEXT_COLOR), + e.scope, + self.reset() + ); + let mut lines = vec![format!("[{}] {} | {}\n", i, e.when, scope)]; + lines.push(format!( + "{} error: {}{}\n", + self.paint(LINE_COLOR), + e.error.message, + self.reset() + )); + lines.extend(cause_chain(&e.error, 1, self)); + if !e.tags.is_empty() { lines.push(format!( - "{} error: {}{}\n", - self.color(LINE_COLOR), - e.error.message, + "{} tags: {:?}{}\n", + self.paint(LINE_COLOR), + e.tags, self.reset() )); - } else { - lines.push(format!(" error: {}\n", e.error.message)); - } - - lines.extend(format_error_chain(&e.error, 1, use_color)); - - if !e.tags.is_empty() { - if use_color { - lines.push(format!( - "{} tags: {:?}{}\n", - self.color(LINE_COLOR), - e.tags, - self.reset() - )); - } else { - lines.push(format!(" tags: {:?}\n", e.tags)); - } } - if !e.context.is_null() { - if use_color { - lines.push(format!( - "{} context: {}{}\n", - self.color(LINE_COLOR), - e.context, - self.reset() - )); - } else { - lines.push(format!(" context: {}\n", e.context)); - } + lines.push(format!( + "{} context: {}{}\n", + self.paint(LINE_COLOR), + e.context, + self.reset() + )); } - EventRender { context: Some(format!("{:?}", e.scope)), lines, From 04cc58947abf8281f9f3fcbda99af717f9ae23bb Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 16:15:02 -0400 Subject: [PATCH 06/15] a bunch of fmt fixes --- src/app.rs | 16 +++++++-- src/channels/errors.rs | 21 ++++++++++-- src/channels/errors_channel.rs | 10 ++++-- src/channels/extras.rs | 10 ++++-- src/channels/messages.rs | 10 ++++-- src/event_bus/bus.rs | 9 +++-- src/event_bus/diagnostics.rs | 5 ++- src/event_bus/event.rs | 36 +++++++++++++++----- src/event_bus/hub.rs | 31 +++++++++++++---- src/event_bus/sink.rs | 25 +++++++++++--- src/graphs/builder.rs | 8 +++-- src/graphs/compilation.rs | 8 ++++- src/graphs/iteration.rs | 40 +++++++++++++++++----- src/graphs/petgraph_compat.rs | 8 ++++- src/message.rs | 5 ++- src/runtimes/checkpointer_sqlite.rs | 52 +++++++++++++++++------------ src/runtimes/execution.rs | 1 - src/runtimes/mod.rs | 16 ++++----- src/runtimes/persistence.rs | 6 +++- src/runtimes/replay.rs | 13 ++++++-- src/runtimes/runner.rs | 30 ++++++++--------- src/runtimes/runtime_config.rs | 5 ++- src/runtimes/session.rs | 1 - src/runtimes/streaming.rs | 1 - src/schedulers/scheduler.rs | 8 ++++- src/state.rs | 9 +++-- src/telemetry/mod.rs | 7 +--- tests/event_bus.rs | 4 ++- 28 files changed, 285 insertions(+), 110 deletions(-) diff --git a/src/app.rs b/src/app.rs index 43ed4c2..a994d09 100644 --- a/src/app.rs +++ b/src/app.rs @@ -427,9 +427,14 @@ impl App { F: FnOnce() -> (EventBus, R), { let (event_bus, output) = build_event_bus(); - let runner = self.build_runner(event_bus, autosave, checkpointer_override).await; + let runner = self + .build_runner(event_bus, autosave, checkpointer_override) + .await; let session_id = self.next_session_id(); - (Self::run_session(runner, session_id, initial_state).await, output) + ( + Self::run_session(runner, session_id, initial_state).await, + output, + ) } /// Invoke the workflow asynchronously while streaming events to the caller. @@ -493,7 +498,12 @@ impl App { let runner = self.build_runner(event_bus, true, None).await; let session_id = self.next_session_id(); let join = tokio::spawn(Self::run_session(runner, session_id, initial_state)); - (InvocationHandle { join_handle: Some(join) }, event_stream) + ( + InvocationHandle { + join_handle: Some(join), + }, + event_stream, + ) } /// Execute the workflow to completion using the runtime-configured event bus. diff --git a/src/channels/errors.rs b/src/channels/errors.rs index d7c6fca..6f23f12 100644 --- a/src/channels/errors.rs +++ b/src/channels/errors.rs @@ -70,7 +70,13 @@ impl ErrorEvent { /// Creates a node-scoped error event. pub fn node>(kind: S, step: u64, error: WeaveError) -> Self { - Self::with_scope(ErrorScope::Node { kind: kind.into(), step }, error) + Self::with_scope( + ErrorScope::Node { + kind: kind.into(), + step, + }, + error, + ) } /// Creates a scheduler-scoped error event. @@ -80,7 +86,13 @@ impl ErrorEvent { /// Creates a runner-scoped error event. pub fn runner>(session: S, step: u64, error: WeaveError) -> Self { - Self::with_scope(ErrorScope::Runner { session: session.into(), step }, error) + Self::with_scope( + ErrorScope::Runner { + session: session.into(), + step, + }, + error, + ) } /// Creates an app-scoped error event. @@ -165,7 +177,10 @@ impl std::error::Error for WeaveError { impl WeaveError { /// Constructs an error from a message. pub fn msg>(m: M) -> Self { - Self { message: m.into(), ..Default::default() } + Self { + message: m.into(), + ..Default::default() + } } /// Attaches structured details. diff --git a/src/channels/errors_channel.rs b/src/channels/errors_channel.rs index 8006de1..c564ef4 100644 --- a/src/channels/errors_channel.rs +++ b/src/channels/errors_channel.rs @@ -13,13 +13,19 @@ pub struct ErrorsChannel { impl ErrorsChannel { /// Creates a new `ErrorsChannel` with the given events and version counter. pub fn new(events: Vec, version: u32) -> Self { - Self { value: events, version } + Self { + value: events, + version, + } } } impl Default for ErrorsChannel { fn default() -> Self { - Self { value: Vec::new(), version: 1 } + Self { + value: Vec::new(), + version: 1, + } } } diff --git a/src/channels/extras.rs b/src/channels/extras.rs index 05137c2..806b703 100644 --- a/src/channels/extras.rs +++ b/src/channels/extras.rs @@ -16,13 +16,19 @@ pub struct ExtrasChannel { impl ExtrasChannel { /// Creates a new `ExtrasChannel` with the given map and version counter. pub fn new(extras: ChannelValue, version: u32) -> Self { - Self { value: extras, version } + Self { + value: extras, + version, + } } } impl Default for ExtrasChannel { fn default() -> Self { - Self { value: FxHashMap::default(), version: 1 } + Self { + value: FxHashMap::default(), + version: 1, + } } } diff --git a/src/channels/messages.rs b/src/channels/messages.rs index 205d9f4..a17f116 100644 --- a/src/channels/messages.rs +++ b/src/channels/messages.rs @@ -14,13 +14,19 @@ pub struct MessagesChannel { impl MessagesChannel { /// Creates a new `MessagesChannel` with the given messages and version counter. pub fn new(messages: ChannelValue, version: u32) -> Self { - Self { value: messages, version } + Self { + value: messages, + version, + } } } impl Default for MessagesChannel { fn default() -> Self { - Self { value: Vec::new(), version: 1 } + Self { + value: Vec::new(), + version: 1, + } } } diff --git a/src/event_bus/bus.rs b/src/event_bus/bus.rs index 1e9f821..0195969 100644 --- a/src/event_bus/bus.rs +++ b/src/event_bus/bus.rs @@ -2,8 +2,8 @@ use std::collections::HashMap; use std::io; -use std::sync::{Arc, Mutex}; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; use chrono::Utc; use tokio::sync::{broadcast, oneshot}; @@ -251,7 +251,12 @@ impl SinkEntry { let sink_name = self.name.clone(); let (shutdown_tx, mut shutdown_rx) = oneshot::channel(); let mut stream = hub.subscribe(); - let WorkerDiag { tx: diagnostics_tx, health, enabled: diagnostics_enabled, emit_as_events: emit_diagnostics_as_events } = diag; + let WorkerDiag { + tx: diagnostics_tx, + health, + enabled: diagnostics_enabled, + emit_as_events: emit_diagnostics_as_events, + } = diag; let handle = task::spawn(async move { loop { if generation_counter.load(Ordering::SeqCst) != spawned_generation { diff --git a/src/event_bus/diagnostics.rs b/src/event_bus/diagnostics.rs index 054cc37..1bb14bc 100644 --- a/src/event_bus/diagnostics.rs +++ b/src/event_bus/diagnostics.rs @@ -5,7 +5,10 @@ use std::time::Duration; use chrono::{DateTime, Utc}; use futures_util::stream::{self, BoxStream, StreamExt}; use serde::{Deserialize, Serialize}; -use tokio::sync::broadcast::{Receiver, error::{RecvError, TryRecvError}}; +use tokio::sync::broadcast::{ + Receiver, + error::{RecvError, TryRecvError}, +}; use tokio::time::timeout; /// A single error event emitted when a sink fails. diff --git a/src/event_bus/event.rs b/src/event_bus/event.rs index cec9d2e..3476770 100644 --- a/src/event_bus/event.rs +++ b/src/event_bus/event.rs @@ -58,14 +58,22 @@ impl Event { metadata: FxHashMap, ) -> Self { Event::Node( - NodeEvent::new(Some(node_id.into()), Some(step), scope.into(), message.into()) - .with_metadata(metadata), + NodeEvent::new( + Some(node_id.into()), + Some(step), + scope.into(), + message.into(), + ) + .with_metadata(metadata), ) } /// Construct a diagnostic event. pub fn diagnostic(scope: impl Into, message: impl Into) -> Self { - Event::Diagnostic(DiagnosticEvent { scope: scope.into(), message: message.into() }) + Event::Diagnostic(DiagnosticEvent { + scope: scope.into(), + message: message.into(), + }) } /// Returns the scope label for this event. @@ -118,8 +126,11 @@ impl Event { let (event_type, metadata, timestamp) = match self { Event::Node(n) => { - let mut meta: serde_json::Map = - n.metadata().iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + let mut meta: serde_json::Map = n + .metadata() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); if let Some(id) = n.node_id() { meta.insert("node_id".to_owned(), json!(id)); } @@ -130,8 +141,11 @@ impl Event { } Event::Diagnostic(_) => ("diagnostic", json!({}), Utc::now()), Event::LLM(l) => { - let mut meta: serde_json::Map = - l.metadata().iter().map(|(k, v)| (k.clone(), v.clone())).collect(); + let mut meta: serde_json::Map = l + .metadata() + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); if let Some(id) = l.session_id() { meta.insert("session_id".to_owned(), json!(id)); } @@ -223,7 +237,13 @@ pub struct NodeEvent { impl NodeEvent { /// Create a new node event. pub fn new(node_id: Option, step: Option, scope: String, message: String) -> Self { - Self { node_id, step, scope, message, metadata: FxHashMap::default() } + Self { + node_id, + step, + scope, + message, + metadata: FxHashMap::default(), + } } /// Returns the node ID, if set. diff --git a/src/event_bus/hub.rs b/src/event_bus/hub.rs index 9e3e899..bebf953 100644 --- a/src/event_bus/hub.rs +++ b/src/event_bus/hub.rs @@ -67,7 +67,11 @@ impl EventHub { rx } }; - EventStream { receiver, hub: Arc::clone(self), shutdown: None } + EventStream { + receiver, + hub: Arc::clone(self), + shutdown: None, + } } /// Returns the configured buffer capacity of the underlying broadcast channel. @@ -90,16 +94,24 @@ impl EventHub { /// Create a [`HubEmitter`] that publishes events to this hub. pub fn emitter(self: &Arc) -> HubEmitter { - HubEmitter { hub: Arc::clone(self) } + HubEmitter { + hub: Arc::clone(self), + } } /// Close the hub and signal all subscribers that no further events will arrive. pub fn close(&self) { - self.sender.write().expect("hub sender lock poisoned").take(); + self.sender + .write() + .expect("hub sender lock poisoned") + .take(); } fn current_sender(&self) -> Option> { - self.sender.read().expect("hub sender lock poisoned").clone() + self.sender + .read() + .expect("hub sender lock poisoned") + .clone() } fn record_lag(&self, missed: u64) { @@ -167,7 +179,10 @@ impl EventStream { /// Convert this stream into a synchronous blocking iterator. pub fn into_blocking_iter(self) -> BlockingEventIter { - BlockingEventIter { receiver: self.receiver, hub: self.hub } + BlockingEventIter { + receiver: self.receiver, + hub: self.hub, + } } /// Attach a shutdown watch channel; the stream ends when the watch value becomes `true`. @@ -178,7 +193,11 @@ impl EventStream { /// Convert this stream into a pinned `BoxStream` for use with async combinators. pub fn into_async_stream(self) -> BoxStream<'static, Event> { - let EventStream { receiver, hub, shutdown } = self; + let EventStream { + receiver, + hub, + shutdown, + } = self; stream::unfold( (receiver, hub, shutdown), |(mut receiver, hub, mut shutdown)| async move { diff --git a/src/event_bus/sink.rs b/src/event_bus/sink.rs index f78dfee..8ed9ba4 100644 --- a/src/event_bus/sink.rs +++ b/src/event_bus/sink.rs @@ -74,18 +74,27 @@ impl MemorySink { /// /// Clones the internal buffer so callers do not hold the mutex. pub fn snapshot(&self) -> Vec { - self.entries.lock().expect("MemorySink mutex poisoned").clone() + self.entries + .lock() + .expect("MemorySink mutex poisoned") + .clone() } /// Discard all captured events. pub fn clear(&self) { - self.entries.lock().expect("MemorySink mutex poisoned").clear(); + self.entries + .lock() + .expect("MemorySink mutex poisoned") + .clear(); } } impl EventSink for MemorySink { fn handle(&mut self, event: &Event) -> IoResult<()> { - self.entries.lock().expect("MemorySink mutex poisoned").push(event.clone()); + self.entries + .lock() + .expect("MemorySink mutex poisoned") + .push(event.clone()); Ok(()) } } @@ -110,7 +119,10 @@ pub struct JsonLinesSink { impl JsonLinesSink { /// Create a compact (one-line-per-event) sink writing to `handle`. pub fn new(handle: Box) -> Self { - Self { handle, pretty: false } + Self { + handle, + pretty: false, + } } /// Create a pretty-printed sink writing to `handle`. @@ -118,7 +130,10 @@ impl JsonLinesSink { /// Pretty-printed output spans multiple lines and is **not** valid JSONL. /// Use for debugging and human-readable logs only. pub fn with_pretty_print(handle: Box) -> Self { - Self { handle, pretty: true } + Self { + handle, + pretty: true, + } } /// Create a compact sink writing to stdout. diff --git a/src/graphs/builder.rs b/src/graphs/builder.rs index 988a29b..5e2bf60 100644 --- a/src/graphs/builder.rs +++ b/src/graphs/builder.rs @@ -87,7 +87,10 @@ impl GraphBuilder { pub fn add_node(mut self, id: NodeKind, node: impl Node + 'static) -> Self { match id { NodeKind::Start | NodeKind::End => { - tracing::warn!(?id, "Ignoring registration of virtual node kind (Start/End are virtual)"); + tracing::warn!( + ?id, + "Ignoring registration of virtual node kind (Start/End are virtual)" + ); } _ => { self.nodes.insert(id, Arc::new(node)); @@ -110,7 +113,8 @@ impl GraphBuilder { /// the next nodes to activate. #[must_use] pub fn add_conditional_edge(mut self, from: NodeKind, predicate: EdgePredicate) -> Self { - self.conditional_edges.push(ConditionalEdge::new(from, predicate)); + self.conditional_edges + .push(ConditionalEdge::new(from, predicate)); self } diff --git a/src/graphs/compilation.rs b/src/graphs/compilation.rs index 284a5a8..d9b14a6 100644 --- a/src/graphs/compilation.rs +++ b/src/graphs/compilation.rs @@ -77,7 +77,13 @@ impl super::builder::GraphBuilder { pub fn compile(self) -> Result { self.validate()?; let (nodes, edges, conditional_edges, runtime_config, reducer_registry) = self.into_parts(); - Ok(App::from_parts(nodes, edges, conditional_edges, runtime_config, reducer_registry)) + Ok(App::from_parts( + nodes, + edges, + conditional_edges, + runtime_config, + reducer_registry, + )) } /// Validates the graph for structural correctness. diff --git a/src/graphs/iteration.rs b/src/graphs/iteration.rs index 6954a85..857386a 100644 --- a/src/graphs/iteration.rs +++ b/src/graphs/iteration.rs @@ -188,7 +188,10 @@ mod tests { fn test_topological_sort_linear() { let mut edges: FxHashMap> = FxHashMap::default(); edges.insert(NodeKind::Start, vec![NodeKind::Custom("A".into())]); - edges.insert(NodeKind::Custom("A".into()), vec![NodeKind::Custom("B".into())]); + edges.insert( + NodeKind::Custom("A".into()), + vec![NodeKind::Custom("B".into())], + ); edges.insert(NodeKind::Custom("B".into()), vec![NodeKind::End]); let sorted = topological_sort(&edges); @@ -196,8 +199,14 @@ mod tests { assert_eq!(sorted[0], NodeKind::Start); assert_eq!(*sorted.last().unwrap(), NodeKind::End); - let a_pos = sorted.iter().position(|n| n == &NodeKind::Custom("A".into())).unwrap(); - let b_pos = sorted.iter().position(|n| n == &NodeKind::Custom("B".into())).unwrap(); + let a_pos = sorted + .iter() + .position(|n| n == &NodeKind::Custom("A".into())) + .unwrap(); + let b_pos = sorted + .iter() + .position(|n| n == &NodeKind::Custom("B".into())) + .unwrap(); assert!(a_pos < b_pos); } @@ -208,8 +217,14 @@ mod tests { NodeKind::Start, vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())], ); - edges.insert(NodeKind::Custom("A".into()), vec![NodeKind::Custom("C".into())]); - edges.insert(NodeKind::Custom("B".into()), vec![NodeKind::Custom("C".into())]); + edges.insert( + NodeKind::Custom("A".into()), + vec![NodeKind::Custom("C".into())], + ); + edges.insert( + NodeKind::Custom("B".into()), + vec![NodeKind::Custom("C".into())], + ); edges.insert(NodeKind::Custom("C".into()), vec![NodeKind::End]); let sorted = topological_sort(&edges); @@ -217,9 +232,18 @@ mod tests { assert_eq!(sorted[0], NodeKind::Start); assert_eq!(*sorted.last().unwrap(), NodeKind::End); - let a_pos = sorted.iter().position(|n| n == &NodeKind::Custom("A".into())).unwrap(); - let b_pos = sorted.iter().position(|n| n == &NodeKind::Custom("B".into())).unwrap(); - let c_pos = sorted.iter().position(|n| n == &NodeKind::Custom("C".into())).unwrap(); + let a_pos = sorted + .iter() + .position(|n| n == &NodeKind::Custom("A".into())) + .unwrap(); + let b_pos = sorted + .iter() + .position(|n| n == &NodeKind::Custom("B".into())) + .unwrap(); + let c_pos = sorted + .iter() + .position(|n| n == &NodeKind::Custom("C".into())) + .unwrap(); assert!(a_pos < c_pos); assert!(b_pos < c_pos); assert!(a_pos < b_pos); diff --git a/src/graphs/petgraph_compat.rs b/src/graphs/petgraph_compat.rs index 97fb825..f3c47ae 100644 --- a/src/graphs/petgraph_compat.rs +++ b/src/graphs/petgraph_compat.rs @@ -118,7 +118,13 @@ pub(super) fn to_dot(edges: &FxHashMap>) -> String { writeln!(out).unwrap(); for edge in conversion.graph.edge_references() { - writeln!(out, " {} -> {};", edge.source().index(), edge.target().index()).unwrap(); + writeln!( + out, + " {} -> {};", + edge.source().index(), + edge.target().index() + ) + .unwrap(); } writeln!(out, "}}").unwrap(); diff --git a/src/message.rs b/src/message.rs index 6e5d612..1e02f11 100644 --- a/src/message.rs +++ b/src/message.rs @@ -87,7 +87,10 @@ impl Message { /// Construct a message with an explicit role and content. #[must_use] pub fn with_role(role: Role, content: &str) -> Self { - Self { role, content: content.to_owned() } + Self { + role, + content: content.to_owned(), + } } /// Construct a `User` message. diff --git a/src/runtimes/checkpointer_sqlite.rs b/src/runtimes/checkpointer_sqlite.rs index 4dbce05..15a930f 100644 --- a/src/runtimes/checkpointer_sqlite.rs +++ b/src/runtimes/checkpointer_sqlite.rs @@ -110,11 +110,12 @@ impl SQLiteCheckpointer { #[must_use = "checkpointer must be used to persist state"] #[instrument(skip(database_url))] pub async fn connect(database_url: &str) -> std::result::Result { - let pool = SqlitePool::connect(database_url) - .await - .map_err(|e| CheckpointerError::Backend { - message: format!("connect: {e}"), - })?; + let pool = + SqlitePool::connect(database_url) + .await + .map_err(|e| CheckpointerError::Backend { + message: format!("connect: {e}"), + })?; #[cfg(feature = "sqlite-migrations")] sqlx::migrate!("./migrations") @@ -150,7 +151,12 @@ impl Checkpointer for SQLiteCheckpointer { let enc = EncodedCheckpoint::encode(&checkpoint)?; let mut tx = self.begin_tx().await?; - exec_insert_session(&mut tx, &checkpoint.session_id, checkpoint.concurrency_limit).await?; + exec_insert_session( + &mut tx, + &checkpoint.session_id, + checkpoint.concurrency_limit, + ) + .await?; exec_upsert_step(&mut tx, &checkpoint.session_id, checkpoint.step, &enc).await?; tx.commit().await.map_err(|e| CheckpointerError::Backend { @@ -181,22 +187,21 @@ impl Checkpointer for SQLiteCheckpointer { let concurrency_limit: i64 = row.get("concurrency_limit"); let updated_at_str: String = row.get("updated_at"); - let state_json: Option = row.try_get("last_state_json").map_err(|e| { - CheckpointerError::Backend { - message: format!("last_state_json: {e}"), - } - })?; - let frontier_json: Option = row.try_get("last_frontier_json").map_err(|e| { - CheckpointerError::Backend { - message: format!("last_frontier_json: {e}"), - } - })?; + let state_json: Option = + row.try_get("last_state_json") + .map_err(|e| CheckpointerError::Backend { + message: format!("last_state_json: {e}"), + })?; + let frontier_json: Option = + row.try_get("last_frontier_json") + .map_err(|e| CheckpointerError::Backend { + message: format!("last_frontier_json: {e}"), + })?; let versions_seen_json: Option = - row.try_get("last_versions_seen_json").map_err(|e| { - CheckpointerError::Backend { + row.try_get("last_versions_seen_json") + .map_err(|e| CheckpointerError::Backend { message: format!("last_versions_seen_json: {e}"), - } - })?; + })?; // Session row exists but no checkpoint written yet. if last_step == 0 && state_json.is_none() { @@ -396,7 +401,12 @@ impl SQLiteCheckpointer { let enc = EncodedCheckpoint::encode(&checkpoint)?; let mut tx = self.begin_tx().await?; - exec_insert_session(&mut tx, &checkpoint.session_id, checkpoint.concurrency_limit).await?; + exec_insert_session( + &mut tx, + &checkpoint.session_id, + checkpoint.concurrency_limit, + ) + .await?; if let Some(expected) = expected_last_step { let current: Option = diff --git a/src/runtimes/execution.rs b/src/runtimes/execution.rs index a5183ce..23fe81c 100644 --- a/src/runtimes/execution.rs +++ b/src/runtimes/execution.rs @@ -73,4 +73,3 @@ pub(crate) struct SchedulerOutcome { pub skipped_nodes: Vec, pub partials: Vec, } - diff --git a/src/runtimes/mod.rs b/src/runtimes/mod.rs index 8885af1..576c721 100644 --- a/src/runtimes/mod.rs +++ b/src/runtimes/mod.rs @@ -56,19 +56,19 @@ pub use checkpointer_postgres::{ #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] pub use checkpointer_sqlite::{PageInfo, SQLiteCheckpointer, StepQuery, StepQueryResult}; pub use execution::{PausedReason, PausedReport, StepOptions, StepReport, StepResult}; -pub use runner::{AppRunner, AppRunnerBuilder, RunMetadata}; -pub use session::{SessionInit, SessionState, StateVersions}; +#[cfg(feature = "metrics")] +pub use metrics_observer::MetricsObserver; +pub use observer::{ + CheckpointLoadMeta, CheckpointSaveMeta, EventBusEmitMeta, InvocationFinishMeta, + InvocationOutcome, InvocationStartMeta, NodeFinishMeta, NodeOutcome, RuntimeObserver, +}; pub use replay::{ ReplayComparison, ReplayConformanceError, ReplayRun, StateNormalizeProfile, compare_event_sequences, compare_event_sequences_with, compare_final_state, compare_final_state_with, compare_replay_runs, compare_replay_runs_with, compare_replay_runs_with_profile, normalize_event, normalize_state, normalize_state_with, }; +pub use runner::{AppRunner, AppRunnerBuilder, RunMetadata}; pub use runtime_config::{EventBusConfig, RuntimeConfig, SinkConfig}; +pub use session::{SessionInit, SessionState, StateVersions}; pub use types::{SessionId, StepNumber}; -#[cfg(feature = "metrics")] -pub use metrics_observer::MetricsObserver; -pub use observer::{ - CheckpointLoadMeta, CheckpointSaveMeta, EventBusEmitMeta, InvocationFinishMeta, - InvocationOutcome, InvocationStartMeta, NodeFinishMeta, NodeOutcome, RuntimeObserver, -}; diff --git a/src/runtimes/persistence.rs b/src/runtimes/persistence.rs index 6a7ccd6..44f6734 100644 --- a/src/runtimes/persistence.rs +++ b/src/runtimes/persistence.rs @@ -225,7 +225,11 @@ impl TryFrom for Checkpoint { concurrency_limit: p.concurrency_limit, created_at, ran_nodes: p.ran_nodes.iter().map(|s| NodeKind::decode(s)).collect(), - skipped_nodes: p.skipped_nodes.iter().map(|s| NodeKind::decode(s)).collect(), + skipped_nodes: p + .skipped_nodes + .iter() + .map(|s| NodeKind::decode(s)) + .collect(), updated_channels: p.updated_channels, }) } diff --git a/src/runtimes/replay.rs b/src/runtimes/replay.rs index 99a0bc6..4f8721d 100644 --- a/src/runtimes/replay.rs +++ b/src/runtimes/replay.rs @@ -26,7 +26,10 @@ impl ReplayRun { /// Construct from final state and captured events. #[must_use] pub fn new(final_state: VersionedState, events: Vec) -> Self { - Self { final_state, events } + Self { + final_state, + events, + } } } @@ -40,7 +43,9 @@ impl ReplayComparison { /// No differences found. #[must_use] pub fn matched() -> Self { - Self { differences: Vec::new() } + Self { + differences: Vec::new(), + } } /// Construct with the supplied differences. @@ -66,7 +71,9 @@ impl ReplayComparison { if self.is_match() { Ok(()) } else { - Err(ReplayConformanceError::Mismatch { differences: self.differences }) + Err(ReplayConformanceError::Mismatch { + differences: self.differences, + }) } } } diff --git a/src/runtimes/runner.rs b/src/runtimes/runner.rs index 9a64c98..0992b37 100644 --- a/src/runtimes/runner.rs +++ b/src/runtimes/runner.rs @@ -850,12 +850,12 @@ impl AppRunner { } } - let mut session_state = self - .sessions - .remove(session_id) - .ok_or_else(|| RunnerError::SessionNotFound { - session_id: session_id.to_string(), - })?; + let mut session_state = + self.sessions + .remove(session_id) + .ok_or_else(|| RunnerError::SessionNotFound { + session_id: session_id.to_string(), + })?; let step_report = match self.run_one_superstep(session_id, &mut session_state).await { Ok(rep) => rep, @@ -868,8 +868,7 @@ impl AppRunner { .apply_barrier(&mut error_state, &[], vec![partial]) .await; session_state.state = error_state; - self.sessions - .insert(session_id.to_string(), session_state); + self.sessions.insert(session_id.to_string(), session_state); if self.autosave && let Some(cp) = &self.checkpointer && let Some(s) = self.sessions.get(session_id) @@ -989,9 +988,12 @@ impl AppRunner { }, ) .await?; - let mut by_kind: FxHashMap = - raw.outputs.into_iter().collect(); - let partials = raw.ran_nodes.iter().filter_map(|k| by_kind.remove(k)).collect(); + let mut by_kind: FxHashMap = raw.outputs.into_iter().collect(); + let partials = raw + .ran_nodes + .iter() + .filter_map(|k| by_kind.remove(k)) + .collect(); Ok(SchedulerOutcome { ran_nodes: raw.ran_nodes, skipped_nodes: raw.skipped_nodes, @@ -1028,8 +1030,7 @@ impl AppRunner { let conditional_edges = self.app.conditional_edges(); let state_snapshot = session_state.state.snapshot(); - let mut commands_by_node: FxHashMap> = - FxHashMap::default(); + let mut commands_by_node: FxHashMap> = FxHashMap::default(); for (origin, cmd) in &barrier.frontier_commands { commands_by_node .entry(origin.clone()) @@ -1120,8 +1121,7 @@ impl AppRunner { if !self.autosave { return; } - let (Some(cp), Some(ss)) = - (&self.checkpointer, self.sessions.get(session_id)) + let (Some(cp), Some(ss)) = (&self.checkpointer, self.sessions.get(session_id)) else { return; }; diff --git a/src/runtimes/runtime_config.rs b/src/runtimes/runtime_config.rs index a97f368..4d23617 100644 --- a/src/runtimes/runtime_config.rs +++ b/src/runtimes/runtime_config.rs @@ -106,7 +106,10 @@ impl RuntimeConfig { let parts: Vec = [ "weavegraph-runtime-config-v1".to_string(), format!("session_id:{}", self.session_id.as_deref().unwrap_or("")), - format!("sqlite_db_name:{}", self.sqlite_db_name.as_deref().unwrap_or("")), + format!( + "sqlite_db_name:{}", + self.sqlite_db_name.as_deref().unwrap_or("") + ), format!("custom_checkpointer:{}", self.checkpointer_custom.is_some()), format!("clock:{}", self.clock_mode()), ] diff --git a/src/runtimes/session.rs b/src/runtimes/session.rs index 0ee6112..5ffd2b1 100644 --- a/src/runtimes/session.rs +++ b/src/runtimes/session.rs @@ -42,4 +42,3 @@ pub struct StateVersions { /// Current version of the extras channel. pub extra_version: u32, } - diff --git a/src/runtimes/streaming.rs b/src/runtimes/streaming.rs index 2740805..4f9f0a6 100644 --- a/src/runtimes/streaming.rs +++ b/src/runtimes/streaming.rs @@ -79,4 +79,3 @@ pub(crate) fn emit_invocation_end(event_bus: &EventBus, session_id: &str, reason ); } } - diff --git a/src/schedulers/scheduler.rs b/src/schedulers/scheduler.rs index 85e2462..cd5b865 100644 --- a/src/schedulers/scheduler.rs +++ b/src/schedulers/scheduler.rs @@ -294,7 +294,13 @@ impl Scheduler { while let Some((kind, res)) = stream.next().await { match res { Ok(part) => outputs.push((kind, part)), - Err(e) => return Err(SchedulerError::NodeRun { kind, step, source: e }), + Err(e) => { + return Err(SchedulerError::NodeRun { + kind, + step, + source: e, + }); + } } } diff --git a/src/state.rs b/src/state.rs index 5116960..fe56a12 100644 --- a/src/state.rs +++ b/src/state.rs @@ -446,19 +446,22 @@ impl VersionedStateBuilder { /// Append an assistant message. pub fn with_assistant_message(mut self, content: &str) -> Self { - self.messages.push(Message::with_role(Role::Assistant, content)); + self.messages + .push(Message::with_role(Role::Assistant, content)); self } /// Append a system message. pub fn with_system_message(mut self, content: &str) -> Self { - self.messages.push(Message::with_role(Role::System, content)); + self.messages + .push(Message::with_role(Role::System, content)); self } /// Append a message with a custom role. pub fn with_message(mut self, role: &str, content: &str) -> Self { - self.messages.push(Message::with_role(Role::from(role), content)); + self.messages + .push(Message::with_role(Role::from(role), content)); self } diff --git a/src/telemetry/mod.rs b/src/telemetry/mod.rs index b598fa6..cfac2d6 100644 --- a/src/telemetry/mod.rs +++ b/src/telemetry/mod.rs @@ -143,12 +143,7 @@ impl TelemetryFormatter for PlainFormatter { .iter() .enumerate() .map(|(i, e)| { - let scope = format!( - "{}{:?}{}", - self.paint(CONTEXT_COLOR), - e.scope, - self.reset() - ); + let scope = format!("{}{:?}{}", self.paint(CONTEXT_COLOR), e.scope, self.reset()); let mut lines = vec![format!("[{}] {} | {}\n", i, e.when, scope)]; lines.push(format!( "{} error: {}{}\n", diff --git a/tests/event_bus.rs b/tests/event_bus.rs index 8fa90ee..efa2e5c 100644 --- a/tests/event_bus.rs +++ b/tests/event_bus.rs @@ -710,7 +710,9 @@ fn event_strategy() -> impl Strategy { .prop_map( |(session_id, node_id, stream_id, chunk, metadata, is_final)| { let meta: FxHashMap = metadata.into_iter().collect(); - let mut b = LLMStreamingEvent::builder(chunk).is_final(is_final).metadata(meta); + let mut b = LLMStreamingEvent::builder(chunk) + .is_final(is_final) + .metadata(meta); if let Some(id) = session_id { b = b.session_id(id); } From aae1092a986d1746c7804f96fe4cda641e97fbf2 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 17:09:58 -0400 Subject: [PATCH 07/15] remaining revision work for modules --- src/utils/clock.rs | 159 ++-------- src/utils/collections.rs | 520 ++++---------------------------- src/utils/deterministic_rng.rs | 110 +++---- src/utils/id_generator.rs | 417 +++---------------------- src/utils/json_ext.rs | 327 ++++++-------------- src/utils/merge_inspector.rs | 43 --- src/utils/message_id_helpers.rs | 46 --- src/utils/mod.rs | 65 +--- src/utils/type_guards.rs | 50 --- 9 files changed, 279 insertions(+), 1458 deletions(-) delete mode 100644 src/utils/merge_inspector.rs delete mode 100644 src/utils/message_id_helpers.rs delete mode 100644 src/utils/type_guards.rs diff --git a/src/utils/clock.rs b/src/utils/clock.rs index 66238de..678e3b9 100644 --- a/src/utils/clock.rs +++ b/src/utils/clock.rs @@ -1,55 +1,44 @@ -//! Injectable clock abstraction for checkpoints and time-based operations. -//! -//! Provides a mockable time source that wraps chrono for system time operations. -//! This abstraction enables deterministic testing and dependency injection for -//! time-sensitive functionality throughout the Weavegraph framework. +//! Injectable clock abstraction for deterministic testing and time-based operations. use chrono::{DateTime, Utc}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -/// Trait for time sources providing both Unix timestamps and DateTime objects. +/// A mockable time source for dependency injection. +/// +/// Implementors must provide `now()`, `now_datetime()`, and `now_system_time()`. +/// All other methods are derived from those three. pub trait Clock: Send + Sync + std::fmt::Debug { - /// Get the current time as a Unix timestamp (seconds since epoch). + /// Current time as seconds since the Unix epoch. fn now(&self) -> u64; - /// Get the current time as a Unix timestamp in milliseconds. + /// Current time as milliseconds since the Unix epoch. fn now_unix_ms(&self) -> i64 { self.now_datetime().timestamp_millis() } - /// Get the current time as a `DateTime` for more complex time operations. + /// Current time as a `DateTime`. fn now_datetime(&self) -> DateTime; - /// Get the current time as SystemTime for compatibility with std library. + /// Current time as a `SystemTime` for interop with the standard library. fn now_system_time(&self) -> SystemTime; - /// Check if a duration has elapsed since the given timestamp. + /// Returns `true` if at least `duration` has elapsed since `since`. fn has_elapsed(&self, since: u64, duration: Duration) -> bool { - let elapsed = self.now().saturating_sub(since); - elapsed >= duration.as_secs() + self.now().saturating_sub(since) >= duration.as_secs() } - /// Get the duration since a given timestamp. + /// Elapsed time since `timestamp`, saturating at zero. fn duration_since(&self, timestamp: u64) -> Duration { - let elapsed_secs = self.now().saturating_sub(timestamp); - Duration::from_secs(elapsed_secs) + Duration::from_secs(self.now().saturating_sub(timestamp)) } } -/// System clock implementation wrapping chrono for real time operations. -/// -/// This is the default clock implementation that provides actual system time. -/// Use this in production environments where real time is required. -/// -/// # Examples +/// Production clock backed by the operating system. /// /// ```rust /// use weavegraph::utils::clock::{Clock, SystemClock}; /// -/// let clock = SystemClock; -/// let timestamp = clock.now(); -/// let datetime = clock.now_datetime(); -/// println!("Current timestamp: {}", timestamp); +/// assert!(SystemClock.now() > 0); /// ``` #[derive(Debug, Clone, Copy, Default)] pub struct SystemClock; @@ -68,25 +57,15 @@ impl Clock for SystemClock { } } -/// Mock clock for deterministic testing and time manipulation. -/// -/// This clock allows tests to control time progression and test time-dependent -/// behavior in a deterministic way. The time can be advanced manually to -/// simulate different scenarios. -/// -/// # Examples +/// Manually-driven clock for deterministic tests. /// /// ```rust /// use weavegraph::utils::clock::{Clock, MockClock}; /// use std::time::Duration; /// /// let mut clock = MockClock::new(1000); -/// assert_eq!(clock.now(), 1000); -/// /// clock.advance(Duration::from_secs(30)); /// assert_eq!(clock.now(), 1030); -/// -/// // Test timeout behavior /// assert!(clock.has_elapsed(1000, Duration::from_secs(25))); /// ``` #[derive(Debug, Clone)] @@ -95,13 +74,7 @@ pub struct MockClock { } impl MockClock { - /// Create a new mock clock starting at the specified timestamp. - /// - /// # Parameters - /// * `start_time` - Initial timestamp in seconds since Unix epoch - /// - /// # Returns - /// A new MockClock instance set to the given time. + /// Creates a clock pinned to `start_time` (seconds since Unix epoch). #[must_use] pub fn new(start_time: u64) -> Self { Self { @@ -109,25 +82,14 @@ impl MockClock { } } - /// Create a mock clock starting at the current system time. - /// - /// Useful for tests that need to start from "now" but control progression. - /// - /// # Returns - /// A new MockClock instance set to the current system time. + /// Creates a clock pinned to the current system time. #[must_use] pub fn now() -> Self { Self::new(SystemClock.now()) } - /// Advance the clock by the specified duration. - /// - /// This simulates the passage of time for testing purposes. + /// Advances the clock by `duration`. /// - /// # Parameters - /// * `duration` - How much time to advance the clock - /// - /// # Examples /// ```rust /// use weavegraph::utils::clock::{Clock, MockClock}; /// use std::time::Duration; @@ -140,28 +102,17 @@ impl MockClock { self.current_time += duration.as_secs(); } - /// Advance the clock by the specified number of seconds. - /// - /// Convenience method for advancing by whole seconds. - /// - /// # Parameters - /// * `seconds` - Number of seconds to advance + /// Advances the clock by `seconds`. pub fn advance_secs(&mut self, seconds: u64) { self.current_time += seconds; } - /// Set the clock to a specific timestamp. - /// - /// This allows jumping to any point in time, useful for testing - /// edge cases or specific time scenarios. - /// - /// # Parameters - /// * `timestamp` - New timestamp to set the clock to + /// Pins the clock to `timestamp`. pub fn set_time(&mut self, timestamp: u64) { self.current_time = timestamp; } - /// Reset the clock to Unix epoch (timestamp 0). + /// Resets the clock to the Unix epoch. pub fn reset(&mut self) { self.current_time = 0; } @@ -174,7 +125,7 @@ impl Clock for MockClock { fn now_datetime(&self) -> DateTime { DateTime::from_timestamp(self.current_time as i64, 0) - .unwrap_or_else(|| DateTime::from_timestamp(0, 0).unwrap()) + .unwrap_or_else(|| DateTime::from_timestamp(0, 0).expect("epoch is valid")) } fn now_system_time(&self) -> SystemTime { @@ -182,29 +133,8 @@ impl Clock for MockClock { } } -/// Creates a boxed clock instance based on environment or testing needs. -/// -/// This factory function provides a convenient way to create clock instances -/// with appropriate type erasure for dependency injection. -/// -/// # Parameters -/// * `use_mock` - Whether to create a mock clock (true) or system clock (false) -/// * `mock_start_time` - Starting time for mock clock (ignored for system clock) -/// -/// # Returns -/// A boxed Clock trait object -/// -/// # Examples -/// -/// ```rust -/// use weavegraph::utils::clock::create_clock; -/// -/// // System clock for production -/// let sys_clock = create_clock(false, 0); -/// -/// // Mock clock for testing -/// let mock_clock = create_clock(true, 1000); -/// ``` +/// Returns a boxed [`SystemClock`] when `use_mock` is `false`, otherwise a +/// [`MockClock`] starting at `mock_start_time`. #[must_use] pub fn create_clock(use_mock: bool, mock_start_time: u64) -> Box { if use_mock { @@ -214,57 +144,34 @@ pub fn create_clock(use_mock: bool, mock_start_time: u64) -> Box { } } -/// Utility functions for common time operations. +/// Utility functions for common timestamp operations. pub mod time_utils { use super::*; - /// Format a timestamp as a human-readable string. - /// - /// # Parameters - /// * `timestamp` - Unix timestamp to format - /// - /// # Returns - /// Formatted time string in ISO 8601 format + /// Formats `timestamp` as `"YYYY-MM-DD HH:MM:SS UTC"`. /// - /// # Examples + /// Returns `"invalid-timestamp-"` for out-of-range values. /// /// ```rust /// use weavegraph::utils::clock::time_utils::format_timestamp; /// - /// let formatted = format_timestamp(1640995200); // 2022-01-01 00:00:00 UTC - /// assert!(formatted.contains("2022")); + /// assert!(format_timestamp(1640995200).contains("2022")); /// ``` #[must_use] pub fn format_timestamp(timestamp: u64) -> String { DateTime::from_timestamp(timestamp as i64, 0) .map(|dt| dt.format("%Y-%m-%d %H:%M:%S UTC").to_string()) - .unwrap_or_else(|| format!("invalid-timestamp-{}", timestamp)) + .unwrap_or_else(|| format!("invalid-timestamp-{timestamp}")) } - /// Calculate the difference between two timestamps. - /// - /// # Parameters - /// * `earlier` - Earlier timestamp - /// * `later` - Later timestamp - /// - /// # Returns - /// Duration between the timestamps, or zero if earlier > later + /// Duration between `earlier` and `later`, saturating at zero. #[must_use] pub fn duration_between(earlier: u64, later: u64) -> Duration { Duration::from_secs(later.saturating_sub(earlier)) } - /// Check if a timestamp is within a certain age. - /// - /// # Parameters - /// * `timestamp` - Timestamp to check - /// * `max_age` - Maximum acceptable age - /// * `clock` - Clock instance to get current time - /// - /// # Returns - /// True if the timestamp is within the specified age + /// Returns `true` if `timestamp` is within `max_age` of the current time. pub fn is_recent(timestamp: u64, max_age: Duration, clock: &dyn Clock) -> bool { - let age = clock.duration_since(timestamp); - age <= max_age + clock.duration_since(timestamp) <= max_age } } diff --git a/src/utils/collections.rs b/src/utils/collections.rs index 0bfc320..246bf96 100644 --- a/src/utils/collections.rs +++ b/src/utils/collections.rs @@ -1,61 +1,57 @@ -//! Collection utilities and common patterns for the Weavegraph framework. +//! Collection utilities for Weavegraph. //! -//! This module provides type-safe, ergonomic utilities for working with -//! common collection patterns throughout the codebase, particularly for -//! extra data maps, state snapshots, and channel operations. +//! Typed helpers for `FxHashMap` extra data maps and common +//! string-keyed map patterns used throughout the codebase. use rustc_hash::FxHashMap; use serde_json::Value; use std::collections::HashMap; use thiserror::Error; -/// Errors that can occur during collection operations. +/// Errors from typed collection operations. #[derive(Debug, Error)] #[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] pub enum CollectionError { - /// Attempted to access a key that doesn't exist. + /// The requested key was not present. #[error("Key '{key}' not found in collection")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::collections::missing_key)) )] MissingKey { - /// The key that was not found + /// The missing key. key: String, }, - /// Invalid type conversion during value extraction. + /// The stored value's JSON type did not match the requested type. #[error("Invalid type conversion for key '{key}': expected {expected}, found {found}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::collections::type_mismatch)) )] TypeMismatch { - /// The key where the type mismatch occurred + /// The key where the mismatch occurred. key: String, - /// The expected type as a string + /// The type the caller expected. expected: String, - /// The actual type that was found + /// The type that was actually stored. found: String, }, - /// JSON serialization/deserialization error. + /// A `serde_json` serialization or deserialization failure. #[error("JSON operation failed: {source}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::collections::json)) )] Json { - /// The underlying JSON error + /// The underlying JSON error. #[from] source: serde_json::Error, }, } -/// Creates a new `FxHashMap` for string keys and JSON values. -/// -/// This is the standard pattern throughout the codebase for extra data storage. -/// Uses `FxHashMap` for better performance with string keys. +/// Creates a new empty `FxHashMap` for extra data storage. /// /// # Examples /// @@ -65,7 +61,6 @@ pub enum CollectionError { /// /// let mut extra = new_extra_map(); /// extra.insert("key".to_string(), json!("value")); -/// extra.insert("count".to_string(), json!(42)); /// ``` #[must_use] #[inline] @@ -73,23 +68,13 @@ pub fn new_extra_map() -> FxHashMap { FxHashMap::default() } -/// Creates a new `FxHashMap` with the specified capacity. -/// -/// Useful when you know the approximate size of the map ahead of time -/// to avoid reallocations during insertion. -/// -/// # Parameters -/// * `capacity` - Initial capacity hint for the map -/// -/// # Returns -/// A new `FxHashMap` with the specified capacity. +/// Creates a new `FxHashMap` pre-allocated for `capacity` entries. /// /// # Examples /// /// ```rust /// use weavegraph::utils::collections::new_extra_map_with_capacity; /// -/// // Pre-allocate for known size /// let mut extra = new_extra_map_with_capacity(10); /// ``` #[must_use] @@ -98,15 +83,7 @@ pub fn new_extra_map_with_capacity(capacity: usize) -> FxHashMap FxHashMap::with_capacity_and_hasher(capacity, Default::default()) } -/// Creates a new `FxHashMap` from key-value pairs. -/// -/// Convenience function for creating extra maps with initial data. -/// -/// # Parameters -/// * `pairs` - Iterator of (key, value) pairs -/// -/// # Returns -/// A new `FxHashMap` populated with the given pairs. +/// Builds an extra map from an iterator of `(key, value)` pairs. /// /// # Examples /// @@ -114,10 +91,7 @@ pub fn new_extra_map_with_capacity(capacity: usize) -> FxHashMap /// use weavegraph::utils::collections::extra_map_from_pairs; /// use serde_json::json; /// -/// let extra = extra_map_from_pairs([ -/// ("name", json!("test")), -/// ("count", json!(42)), -/// ]); +/// let extra = extra_map_from_pairs([("name", json!("test")), ("count", json!(42))]); /// ``` #[must_use] pub fn extra_map_from_pairs(pairs: I) -> FxHashMap @@ -132,16 +106,7 @@ where .collect() } -/// Merges multiple extra maps into one, with later maps overriding earlier ones. -/// -/// This is useful for combining extra data from multiple sources in a -/// predictable way. -/// -/// # Parameters -/// * `maps` - Iterator of maps to merge -/// -/// # Returns -/// A new map containing all entries, with conflicts resolved by last-wins. +/// Merges extra maps left-to-right; later entries win on key collision. /// /// # Examples /// @@ -149,42 +114,26 @@ where /// use weavegraph::utils::collections::{new_extra_map, merge_extra_maps}; /// use serde_json::json; /// -/// let mut map1 = new_extra_map(); -/// map1.insert("a".into(), json!(1)); -/// map1.insert("b".into(), json!(2)); +/// let mut a = new_extra_map(); +/// a.insert("x".to_string(), json!(1)); /// -/// let mut map2 = new_extra_map(); -/// map2.insert("b".into(), json!(3)); // Overrides map1 -/// map2.insert("c".into(), json!(4)); +/// let mut b = new_extra_map(); +/// b.insert("x".to_string(), json!(2)); /// -/// let merged = merge_extra_maps([&map1, &map2]); -/// assert_eq!(merged["b"], json!(3)); // Last wins +/// let merged = merge_extra_maps([&a, &b]); +/// assert_eq!(merged["x"], json!(2)); /// ``` #[must_use] pub fn merge_extra_maps<'a, I>(maps: I) -> FxHashMap where I: IntoIterator>, { - let mut result = new_extra_map(); - for map in maps { - result.extend(map.iter().map(|(k, v)| (k.clone(), v.clone()))); - } - result + maps.into_iter() + .flat_map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone()))) + .collect() } -/// Extension trait for working with extra data maps in a type-safe manner. -/// -/// This trait provides convenient methods for inserting and retrieving typed values -/// from JSON-based extra data maps. It's designed specifically for the common pattern -/// of storing heterogeneous data in `FxHashMap` structures throughout -/// the Weavegraph framework. -/// -/// # Design Principles -/// -/// - **Type Safety**: Methods ensure type correctness at compile time where possible -/// - **Error Handling**: Retrieval methods return `Result` for proper error handling -/// - **Ergonomics**: Convenient insertion methods that accept `Into` traits -/// - **JSON Compatibility**: Full support for JSON serialization/deserialization +/// Typed insert and get methods for `FxHashMap` extra data maps. /// /// # Examples /// @@ -192,360 +141,121 @@ where /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; /// /// let mut extra = new_extra_map(); -/// -/// // Insert different types of data /// extra.insert_string("name", "Alice"); /// extra.insert_number("age", 30); /// extra.insert_bool("active", true); /// -/// // Retrieve with type checking /// assert_eq!(extra.get_string("name").unwrap(), "Alice"); /// assert_eq!(extra.get_number("age").unwrap(), 30.into()); -/// assert_eq!(extra.get_bool("active").unwrap(), true); -/// -/// // Type validation +/// assert!(extra.get_bool("active").unwrap()); /// assert!(extra.has_typed("name", "string")); /// assert!(!extra.has_typed("name", "number")); /// ``` pub trait ExtraMapExt { - /// Insert a string value into the extra map. - /// - /// This method provides a convenient way to insert string values without - /// manually wrapping them in `Value::String`. - /// - /// # Parameters - /// * `key` - Map key (will be converted to String) - /// * `value` - String value (will be converted to String) - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// - /// let mut extra = new_extra_map(); - /// extra.insert_string("username", "alice123"); - /// extra.insert_string("display_name", String::from("Alice Smith")); - /// ``` + /// Inserts a string value. fn insert_string(&mut self, key: impl Into, value: impl Into); - /// Insert a numeric value into the extra map. - /// - /// Accepts any value that can be converted to `serde_json::Number`, - /// including integers and floats. - /// - /// # Parameters - /// * `key` - Map key (will be converted to String) - /// * `value` - Numeric value (will be converted to serde_json::Number) - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// use serde_json::Number; - /// - /// let mut extra = new_extra_map(); - /// extra.insert_number("count", 42); - /// - /// // For floats, create Number explicitly - /// let price = Number::from_f64(19.99).unwrap(); - /// extra.insert_number("price", price); - /// ``` + /// Inserts a numeric value. fn insert_number(&mut self, key: impl Into, value: impl Into); - /// Insert a boolean value into the extra map. - /// - /// # Parameters - /// * `key` - Map key (will be converted to String) - /// * `value` - Boolean value - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// - /// let mut extra = new_extra_map(); - /// extra.insert_bool("enabled", true); - /// extra.insert_bool("verified", false); - /// ``` + /// Inserts a boolean value. fn insert_bool(&mut self, key: impl Into, value: bool); - /// Insert any serializable value into the extra map. - /// - /// This method can handle complex types by serializing them to JSON. - /// It's useful for storing structured data that needs to be preserved - /// exactly as it was inserted. - /// - /// # Parameters - /// * `key` - Map key (will be converted to String) - /// * `value` - Any value that implements `serde::Serialize` - /// - /// # Returns - /// * `Ok(())` - Value was successfully serialized and inserted - /// * `Err(CollectionError::Json)` - Serialization failed - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// use serde::Serialize; - /// - /// #[derive(Serialize)] - /// struct UserProfile { - /// id: u64, - /// email: String, - /// } - /// - /// let mut extra = new_extra_map(); - /// let profile = UserProfile { - /// id: 123, - /// email: "alice@example.com".to_string(), - /// }; - /// - /// extra.insert_json("profile", &profile).unwrap(); - /// ``` + /// Serializes and inserts any `Serialize` value; returns `Err` on serialization failure. fn insert_json( &mut self, key: impl Into, value: T, ) -> Result<(), CollectionError>; - /// Get a string value from the map with type validation. - /// - /// Returns the string value if the key exists and contains a string, - /// otherwise returns an appropriate error. - /// - /// # Parameters - /// * `key` - Map key to look up - /// - /// # Returns - /// * `Ok(String)` - The string value - /// * `Err(CollectionError::MissingKey)` - Key doesn't exist - /// * `Err(CollectionError::TypeMismatch)` - Key exists but value is not a string - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// - /// let mut extra = new_extra_map(); - /// extra.insert_string("name", "Alice"); - /// extra.insert_number("count", 42); - /// - /// assert_eq!(extra.get_string("name").unwrap(), "Alice"); - /// assert!(extra.get_string("count").is_err()); // Type mismatch - /// assert!(extra.get_string("missing").is_err()); // Missing key - /// ``` + /// Returns the string at `key`, or `Err` on missing key or type mismatch. fn get_string(&self, key: &str) -> Result; - /// Get a numeric value from the map with type validation. - /// - /// Returns the numeric value if the key exists and contains a number, - /// otherwise returns an appropriate error. - /// - /// # Parameters - /// * `key` - Map key to look up - /// - /// # Returns - /// * `Ok(serde_json::Number)` - The numeric value - /// * `Err(CollectionError::MissingKey)` - Key doesn't exist - /// * `Err(CollectionError::TypeMismatch)` - Key exists but value is not a number + /// Returns the number at `key`, or `Err` on missing key or type mismatch. fn get_number(&self, key: &str) -> Result; - /// Get a boolean value from the map with type validation. - /// - /// Returns the boolean value if the key exists and contains a boolean, - /// otherwise returns an appropriate error. - /// - /// # Parameters - /// * `key` - Map key to look up - /// - /// # Returns - /// * `Ok(bool)` - The boolean value - /// * `Err(CollectionError::MissingKey)` - Key doesn't exist - /// * `Err(CollectionError::TypeMismatch)` - Key exists but value is not a boolean + /// Returns the bool at `key`, or `Err` on missing key or type mismatch. fn get_bool(&self, key: &str) -> Result; - /// Get a value and deserialize it to the specified type. - /// - /// This method attempts to deserialize the JSON value at the given key - /// to the requested type. It's useful for retrieving complex structured - /// data that was stored using `insert_json`. - /// - /// # Parameters - /// * `key` - Map key to look up - /// - /// # Returns - /// * `Ok(T)` - Successfully deserialized value - /// * `Err(CollectionError::MissingKey)` - Key doesn't exist - /// * `Err(CollectionError::Json)` - Deserialization failed - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// use serde::{Deserialize, Serialize}; - /// - /// #[derive(Serialize, Deserialize, PartialEq, Debug)] - /// struct Config { - /// timeout: u64, - /// retries: u32, - /// } - /// - /// let mut extra = new_extra_map(); - /// let config = Config { timeout: 5000, retries: 3 }; - /// extra.insert_json("config", &config).unwrap(); - /// - /// let retrieved: Config = extra.get_typed("config").unwrap(); - /// assert_eq!(retrieved, config); - /// ``` + /// Deserializes the value at `key` to `T`. fn get_typed(&self, key: &str) -> Result; - /// Check if a key exists and has the expected type. - /// - /// This method provides a quick way to validate the presence and type - /// of a value without retrieving it. Useful for conditional logic - /// and validation scenarios. - /// - /// # Parameters - /// * `key` - Map key to check - /// * `expected_type` - Expected type as string ("string", "number", "bool", "array", "object", "null") - /// - /// # Returns - /// `true` if the key exists and has the expected type, `false` otherwise - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; - /// - /// let mut extra = new_extra_map(); - /// extra.insert_string("name", "Alice"); - /// extra.insert_number("age", 30); - /// - /// assert!(extra.has_typed("name", "string")); - /// assert!(extra.has_typed("age", "number")); - /// assert!(!extra.has_typed("name", "number")); - /// assert!(!extra.has_typed("missing", "string")); - /// ``` + /// Returns `true` if `key` exists with the given JSON type name + /// (`"string"`, `"number"`, `"bool"`, `"array"`, `"object"`, `"null"`). fn has_typed(&self, key: &str, expected_type: &str) -> bool; } impl ExtraMapExt for FxHashMap { - /// Inserts a string value into the FxHashMap as a JSON String value. - /// - /// This implementation converts both the key and value to their owned forms - /// and wraps the value in `Value::String` for JSON compatibility. fn insert_string(&mut self, key: impl Into, value: impl Into) { self.insert(key.into(), Value::String(value.into())); } - /// Inserts a numeric value into the FxHashMap as a JSON Number value. - /// - /// This implementation accepts any type that can be converted to `serde_json::Number` - /// and wraps it in `Value::Number`. fn insert_number(&mut self, key: impl Into, value: impl Into) { self.insert(key.into(), Value::Number(value.into())); } - /// Inserts a boolean value into the FxHashMap as a JSON Bool value. - /// - /// This implementation directly wraps the boolean in `Value::Bool`. fn insert_bool(&mut self, key: impl Into, value: bool) { self.insert(key.into(), Value::Bool(value)); } - /// Serializes and inserts any serializable value into the FxHashMap. - /// - /// This implementation uses `serde_json::to_value` to convert the input - /// to a JSON value, which may fail if serialization is not possible. - /// The serialized value is then inserted into the map. fn insert_json( &mut self, key: impl Into, value: T, ) -> Result<(), CollectionError> { - let json_value = serde_json::to_value(value)?; - self.insert(key.into(), json_value); + self.insert(key.into(), serde_json::to_value(value)?); Ok(()) } - /// Retrieves a string value from the FxHashMap with type validation. - /// - /// This implementation checks that the value exists and is of type `Value::String`, - /// returning appropriate errors for missing keys or type mismatches. - /// The string is cloned to return an owned value. fn get_string(&self, key: &str) -> Result { match self.get(key) { Some(Value::String(s)) => Ok(s.clone()), Some(other) => Err(CollectionError::TypeMismatch { - key: key.to_string(), - expected: "string".to_string(), - found: format!("{:?}", other), + key: key.to_owned(), + expected: "string".to_owned(), + found: format!("{other:?}"), }), None => Err(CollectionError::MissingKey { - key: key.to_string(), + key: key.to_owned(), }), } } - /// Retrieves a numeric value from the FxHashMap with type validation. - /// - /// This implementation checks that the value exists and is of type `Value::Number`, - /// returning appropriate errors for missing keys or type mismatches. - /// The number is cloned to return an owned value. fn get_number(&self, key: &str) -> Result { match self.get(key) { Some(Value::Number(n)) => Ok(n.clone()), Some(other) => Err(CollectionError::TypeMismatch { - key: key.to_string(), - expected: "number".to_string(), - found: format!("{:?}", other), + key: key.to_owned(), + expected: "number".to_owned(), + found: format!("{other:?}"), }), None => Err(CollectionError::MissingKey { - key: key.to_string(), + key: key.to_owned(), }), } } - /// Retrieves a boolean value from the FxHashMap with type validation. - /// - /// This implementation checks that the value exists and is of type `Value::Bool`, - /// returning appropriate errors for missing keys or type mismatches. - /// The boolean is dereferenced to return the primitive value. fn get_bool(&self, key: &str) -> Result { match self.get(key) { Some(Value::Bool(b)) => Ok(*b), Some(other) => Err(CollectionError::TypeMismatch { - key: key.to_string(), - expected: "boolean".to_string(), - found: format!("{:?}", other), + key: key.to_owned(), + expected: "boolean".to_owned(), + found: format!("{other:?}"), }), None => Err(CollectionError::MissingKey { - key: key.to_string(), + key: key.to_owned(), }), } } - /// Deserializes a JSON value to the specified type. - /// - /// This implementation uses `serde_json::from_value` to convert the stored - /// JSON value to the requested type. The value is cloned before deserialization. - /// Returns JSON errors if deserialization fails. fn get_typed(&self, key: &str) -> Result { - match self.get(key) { - Some(value) => serde_json::from_value(value.clone()) - .map_err(|e| CollectionError::Json { source: e }), - None => Err(CollectionError::MissingKey { - key: key.to_string(), - }), - } + let value = self.get(key).ok_or_else(|| CollectionError::MissingKey { + key: key.to_owned(), + })?; + Ok(serde_json::from_value(value.clone())?) } - /// Checks if a key exists and matches the expected JSON type. - /// - /// This implementation performs pattern matching against the JSON value variants - /// and the expected type string. It's optimized to avoid cloning or deserializing - /// the actual value, making it suitable for validation scenarios. fn has_typed(&self, key: &str, expected_type: &str) -> bool { matches!( (self.get(key), expected_type), @@ -559,18 +269,7 @@ impl ExtraMapExt for FxHashMap { } } -/// Extension trait for `HashMap` with string keys to provide common operations. -/// -/// This trait extends both `HashMap` and `FxHashMap` with -/// convenient methods for common access patterns. These methods reduce boilerplate -/// code when working with string-keyed maps throughout the codebase. -/// -/// # Design Goals -/// -/// - **Ergonomics**: Reduce common boilerplate patterns -/// - **Flexibility**: Support both standard HashMap and FxHashMap -/// - **Performance**: Minimize allocations and clones where possible -/// - **Safety**: Provide safe alternatives to unwrap-heavy code +/// Ergonomic methods for string-keyed maps. /// /// # Examples /// @@ -578,100 +277,25 @@ impl ExtraMapExt for FxHashMap { /// use weavegraph::utils::collections::StringMapExt; /// use rustc_hash::FxHashMap; /// -/// let mut config: FxHashMap = FxHashMap::default(); -/// config.insert("max_connections".to_string(), 100); -/// -/// // Get with default - no panic if key missing -/// let connections = config.get_or_default("max_connections", 50); // Returns 100 -/// let timeout = config.get_or_default("timeout", 30); // Returns 30 (default) -/// -/// // Insert or update pattern -/// config.insert_or_update( -/// "retry_count".to_string(), -/// 1, // Initial value if key doesn't exist -/// |count| *count += 1, // Update function if key exists -/// ); +/// let mut counts: FxHashMap = FxHashMap::default(); +/// counts.insert_or_update("hits".to_string(), 1, |n| *n += 1); +/// counts.insert_or_update("hits".to_string(), 1, |n| *n += 1); +/// assert_eq!(counts.get_or_default("hits", 0), 2); +/// assert_eq!(counts.get_or_default("misses", 0), 0); /// ``` pub trait StringMapExt { - /// Get a value with a default if the key doesn't exist. - /// - /// This method provides a safe way to retrieve values from a map with - /// a fallback default, avoiding the need for `unwrap()` or complex - /// match statements. - /// - /// # Parameters - /// * `key` - Key to look up in the map - /// * `default` - Default value to return if key is not found - /// - /// # Returns - /// The value associated with the key, or the default if key doesn't exist - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::StringMapExt; - /// use std::collections::HashMap; - /// - /// let mut settings = HashMap::new(); - /// settings.insert("debug".to_string(), true); - /// - /// assert_eq!(settings.get_or_default("debug", false), true); - /// assert_eq!(settings.get_or_default("verbose", false), false); - /// ``` + /// Returns the value at `key`, or `default` if the key is absent. fn get_or_default(&self, key: &str, default: V) -> V where V: Clone; - /// Insert if the key doesn't exist, otherwise update with a function. - /// - /// This method implements the common "insert or update" pattern efficiently. - /// If the key exists, the update function is called with a mutable reference - /// to the existing value. If the key doesn't exist, the provided value is inserted. - /// - /// This is particularly useful for counters, accumulators, and other scenarios - /// where you need to modify existing values or insert new ones. - /// - /// # Parameters - /// * `key` - Key to insert or update - /// * `value` - Value to insert if key doesn't exist - /// * `update_fn` - Function to call with existing value if key exists - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::collections::StringMapExt; - /// use rustc_hash::FxHashMap; - /// - /// let mut counters = FxHashMap::default(); - /// - /// // First call inserts the initial value - /// counters.insert_or_update( - /// "page_views".to_string(), - /// 1, - /// |count| *count += 1, - /// ); - /// assert_eq!(counters["page_views"], 1); - /// - /// // Subsequent calls update the existing value - /// counters.insert_or_update( - /// "page_views".to_string(), - /// 1, - /// |count| *count += 1, - /// ); - /// assert_eq!(counters["page_views"], 2); - /// ``` + /// Inserts `value` if `key` is absent; otherwise calls `update_fn` on the existing value. fn insert_or_update(&mut self, key: String, value: V, update_fn: F) where F: FnOnce(&mut V); } impl StringMapExt for HashMap { - /// Gets a value from the HashMap with a default fallback. - /// - /// This implementation uses the standard HashMap's `get` method with `cloned()` - /// to return an owned value, falling back to the provided default if the key - /// is not found. The standard HashMap uses SipHash for better security against - /// hash collision attacks. fn get_or_default(&self, key: &str, default: V) -> V where V: Clone, @@ -679,12 +303,6 @@ impl StringMapExt for HashMap { self.get(key).cloned().unwrap_or(default) } - /// Inserts a new value or updates an existing one using the entry API. - /// - /// This implementation leverages HashMap's entry API for efficient - /// insert-or-update operations. If the key exists, `and_modify` calls the - /// update function. If not, `or_insert` adds the initial value. - /// This avoids double-lookup that would occur with separate contains/get/insert calls. fn insert_or_update(&mut self, key: String, value: V, update_fn: F) where F: FnOnce(&mut V), @@ -694,12 +312,6 @@ impl StringMapExt for HashMap { } impl StringMapExt for FxHashMap { - /// Gets a value from the FxHashMap with a default fallback. - /// - /// This implementation uses FxHashMap's `get` method with `cloned()` - /// to return an owned value, falling back to the provided default if the key - /// is not found. FxHashMap uses a faster hash function (FxHash) that's suitable - /// for trusted input and provides better performance than the standard HashMap. fn get_or_default(&self, key: &str, default: V) -> V where V: Clone, @@ -707,12 +319,6 @@ impl StringMapExt for FxHashMap { self.get(key).cloned().unwrap_or(default) } - /// Inserts a new value or updates an existing one using the entry API. - /// - /// This implementation leverages FxHashMap's entry API for efficient - /// insert-or-update operations. The behavior is identical to HashMap's - /// implementation but benefits from FxHashMap's faster hashing for string keys. - /// This is particularly beneficial in the Weavegraph framework where string keys are common. fn insert_or_update(&mut self, key: String, value: V, update_fn: F) where F: FnOnce(&mut V), diff --git a/src/utils/deterministic_rng.rs b/src/utils/deterministic_rng.rs index 50a8768..2b109ad 100644 --- a/src/utils/deterministic_rng.rs +++ b/src/utils/deterministic_rng.rs @@ -1,18 +1,14 @@ -//! Deterministic random number generation for testing and reproducible behavior. +//! Deterministic random number generation for reproducible behavior and testing. //! -//! Provides deterministic random number generators that can be seeded for -//! reproducible test scenarios and deterministic workflow execution when -//! randomness is required but consistency is needed. +//! Seeded wrappers around [`rand::rngs::StdRng`] that produce the same sequence +//! for any given seed, enabling reproducible test scenarios and deterministic +//! workflow execution. use rand::rngs::StdRng; use rand::{RngExt, SeedableRng}; use std::collections::HashMap; -/// Deterministic random number generator wrapping rand::StdRng. -/// -/// This generator provides deterministic random values when seeded with a -/// fixed value, enabling reproducible test scenarios and consistent behavior -/// across runs when needed. +/// Seeded deterministic RNG that produces the same sequence for a given seed. /// /// # Examples /// @@ -22,10 +18,8 @@ use std::collections::HashMap; /// let mut rng = DeterministicRng::new(42); /// let value1 = rng.random_u64(); /// -/// // Same seed produces same sequence /// let mut rng2 = DeterministicRng::new(42); -/// let value2 = rng2.random_u64(); -/// assert_eq!(value1, value2); +/// assert_eq!(value1, rng2.random_u64()); /// ``` #[derive(Debug)] pub struct DeterministicRng { @@ -34,13 +28,7 @@ pub struct DeterministicRng { } impl DeterministicRng { - /// Create a deterministic RNG with a fixed seed. - /// - /// # Parameters - /// * `seed` - Seed value for deterministic generation - /// - /// # Returns - /// A new DeterministicRng instance seeded with the given value. + /// Create a new RNG seeded with `seed`. #[must_use] pub fn new(seed: u64) -> Self { Self { @@ -49,49 +37,44 @@ impl DeterministicRng { } } - /// Get the seed used to initialize this RNG. + /// Return the original seed. #[must_use] pub fn seed(&self) -> u64 { self.seed } - /// Reset the RNG to its initial state using the original seed. + /// Reset the sequence to its initial state. pub fn reset(&mut self) { self.rng = StdRng::seed_from_u64(self.seed); } - /// Create a new RNG with a different seed derived from the current state. + /// Create a child RNG seeded from the next value in this sequence. #[must_use] pub fn fork(&mut self) -> Self { - let new_seed = self.random_u64(); - Self::new(new_seed) + Self::new(self.random_u64()) } - /// Generate a random u64 value. + /// Draw a random `u64`. pub fn random_u64(&mut self) -> u64 { self.rng.random() } - /// Generate a random u32 value. + /// Draw a random `u32`. pub fn random_u32(&mut self) -> u32 { self.rng.random() } - /// Generate a random boolean value. + /// Draw a random `bool`. pub fn random_bool(&mut self) -> bool { self.rng.random() } - /// Generate a random f64 value between 0.0 and 1.0. + /// Draw a random `f64` in `[0.0, 1.0)`. pub fn random_f64(&mut self) -> f64 { self.rng.random() } - /// Generate a random u32 value in the specified range. - /// - /// # Parameters - /// * `min` - Minimum value (inclusive) - /// * `max` - Maximum value (exclusive) + /// Draw a `u32` in `[min, max)`. Returns `min` when `min >= max`. pub fn random_range_u32(&mut self, min: u32, max: u32) -> u32 { if min >= max { return min; @@ -99,10 +82,7 @@ impl DeterministicRng { min + (self.rng.random::() % (max - min)) } - /// Generate a random string of lowercase letters. - /// - /// # Parameters - /// * `len` - Length of the string to generate + /// Generate a random lowercase-ASCII string of length `len`. /// /// # Examples /// @@ -110,25 +90,20 @@ impl DeterministicRng { /// use weavegraph::utils::deterministic_rng::DeterministicRng; /// /// let mut rng = DeterministicRng::new(42); - /// let random_string = rng.random_string(8); - /// assert_eq!(random_string.len(), 8); - /// assert!(random_string.chars().all(|c| c.is_ascii_lowercase())); + /// let s = rng.random_string(8); + /// assert_eq!(s.len(), 8); + /// assert!(s.chars().all(|c| c.is_ascii_lowercase())); /// ``` pub fn random_string(&mut self, len: usize) -> String { (0..len) - .map(|_| { - let byte = self.random_range_u32(b'a' as u32, b'z' as u32 + 1) as u8; - byte as char - }) + .map(|_| (b'a' + self.random_range_u32(0, 26) as u8) as char) .collect() } - /// Generate a random alphanumeric string. - /// - /// # Parameters - /// * `len` - Length of the string to generate + /// Generate a random alphanumeric string of length `len`. /// /// # Examples + /// /// ```rust /// use weavegraph::utils::deterministic_rng::DeterministicRng; /// @@ -140,31 +115,20 @@ impl DeterministicRng { pub fn random_alphanumeric(&mut self, len: usize) -> String { const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"; (0..len) - .map(|_| { - let idx = self.random_range_u32(0, CHARS.len() as u32) as usize; - CHARS[idx] as char - }) + .map(|_| CHARS[self.random_range_u32(0, CHARS.len() as u32) as usize] as char) .collect() } - /// Choose a random element from a slice. - /// - /// # Parameters - /// * `choices` - Slice to choose from - /// - /// # Returns - /// A reference to a randomly chosen element, or None if slice is empty. + /// Return a reference to a randomly chosen element of `choices`, or `None` if empty. pub fn choose<'a, T>(&mut self, choices: &'a [T]) -> Option<&'a T> { if choices.is_empty() { - None - } else { - let idx = self.random_range_u32(0, choices.len() as u32) as usize; - Some(&choices[idx]) + return None; } + Some(&choices[self.random_range_u32(0, choices.len() as u32) as usize]) } } -/// Thread-safe registry for managing multiple named deterministic RNG instances. +/// Registry of named [`DeterministicRng`] instances, all seeded from a shared base. #[derive(Debug)] pub struct RngRegistry { rngs: HashMap, @@ -172,7 +136,7 @@ pub struct RngRegistry { } impl RngRegistry { - /// Create a new RNG registry with a base seed. + /// Create a new registry with the given base seed. #[must_use] pub fn new(base_seed: u64) -> Self { Self { @@ -181,32 +145,32 @@ impl RngRegistry { } } - /// Get or create a deterministic RNG for a specific name. + /// Return the [`DeterministicRng`] for `name`, inserting one on first access. + /// + /// Each name receives a stable seed derived by hashing the name against the base seed. pub fn get_rng(&mut self, name: &str) -> &mut DeterministicRng { - self.rngs.entry(name.to_string()).or_insert_with(|| { - // Create deterministic seed from name and base seed + self.rngs.entry(name.to_owned()).or_insert_with(|| { let name_hash = name .bytes() .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64)); - let derived_seed = self.base_seed.wrapping_add(name_hash); - DeterministicRng::new(derived_seed) + DeterministicRng::new(self.base_seed.wrapping_add(name_hash)) }) } - /// Reset all RNGs in the registry to their initial state. + /// Reset every RNG in the registry to its initial state. pub fn reset_all(&mut self) { for rng in self.rngs.values_mut() { rng.reset(); } } - /// Get the number of RNGs in the registry. + /// Return the number of named RNGs in this registry. #[must_use] pub fn len(&self) -> usize { self.rngs.len() } - /// Check if the registry is empty. + /// Return `true` if no named RNGs have been created yet. #[must_use] pub fn is_empty(&self) -> bool { self.rngs.is_empty() diff --git a/src/utils/id_generator.rs b/src/utils/id_generator.rs index bc7c53a..11b3fe5 100644 --- a/src/utils/id_generator.rs +++ b/src/utils/id_generator.rs @@ -1,131 +1,55 @@ //! ID generation utilities for run, step, node, and session identifiers. //! -//! Provides utilities for generating unique, deterministic, and contextual IDs -//! throughout the Weavegraph framework. Supports both random UUID-based generation -//! and deterministic seeded generation for testing and reproducibility. +//! Supports UUID-based random generation and deterministic seeded generation +//! for testing and reproducibility. -use std::fmt; use std::sync::atomic::{AtomicU64, Ordering}; -use thiserror::Error; use uuid::Uuid; -/// Errors that can occur during ID generation. -#[derive(Debug, Error)] -#[cfg_attr(feature = "diagnostics", derive(miette::Diagnostic))] -pub enum IdError { - /// Invalid format for ID parsing or validation. - #[error("Invalid ID format: {format}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic(code(weavegraph::id::invalid_format)) - )] - InvalidFormat { - /// The invalid format string that caused the error. - format: String, - }, - - /// ID generation failed due to system constraints. - #[error("ID generation failed: {reason}")] - #[cfg_attr( - feature = "diagnostics", - diagnostic(code(weavegraph::id::generation_failed)) - )] - GenerationFailed { - /// Human-readable description of why generation failed. - reason: String, - }, -} - -/// Configuration for ID generation behavior. +/// Controls how [`IdGenerator`] produces IDs. #[derive(Debug, Clone, Default)] pub struct IdConfig { - /// Random seed for deterministic ID generation (optional). + /// When set, IDs are derived from this seed rather than a random UUID. pub seed: Option, - /// Prefix to use for all generated IDs. + /// Prepended to every ID produced by [`IdGenerator::generate_id`]. pub prefix: Option, - /// Whether to include timestamps in IDs. + /// Append a Unix-timestamp suffix (`-t`) to each generated ID. pub include_timestamp: bool, - /// Counter for sequential ID generation. + /// Append a monotonically increasing counter to each generated ID. pub use_counter: bool, } -impl fmt::Display for IdConfig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "IdConfig {{ seed: {:?}, prefix: {:?}, timestamp: {}, counter: {} }}", - self.seed, self.prefix, self.include_timestamp, self.use_counter - ) - } -} - -/// High-performance ID generator with configurable behavior. -/// -/// Supports multiple ID generation strategies including UUID-based random IDs, -/// deterministic seeded IDs, and sequential counter-based IDs. Thread-safe -/// and optimized for high-throughput scenarios. +/// Thread-safe ID generator with configurable strategies. /// -/// # Examples +/// # Example /// /// ```rust /// use weavegraph::utils::id_generator::{IdGenerator, IdConfig}; /// -/// // Default random ID generation -/// let generator = IdGenerator::new(); -/// let id = generator.generate_run_id(); -/// assert!(id.starts_with("run-")); +/// let id_gen = IdGenerator::new(); +/// assert!(id_gen.generate_run_id().starts_with("run-")); /// -/// // Deterministic generation for testing -/// let config = IdConfig { -/// seed: Some(12345), -/// prefix: Some("test".into()), +/// let det = IdGenerator::with_config(IdConfig { +/// seed: Some(1), +/// use_counter: true, /// ..Default::default() -/// }; -/// let det_generator = IdGenerator::with_config(config); -/// let det_id = det_generator.generate_id(); +/// }); +/// assert_ne!(det.generate_id(), det.generate_id()); /// ``` -#[derive(Debug)] +#[derive(Debug, Default)] pub struct IdGenerator { config: IdConfig, counter: AtomicU64, } impl IdGenerator { - /// Create a new ID generator with default configuration. - /// - /// Uses random UUID generation without prefixes or deterministic behavior. - /// - /// # Returns - /// A new IdGenerator instance with default settings. + /// Create a generator with default (random UUID) settings. #[must_use] pub fn new() -> Self { - Self { - config: IdConfig::default(), - counter: AtomicU64::new(0), - } + Self::default() } - /// Create an ID generator with custom configuration. - /// - /// # Parameters - /// * `config` - Configuration specifying ID generation behavior - /// - /// # Returns - /// A new IdGenerator instance with the specified configuration. - /// - /// # Examples - /// - /// ```rust - /// use weavegraph::utils::id_generator::{IdGenerator, IdConfig}; - /// - /// let config = IdConfig { - /// seed: Some(42), - /// prefix: Some("weavegraph".into()), - /// include_timestamp: true, - /// use_counter: false, - /// }; - /// let generator = IdGenerator::with_config(config); - /// ``` + /// Create a generator with custom configuration. #[must_use] pub fn with_config(config: IdConfig) -> Self { Self { @@ -134,323 +58,74 @@ impl IdGenerator { } } - /// Generate a generic ID using the configured strategy. - /// - /// This is the core ID generation method that respects all configuration - /// options including seeds, prefixes, timestamps, and counters. - /// - /// # Returns - /// A newly generated ID string. - /// - /// # Examples + /// Generate an ID, applying seed, counter, timestamp, and prefix from config. /// /// ```rust /// use weavegraph::utils::id_generator::IdGenerator; /// - /// let generator = IdGenerator::new(); - /// let id = generator.generate_id(); - /// assert!(!id.is_empty()); + /// assert!(!IdGenerator::new().generate_id().is_empty()); /// ``` #[must_use] pub fn generate_id(&self) -> String { - let base_id = if let Some(seed) = self.config.seed { - if self.config.use_counter { - let counter = self.counter.fetch_add(1, Ordering::Relaxed); - format!("seeded-{}-{}", seed, counter) - } else { - format!("seeded-{}", seed) - } - } else if self.config.use_counter { - let counter = self.counter.fetch_add(1, Ordering::Relaxed); - format!("counter-{}", counter) - } else { - self.generate_uuid() - }; - - let mut final_id = base_id; - + let mut id = self.base_id(); if self.config.include_timestamp { - let timestamp = chrono::Utc::now().timestamp(); - final_id = format!("{}-t{}", final_id, timestamp); + let ts = chrono::Utc::now().timestamp(); + id = format!("{id}-t{ts}"); } - - if let Some(prefix) = &self.config.prefix { - final_id = format!("{}-{}", prefix, final_id); + if let Some(p) = &self.config.prefix { + id = format!("{p}-{id}"); } - - final_id + id } - /// Generate a UUID v4 as a string. - /// - /// Provides direct access to UUID generation regardless of configuration. - /// - /// # Returns - /// A new UUID v4 formatted as a string. - #[must_use] - pub fn generate_uuid(&self) -> String { - Uuid::new_v4().to_string() - } - - /// Generate an ID with a specific prefix. - /// - /// This method combines the configured generation strategy with a - /// custom prefix, useful for creating typed IDs. - /// - /// # Parameters - /// * `prefix` - Prefix to prepend to the generated ID - /// - /// # Returns - /// A new ID with the specified prefix. - /// - /// # Examples + /// Generate an ID prefixed with `prefix`, bypassing config prefix and timestamp. /// /// ```rust /// use weavegraph::utils::id_generator::IdGenerator; /// - /// let generator = IdGenerator::new(); - /// let session_id = generator.generate_id_with_prefix("session"); - /// assert!(session_id.starts_with("session-")); + /// assert!(IdGenerator::new().generate_id_with_prefix("session").starts_with("session-")); /// ``` #[must_use] pub fn generate_id_with_prefix(&self, prefix: &str) -> String { - format!("{}-{}", prefix, self.generate_base_id()) + format!("{prefix}-{}", self.base_id()) } - /// Generate a run ID for workflow execution tracking. - /// - /// Run IDs are used to track complete workflow executions from start to finish. - /// - /// # Returns - /// A new run ID with "run" prefix. + /// Generate a `run-` identifier. #[must_use] pub fn generate_run_id(&self) -> String { self.generate_id_with_prefix("run") } - /// Generate a step ID for individual workflow steps. - /// - /// Step IDs track individual execution steps within a workflow run. - /// - /// # Returns - /// A new step ID with "step" prefix. + /// Generate a `step-` identifier. #[must_use] pub fn generate_step_id(&self) -> String { self.generate_id_with_prefix("step") } - /// Generate a node ID for graph node instances. - /// - /// Node IDs identify specific instances of nodes in the execution graph. - /// - /// # Returns - /// A new node ID with "node" prefix. + /// Generate a `node-` identifier. #[must_use] pub fn generate_node_id(&self) -> String { self.generate_id_with_prefix("node") } - /// Generate a session ID for checkpoint and persistence tracking. - /// - /// Session IDs are used for checkpoint management and state persistence. - /// - /// # Returns - /// A new session ID with "session" prefix. + /// Generate a `session-` identifier. #[must_use] pub fn generate_session_id(&self) -> String { self.generate_id_with_prefix("session") } - /// Generate a random ID without any configuration-based modifications. - /// - /// This bypasses all configuration and always generates a random UUID. - /// - /// # Returns - /// A random UUID string. - #[must_use] - pub fn generate_random_id(&self) -> String { - self.generate_uuid() - } - - /// Parse and validate an ID format. - /// - /// Checks if an ID follows expected patterns and extracts components. - /// - /// # Parameters - /// * `id` - ID string to validate - /// - /// # Returns - /// Ok(ParsedId) if valid, Err(IdError) if invalid. - pub fn parse_id(&self, id: &str) -> Result { - if id.is_empty() { - return Err(IdError::InvalidFormat { - format: "empty string".into(), - }); - } - - let parts: Vec<&str> = id.split('-').collect(); - if parts.len() < 2 { - return Err(IdError::InvalidFormat { - format: "missing separator".into(), - }); - } - - Ok(ParsedId { - prefix: parts[0].to_string(), - base: parts[1..].join("-"), - original: id.to_string(), - }) - } - - /// Get the current counter value. - /// - /// Useful for testing and debugging sequential ID generation. - /// - /// # Returns - /// Current counter value. - #[must_use] - pub fn current_counter(&self) -> u64 { - self.counter.load(Ordering::Relaxed) - } - - /// Reset the counter to zero. - /// - /// Useful for testing scenarios that require predictable counter values. - pub fn reset_counter(&self) { - self.counter.store(0, Ordering::Relaxed); - } - - /// Private helper method for base ID generation - fn generate_base_id(&self) -> String { - if let Some(seed) = self.config.seed { - if self.config.use_counter { - let counter = self.counter.fetch_add(1, Ordering::Relaxed); - format!("seeded-{}-{}", seed, counter) - } else { - format!("seeded-{}", seed) + fn base_id(&self) -> String { + match (self.config.seed, self.config.use_counter) { + (Some(seed), true) => { + let n = self.counter.fetch_add(1, Ordering::Relaxed); + format!("seeded-{seed}-{n}") } - } else if self.config.use_counter { - let counter = self.counter.fetch_add(1, Ordering::Relaxed); - format!("counter-{}", counter) - } else { - self.generate_uuid() - } - } -} - -impl Default for IdGenerator { - fn default() -> Self { - Self::new() - } -} - -/// Parsed ID components for analysis and validation. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ParsedId { - /// The prefix part of the ID (e.g., "run", "step", "node"). - pub prefix: String, - /// The base identifier part after the prefix. - pub base: String, - /// The original complete ID string. - pub original: String, -} - -impl ParsedId { - /// Check if this ID has a specific prefix. - /// - /// # Parameters - /// * `expected_prefix` - Prefix to check for - /// - /// # Returns - /// True if the ID has the expected prefix. - #[must_use] - pub fn has_prefix(&self, expected_prefix: &str) -> bool { - self.prefix == expected_prefix - } - - /// Extract timestamp from the ID if present. - /// - /// # Returns - /// Timestamp if found and valid, None otherwise. - #[must_use] - pub fn extract_timestamp(&self) -> Option { - if let Some(timestamp_part) = self.base.split('-').find(|part| part.starts_with('t')) { - timestamp_part[1..].parse().ok() - } else { - None - } - } -} - -/// Utility functions for common ID operations. -pub mod id_utils { - use super::*; - - /// Create a deterministic ID generator for testing. - /// - /// # Parameters - /// * `seed` - Seed value for deterministic generation - /// - /// # Returns - /// IdGenerator configured for deterministic behavior. - #[must_use] - pub fn create_test_generator(seed: u64) -> IdGenerator { - IdGenerator::with_config(IdConfig { - seed: Some(seed), - use_counter: true, - ..Default::default() - }) - } - - /// Create a production ID generator with timestamps. - /// - /// # Returns - /// IdGenerator configured for production use with timestamps. - #[must_use] - pub fn create_production_generator() -> IdGenerator { - IdGenerator::with_config(IdConfig { - include_timestamp: true, - ..Default::default() - }) - } - - /// Validate that an ID follows expected patterns. - /// - /// # Parameters - /// * `id` - ID to validate - /// * `expected_prefix` - Expected prefix (optional) - /// - /// # Returns - /// True if the ID is valid. - pub fn is_valid_id(id: &str, expected_prefix: Option<&str>) -> bool { - if id.is_empty() { - return false; - } - - let parts: Vec<&str> = id.split('-').collect(); - if parts.len() < 2 { - return false; - } - - if let Some(prefix) = expected_prefix { - parts[0] == prefix - } else { - true + (Some(seed), false) => format!("seeded-{seed}"), + (None, true) => { + let n = self.counter.fetch_add(1, Ordering::Relaxed); + format!("counter-{n}") + } + (None, false) => Uuid::new_v4().to_string(), } } - - /// Extract the type of an ID from its prefix. - /// - /// # Parameters - /// * `id` - ID to analyze - /// - /// # Returns - /// ID type based on prefix, or "unknown" if unrecognized. - #[must_use] - pub fn get_id_type(id: &str) -> String { - id.split('-') - .next() - .map(|s| s.to_string()) - .unwrap_or_else(|| "unknown".to_string()) - } } diff --git a/src/utils/json_ext.rs b/src/utils/json_ext.rs index 608b0dd..ee7a709 100644 --- a/src/utils/json_ext.rs +++ b/src/utils/json_ext.rs @@ -1,7 +1,6 @@ //! JSON manipulation utilities and extensions for the Weavegraph framework. //! -//! Provides utilities for deep merging JSON objects, pointer-based access, -//! and common JSON manipulation patterns used throughout the framework. +//! Provides deep-merge, dot-path access, and common JSON manipulation patterns. use serde_json::{Map, Value}; use std::collections::HashMap; @@ -18,159 +17,127 @@ pub enum JsonError { diagnostic(code(weavegraph::json::invalid_pointer)) )] InvalidPointer { - /// The invalid JSON pointer string. + /// The offending pointer string. pointer: String, }, - /// JSON merge conflict that cannot be resolved. + /// Merge conflict that cannot be resolved. #[error("Merge conflict at path '{path}': cannot merge {left_type} with {right_type}")] #[cfg_attr( feature = "diagnostics", diagnostic(code(weavegraph::json::merge_conflict)) )] MergeConflict { - /// JSON path where the conflict occurred. + /// Dot-separated path where the conflict occurred. path: String, - /// Type of the left operand at the conflict point. + /// JSON type of the left operand. left_type: String, - /// Type of the right operand at the conflict point. + /// JSON type of the right operand. right_type: String, }, - /// Serialization/deserialization error. + /// Serialization or deserialization error. #[error("JSON serialization error: {source}")] #[cfg_attr(feature = "diagnostics", diagnostic(code(weavegraph::json::serde)))] Serde { - /// The underlying serde_json error. + /// Underlying serde_json error. #[from] source: serde_json::Error, }, } -/// Strategy for handling conflicts during JSON merges. +/// Strategy for resolving conflicts during JSON merges. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MergeStrategy { - /// Prefer values from the left operand when conflicts occur. + /// Prefer the left value on conflict. PreferLeft, - /// Prefer values from the right operand when conflicts occur. + /// Prefer the right value on conflict. PreferRight, - /// Fail on any merge conflict. + /// Fail on any conflict. FailOnConflict, - /// Attempt to merge values recursively, failing only on type mismatches. + /// Merge objects recursively; prefer right for primitive conflicts; concatenate arrays. DeepMerge, } -/// Performs deep merge of two JSON values according to the specified strategy. +/// Deep-merge two JSON values according to `strategy`. /// -/// # Parameters -/// * `left` - Left operand for the merge -/// * `right` - Right operand for the merge -/// * `strategy` - Strategy for handling conflicts -/// -/// # Returns -/// Merged JSON value or error if merge fails -/// -/// # Examples +/// Objects are always merged key-by-key. Arrays and scalars follow `strategy`. /// /// ```rust /// use weavegraph::utils::json_ext::{deep_merge, MergeStrategy}; -/// use serde_json::{json, Value}; +/// use serde_json::json; /// /// let left = json!({"a": 1, "b": {"x": 10}}); /// let right = json!({"b": {"y": 20}, "c": 3}); -/// /// let merged = deep_merge(&left, &right, MergeStrategy::DeepMerge).unwrap(); -/// let expected = json!({"a": 1, "b": {"x": 10, "y": 20}, "c": 3}); -/// assert_eq!(merged, expected); +/// assert_eq!(merged, json!({"a": 1, "b": {"x": 10, "y": 20}, "c": 3})); /// ``` pub fn deep_merge( left: &Value, right: &Value, strategy: MergeStrategy, ) -> Result { - deep_merge_with_path(left, right, strategy, "") + merge_at(left, right, strategy, "") } -/// Internal function that tracks the current path for better error reporting. -fn deep_merge_with_path( +fn merge_at( left: &Value, right: &Value, strategy: MergeStrategy, path: &str, ) -> Result { match (left, right) { - // Both are objects - merge recursively - (Value::Object(left_obj), Value::Object(right_obj)) => { + (Value::Object(l), Value::Object(r)) => { let mut result = Map::new(); - - // Add all keys from left - for (key, value) in left_obj { - let current_path = if path.is_empty() { + for (key, lv) in l { + let child_path = if path.is_empty() { key.clone() } else { - format!("{}.{}", path, key) + format!("{path}.{key}") }; - - if let Some(right_value) = right_obj.get(key) { - // Key exists in both - merge recursively - let merged = deep_merge_with_path(value, right_value, strategy, ¤t_path)?; - result.insert(key.clone(), merged); - } else { - // Key only in left - result.insert(key.clone(), value.clone()); - } + let merged = match r.get(key) { + Some(rv) => merge_at(lv, rv, strategy, &child_path)?, + None => lv.clone(), + }; + result.insert(key.clone(), merged); } - - // Add keys that only exist in right - for (key, value) in right_obj { - if !left_obj.contains_key(key) { - result.insert(key.clone(), value.clone()); + for (key, rv) in r { + if !l.contains_key(key) { + result.insert(key.clone(), rv.clone()); } } - Ok(Value::Object(result)) } - - // Both are arrays - strategy determines behavior - (Value::Array(left_arr), Value::Array(right_arr)) => match strategy { - MergeStrategy::PreferLeft => Ok(Value::Array(left_arr.clone())), - MergeStrategy::PreferRight => Ok(Value::Array(right_arr.clone())), + (Value::Array(l), Value::Array(r)) => match strategy { + MergeStrategy::PreferLeft => Ok(Value::Array(l.clone())), + MergeStrategy::PreferRight => Ok(Value::Array(r.clone())), MergeStrategy::FailOnConflict => Err(JsonError::MergeConflict { - path: path.to_string(), - left_type: "array".to_string(), - right_type: "array".to_string(), + path: path.to_owned(), + left_type: "array".to_owned(), + right_type: "array".to_owned(), }), MergeStrategy::DeepMerge => { - // Concatenate arrays - let mut result = left_arr.clone(); - result.extend(right_arr.clone()); - Ok(Value::Array(result)) + let mut out = l.clone(); + out.extend_from_slice(r); + Ok(Value::Array(out)) } }, - - // Same primitive values - (left_val, right_val) if left_val == right_val => Ok(left_val.clone()), - - // Different values - strategy determines behavior - (left_val, right_val) => match strategy { - MergeStrategy::PreferLeft => Ok(left_val.clone()), - MergeStrategy::PreferRight => Ok(right_val.clone()), + (lv, rv) if lv == rv => Ok(lv.clone()), + (lv, rv) => match strategy { + MergeStrategy::PreferLeft => Ok(lv.clone()), + MergeStrategy::PreferRight => Ok(rv.clone()), MergeStrategy::FailOnConflict => Err(JsonError::MergeConflict { - path: path.to_string(), - left_type: get_value_type(left_val).to_string(), - right_type: get_value_type(right_val).to_string(), + path: path.to_owned(), + left_type: value_type(lv).to_owned(), + right_type: value_type(rv).to_owned(), }), - MergeStrategy::DeepMerge => { - // For primitives in deep merge, prefer right - Ok(right_val.clone()) - } + MergeStrategy::DeepMerge => Ok(rv.clone()), }, } } -/// Get a human-readable type name for a JSON value. -fn get_value_type(value: &Value) -> &'static str { - match value { +fn value_type(v: &Value) -> &'static str { + match v { Value::Null => "null", Value::Bool(_) => "boolean", Value::Number(_) => "number", @@ -180,97 +147,51 @@ fn get_value_type(value: &Value) -> &'static str { } } -/// Merge multiple JSON values using the specified strategy. -/// -/// # Parameters -/// * `values` - Iterator of JSON values to merge -/// * `strategy` - Strategy for handling conflicts -/// -/// # Returns -/// Merged JSON value or error if merge fails -/// -/// # Examples +/// Fold-merge an iterator of JSON values using `strategy`. /// /// ```rust /// use weavegraph::utils::json_ext::{merge_multiple, MergeStrategy}; /// use serde_json::json; /// -/// let values = vec![ -/// json!({"a": 1}), -/// json!({"b": 2}), -/// json!({"c": 3}), -/// ]; -/// +/// let values = [json!({"a": 1}), json!({"b": 2}), json!({"c": 3})]; /// let merged = merge_multiple(values.iter(), MergeStrategy::DeepMerge).unwrap(); -/// let expected = json!({"a": 1, "b": 2, "c": 3}); -/// assert_eq!(merged, expected); +/// assert_eq!(merged, json!({"a": 1, "b": 2, "c": 3})); /// ``` pub fn merge_multiple<'a, I>(values: I, strategy: MergeStrategy) -> Result where I: IntoIterator, { - let mut result = Value::Object(Map::new()); - for value in values { - result = deep_merge(&result, value, strategy)?; + let mut acc = Value::Object(Map::new()); + for v in values { + acc = deep_merge(&acc, v, strategy)?; } - Ok(result) + Ok(acc) } -/// Get a value using a JSON pointer-like path. -/// -/// # Parameters -/// * `value` - JSON value to search in -/// * `path` - Dot-separated path (e.g., "user.profile.name") -/// -/// # Returns -/// Reference to the value if found, None otherwise +/// Walk a dot-separated path into a JSON value. /// -/// # Examples +/// Array segments are parsed as numeric indices. /// /// ```rust /// use weavegraph::utils::json_ext::get_by_path; /// use serde_json::json; /// /// let data = json!({"user": {"profile": {"name": "Alice"}}}); -/// let name = get_by_path(&data, "user.profile.name"); -/// assert_eq!(name, Some(&json!("Alice"))); +/// assert_eq!(get_by_path(&data, "user.profile.name"), Some(&json!("Alice"))); /// ``` #[must_use] pub fn get_by_path<'a>(value: &'a Value, path: &str) -> Option<&'a Value> { if path.is_empty() { return Some(value); } - - let parts: Vec<&str> = path.split('.').collect(); - let mut current = value; - - for part in parts { - match current { - Value::Object(obj) => { - current = obj.get(part)?; - } - Value::Array(arr) => { - let index: usize = part.parse().ok()?; - current = arr.get(index)?; - } - _ => return None, - } - } - - Some(current) + path.split('.').try_fold(value, |cur, seg| match cur { + Value::Object(obj) => obj.get(seg), + Value::Array(arr) => arr.get(seg.parse::().ok()?), + _ => None, + }) } -/// Set a value using a JSON pointer-like path, creating intermediate objects as needed. -/// -/// # Parameters -/// * `target` - Mutable JSON value to modify -/// * `path` - Dot-separated path (e.g., "user.profile.name") -/// * `value` - Value to set -/// -/// # Returns -/// Result indicating success or failure -/// -/// # Examples +/// Walk a dot-separated path and assign a value, auto-vivifying intermediate objects. /// /// ```rust /// use weavegraph::utils::json_ext::set_by_path; @@ -278,58 +199,42 @@ pub fn get_by_path<'a>(value: &'a Value, path: &str) -> Option<&'a Value> { /// /// let mut data = json!({}); /// set_by_path(&mut data, "user.profile.name", json!("Alice")).unwrap(); -/// -/// let expected = json!({"user": {"profile": {"name": "Alice"}}}); -/// assert_eq!(data, expected); +/// assert_eq!(data, json!({"user": {"profile": {"name": "Alice"}}})); /// ``` pub fn set_by_path(target: &mut Value, path: &str, value: Value) -> Result<(), JsonError> { if path.is_empty() { *target = value; return Ok(()); } - - let parts: Vec<&str> = path.split('.').collect(); - let mut current = target; - - // Navigate to the parent of the target location - for part in &parts[..parts.len() - 1] { - match current { + let mut segs: Vec<&str> = path.split('.').collect(); + let final_key = segs.pop().expect("split yields at least one element"); + let mut cur = target; + for &seg in &segs { + match cur { Value::Object(obj) => { - current = obj - .entry(part.to_string()) + cur = obj + .entry(seg.to_owned()) .or_insert_with(|| Value::Object(Map::new())); } _ => { return Err(JsonError::InvalidPointer { - pointer: path.to_string(), + pointer: path.to_owned(), }); } } } - - // Set the final value - let final_key = parts[parts.len() - 1]; - match current { + match cur { Value::Object(obj) => { - obj.insert(final_key.to_string(), value); + obj.insert(final_key.to_owned(), value); Ok(()) } _ => Err(JsonError::InvalidPointer { - pointer: path.to_string(), + pointer: path.to_owned(), }), } } -/// Check if a JSON value has a specific structure. -/// -/// # Parameters -/// * `value` - JSON value to validate -/// * `expected_keys` - Expected object keys -/// -/// # Returns -/// True if the value is an object containing all expected keys -/// -/// # Examples +/// Returns `true` if `value` is an object containing every key in `expected_keys`. /// /// ```rust /// use weavegraph::utils::json_ext::has_structure; @@ -342,39 +247,29 @@ pub fn set_by_path(target: &mut Value, path: &str, value: Value) -> Result<(), J #[must_use] pub fn has_structure(value: &Value, expected_keys: &[&str]) -> bool { match value { - Value::Object(obj) => expected_keys.iter().all(|key| obj.contains_key(*key)), + Value::Object(obj) => expected_keys.iter().all(|k| obj.contains_key(*k)), _ => false, } } -/// Convert a HashMap to a JSON object. -/// -/// # Parameters -/// * `map` - HashMap to convert -/// -/// # Returns -/// JSON object representation +/// Convert a `HashMap` into a JSON object. pub fn hashmap_to_json>(map: HashMap) -> Value { - let json_map: Map = map.into_iter().map(|(k, v)| (k, v.into())).collect(); - Value::Object(json_map) + Value::Object(map.into_iter().map(|(k, v)| (k, v.into())).collect()) } -/// Extension trait for JSON Value providing additional utility methods. +/// Extension trait for `Value` with path-navigation and introspection helpers. pub trait JsonValueExt { - /// Get a value by path with a default if not found. + /// Return the value at `path`, or `default` if the path is absent. fn get_path_or<'a>(&'a self, path: &str, default: &'a Value) -> &'a Value; - /// Check if this value is an empty object or array. + /// Return `true` if this is an empty object or array. fn is_empty_container(&self) -> bool; - /// Get the number of elements (for objects/arrays) or 1 (for primitives). + /// Number of elements for objects/arrays, or `1` for scalars. fn element_count(&self) -> usize; - /// Get all keys if this is an object. + /// Object keys, or an empty vec for non-objects. fn keys(&self) -> Vec; - - /// Deep clone with type conversion. - fn deep_clone(&self) -> Value; } impl JsonValueExt for Value { @@ -404,45 +299,20 @@ impl JsonValueExt for Value { _ => vec![], } } - - fn deep_clone(&self) -> Value { - self.clone() - } } -/// Trait for types that can be serialized to/from JSON strings with specific error handling. +/// Generic serialization interface for types that round-trip through JSON. /// -/// This provides a consistent interface for JSON operations throughout the framework. -/// Unlike the other utilities in this module which work with `JsonError`, this trait -/// is generic over the error type to allow different modules to use their own error types. +/// Generic over `E` so each module can map errors to its own type. pub trait JsonSerializable: serde::Serialize + for<'de> serde::de::DeserializeOwned { - /// Serialize this object to a JSON string. - /// - /// # Errors - /// - /// Returns an error if serialization fails. + /// Serialize to a JSON string. fn to_json_string(&self) -> Result; - /// Deserialize an object from a JSON string. - /// - /// # Errors - /// - /// Returns an error if deserialization fails. + /// Deserialize from a JSON string. fn from_json_str(s: &str) -> Result; } -/// Helper for JSON serialization with custom error context. -/// -/// This utility provides context-aware JSON serialization that can be used -/// by different modules with their own error types. -/// -/// # Parameters -/// * `value` - The value to serialize -/// * `context` - Context string for error messages -/// * `error_mapper` - Function to convert serde_json::Error to the target error type -/// -/// # Returns -/// JSON string or mapped error +/// Serialize `value` to a JSON string, mapping any error through `error_mapper`. pub fn serialize_with_context( value: &T, context: &str, @@ -454,18 +324,7 @@ where serde_json::to_string(value).map_err(|e| error_mapper(e, context)) } -/// Helper for JSON deserialization with custom error context. -/// -/// This utility provides context-aware JSON deserialization that can be used -/// by different modules with their own error types. -/// -/// # Parameters -/// * `json` - The JSON string to deserialize -/// * `context` - Context string for error messages -/// * `error_mapper` - Function to convert serde_json::Error to the target error type -/// -/// # Returns -/// Deserialized value or mapped error +/// Deserialize a JSON string, mapping any error through `error_mapper`. pub fn deserialize_with_context( json: &str, context: &str, diff --git a/src/utils/merge_inspector.rs b/src/utils/merge_inspector.rs deleted file mode 100644 index 3525769..0000000 --- a/src/utils/merge_inspector.rs +++ /dev/null @@ -1,43 +0,0 @@ -//! Merge Inspector for debug merge traces and state debugging. -//! -//! This module provides tools to inspect and debug state merges during -//! barrier synchronization. Useful for understanding how `NodePartial` -//! updates are combined into the final `VersionedState`. -//! -//! # Overview -//! -//! When multiple nodes execute concurrently, their outputs are merged -//! during the barrier phase. The merge inspector helps diagnose issues -//! like: -//! -//! - Unexpected state after merges -//! - Reducer conflicts or ordering issues -//! - Missing or overwritten channel data -//! -//! # Future Implementation -//! -//! This module is currently a placeholder. Planned features include: -//! -//! - `MergeTrace` struct capturing before/after snapshots -//! - `MergeInspector` trait for custom inspection hooks -//! - Integration with tracing for structured merge logging -//! - Diff generation between pre/post merge states -//! -//! # Example (Future API) -//! -//! ```rust,ignore -//! use weavegraph::utils::merge_inspector::MergeInspector; -//! -//! // Attach an inspector to capture merge operations -//! let inspector = MergeInspector::new() -//! .with_diff_output(true) -//! .with_channel_filter(ChannelType::Message); -//! -//! // Inspect will be called during barrier synchronization -//! let traces = inspector.traces(); -//! for trace in traces { -//! println!("Node {} merged: {:?}", trace.node_id, trace.diff); -//! } -//! ``` - -// Placeholder for future implementation diff --git a/src/utils/message_id_helpers.rs b/src/utils/message_id_helpers.rs deleted file mode 100644 index 0c44b4b..0000000 --- a/src/utils/message_id_helpers.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Message and tool-call ID generation helpers. -//! -//! Provides utilities for generating unique identifiers for messages -//! and tool calls within a workflow execution. These IDs enable: -//! -//! - Message correlation and threading -//! - Tool call tracking and response matching -//! - Audit trails and debugging -//! -//! # Thread Safety -//! -//! ID generators in this module ensure per-thread uniqueness to avoid -//! collisions in concurrent execution scenarios. The implementation -//! uses atomic counters combined with thread-local state for efficiency. -//! -//! # ID Formats -//! -//! Generated IDs follow predictable formats for parseability: -//! -//! - Message IDs: `msg-{session_id}-{step}-{counter}` -//! - Tool Call IDs: `tool-{node_id}-{step}-{counter}` -//! -//! # Future Implementation -//! -//! This module is currently a placeholder. Planned features include: -//! -//! - `MessageIdGenerator` struct with configurable prefixes -//! - `ToolCallIdGenerator` for tracking tool invocations -//! - Integration with `IdGenerator` for consistent ID semantics -//! - Parsing utilities to extract components from generated IDs -//! -//! # Example (Future API) -//! -//! ```rust,ignore -//! use weavegraph::utils::message_id_helpers::MessageIdGenerator; -//! -//! let generator = MessageIdGenerator::new("session-123"); -//! -//! let msg_id = generator.next_message_id(1); // step 1 -//! // msg_id = "msg-session-123-1-0" -//! -//! let tool_id = generator.next_tool_call_id("my_node", 1); -//! // tool_id = "tool-my_node-1-0" -//! ``` - -// Placeholder for future implementation diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 0252713..47cd88c 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,66 +1,15 @@ -//! Utilities module for common functionality across the Weavegraph framework. +//! Shared utilities used throughout the Weavegraph framework. //! -//! This module provides reusable utilities and common patterns that are used -//! throughout the codebase. These utilities are designed to be generic, -//! type-safe, and ergonomic to maximize applicability and maintain consistency. +//! # Submodules //! -//! # Module Organization -//! -//! - [`collections`]: Collection utilities and common patterns for maps and data structures -//! - [`clock`]: Injectable time sources for checkpoints and time-based operations -//! - [`deterministic_rng`]: Deterministic random number generation for testing -//! - [`id_generator`]: ID generation utilities for runs, steps, nodes, and sessions -//! - [`json_ext`]: JSON manipulation utilities and extensions -//! - [`merge_inspector`]: State merge debugging and inspection tools *(placeholder)* -//! - [`message_id_helpers`]: Message and tool-call ID generation *(placeholder)* -//! - [`type_guards`]: Type validation and shape checking utilities *(placeholder)* -//! -//! # Design Principles -//! -//! - **Type Safety**: Leverage Rust's type system for compile-time correctness -//! - **Error Handling**: Consistent use of `Result`, `Option`, and `thiserror`/`miette` -//! - **Ergonomics**: Builder patterns, extension traits, and intuitive APIs -//! - **Reusability**: Generic implementations that work across different contexts -//! - **Testing**: Comprehensive test coverage and deterministic behavior -//! -//! # Common Patterns -//! -//! ## Extra Data Maps -//! ```rust -//! use weavegraph::utils::collections::{new_extra_map, ExtraMapExt}; -//! -//! let mut extra = new_extra_map(); -//! extra.insert_string("key", "value"); -//! extra.insert_number("count", 42); -//! ``` -//! -//! ## ID Generation -//! ```rust -//! use weavegraph::utils::id_generator::IdGenerator; -//! -//! let generator = IdGenerator::new(); -//! let run_id = generator.generate_run_id(); -//! let step_id = generator.generate_step_id(); -//! ``` -//! -//! ## Time Abstraction -//! ```rust -//! use weavegraph::utils::clock::{Clock, SystemClock, MockClock}; -//! -//! // Production -//! let clock = SystemClock; -//! let timestamp = clock.now(); -//! -//! // Testing -//! let mut mock_clock = MockClock::new(1000); -//! mock_clock.advance_secs(30); -//! ``` +//! - [`clock`]: Injectable time sources for checkpointing and timing +//! - [`collections`]: Extra-map factories and extension traits +//! - [`deterministic_rng`]: Seeded RNG for reproducible behaviour in tests +//! - [`id_generator`]: Run, step, node, and session ID generation +//! - [`json_ext`]: JSON manipulation helpers pub mod clock; pub mod collections; pub mod deterministic_rng; pub mod id_generator; pub mod json_ext; -pub mod merge_inspector; -pub mod message_id_helpers; -pub mod type_guards; diff --git a/src/utils/type_guards.rs b/src/utils/type_guards.rs deleted file mode 100644 index aa2cea3..0000000 --- a/src/utils/type_guards.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! Type-erasure guards and shape validation utilities. -//! -//! This module provides runtime validation for type-erased data structures, -//! particularly for validating `NodePartial` updates against `ChannelSpec` -//! definitions. These guards provide clear, actionable error messages when -//! type mismatches occur. -//! -//! # Purpose -//! -//! In workflow systems, data often flows through type-erased containers -//! (like `serde_json::Value` or `Box`). This module helps catch -//! type errors at runtime boundaries with descriptive diagnostics. -//! -//! # Use Cases -//! -//! - Validating reducer inputs match expected channel types -//! - Checking `NodePartial` updates contain valid data shapes -//! - Runtime schema validation for dynamic workflows -//! - Clear error reporting for configuration mistakes -//! -//! # Future Implementation -//! -//! This module is currently a placeholder. Planned features include: -//! -//! - `TypeGuard` trait for custom validation logic -//! - `ShapeValidator` for JSON schema-like validation -//! - `ChannelTypeChecker` for reducer input validation -//! - Integration with `miette` for rich error diagnostics -//! -//! # Example (Future API) -//! -//! ```rust,ignore -//! use weavegraph::utils::type_guards::{TypeGuard, validate_shape}; -//! use serde_json::json; -//! -//! // Define expected shape -//! let expected = ShapeSpec::object() -//! .with_field("messages", ShapeSpec::array(ShapeSpec::string())) -//! .with_field("count", ShapeSpec::number()); -//! -//! // Validate at runtime -//! let data = json!({"messages": ["hello"], "count": 42}); -//! validate_shape(&data, &expected)?; // Ok -//! -//! let bad_data = json!({"messages": "not an array"}); -//! validate_shape(&bad_data, &expected)?; -//! // Error: field 'messages' expected array, found string -//! ``` - -// Placeholder for future implementation From c48d79a67004c4c2d64e3fd3b30eeb8b07e7b212 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 18:41:46 -0400 Subject: [PATCH 08/15] first batch of test revisions --- tests/app.rs | 289 +++++++++++---------------------- tests/channels.rs | 142 ++-------------- tests/event_bus.rs | 68 ++------ tests/event_bus_diagnostics.rs | 9 +- tests/event_bus_stress.rs | 19 ++- tests/graphs.rs | 138 +++++----------- tests/graphs_property.rs | 18 +- tests/messages.rs | 8 +- tests/nodes.rs | 38 ++--- tests/reducers.rs | 202 +++++++++-------------- 10 files changed, 275 insertions(+), 656 deletions(-) diff --git a/tests/app.rs b/tests/app.rs index 2384344..358fb42 100644 --- a/tests/app.rs +++ b/tests/app.rs @@ -14,22 +14,32 @@ mod common; use common::*; fn make_app() -> weavegraph::app::App { - // Minimal app via GraphBuilder; node graph is irrelevant for apply_barrier GraphBuilder::new() .add_edge(NodeKind::Start, NodeKind::End) .compile() .unwrap() } +fn message_app() -> weavegraph::app::App { + GraphBuilder::new() + .add_node( + NodeKind::Custom("node".into()), + SimpleMessageNode::new("response"), + ) + .add_edge(NodeKind::Start, NodeKind::Custom("node".into())) + .add_edge(NodeKind::Custom("node".into()), NodeKind::End) + .compile() + .unwrap() +} + #[tokio::test] -async fn test_apply_barrier_messages_update() { +async fn apply_barrier_appends_messages_and_bumps_version() { let app = make_app(); let state = &mut state_with_user("hi"); - let run_ids = vec![NodeKind::Start]; let partial = NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "foo")]); let outcome = app - .apply_barrier(state, &run_ids, vec![partial]) + .apply_barrier(state, &[NodeKind::Start], vec![partial]) .await .unwrap(); assert!(outcome.updated_channels.contains(&"messages")); @@ -40,13 +50,11 @@ async fn test_apply_barrier_messages_update() { } #[tokio::test] -async fn test_apply_barrier_no_update() { +async fn apply_barrier_with_empty_partial_changes_nothing() { let app = make_app(); let state = &mut state_with_user("hi"); - let run_ids = vec![NodeKind::Start]; - let partial = NodePartial::new(); let outcome = app - .apply_barrier(state, &run_ids, vec![partial]) + .apply_barrier(state, &[NodeKind::Start], vec![NodePartial::new()]) .await .unwrap(); assert!(outcome.updated_channels.is_empty()); @@ -56,10 +64,9 @@ async fn test_apply_barrier_no_update() { } #[tokio::test] -async fn test_apply_barrier_saturating_version() { +async fn apply_barrier_version_saturates_at_max() { let app = make_app(); let state = &mut state_with_user("hi"); - // push messages version to max to verify saturating add behavior state.messages.set_version(u32::MAX); let partial = NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "x")]); app.apply_barrier(state, &[NodeKind::Start], vec![partial]) @@ -69,12 +76,11 @@ async fn test_apply_barrier_saturating_version() { } #[tokio::test] -async fn test_apply_barrier_preserves_updated_channel_order() { +async fn apply_barrier_channel_order_is_messages_then_extra() { use weavegraph::channels::errors::{ErrorEvent, ErrorScope}; let app = make_app(); let state = &mut state_with_user("hi"); - let run_ids = vec![NodeKind::Start]; let partial_a = NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "a")]); @@ -95,7 +101,11 @@ async fn test_apply_barrier_preserves_updated_channel_order() { let partial_c = NodePartial::new().with_errors(vec![err_event.clone()]); let outcome = app - .apply_barrier(state, &run_ids, vec![partial_a, partial_b, partial_c]) + .apply_barrier( + state, + &[NodeKind::Start], + vec![partial_a, partial_b, partial_c], + ) .await .unwrap(); @@ -103,8 +113,7 @@ async fn test_apply_barrier_preserves_updated_channel_order() { assert_eq!(outcome.errors, vec![err_event]); assert_eq!(state.messages.version(), 2); assert_eq!(state.extra.version(), 2); - let extra_snapshot = state.extra.snapshot(); - let mut keys: Vec<_> = extra_snapshot.keys().cloned().collect(); + let mut keys: Vec<_> = state.extra.snapshot().keys().cloned().collect(); keys.sort(); assert_eq!(keys, vec!["a".to_string(), "z".to_string()]); } @@ -156,19 +165,13 @@ async fn invoke_streaming_closes_stream() { } #[tokio::test] -async fn test_apply_barrier_multiple_updates() { +async fn apply_barrier_from_multiple_partials_appends_all_messages() { let app = make_app(); let state = &mut state_with_user("hi"); - let partial1 = - NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "foo")]); - let partial2 = - NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "bar")]); + let p1 = NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "foo")]); + let p2 = NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "bar")]); let outcome = app - .apply_barrier( - state, - &[NodeKind::Start, NodeKind::End], - vec![partial1, partial2], - ) + .apply_barrier(state, &[NodeKind::Start, NodeKind::End], vec![p1, p2]) .await .unwrap(); let snap = state.messages.snapshot(); @@ -179,18 +182,17 @@ async fn test_apply_barrier_multiple_updates() { } #[tokio::test] -async fn test_apply_barrier_empty_vectors_and_maps() { +async fn apply_barrier_empty_collections_are_no_ops() { let app = make_app(); let state = &mut state_with_user("hi"); - // Empty messages vector -> Some(vec![]) should be treated as no-op by guard - let empty_msgs = NodePartial::new().with_messages(vec![]); - // Empty extra map -> Some(empty) should be treated as no-op by guard - let empty_extra = NodePartial::new().with_extra(FxHashMap::default()); let outcome = app .apply_barrier( state, &[NodeKind::Start, NodeKind::End], - vec![empty_msgs, empty_extra], + vec![ + NodePartial::new().with_messages(vec![]), + NodePartial::new().with_extra(FxHashMap::default()), + ], ) .await .unwrap(); @@ -200,7 +202,7 @@ async fn test_apply_barrier_empty_vectors_and_maps() { } #[tokio::test] -async fn test_apply_barrier_extra_merge_and_version() { +async fn apply_barrier_extra_partials_merge_and_later_key_wins() { let app = make_app(); let state = &mut state_with_user("hi"); @@ -208,14 +210,17 @@ async fn test_apply_barrier_extra_merge_and_version() { m1.insert("k1".into(), Value::String("v1".into())); let mut m2 = FxHashMap::default(); m2.insert("k2".into(), Value::String("v2".into())); - // Overwrite k1 in second partial to test key overwrite still counts as change m2.insert("k1".into(), Value::String("v3".into())); - let p1 = NodePartial::new().with_extra(m1); - let p2 = NodePartial::new().with_extra(m2); - let outcome = app - .apply_barrier(state, &[NodeKind::Start, NodeKind::End], vec![p1, p2]) + .apply_barrier( + state, + &[NodeKind::Start, NodeKind::End], + vec![ + NodePartial::new().with_extra(m1), + NodePartial::new().with_extra(m2), + ], + ) .await .unwrap(); assert!(outcome.updated_channels.contains(&"extra")); @@ -226,94 +231,51 @@ async fn test_apply_barrier_extra_merge_and_version() { } #[tokio::test] -async fn test_apply_barrier_collects_errors() { +async fn apply_barrier_error_partials_appear_in_outcome() { use weavegraph::channels::errors::ErrorEvent; let app = make_app(); let state = &mut state_with_user("hi"); - let run_ids = vec![NodeKind::Start]; let partial = NodePartial::new().with_errors(vec![ErrorEvent::default()]); - let outcome = app - .apply_barrier(state, &run_ids, vec![partial]) + .apply_barrier(state, &[NodeKind::Start], vec![partial]) .await .unwrap(); - assert!(outcome.updated_channels.is_empty()); assert_eq!(outcome.errors.len(), 1); } #[tokio::test] -async fn test_invoke_with_channel() { - // Build a simple graph with a test node - let app = GraphBuilder::new() - .add_node( - NodeKind::Custom("test".into()), - SimpleMessageNode::new("test output"), - ) - .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) - .add_edge(NodeKind::Custom("test".into()), NodeKind::End) - .compile() - .unwrap(); - - // Execute with channel - let initial_state = state_with_user("test input"); - let (result, events) = app.invoke_with_channel(initial_state).await; - - // Spawn task to collect events (simulating client consumption) - let event_task = tokio::spawn(async move { - let mut count = 0; - // Use timeout to avoid hanging if no events come - let timeout_duration = tokio::time::Duration::from_millis(100); - loop { - match tokio::time::timeout(timeout_duration, events.recv_async()).await { - Ok(Ok(_event)) => count += 1, - Ok(Err(_)) => break, // Channel closed - Err(_) => break, // Timeout - no more events - } - } - count - }); - - // Wait for workflow to complete - let final_state = result.expect("Workflow should complete successfully"); - assert!(!final_state.messages.is_empty(), "Should have messages"); - - // The method itself works - we got a receiver and a result - // Note: Event count verification is inherently racy due to EventBus Drop behavior - let _event_count = event_task.await.expect("Event task should complete"); - // We just verify the API works, not exact event counts +async fn invoke_with_channel_returns_completed_state() { + let app = message_app(); + let (result, _events) = app.invoke_with_channel(state_with_user("prompt")).await; + let final_state = result.expect("workflow completes"); + assert_eq!( + final_state.messages.snapshot().last().unwrap().content, + "response" + ); } #[tokio::test] -async fn test_invoke_with_channel_resumption_updates_versions() { - let app = GraphBuilder::new() - .add_node( - NodeKind::Custom("test".into()), - SimpleMessageNode::new("test output"), - ) - .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) - .add_edge(NodeKind::Custom("test".into()), NodeKind::End) - .compile() - .unwrap(); - - let state = VersionedState::new_with_user_message("first run"); - let (result, _events) = app.invoke_with_channel(state).await; - let final_state = result.expect("first run succeeds"); - assert_eq!(final_state.messages.version(), 2); - - // Re-run with the output state to ensure versions bump deterministically. - let (second_result, _second_events) = app.invoke_with_channel(final_state.clone()).await; - let second_state = second_result.expect("second run succeeds"); - assert_eq!(second_state.messages.version(), 3); - assert_eq!(second_state.extra.version(), final_state.extra.version()); +async fn invoke_with_channel_second_run_increments_message_version() { + let app = message_app(); + + let (result, _events) = app + .invoke_with_channel(VersionedState::new_with_user_message("first")) + .await; + let after_first = result.expect("first run succeeds"); + assert_eq!(after_first.messages.version(), 2); + + let (result, _events) = app.invoke_with_channel(after_first.clone()).await; + let after_second = result.expect("second run succeeds"); + assert_eq!(after_second.messages.version(), 3); + assert_eq!(after_second.extra.version(), after_first.extra.version()); } #[tokio::test] -async fn test_invoke_with_channel_collects_events() { +async fn invoke_streaming_delivers_node_events_before_close() { use weavegraph::event_bus::Event; - // Build graph with a node that emits events let app = GraphBuilder::new() .add_node(NodeKind::Custom("emitter".into()), EmitterNode) .add_edge(NodeKind::Start, NodeKind::Custom("emitter".into())) @@ -321,113 +283,60 @@ async fn test_invoke_with_channel_collects_events() { .compile() .unwrap(); - let initial_state = state_with_user("emit events"); - let (result, events) = app.invoke_with_channel(initial_state).await; - - // Collect events with timeout - let event_task = tokio::spawn(async move { - let mut collected = Vec::new(); - let timeout_duration = tokio::time::Duration::from_millis(100); - loop { - match tokio::time::timeout(timeout_duration, events.recv_async()).await { - Ok(Ok(event)) => collected.push(event), - Ok(Err(_)) => break, - Err(_) => break, - } + let (invocation, events) = app.invoke_streaming(state_with_user("go")).await; + let mut stream = events.into_async_stream(); + let mut node_events = 0; + let mut stream_closed = false; + while let Some(event) = stream.next().await { + if event.scope_label() == Some(STREAM_END_SCOPE) { + stream_closed = true; + } else if matches!(event, Event::Node(_)) { + node_events += 1; } - collected - }); - - // Verify workflow succeeded - result.expect("Workflow should complete"); - - // Wait for events - let collected_events = event_task.await.expect("Event task should complete"); - - // The API works - we can receive events (even if timing makes this racy) - // In production, the EventBus stays alive longer so events flow properly - if !collected_events.is_empty() { - // If we got events, verify they're the right type - let has_node_event = collected_events.iter().any(|e| matches!(e, Event::Node(_))); - assert!(has_node_event, "Should have at least one node event"); } - // Test passes if we got a valid result and receiver, regardless of timing + + assert!(stream_closed, "stream must close with STREAM_END_SCOPE"); + assert!( + node_events > 0, + "node events must arrive before stream closes" + ); + invocation.join().await.unwrap(); } #[tokio::test] -async fn test_invoke_with_sinks() { +async fn invoke_with_sinks_completes_workflow() { use weavegraph::event_bus::MemorySink; - // Build simple graph - let app = GraphBuilder::new() - .add_node( - NodeKind::Custom("test".into()), - SimpleMessageNode::new("test output"), - ) - .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) - .add_edge(NodeKind::Custom("test".into()), NodeKind::End) - .compile() - .unwrap(); - - // Use MemorySink which captures synchronously (no async timing issues) - let memory_sink = MemorySink::new(); - - // Execute with custom sink - let initial_state = state_with_user("test with sinks"); + let app = message_app(); let final_state = app - .invoke_with_sinks(initial_state, vec![Box::new(memory_sink.clone())]) + .invoke_with_sinks(state_with_user("prompt"), vec![Box::new(MemorySink::new())]) .await - .expect("Workflow should complete successfully"); - - // Verify execution completed - assert!(!final_state.messages.is_empty(), "Should have messages"); - - // MemorySink should have captured events (it's synchronous in the listener loop) - // However, due to Drop abort, we might miss some events - // The test verifies the API works, not exact event counts - let _events = memory_sink.snapshot(); - // API works if we reach here without errors + .expect("workflow completes"); + assert_eq!( + final_state.messages.snapshot().last().unwrap().content, + "response" + ); } #[tokio::test] -async fn test_invoke_with_sinks_multiple() { +async fn invoke_with_sinks_accepts_multiple_sink_types() { use weavegraph::event_bus::{ChannelSink, MemorySink, StdOutSink}; - // Build simple graph - let app = GraphBuilder::new() - .add_node( - NodeKind::Custom("test".into()), - SimpleMessageNode::new("test output"), - ) - .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) - .add_edge(NodeKind::Custom("test".into()), NodeKind::End) - .compile() - .unwrap(); - - // Create multiple sinks to verify the API accepts Vec> + let app = message_app(); let (tx, _rx) = flume::unbounded(); - let memory_sink = MemorySink::new(); - - // Execute with multiple sinks - this tests type compatibility - let initial_state = state_with_user("test multiple sinks"); let final_state = app .invoke_with_sinks( - initial_state, + state_with_user("prompt"), vec![ Box::new(StdOutSink::default()), Box::new(ChannelSink::new(tx)), - Box::new(memory_sink.clone()), + Box::new(MemorySink::new()), ], ) .await - .expect("Workflow should complete"); - - // Verify execution completed - assert!(!final_state.messages.is_empty(), "Should have messages"); - - // The test verifies that: - // 1. invoke_with_sinks() accepts multiple different sink types - // 2. The workflow completes successfully with multiple sinks - // 3. Type system allows Vec> as expected - // Event counting is inherently racy in tests due to EventBus Drop behavior + .expect("workflow completes with multiple sinks"); + assert_eq!( + final_state.messages.snapshot().last().unwrap().content, + "response" + ); } diff --git a/tests/channels.rs b/tests/channels.rs index 0ba6d26..fbb03e0 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -5,9 +5,7 @@ use weavegraph::channels::errors::*; use weavegraph::channels::{Channel, ErrorsChannel}; use weavegraph::types::ChannelType; -/******************** - * WeaveError tests - ********************/ +// WeaveError tests #[test] fn ladder_error_msg_and_chain() { @@ -32,48 +30,7 @@ fn ladder_error_serde_roundtrip() { assert_eq!(de, err); } -/******************** - * ErrorScope tests - ********************/ - -#[test] -fn error_scope_enum_variants_serde() { - let node = ErrorScope::Node { - kind: "Custom:Parser".into(), - step: 42, - }; - let ser_node = serde_json::to_value(&node).unwrap(); - assert_eq!(ser_node["scope"], "node"); - assert_eq!(ser_node["kind"], "Custom:Parser"); - assert_eq!(ser_node["step"], 42); - - let sch = ErrorScope::Scheduler { step: 10 }; - let ser_sch = serde_json::to_value(&sch).unwrap(); - assert_eq!(ser_sch["scope"], "scheduler"); - - let run = ErrorScope::Runner { - session: "abc".into(), - step: 7, - }; - let ser_run = serde_json::to_value(&run).unwrap(); - assert_eq!(ser_run["scope"], "runner"); - - let app = ErrorScope::App; - let ser_app = serde_json::to_value(&app).unwrap(); - assert_eq!(ser_app["scope"], "app"); - - assert_eq!( - serde_json::from_value::(ser_node).unwrap(), - node - ); - assert_eq!(serde_json::from_value::(ser_sch).unwrap(), sch); - assert_eq!(serde_json::from_value::(ser_run).unwrap(), run); - assert_eq!(serde_json::from_value::(ser_app).unwrap(), app); -} - -/******************** - * ErrorEvent tests - ********************/ +// ErrorEvent tests #[test] fn error_event_defaults_and_roundtrip() { @@ -104,9 +61,7 @@ fn error_event_defaults_are_empty_when_missing() { assert!(de.context.is_null()); } -/******************** - * ErrorEvent Constructor Tests - ********************/ +// ErrorEvent constructor tests #[test] fn error_event_node_constructor() { @@ -227,35 +182,6 @@ fn error_event_full_builder_chain() { assert_eq!(err.context["max_attempts"], 5); } -#[test] -fn error_event_constructors_serialize_correctly() { - // Test that constructed events serialize the same as manual construction - let manual = ErrorEvent { - when: Utc.with_ymd_and_hms(2024, 1, 1, 0, 0, 0).unwrap(), - scope: ErrorScope::Node { - kind: "Test".to_string(), - step: 1, - }, - error: WeaveError::msg("test"), - tags: vec!["tag1".to_string()], - context: json!({"key": "value"}), - }; - - let constructed = ErrorEvent::node("Test", 1, WeaveError::msg("test")) - .with_tag("tag1") - .with_context(json!({"key": "value"})); - - // Serialize both (ignore timestamp difference) - let manual_json = serde_json::to_value(&manual).unwrap(); - let constructed_json = serde_json::to_value(&constructed).unwrap(); - - // Compare everything except timestamp - assert_eq!(manual_json["scope"], constructed_json["scope"]); - assert_eq!(manual_json["error"], constructed_json["error"]); - assert_eq!(manual_json["tags"], constructed_json["tags"]); - assert_eq!(manual_json["context"], constructed_json["context"]); -} - #[test] fn error_event_string_into_conversions() { // Test that Into works for both &str and String @@ -282,13 +208,10 @@ fn error_event_string_into_conversions() { } } -/******************** - * Comprehensive Serialization Tests - ********************/ +// Serialization tests #[test] -fn test_error_event_serialization_all_scopes() { - // Test serialization of all ErrorScope variants +fn all_scope_variants_produce_correct_event_json() { let when = Utc.with_ymd_and_hms(2024, 1, 1, 12, 0, 0).unwrap(); // Node scope @@ -305,7 +228,6 @@ fn test_error_event_serialization_all_scopes() { assert_eq!(node_json["tags"], json!(["test"])); assert_eq!(node_json["context"]["node_id"], 42); - // Round-trip test let node_deserialized: ErrorEvent = serde_json::from_value(node_json).unwrap(); assert_eq!(node_deserialized.scope, node_event.scope); assert_eq!(node_deserialized.error, node_event.error); @@ -347,39 +269,31 @@ fn test_error_event_serialization_all_scopes() { } #[test] -fn test_ladder_error_nested_serialization() { - // Test serialization of nested WeaveError chains +fn weave_error_nested_cause_chain_serializes() { let simple_error = WeaveError::msg("simple error"); let simple_json = serde_json::to_value(&simple_error).unwrap(); assert_eq!(simple_json["message"], "simple error"); assert!(simple_json["cause"].is_null()); assert!(simple_json["details"].is_null()); - // Error with details let error_with_details = WeaveError::msg("error with details").with_details(json!({"code": 500, "retry": true})); let details_json = serde_json::to_value(&error_with_details).unwrap(); assert_eq!(details_json["message"], "error with details"); assert_eq!(details_json["details"]["code"], 500); assert_eq!(details_json["details"]["retry"], true); - - // Round-trip test let details_deserialized: WeaveError = serde_json::from_value(details_json).unwrap(); assert_eq!(details_deserialized, error_with_details); - // Error with cause let error_with_cause = WeaveError::msg("outer error").with_cause(WeaveError::msg("inner error")); let cause_json = serde_json::to_value(&error_with_cause).unwrap(); assert_eq!(cause_json["message"], "outer error"); assert_eq!(cause_json["cause"]["message"], "inner error"); assert!(cause_json["cause"]["cause"].is_null()); - - // Round-trip test let cause_deserialized: WeaveError = serde_json::from_value(cause_json).unwrap(); assert_eq!(cause_deserialized, error_with_cause); - // Deeply nested error chain let deep_error = WeaveError::msg("level 1").with_cause( WeaveError::msg("level 2") .with_cause(WeaveError::msg("level 3").with_details(json!({"deep": true}))), @@ -389,15 +303,12 @@ fn test_ladder_error_nested_serialization() { assert_eq!(deep_json["cause"]["message"], "level 2"); assert_eq!(deep_json["cause"]["cause"]["message"], "level 3"); assert_eq!(deep_json["cause"]["cause"]["details"]["deep"], true); - - // Round-trip test for complex error let deep_deserialized: WeaveError = serde_json::from_value(deep_json).unwrap(); assert_eq!(deep_deserialized, deep_error); } #[test] -fn test_error_event_full_serialization_roundtrip() { - // Test complete ErrorEvent with all fields populated +fn error_event_with_complex_nested_fields_round_trips() { let original = ErrorEvent::node( "ComplexNode", 99, @@ -420,7 +331,6 @@ fn test_error_event_full_serialization_roundtrip() { let json_str = serde_json::to_string(&original).unwrap(); let json_value: serde_json::Value = serde_json::from_str(&json_str).unwrap(); - // Verify structure assert_eq!(json_value["scope"]["scope"], "node"); assert_eq!(json_value["scope"]["kind"], "ComplexNode"); assert_eq!(json_value["scope"]["step"], 99); @@ -437,7 +347,6 @@ fn test_error_event_full_serialization_roundtrip() { "production" ); - // Deserialize back let deserialized: ErrorEvent = serde_json::from_str(&json_str).unwrap(); assert_eq!(deserialized.scope, original.scope); assert_eq!(deserialized.error, original.error); @@ -446,42 +355,34 @@ fn test_error_event_full_serialization_roundtrip() { } #[test] -fn test_error_event_schema_stability() { - // This test serves as a regression check - if the schema changes unexpectedly, - // this test will fail and alert us to review the change - +fn error_event_json_has_required_top_level_fields() { + // Regression guard: if the schema changes unexpectedly, this test will catch it. let event = ErrorEvent::node("SchemaTest", 1, WeaveError::msg("test")) .with_tag("regression") .with_context(json!({"test": true})); let json = serde_json::to_value(&event).unwrap(); - // Check that all expected top-level fields exist assert!(json.get("when").is_some(), "Missing 'when' field"); assert!(json.get("scope").is_some(), "Missing 'scope' field"); assert!(json.get("error").is_some(), "Missing 'error' field"); assert!(json.get("tags").is_some(), "Missing 'tags' field"); assert!(json.get("context").is_some(), "Missing 'context' field"); - // Check scope structure let scope = &json["scope"]; assert!( scope.get("scope").is_some(), "Missing 'scope.scope' discriminator" ); - // Check error structure let error = &json["error"]; assert!( error.get("message").is_some(), "Missing 'error.message' field" ); - // Note: cause and details may be omitted if null/empty due to skip_serializing_if + // Note: cause and details may be omitted when null due to skip_serializing_if - // Ensure tags is an array assert!(json["tags"].is_array(), "tags should be an array"); - - // Ensure context can be any JSON value assert!( json["context"].is_object() || json["context"].is_null(), "context should be object or null" @@ -489,8 +390,7 @@ fn test_error_event_schema_stability() { } #[test] -fn test_error_scope_variants_complete_coverage() { - // Ensure all scope variants serialize and deserialize correctly +fn all_error_scope_variants_serialize_and_round_trip() { let test_cases = vec![ ( ErrorScope::Node { @@ -514,23 +414,18 @@ fn test_error_scope_variants_complete_coverage() { ]; for (scope, expected_json) in test_cases { - // Serialize let serialized = serde_json::to_value(&scope).unwrap(); assert_eq!( serialized, expected_json, "Serialization mismatch for {:?}", scope ); - - // Deserialize let deserialized: ErrorScope = serde_json::from_value(serialized).unwrap(); assert_eq!(deserialized, scope, "Round-trip failed for {:?}", scope); } } -/******************** - * pretty_print tests - ********************/ +// pretty_print tests #[test] fn pretty_print_renders_usefully() { @@ -556,9 +451,7 @@ fn pretty_print_renders_usefully() { assert!(out.contains("/tmp/x")); } -/******************** - * ErrorsChannel tests - ********************/ +// ErrorsChannel tests #[test] fn errors_channel_basics() { @@ -571,13 +464,11 @@ fn errors_channel_basics() { let when = Utc::now(); - // Add first error using scheduler constructor let err1 = ErrorEvent::scheduler(1, WeaveError::msg("first")); let mut err1_with_time = err1; err1_with_time.when = when; ch.get_mut().push(err1_with_time); - // Add second error using node constructor with builder let err2 = ErrorEvent::node("Start", 2, WeaveError::msg("second")) .with_tag("retryable") .with_context(json!({"try":2})); @@ -609,17 +500,14 @@ fn errors_channel_new_constructor() { } #[test] -fn optional_cli_pretty_demo() { +fn pretty_print_formats_app_event_with_context() { let when = Utc.with_ymd_and_hms(2024, 2, 2, 2, 2, 2).unwrap(); let mut event = ErrorEvent::app(WeaveError::msg("display")) .with_tag("cli") .with_context(json!({})); event.when = when; - let events = vec![event]; - - let out = pretty_print(&events); - println!("\n=== Errors pretty showcase ===\n{}", out); + let out = pretty_print(&vec![event]); assert!(out.contains("display")); } diff --git a/tests/event_bus.rs b/tests/event_bus.rs index efa2e5c..4450378 100644 --- a/tests/event_bus.rs +++ b/tests/event_bus.rs @@ -58,7 +58,6 @@ async fn memory_sink_captures_events_with_scope_and_messages() { let emitter = bus.get_emitter(); - // Same scope twice emitter .emit(Event::node_message("Scope1", "one")) .expect("emit one"); @@ -66,7 +65,6 @@ async fn memory_sink_captures_events_with_scope_and_messages() { .emit(Event::node_message("Scope1", "two")) .expect("emit two"); - // Different scope emitter .emit(Event::diagnostic("Scope2", "three")) .expect("emit three"); @@ -80,7 +78,6 @@ async fn memory_sink_captures_events_with_scope_and_messages() { let entries = sink_snapshot.snapshot(); assert_eq!(entries.len(), 4); - // Verify events captured with correct scope and message assert_eq!(entries[0].scope_label(), Some("Scope1")); assert_eq!(entries[0].message(), "one"); @@ -105,7 +102,6 @@ async fn multiple_listen_calls_are_idempotent() { bus.listen_for_events(); bus.listen_for_events(); - // Emit a couple of events and ensure we don't get duplicate output. let emitter = bus.get_emitter(); emitter.emit(Event::node_message("S", "a")).unwrap(); emitter.emit(Event::node_message("S", "b")).unwrap(); @@ -146,7 +142,6 @@ async fn memory_sink_preserves_order_under_concurrency() { let _ = h.await; } - // Allow listener to drain the channel. tokio::time::sleep(std::time::Duration::from_millis((total * 3) as u64)).await; bus.stop_listener().await; @@ -195,10 +190,8 @@ async fn multi_sink_broadcast() { .emit(Event::diagnostic("test", "broadcast message")) .unwrap(); - // Give listener time to process tokio::time::sleep(std::time::Duration::from_millis(50)).await; - // Both sinks received the event let memory_events = memory.snapshot(); assert_eq!(memory_events.len(), 1); assert_eq!(memory_events[0].message(), "broadcast message"); @@ -228,12 +221,10 @@ async fn add_sink_dynamically() { #[tokio::test] async fn channel_sink_handles_dropped_receiver() { use std::io::ErrorKind; - use weavegraph::event_bus::sink::EventSink; let (tx, rx) = flume::unbounded(); let mut sink = ChannelSink::new(tx); - // Drop receiver drop(rx); let event = Event::diagnostic("test", "msg"); @@ -346,14 +337,12 @@ async fn stop_listener_drains_multiple_sinks() { tokio::time::sleep(Duration::from_millis(50)).await; bus.stop_listener().await; - // Both sinks should have received all events assert_eq!(snapshot1.snapshot().len(), 10); assert_eq!(snapshot2.snapshot().len(), 10); } #[tokio::test] async fn stop_listener_during_emission() { - use std::sync::Arc; use tokio::task; let bus = Arc::new(EventBus::with_sink(MemorySink::new())); @@ -368,9 +357,7 @@ async fn stop_listener_during_emission() { }); tokio::time::sleep(std::time::Duration::from_millis(20)).await; - // Should not panic and should shut down cleanly bus.stop_listener().await; - // Clean up emission task if still running emit_task.abort(); } @@ -382,7 +369,6 @@ async fn restart_after_stop() { let snapshot = sink.clone(); let bus = EventBus::with_sink(sink); - // First cycle bus.listen_for_events(); bus.get_emitter() .emit(Event::diagnostic("cycle1", "msg1")) @@ -390,7 +376,6 @@ async fn restart_after_stop() { tokio::time::sleep(Duration::from_millis(10)).await; bus.stop_listener().await; - // Second cycle bus.listen_for_events(); bus.get_emitter() .emit(Event::diagnostic("cycle2", "msg2")) @@ -494,12 +479,6 @@ struct RecordingEmitter { } impl RecordingEmitter { - fn new() -> Self { - Self { - events: Arc::new(Mutex::new(Vec::new())), - } - } - fn record(&self, event: Event) { self.events .lock() @@ -539,7 +518,7 @@ impl EventEmitter for RecordingEmitter { #[test] fn node_context_emits_all_event_variants() { - let emitter = Arc::new(RecordingEmitter::new()); + let emitter = Arc::new(RecordingEmitter::default()); let event_emitter: Arc = emitter.clone(); let ctx = NodeContext::new("node-a", 7, event_emitter); @@ -738,11 +717,8 @@ proptest! { } } -// ============================================================================ -// JSON Serialization Tests -// ============================================================================ +// JSON serialization -// Helper for shared writer in tests struct SharedWriter(Arc>>>); impl std::io::Write for SharedWriter { @@ -759,7 +735,7 @@ impl std::io::Write for SharedWriter { } #[test] -fn test_node_event_to_json_value() { +fn node_event_serializes_type_scope_message_and_metadata() { let event = Event::node_message_with_meta("router", 5, "routing", "Processing request"); let json = event.to_json_value(); @@ -772,8 +748,7 @@ fn test_node_event_to_json_value() { } #[test] -fn test_node_event_partial_metadata() { - // Node without node_id or step +fn node_event_without_node_id_or_step_serializes_with_null_fields() { let event = Event::node_message("test_scope", "test message"); let json = event.to_json_value(); @@ -786,7 +761,7 @@ fn test_node_event_partial_metadata() { } #[test] -fn test_diagnostic_event_to_json_value() { +fn diagnostic_event_serializes_type_scope_message_and_empty_metadata() { let event = Event::diagnostic("error_scope", "Something went wrong"); let json = event.to_json_value(); @@ -794,14 +769,13 @@ fn test_diagnostic_event_to_json_value() { assert_eq!(json["scope"], "error_scope"); assert_eq!(json["message"], "Something went wrong"); assert!(json["timestamp"].is_string()); - // Diagnostic events have minimal metadata assert!(json["metadata"].is_object()); let metadata = json["metadata"].as_object().unwrap(); assert!(metadata.is_empty()); } #[test] -fn test_llm_event_to_json_value() { +fn llm_event_serializes_all_fields_to_json_value() { let mut metadata = FxHashMap::default(); metadata.insert("content_type".to_string(), json!("reasoning")); metadata.insert("token_count".to_string(), json!(42)); @@ -829,7 +803,7 @@ fn test_llm_event_to_json_value() { } #[test] -fn test_llm_event_final_chunk() { +fn llm_event_final_chunk_sets_is_final_and_nulls_optional_ids() { let llm_event = LLMStreamingEvent::builder("Final chunk") .stream_id("stream-999") .is_final(true) @@ -845,11 +819,10 @@ fn test_llm_event_final_chunk() { } #[test] -fn test_to_json_string_compact() { +fn json_string_output_is_compact() { let event = Event::diagnostic("test", "message"); let json_str = event.to_json_string().unwrap(); - // Compact format has no extra whitespace assert!(json_str.contains("\"type\":\"diagnostic\"")); assert!(json_str.contains("\"scope\":\"test\"")); assert!(json_str.contains("\"message\":\"message\"")); @@ -857,18 +830,17 @@ fn test_to_json_string_compact() { } #[test] -fn test_to_json_pretty_formatted() { +fn json_pretty_output_has_indentation() { let event = Event::node_message("test", "hello"); let json_str = event.to_json_pretty().unwrap(); - // Pretty format has indentation assert!(json_str.contains(" \"type\": \"node\"")); assert!(json_str.contains(" \"scope\": \"test\"")); assert!(json_str.contains(" \"message\": \"hello\"")); } #[test] -fn test_json_roundtrip_via_to_json_string() { +fn json_string_round_trips_to_parsed_value() { let original = Event::node_message_with_meta("node1", 10, "scope1", "msg1"); let json_str = original.to_json_string().unwrap(); let parsed: Value = serde_json::from_str(&json_str).unwrap(); @@ -879,10 +851,9 @@ fn test_json_roundtrip_via_to_json_string() { } #[tokio::test] -async fn test_jsonlines_sink_stdout() { +async fn jsonlines_sink_writes_one_line_per_event() { use std::io::Cursor; - // Create in-memory buffer to capture output let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); let buffer_clone = buffer.clone(); @@ -894,20 +865,17 @@ async fn test_jsonlines_sink_stdout() { sink.handle(&event1).unwrap(); sink.handle(&event2).unwrap(); - // Extract buffer contents let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); let output = String::from_utf8(locked.get_ref().clone()).unwrap(); let lines: Vec<&str> = output.lines().collect(); assert_eq!(lines.len(), 2); - // Parse first line let json1: Value = serde_json::from_str(lines[0]).unwrap(); assert_eq!(json1["type"], "diagnostic"); assert_eq!(json1["scope"], "test1"); assert_eq!(json1["message"], "first message"); - // Parse second line let json2: Value = serde_json::from_str(lines[1]).unwrap(); assert_eq!(json2["type"], "node"); assert_eq!(json2["scope"], "test2"); @@ -915,7 +883,7 @@ async fn test_jsonlines_sink_stdout() { } #[tokio::test] -async fn test_jsonlines_sink_pretty_print() { +async fn jsonlines_sink_pretty_print_indents_output() { use std::io::Cursor; let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); @@ -935,7 +903,7 @@ async fn test_jsonlines_sink_pretty_print() { } #[tokio::test] -async fn test_jsonlines_sink_file_output() { +async fn jsonlines_sink_writes_events_to_file() { use std::fs; let temp_file = tempfile::NamedTempFile::new().unwrap(); @@ -953,7 +921,6 @@ async fn test_jsonlines_sink_file_output() { sink.handle(&event3).unwrap(); } // sink dropped, file flushed - // Read file contents let contents = fs::read_to_string(&path).unwrap(); let lines: Vec<&str> = contents.lines().collect(); @@ -971,7 +938,7 @@ async fn test_jsonlines_sink_file_output() { } #[tokio::test] -async fn test_jsonlines_sink_flush_behavior() { +async fn jsonlines_sink_flushes_after_each_event() { use std::io::Cursor; let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); @@ -982,8 +949,7 @@ async fn test_jsonlines_sink_flush_behavior() { let event = Event::diagnostic("flush_test", "should be flushed immediately"); sink.handle(&event).unwrap(); - // After single handle() call, buffer should already contain the event - // because EventSink::handle flushes after each event + // handle() flushes after each event, so the buffer is immediately readable let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); let output = String::from_utf8(locked.get_ref().clone()).unwrap(); @@ -991,11 +957,10 @@ async fn test_jsonlines_sink_flush_behavior() { } #[tokio::test] -async fn test_jsonlines_sink_with_eventbus() { +async fn jsonlines_sink_with_event_bus_captures_all_events() { let buffer = Arc::new(Mutex::new(std::io::Cursor::new(Vec::new()))); let buffer_clone = buffer.clone(); - // Create sink with shared buffer let sink = JsonLinesSink::new(Box::new(SharedWriter(buffer))); let bus = EventBus::with_sink(sink); bus.listen_for_events(); @@ -1011,7 +976,6 @@ async fn test_jsonlines_sink_with_eventbus() { tokio::time::sleep(Duration::from_millis(50)).await; bus.stop_listener().await; - // Extract and verify output let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); let output = String::from_utf8(locked.get_ref().clone()).unwrap(); let lines: Vec<&str> = output.lines().collect(); diff --git a/tests/event_bus_diagnostics.rs b/tests/event_bus_diagnostics.rs index d362fe4..5a6d215 100644 --- a/tests/event_bus_diagnostics.rs +++ b/tests/event_bus_diagnostics.rs @@ -1,3 +1,8 @@ +//! Integration tests for the event bus diagnostics subsystem. +//! +//! Covers sink-failure reporting, health snapshot aggregation, custom sink naming, +//! and the `emit_to_events` configuration flag. + use std::io; use weavegraph::event_bus::{Event, EventBus, EventSink}; @@ -42,7 +47,7 @@ fn bus_with_diagnostics( } #[tokio::test] -async fn diagnostics_happy_path_and_lagged_receiver() { +async fn diagnostic_received_on_sink_failure_and_lag_is_tolerated() { // Create a bus with a failing sink and a small diagnostics buffer to provoke lag. let bus = bus_with_diagnostics(true, false, 4, 1); bus.add_sink(FailingSink); @@ -132,7 +137,7 @@ async fn sink_naming_default_and_override() { } #[tokio::test] -async fn emit_to_events_toggle_behavior() { +async fn diagnostics_appear_in_event_stream_only_when_enabled() { // 1) emit_to_events = false: no diagnostics in main event stream let bus_off = bus_with_diagnostics(true, false, 8, 8); bus_off.add_sink(FailingSink); diff --git a/tests/event_bus_stress.rs b/tests/event_bus_stress.rs index 8ba908f..98207f2 100644 --- a/tests/event_bus_stress.rs +++ b/tests/event_bus_stress.rs @@ -61,7 +61,7 @@ fn make_burst_app(event_count: usize, counter: Arc) -> weavegraph:: } #[tokio::test] -async fn test_high_volume_event_emission() { +async fn bus_delivers_most_events_under_high_load() { let sink = MemorySink::new(); let sink_snapshot = sink.clone(); let bus = EventBus::with_sink(sink); @@ -91,7 +91,7 @@ async fn test_high_volume_event_emission() { } #[tokio::test] -async fn test_burst_node_emission() { +async fn burst_node_emits_all_events_on_run() { let counter = Arc::new(AtomicUsize::new(0)); let app = make_burst_app(100, counter.clone()); let mut runner = AppRunner::builder() @@ -120,7 +120,7 @@ async fn test_burst_node_emission() { } #[tokio::test] -async fn test_multiple_sinks() { +async fn all_sinks_receive_all_emitted_events() { let sink1 = MemorySink::new(); let sink2 = MemorySink::new(); let snap1 = sink1.clone(); @@ -145,7 +145,7 @@ async fn test_multiple_sinks() { } #[tokio::test] -async fn test_emit_after_stop_behavior() { +async fn events_emitted_before_stop_are_captured() { let sink = MemorySink::new(); let snap = sink.clone(); let bus = EventBus::with_sink(sink); @@ -165,7 +165,7 @@ async fn test_emit_after_stop_behavior() { } #[tokio::test] -async fn test_rapid_start_stop_cycles() { +async fn bus_survives_repeated_start_stop_cycles() { let sink = MemorySink::new(); let snap = sink.clone(); let bus = EventBus::with_sink(sink); @@ -184,7 +184,7 @@ async fn test_rapid_start_stop_cycles() { } #[tokio::test] -async fn test_event_ordering() { +async fn events_arrive_in_emission_order() { let sink = MemorySink::new(); let snap = sink.clone(); let bus = EventBus::with_sink(sink); @@ -215,7 +215,7 @@ async fn test_event_ordering() { } #[tokio::test] -async fn test_metrics_reflect_emissions() { +async fn metrics_reports_nonzero_capacity_and_no_drops() { let bus = EventBus::with_sink(MemorySink::new()); bus.listen_for_events(); @@ -228,6 +228,9 @@ async fn test_metrics_reflect_emissions() { bus.stop_listener().await; let metrics = bus.metrics(); - // Metrics should be valid assert!(metrics.capacity > 0); + assert_eq!( + metrics.dropped, 0, + "no events should be dropped under moderate load" + ); } diff --git a/tests/graphs.rs b/tests/graphs.rs index 5079c7f..38b9606 100644 --- a/tests/graphs.rs +++ b/tests/graphs.rs @@ -2,7 +2,7 @@ mod common; use common::*; use std::sync::Arc; -use weavegraph::graphs::{EdgePredicate, GraphBuilder}; +use weavegraph::graphs::{EdgePredicate, GraphBuilder, GraphCompileError}; use weavegraph::node::NodePartial; use weavegraph::reducers::Reducer; use weavegraph::state::VersionedState; @@ -41,7 +41,7 @@ impl Reducer for StableLabelReducerB { } #[test] -fn test_add_conditional_edge() { +fn conditional_edge_is_accessible_and_predicate_evaluates_correctly() { let route_to_y: EdgePredicate = std::sync::Arc::new(|_s| vec!["Y".to_string()]); let app = GraphBuilder::new() .add_node(NodeKind::Custom("Y".into()), NoopNode) @@ -61,14 +61,15 @@ fn test_add_conditional_edge() { } #[test] -fn test_graph_builder_new() { - let err = GraphBuilder::new().compile().err().unwrap(); - // Expect MissingEntry; structural validation prevents compiling empty graphs - let _ = err; // just ensure it returns an error; specific variant tested elsewhere +fn empty_graph_compilation_fails_with_missing_entry() { + assert!(matches!( + GraphBuilder::new().compile().err().unwrap(), + GraphCompileError::MissingEntry + )); } #[test] -fn test_add_node() { +fn compiled_graph_contains_registered_nodes() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -85,7 +86,7 @@ fn test_add_node() { } #[test] -fn test_add_edge() { +fn compiled_graph_contains_registered_edges() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("C".to_string()), NoopNode) .add_edge(NodeKind::Start, NodeKind::End) @@ -101,7 +102,7 @@ fn test_add_edge() { } #[test] -fn test_compile() { +fn minimal_start_to_end_graph_compiles() { let gb = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End); let app = gb.compile().unwrap(); assert_eq!(app.edges().len(), 1); @@ -114,7 +115,7 @@ fn test_compile() { } #[test] -fn test_graph_metadata_and_hash_change_with_definition() { +fn graph_metadata_reflects_structure_and_hash_differs_for_distinct_definitions() { let app_a = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) @@ -137,7 +138,7 @@ fn test_graph_metadata_and_hash_change_with_definition() { } #[test] -fn test_graph_hash_changes_with_reducer_identity() { +fn graph_hash_differs_when_reducer_type_changes() { let app_a = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .with_reducer(ChannelType::Extra, Arc::new(FirstExtraReducer)) @@ -157,7 +158,7 @@ fn test_graph_hash_changes_with_reducer_identity() { } #[test] -fn test_graph_hash_is_stable_for_equivalent_definition_ordering() { +fn graph_hash_is_stable_regardless_of_builder_insertion_order() { let app_a = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -181,7 +182,7 @@ fn test_graph_hash_is_stable_for_equivalent_definition_ordering() { } #[test] -fn test_graph_hash_changes_with_conditional_edge_registration_count() { +fn graph_hash_differs_when_conditional_edge_is_added() { let route_to_end: EdgePredicate = Arc::new(|_snapshot| vec!["End".to_string()]); let app_without_conditional = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) @@ -208,7 +209,7 @@ fn test_graph_hash_changes_with_conditional_edge_registration_count() { } #[test] -fn test_graph_hash_uses_custom_reducer_definition_label() { +fn graph_hash_reflects_custom_reducer_definition_label() { let app_a = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .with_reducer(ChannelType::Extra, Arc::new(StableLabelReducerA)) @@ -235,21 +236,7 @@ fn test_graph_hash_uses_custom_reducer_definition_label() { } #[test] -fn test_compile_missing_entry() { - let gb = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End); - let app = gb.compile().unwrap(); - assert!(app.edges().get(&NodeKind::Start).is_some()); -} - -#[test] -fn test_compile_entry_not_registered() { - let gb = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End); - let app = gb.compile().unwrap(); - assert_eq!(app.edges().len(), 1); -} - -#[test] -fn test_nodekind_other_variant() { +fn custom_nodekind_equality_compares_inner_string() { let k1 = NodeKind::Custom("foo".to_string()); let k2 = NodeKind::Custom("foo".to_string()); let k3 = NodeKind::Custom("bar".to_string()); @@ -258,28 +245,7 @@ fn test_nodekind_other_variant() { } #[test] -fn test_duplicate_edges_rejected() { - // Duplicate edges should now be rejected by validation - use weavegraph::graphs::GraphCompileError; - - let result = GraphBuilder::new() - .add_edge(NodeKind::Start, NodeKind::End) - .add_edge(NodeKind::Start, NodeKind::End) - .compile(); - - assert!(result.is_err()); - let err = result.err().unwrap(); - matches!(err, GraphCompileError::DuplicateEdge { .. }); -} - -#[test] -fn test_builder_fluent_api() { - let final_builder = GraphBuilder::new().add_edge(NodeKind::Start, NodeKind::End); - let _app = final_builder.compile().unwrap(); -} - -#[test] -fn test_runtime_config_integration() { +fn graph_builder_accepts_runtime_config() { use weavegraph::runtimes::RuntimeConfig; let config = RuntimeConfig::new(Some("test_session".into()), None); @@ -291,14 +257,10 @@ fn test_runtime_config_integration() { let _app = builder.compile().unwrap(); } -// ============================================================================ -// Enhanced Validation Tests (Directive 1) -// ============================================================================ +// Validation tests #[test] -fn test_cycle_detection_simple_cycle() { - use weavegraph::graphs::GraphCompileError; - +fn compile_rejects_graph_containing_a_cycle() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -311,7 +273,6 @@ fn test_cycle_detection_simple_cycle() { match result.err().unwrap() { GraphCompileError::CycleDetected { cycle } => { assert!(!cycle.is_empty()); - // Verify cycle contains A and B assert!(cycle.contains(&NodeKind::Custom("A".into()))); assert!(cycle.contains(&NodeKind::Custom("B".into()))); } @@ -320,9 +281,7 @@ fn test_cycle_detection_simple_cycle() { } #[test] -fn test_cycle_detection_self_loop() { - use weavegraph::graphs::GraphCompileError; - +fn compile_rejects_graph_containing_a_self_loop() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) @@ -340,8 +299,7 @@ fn test_cycle_detection_self_loop() { } #[test] -fn test_cycle_detection_no_cycle() { - // Linear graph should pass +fn compile_accepts_acyclic_linear_graph() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -354,9 +312,7 @@ fn test_cycle_detection_no_cycle() { } #[test] -fn test_unreachable_nodes_detection() { - use weavegraph::graphs::GraphCompileError; - +fn compile_rejects_graph_with_unreachable_nodes() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -378,8 +334,7 @@ fn test_unreachable_nodes_detection() { } #[test] -fn test_unreachable_nodes_all_reachable() { - // All nodes reachable should pass +fn compile_accepts_fully_connected_graph() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -392,9 +347,7 @@ fn test_unreachable_nodes_all_reachable() { } #[test] -fn test_no_path_to_end_detection() { - use weavegraph::graphs::GraphCompileError; - +fn compile_rejects_graph_with_nodes_having_no_path_to_end() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -416,8 +369,7 @@ fn test_no_path_to_end_detection() { } #[test] -fn test_no_path_to_end_all_paths_valid() { - // All nodes can reach End should pass +fn compile_accepts_graph_where_all_nodes_reach_end() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) @@ -428,12 +380,10 @@ fn test_no_path_to_end_all_paths_valid() { } #[test] -fn test_duplicate_edge_detection() { - use weavegraph::graphs::GraphCompileError; - +fn compile_rejects_duplicate_edges_with_error_detail() { let result = GraphBuilder::new() .add_edge(NodeKind::Start, NodeKind::End) - .add_edge(NodeKind::Start, NodeKind::End) // Duplicate + .add_edge(NodeKind::Start, NodeKind::End) .compile(); assert!(result.is_err()); @@ -447,13 +397,12 @@ fn test_duplicate_edge_detection() { } #[test] -fn test_duplicate_edge_with_different_targets() { - // Multiple different targets from same source should be allowed +fn compile_allows_multiple_edges_from_same_source_node() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) - .add_edge(NodeKind::Start, NodeKind::Custom("B".into())) // Different target, OK + .add_edge(NodeKind::Start, NodeKind::Custom("B".into())) .add_edge(NodeKind::Custom("A".into()), NodeKind::End) .add_edge(NodeKind::Custom("B".into()), NodeKind::End) .compile(); @@ -462,8 +411,7 @@ fn test_duplicate_edge_with_different_targets() { } #[test] -fn test_happy_path_simple_graph() { - // Verify a simple valid graph passes all validations +fn simple_process_node_graph_compiles() { let result = GraphBuilder::new() .add_node(NodeKind::Custom("process".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("process".into())) @@ -473,22 +421,10 @@ fn test_happy_path_simple_graph() { assert!(result.is_ok()); } -#[test] -fn test_happy_path_start_to_end_direct() { - // Direct Start -> End should pass - let result = GraphBuilder::new() - .add_edge(NodeKind::Start, NodeKind::End) - .compile(); - - assert!(result.is_ok()); -} - -// ============================================================================ -// Graph Iteration Tests (Phase 3.1) -// ============================================================================ +// Iteration and traversal tests #[test] -fn test_nodes_iterator() { +fn builder_nodes_iter_yields_all_registered_nodes() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -508,7 +444,7 @@ fn test_nodes_iterator() { } #[test] -fn test_edges_iterator() { +fn builder_edges_iter_yields_all_registered_edges() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_edge(NodeKind::Start, NodeKind::Custom("A".into())) @@ -523,7 +459,7 @@ fn test_edges_iterator() { } #[test] -fn test_node_count_and_edge_count() { +fn builder_reports_correct_node_and_edge_counts() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -537,7 +473,7 @@ fn test_node_count_and_edge_count() { } #[test] -fn test_topological_sort_basic() { +fn topological_sort_respects_linear_dependency_order() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -571,7 +507,7 @@ fn test_topological_sort_basic() { } #[test] -fn test_topological_sort_fan_out() { +fn topological_sort_places_start_first_and_end_last() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("A".into()), NoopNode) .add_node(NodeKind::Custom("B".into()), NoopNode) @@ -615,7 +551,7 @@ fn test_topological_sort_fan_out() { } #[test] -fn test_topological_sort_deterministic() { +fn topological_sort_is_deterministic_and_lexicographic() { let builder = GraphBuilder::new() .add_node(NodeKind::Custom("Z".into()), NoopNode) .add_node(NodeKind::Custom("Y".into()), NoopNode) diff --git a/tests/graphs_property.rs b/tests/graphs_property.rs index 583481f..938484e 100644 --- a/tests/graphs_property.rs +++ b/tests/graphs_property.rs @@ -20,15 +20,6 @@ fn node_name_strategy() -> impl Strategy { }) } -// Minimal sanity property using the generator (real graph properties will follow in later steps) -proptest! { - #[test] - fn prop_node_name_non_empty(name in node_name_strategy()) { - prop_assert!(!name.is_empty()); - prop_assert!(name.chars().next().unwrap().is_ascii_alphabetic()); - } -} - mod common; use common::*; @@ -87,14 +78,12 @@ proptest! { let nf: FxHashSet<_> = rep.next_frontier.into_iter().collect(); - // All predicate targets must appear (translated) let allowed: FxHashSet<_> = names.clone().into_iter().collect(); for n in names.clone() { assert!(nf.contains(&NodeKind::Custom(n))); } if include_end { assert!(nf.contains(&NodeKind::End)); } - // Frontier must not contain unknown custom nodes for k in nf { if let NodeKind::Custom(s) = k { assert!(allowed.contains(&s)); } } @@ -137,10 +126,8 @@ proptest! { let rep = match runner.run_step("sess_mix", StepOptions::default()).await.unwrap() { StepResult::Completed(rep) => rep, _ => unreachable!() }; let nf: FxHashSet<_> = rep.next_frontier.into_iter().collect(); - // Valid appear for n in &valid { assert!(nf.contains(&NodeKind::Custom(n.clone()))); } assert!(nf.contains(&NodeKind::End)); - // Invalid never appear for n in &invalid { assert!(!nf.contains(&NodeKind::Custom(n.clone()))); } }); } @@ -181,17 +168,14 @@ proptest! { for k in rep.next_frontier { if let NodeKind::Custom(s) = k { *counts.entry(s).or_insert(0) += 1; } } - // Each targeted custom node should appear at most once for n in pool { assert!(counts.get(&n).cloned().unwrap_or(0) <= 1); } }); } } -// ============================================================================ // Additional property tests for conditional edge routing -// ============================================================================ -/// Generate a valid key for extra data that can be used in predicates. +// Generate a valid key for extra data that can be used in predicates. fn extra_key_strategy() -> impl Strategy { prop::string::string_regex("[a-z][a-z0-9_]{0,8}").unwrap() } diff --git a/tests/messages.rs b/tests/messages.rs index 8e59e6c..4c6c47c 100644 --- a/tests/messages.rs +++ b/tests/messages.rs @@ -1,14 +1,14 @@ use weavegraph::message::{Message, Role}; #[test] -fn test_message_construction() { +fn message_with_role_stores_role_and_content() { let msg = Message::with_role(Role::User, "hello"); assert_eq!(msg.role, Role::User); assert_eq!(msg.content, "hello"); } #[test] -fn test_convenience_constructors() { +fn convenience_constructors_produce_typed_role_messages() { let user_msg = Message::user("Hello"); assert_eq!(user_msg.role, Role::User); assert_eq!(user_msg.content, "Hello"); @@ -27,14 +27,14 @@ fn test_convenience_constructors() { } #[test] -fn test_role_checking() { +fn user_message_role_is_user_not_assistant() { let user_msg = Message::user("Hello"); assert_eq!(user_msg.role, Role::User); assert_ne!(user_msg.role, Role::Assistant); } #[test] -fn test_serialization() { +fn message_survives_json_roundtrip() { let original = Message::user("Test message"); let json = serde_json::to_string(&original).expect("Serialization failed"); let deserialized: Message = serde_json::from_str(&json).expect("Deserialization failed"); diff --git a/tests/nodes.rs b/tests/nodes.rs index 1792f39..99d2b53 100644 --- a/tests/nodes.rs +++ b/tests/nodes.rs @@ -16,14 +16,14 @@ fn make_ctx(step: u64) -> (NodeContext, EventBus) { } #[tokio::test] -async fn test_node_context_creation() { +async fn node_context_exposes_id_and_step() { let (ctx, _event_bus) = make_ctx(5); assert_eq!(ctx.node_id, "test-node"); assert_eq!(ctx.step, 5); } #[test] -fn test_node_partial_default() { +fn node_partial_default_has_all_fields_none() { let np = NodePartial::default(); assert!(np.messages.is_none()); assert!(np.extra.is_none()); @@ -31,7 +31,7 @@ fn test_node_partial_default() { } #[test] -fn test_node_partial_with_messages() { +fn node_partial_with_messages_leaves_other_fields_none() { let messages = vec![Message::with_role( Role::Custom("test".to_string()), "test message", @@ -43,7 +43,7 @@ fn test_node_partial_with_messages() { } #[test] -fn test_node_partial_with_extra() { +fn node_partial_with_extra_leaves_other_fields_none() { let mut extra = new_extra_map(); extra.insert("test_key".to_string(), serde_json::json!("test_value")); @@ -54,7 +54,7 @@ fn test_node_partial_with_extra() { } #[test] -fn test_node_partial_with_errors() { +fn node_partial_with_errors_leaves_other_fields_none() { let errors = vec![ErrorEvent::default()]; let partial = NodePartial::new().with_errors(errors.clone()); assert!(partial.messages.is_none()); @@ -63,24 +63,22 @@ fn test_node_partial_with_errors() { } #[tokio::test] -async fn test_node_context_emit_error() { +async fn emit_fails_when_event_bus_dropped() { let (ctx, event_bus) = make_ctx(1); - drop(event_bus); // Drop the event bus to disconnect sender + drop(event_bus); tokio::task::yield_now().await; let result = ctx.emit("scope", "message"); assert!(matches!(result, Err(NodeContextError::EventBusUnavailable))); } #[test] -fn test_node_error_variants() { - // MissingInput +fn node_error_variants_expose_correct_fields() { let err = NodeError::MissingInput { what: "field" }; match err { NodeError::MissingInput { what } => assert_eq!(what, "field"), _ => panic!("Wrong variant"), } - // Provider let err = NodeError::Provider { provider: "svc", message: "fail".to_string(), @@ -93,7 +91,6 @@ fn test_node_error_variants() { _ => panic!("Wrong variant"), } - // Serde let json_err = serde_json::from_str::("not_json").unwrap_err(); let err = NodeError::Serde(json_err); match err { @@ -101,21 +98,18 @@ fn test_node_error_variants() { _ => panic!("Wrong variant"), } - // ValidationFailed let err = NodeError::ValidationFailed("bad input".to_string()); match err { NodeError::ValidationFailed(msg) => assert_eq!(msg, "bad input"), _ => panic!("Wrong variant"), } - // EventBus let err = NodeError::EventBus(NodeContextError::EventBusUnavailable); match err { NodeError::EventBus(NodeContextError::EventBusUnavailable) => (), _ => panic!("Wrong variant"), } - // Other let err = NodeError::other(std::io::Error::other("boom")); match err { NodeError::Other(inner) => assert_eq!(inner.to_string(), "boom"), @@ -124,7 +118,7 @@ fn test_node_error_variants() { } #[test] -fn test_node_result_ext_maps_external_error() { +fn external_error_maps_to_node_other_error() { let result: std::result::Result = Err(std::io::Error::other("io boom")); let err = result.node_err().unwrap_err(); @@ -134,14 +128,6 @@ fn test_node_result_ext_maps_external_error() { } } -#[test] -fn test_node_context_error_variant() { - let err = NodeContextError::EventBusUnavailable; - match err { - NodeContextError::EventBusUnavailable => (), - } -} - struct DummyNode; #[async_trait] impl Node for DummyNode { @@ -159,7 +145,7 @@ impl Node for DummyNode { } #[tokio::test] -async fn test_node_trait_success() { +async fn node_returns_partial_with_expected_message_role() { let (ctx, _event_bus) = make_ctx(0); let node = DummyNode; let snapshot = VersionedState::new_with_user_message("dummy").snapshot(); @@ -173,9 +159,9 @@ async fn test_node_trait_success() { } #[tokio::test] -async fn test_node_trait_eventbus_error() { +async fn node_run_propagates_event_bus_disconnect_as_error() { let (ctx, event_bus) = make_ctx(0); - drop(event_bus); // disconnect event bus + drop(event_bus); tokio::task::yield_now().await; let node = DummyNode; let snapshot = VersionedState::new_with_user_message("dummy").snapshot(); diff --git a/tests/reducers.rs b/tests/reducers.rs index 8de3330..295bf60 100644 --- a/tests/reducers.rs +++ b/tests/reducers.rs @@ -12,34 +12,14 @@ mod common; use common::*; use weavegraph::types::ChannelType; -// Fresh baseline state helper fn base_state() -> VersionedState { state_with_user("a") } -// Local guard prototype mirroring runtime logic -fn channel_guard(channel: ChannelType, partial: &NodePartial) -> bool { - match channel { - ChannelType::Message => partial - .messages - .as_ref() - .map(|v| !v.is_empty()) - .unwrap_or(false), - ChannelType::Extra => partial - .extra - .as_ref() - .map(|m| !m.is_empty()) - .unwrap_or(false), - ChannelType::Error => false, - } -} - -/******************** - * AddMessages tests - ********************/ +// AddMessages #[test] -fn test_add_messages_appends_state() { +fn add_messages_appends_to_existing_messages_without_bumping_version() { let reducer = AddMessages; let mut state = base_state(); let initial_version = state.messages.version(); @@ -53,12 +33,11 @@ fn test_add_messages_appends_state() { assert_eq!(snapshot.len(), initial_len + 1); assert_eq!(snapshot[0].role, Role::User); assert_eq!(snapshot[1].role, Role::System); - // Reducer does not bump version (barrier responsibility) assert_eq!(state.messages.version(), initial_version); } #[test] -fn test_add_messages_empty_partial_noop() { +fn add_messages_with_empty_vec_is_a_no_op() { let reducer = AddMessages; let mut state = base_state(); let initial_version = state.messages.version(); @@ -72,15 +51,12 @@ fn test_add_messages_empty_partial_noop() { assert_eq!(state.messages.version(), initial_version); } -/******************** - * MapMerge (extra) tests - ********************/ +// MapMerge (extra) #[test] -fn test_map_merge_merges_and_overwrites_state() { +fn map_merge_inserts_new_and_overwrites_existing_keys() { let reducer = MapMerge; let mut state = base_state(); - // Seed extra state .extra .get_mut() @@ -89,7 +65,7 @@ fn test_map_merge_merges_and_overwrites_state() { let mut extra_update = FxHashMap::default(); extra_update.insert("k2".into(), Value::String("v2".into())); - extra_update.insert("k1".into(), Value::String("v3".into())); // overwrite existing + extra_update.insert("k1".into(), Value::String("v3".into())); let partial = NodePartial::new().with_extra(extra_update); @@ -108,12 +84,11 @@ fn test_map_merge_merges_and_overwrites_state() { Some(&Value::String("v2".into())), "new key should be inserted" ); - // Version unchanged (barrier responsibility) assert_eq!(state.extra.version(), initial_version); } #[test] -fn test_map_merge_empty_partial_noop() { +fn map_merge_with_empty_map_is_a_no_op() { let reducer = MapMerge; let mut state = base_state(); state @@ -131,12 +106,10 @@ fn test_map_merge_empty_partial_noop() { assert_eq!(state.extra.version(), initial_version); } -/******************** - * Enum wrapper / dispatch - ********************/ +// Reducer dispatch #[test] -fn test_enum_wrapper_dispatch() { +fn multiple_reducers_applied_sequentially_update_all_channels() { let reducers: Vec> = vec![Arc::new(AddMessages), Arc::new(MapMerge)]; let mut state = base_state(); @@ -164,33 +137,26 @@ fn test_enum_wrapper_dispatch() { ); } -/******************** - * Guard logic - ********************/ +// Registry #[test] -fn test_channel_guard_logic() { - let empty = NodePartial::default(); - assert!(!channel_guard(ChannelType::Message, &empty)); - assert!(!channel_guard(ChannelType::Extra, &empty)); +fn registry_skips_empty_channels_and_leaves_state_unchanged() { + let registry = ReducerRegistry::default(); + let mut state = base_state(); + let initial_messages = state.messages.snapshot(); + let initial_extra = state.extra.snapshot(); - let msg_partial = - NodePartial::new().with_messages(vec![Message::with_role(Role::Assistant, "m")]); - assert!(channel_guard(ChannelType::Message, &msg_partial)); - assert!(!channel_guard(ChannelType::Extra, &msg_partial)); + let empty = NodePartial::default(); + for channel in [ChannelType::Message, ChannelType::Extra] { + let _ = registry.try_update(channel, &mut state, &empty); + } - let mut extra_map = FxHashMap::default(); - extra_map.insert("k".into(), Value::String("v".into())); - let extra_partial = NodePartial::new().with_extra(extra_map); - assert!(channel_guard(ChannelType::Extra, &extra_partial)); + assert_eq!(state.messages.snapshot(), initial_messages); + assert_eq!(state.extra.snapshot(), initial_extra); } -/******************** - * Registry integration-like flow - ********************/ - #[test] -fn test_registry_integration_like_flow() { +fn registry_applies_reducers_for_channels_with_data() { let registry = ReducerRegistry::default(); let mut state = base_state(); @@ -201,24 +167,18 @@ fn test_registry_integration_like_flow() { .with_messages(vec![Message::with_role(Role::Assistant, "from node")]) .with_extra(extra_update); - // Simulate runtime iterating channels for channel in [ChannelType::Message, ChannelType::Extra] { - if channel_guard(channel.clone(), &partial) { - let _ = registry.try_update(channel, &mut state, &partial); - } + let _ = registry.try_update(channel, &mut state, &partial); } assert_message_contains(&state, "from node"); assert_extra_has(&state, "origin"); } -/***************************** - * Concurrency tests (Stage 4) - *****************************/ +// Concurrency -/// Test concurrent reducer application from multiple threads #[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn test_reducer_thread_safety() { +async fn concurrent_registry_updates_accumulate_all_messages() { let registry = Arc::new(ReducerRegistry::default()); let state = Arc::new(tokio::sync::Mutex::new(base_state())); @@ -244,85 +204,72 @@ async fn test_reducer_thread_safety() { } let final_state = state.lock().await; - // Initial state has 1 message, we added 10 more assert_eq!(final_state.messages.snapshot().len(), 11); } -/// Test deterministic behavior under concurrent access #[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn test_reducer_determinism_under_concurrency() { - // Run same operations multiple times, verify state convergence - for _ in 0..10 { - let registry = Arc::new(ReducerRegistry::default()); - let state1 = Arc::new(tokio::sync::Mutex::new(base_state())); - let state2 = Arc::new(tokio::sync::Mutex::new(base_state())); - - // Apply same partials concurrently to both states - let partials: Vec = (0..5) - .map(|i| { - NodePartial::new() - .with_messages(vec![Message::with_role(Role::User, &format!("test_{}", i))]) - }) - .collect(); - - // Apply to state1 - let handles1: Vec<_> = partials - .iter() - .map(|partial| { - let registry = Arc::clone(®istry); - let state = Arc::clone(&state1); - let partial = partial.clone(); - - tokio::spawn(async move { - let mut state_guard = state.lock().await; - let _ = registry.try_update(ChannelType::Message, &mut state_guard, &partial); - }) - }) - .collect(); - - // Apply to state2 - let handles2: Vec<_> = partials - .iter() - .map(|partial| { - let registry = Arc::clone(®istry); - let state = Arc::clone(&state2); - let partial = partial.clone(); - - tokio::spawn(async move { - let mut state_guard = state.lock().await; - let _ = registry.try_update(ChannelType::Message, &mut state_guard, &partial); - }) - }) - .collect(); +async fn separate_states_with_same_concurrent_updates_have_equal_message_counts() { + let registry = Arc::new(ReducerRegistry::default()); + let state1 = Arc::new(tokio::sync::Mutex::new(base_state())); + let state2 = Arc::new(tokio::sync::Mutex::new(base_state())); + + let partials: Vec = (0..5) + .map(|i| { + NodePartial::new() + .with_messages(vec![Message::with_role(Role::User, &format!("test_{}", i))]) + }) + .collect(); - for handle in handles1.into_iter().chain(handles2) { - handle.await.unwrap(); - } + let handles1: Vec<_> = partials + .iter() + .map(|partial| { + let registry = Arc::clone(®istry); + let state = Arc::clone(&state1); + let partial = partial.clone(); + + tokio::spawn(async move { + let mut state_guard = state.lock().await; + let _ = registry.try_update(ChannelType::Message, &mut state_guard, &partial); + }) + }) + .collect(); - // Verify final states are identical - let final_state1 = state1.lock().await; - let final_state2 = state2.lock().await; + let handles2: Vec<_> = partials + .iter() + .map(|partial| { + let registry = Arc::clone(®istry); + let state = Arc::clone(&state2); + let partial = partial.clone(); - assert_eq!( - final_state1.messages.snapshot().len(), - final_state2.messages.snapshot().len() - ); + tokio::spawn(async move { + let mut state_guard = state.lock().await; + let _ = registry.try_update(ChannelType::Message, &mut state_guard, &partial); + }) + }) + .collect(); - // Both should have initial message + 5 new messages - assert_eq!(final_state1.messages.snapshot().len(), 6); + for handle in handles1.into_iter().chain(handles2) { + handle.await.unwrap(); } + + let final_state1 = state1.lock().await; + let final_state2 = state2.lock().await; + + assert_eq!( + final_state1.messages.snapshot().len(), + final_state2.messages.snapshot().len() + ); + assert_eq!(final_state1.messages.snapshot().len(), 6); } -/// Test channel isolation - reducers for one channel don't affect others #[test] -fn test_reducer_channel_isolation() { +fn message_and_extra_reducers_do_not_affect_each_other() { let registry = ReducerRegistry::default(); let mut state = base_state(); let initial_messages = state.messages.snapshot().len(); let initial_extra_keys = state.extra.snapshot().len(); - // Apply message-only partial let message_partial = NodePartial::new() .with_messages(vec![Message::with_role(Role::System, "isolated message")]); @@ -330,11 +277,9 @@ fn test_reducer_channel_isolation() { .try_update(ChannelType::Message, &mut state, &message_partial) .unwrap(); - // Verify only messages channel was affected assert_eq!(state.messages.snapshot().len(), initial_messages + 1); assert_eq!(state.extra.snapshot().len(), initial_extra_keys); - // Apply extra-only partial let mut extra_map = FxHashMap::default(); extra_map.insert( "isolated_key".into(), @@ -347,7 +292,6 @@ fn test_reducer_channel_isolation() { .try_update(ChannelType::Extra, &mut state, &extra_partial) .unwrap(); - // Verify only extra channel was affected (messages unchanged from previous operation) assert_eq!(state.messages.snapshot().len(), initial_messages + 1); assert_eq!(state.extra.snapshot().len(), initial_extra_keys + 1); assert_eq!( From 143aa99e4d9bbf5f88bc5104ce480465fb66a5dc Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 19:52:58 -0400 Subject: [PATCH 09/15] second batch of test revisions --- tests/runtimes_checkpointers.rs | 80 +---- tests/runtimes_concurrent.rs | 32 +- tests/runtimes_persistence.rs | 10 +- tests/runtimes_persistence_postgres.rs | 23 +- tests/runtimes_persistence_sqlite.rs | 12 +- tests/runtimes_replay.rs | 4 +- tests/runtimes_runner.rs | 465 ++++++++++++------------- tests/runtimes_types.rs | 13 +- tests/schedulers.rs | 38 +- tests/state_channels.rs | 85 +++-- tests/streaming_sse.rs | 3 +- tests/telemetry.rs | 41 +-- tests/types.rs | 18 +- tests/utils.rs | 10 +- 14 files changed, 351 insertions(+), 483 deletions(-) diff --git a/tests/runtimes_checkpointers.rs b/tests/runtimes_checkpointers.rs index 296bcef..763463d 100644 --- a/tests/runtimes_checkpointers.rs +++ b/tests/runtimes_checkpointers.rs @@ -10,15 +10,13 @@ use weavegraph::runtimes::checkpointer::{ }; use weavegraph::runtimes::checkpointer_sqlite::{SQLiteCheckpointer, StepQuery}; use weavegraph::schedulers::{Scheduler, SchedulerState}; -use weavegraph::state::VersionedState; use weavegraph::types::NodeKind; mod common; use common::*; -// Base Checkpointer trait tests #[tokio::test] -async fn test_inmemory_checkpointer_save_and_load_roundtrip() { +async fn inmemory_checkpoint_survives_save_and_reload() { let cp_store = InMemoryCheckpointer::new(); let mut session = SessionState { state: state_with_user("hi"), @@ -49,7 +47,7 @@ async fn test_inmemory_checkpointer_save_and_load_roundtrip() { } #[tokio::test] -async fn test_inmemory_checkpointer_list_sessions() { +async fn inmemory_checkpointer_lists_all_saved_session_ids() { let cp_store = InMemoryCheckpointer::new(); let session = SessionState { state: state_with_user("x"), @@ -71,63 +69,10 @@ async fn test_inmemory_checkpointer_list_sessions() { assert_eq!(ids, vec!["alpha", "beta"]); } -#[tokio::test] -async fn test_save_and_load_roundtrip() { - let cp_store = InMemoryCheckpointer::new(); - let mut session = weavegraph::runtimes::SessionState { - state: VersionedState::new_with_user_message("hi"), - step: 3, - frontier: vec![NodeKind::Start], - scheduler: weavegraph::schedulers::Scheduler::new(4), - scheduler_state: SchedulerState::default(), - }; - session.scheduler_state.versions_seen.insert( - "Start".into(), - FxHashMap::from_iter([("messages".into(), 1_u64), ("extra".into(), 1_u64)]), - ); - - let cp = Checkpoint::from_session("sess1", &session); - cp_store.save(cp.clone()).await.unwrap(); - - let loaded = cp_store.load_latest("sess1").await.unwrap().unwrap(); - assert_eq!(loaded.step, 3); - assert_eq!(loaded.frontier, vec![NodeKind::Start]); - assert_eq!( - loaded.versions_seen.get("Start").unwrap().get("messages"), - Some(&1) - ); - assert_eq!( - loaded.state.messages.snapshot().len(), - session.state.messages.snapshot().len() - ); -} - -#[tokio::test] -async fn test_list_sessions() { - let cp_store = InMemoryCheckpointer::new(); - let session = weavegraph::runtimes::SessionState { - state: VersionedState::new_with_user_message("x"), - step: 0, - frontier: vec![NodeKind::Start], - scheduler: weavegraph::schedulers::Scheduler::new(1), - scheduler_state: SchedulerState::default(), - }; - cp_store - .save(Checkpoint::from_session("alpha", &session)) - .await - .unwrap(); - cp_store - .save(Checkpoint::from_session("beta", &session)) - .await - .unwrap(); - let mut ids = cp_store.list_sessions().await.unwrap(); - ids.sort(); - assert_eq!(ids, vec!["alpha", "beta"]); -} +// SQLite checkpointer tests -// SQLite Checkpointer tests #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_checkpointer_roundtrip() { +async fn sqlite_checkpoint_survives_save_and_reload() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect sqlite memory"); @@ -156,10 +101,8 @@ async fn test_sqlite_checkpointer_roundtrip() { updated_channels: vec!["messages".to_string()], }; - // Save (async trait method) cp.save(cp_struct.clone()).await.expect("save"); - // Load let loaded = cp .load_latest("sessX") .await @@ -181,7 +124,6 @@ async fn test_sqlite_checkpointer_roundtrip() { Some(&serde_json::json!(42)) ); - // Restore session state utility compatibility let session_state = restore_session_state(&loaded); assert_eq!(session_state.step, 1); assert_eq!(session_state.frontier.len(), 1); @@ -189,7 +131,7 @@ async fn test_sqlite_checkpointer_roundtrip() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_list_sessions() { +async fn sqlite_checkpointer_lists_all_saved_session_ids() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect"); @@ -216,7 +158,7 @@ async fn test_sqlite_list_sessions() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_load_nonexistent() { +async fn sqlite_checkpointer_returns_none_for_unknown_session() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect"); @@ -225,7 +167,7 @@ async fn test_sqlite_load_nonexistent() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_step_execution_metadata() { +async fn sqlite_checkpoint_preserves_step_execution_metadata() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect sqlite memory"); @@ -246,7 +188,6 @@ async fn test_sqlite_step_execution_metadata() { cp.save(checkpoint.clone()).await.expect("save checkpoint"); - // Query the step to verify execution metadata is preserved let query = StepQuery { limit: Some(10), ..Default::default() @@ -265,12 +206,11 @@ async fn test_sqlite_step_execution_metadata() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_query_steps_pagination() { +async fn sqlite_query_steps_respects_limit_and_offset() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect sqlite memory"); - // Create multiple checkpoints for step in 1..=5 { let state = state_with_user(&format!("step {step}")); let checkpoint = Checkpoint { @@ -292,7 +232,6 @@ async fn test_sqlite_query_steps_pagination() { cp.save(checkpoint).await.expect("save checkpoint"); } - // Test pagination with limit let query = StepQuery { limit: Some(2), offset: Some(0), @@ -312,9 +251,8 @@ async fn test_sqlite_query_steps_pagination() { assert_eq!(result.checkpoints[1].step, 4); } -// Concurrency behavior: ensure async RwLock in InMemoryCheckpointer allows many concurrent saves #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_inmemory_checkpointer_concurrent_operations() { +async fn inmemory_checkpointer_accepts_concurrent_saves() { let cp_store = Arc::new(InMemoryCheckpointer::new()); let mut handles = Vec::new(); diff --git a/tests/runtimes_concurrent.rs b/tests/runtimes_concurrent.rs index f2af152..e2927cc 100644 --- a/tests/runtimes_concurrent.rs +++ b/tests/runtimes_concurrent.rs @@ -93,7 +93,7 @@ fn make_marker_app(marker: &str) -> weavegraph::app::App { } #[tokio::test] -async fn test_multiple_sessions_same_runner() { +async fn all_sessions_on_shared_runner_complete_exactly_once() { let counter = Arc::new(AtomicUsize::new(0)); let app = make_counting_app(counter.clone(), 0); let mut runner = AppRunner::builder() @@ -104,7 +104,6 @@ async fn test_multiple_sessions_same_runner() { let session_count = 5; - // Create multiple sessions for i in 0..session_count { let session_id = format!("session_{i}"); let state = state_with_user(&format!("message {i}")); @@ -114,7 +113,6 @@ async fn test_multiple_sessions_same_runner() { .expect("session creation"); } - // Run steps on each session sequentially for i in 0..session_count { let session_id = format!("session_{i}"); let result = runner.run_step(&session_id, StepOptions::default()).await; @@ -124,13 +122,11 @@ async fn test_multiple_sessions_same_runner() { } } - // Counter should reflect all executions assert_eq!(counter.load(Ordering::SeqCst), session_count); } #[tokio::test] -async fn test_session_isolation() { - // Each session gets its own marker to verify state isolation +async fn independent_runners_each_execute_their_own_app_logic() { let app1 = make_marker_app("session_A_marker"); let app2 = make_marker_app("session_B_marker"); @@ -145,7 +141,6 @@ async fn test_session_isolation() { .build() .await; - // Create and run both sessions runner1 .create_session("session_A".into(), state_with_user("A")) .await @@ -164,10 +159,8 @@ async fn test_session_isolation() { .await .unwrap(); - // Verify both completed match (result_a, result_b) { (StepResult::Completed(rep_a), StepResult::Completed(rep_b)) => { - // Both should have run their marker nodes assert!(rep_a.ran_nodes.contains(&NodeKind::Custom("marker".into()))); assert!(rep_b.ran_nodes.contains(&NodeKind::Custom("marker".into()))); } @@ -176,8 +169,7 @@ async fn test_session_isolation() { } #[tokio::test] -async fn test_session_state_independence() { - // Create a single app but multiple sessions to verify state is isolated +async fn two_sessions_on_shared_runner_execute_independently() { let counter = Arc::new(AtomicUsize::new(0)); let app = make_counting_app(counter.clone(), 0); let mut runner = AppRunner::builder() @@ -186,7 +178,6 @@ async fn test_session_state_independence() { .build() .await; - // Create sessions with different initial states runner .create_session("session_1".into(), state_with_user("initial_1")) .await @@ -196,30 +187,24 @@ async fn test_session_state_independence() { .await .unwrap(); - // Run session 1 - the graph completes in one step (Start -> counter -> End) let result1 = runner .run_step("session_1", StepOptions::default()) .await .unwrap(); - - // Run session 2 once let result2 = runner .run_step("session_2", StepOptions::default()) .await .unwrap(); - // Both sessions should complete assert!(matches!(result1, StepResult::Completed(_))); assert!(matches!(result2, StepResult::Completed(_))); - - // Counter should reflect 2 executions (one per session) assert_eq!(counter.load(Ordering::SeqCst), 2); } #[tokio::test] -async fn test_high_session_count() { +async fn fifty_sessions_on_shared_runner_all_complete() { let counter = Arc::new(AtomicUsize::new(0)); - let app = make_counting_app(counter.clone(), 1); // 1ms delay + let app = make_counting_app(counter.clone(), 1); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) @@ -228,7 +213,6 @@ async fn test_high_session_count() { let session_count = 50; - // Create all sessions for i in 0..session_count { runner .create_session(format!("stress_{i}"), state_with_user(&format!("{i}"))) @@ -236,7 +220,6 @@ async fn test_high_session_count() { .unwrap(); } - // Run all sessions sequentially let mut success_count = 0; for i in 0..session_count { let result = runner @@ -252,7 +235,7 @@ async fn test_high_session_count() { } #[tokio::test] -async fn test_session_resume_different_order() { +async fn sessions_complete_correctly_regardless_of_run_order() { let counter = Arc::new(AtomicUsize::new(0)); let app = make_counting_app(counter.clone(), 0); let mut runner = AppRunner::builder() @@ -261,7 +244,6 @@ async fn test_session_resume_different_order() { .build() .await; - // Create sessions runner .create_session("first".into(), state_with_user("1")) .await @@ -275,7 +257,6 @@ async fn test_session_resume_different_order() { .await .unwrap(); - // Run in reverse order runner .run_step("third", StepOptions::default()) .await @@ -289,6 +270,5 @@ async fn test_session_resume_different_order() { .await .unwrap(); - // All three should have executed assert_eq!(counter.load(Ordering::SeqCst), 3); } diff --git a/tests/runtimes_persistence.rs b/tests/runtimes_persistence.rs index ecfe8bb..84c928c 100644 --- a/tests/runtimes_persistence.rs +++ b/tests/runtimes_persistence.rs @@ -16,7 +16,7 @@ mod common; use common::*; #[test] -fn test_state_round_trip() { +fn state_survives_json_serialization_roundtrip() { let mut vs = state_with_user("hello"); vs.extra .get_mut() @@ -33,7 +33,7 @@ fn test_state_round_trip() { } #[test] -fn test_state_deserialize_without_errors_channel() { +fn errors_channel_defaults_to_empty_when_omitted_from_json() { let json = r#"{ "messages": {"version": 1, "items": []}, "extra": {"version": 1, "map": {}} @@ -44,7 +44,7 @@ fn test_state_deserialize_without_errors_channel() { } #[test] -fn test_checkpoint_round_trip() { +fn checkpoint_survives_json_serialization_roundtrip() { // Build synthetic checkpoint let vs = state_with_user("seed"); let cp = Checkpoint { @@ -91,9 +91,7 @@ fn test_checkpoint_round_trip() { } #[test] -fn test_nodekind_encode_decode() { - use weavegraph::types::NodeKind; - +fn nodekind_encodes_and_decodes_all_variants() { let kinds = vec![ NodeKind::Start, NodeKind::End, diff --git a/tests/runtimes_persistence_postgres.rs b/tests/runtimes_persistence_postgres.rs index 392ede3..d4ab6f2 100644 --- a/tests/runtimes_persistence_postgres.rs +++ b/tests/runtimes_persistence_postgres.rs @@ -53,7 +53,7 @@ fn unique_session_id(prefix: &str) -> String { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_postgres_checkpointer_roundtrip() { +async fn checkpoint_state_and_metadata_survive_postgres_roundtrip() { let cp = connect_or_fail().await; let session_id = unique_session_id("roundtrip"); @@ -108,7 +108,7 @@ async fn test_postgres_checkpointer_roundtrip() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_list_sessions_and_empty_load() { +async fn saved_sessions_appear_in_list_and_missing_session_returns_none() { let cp = connect_or_fail().await; // Use unique session IDs for this test @@ -148,7 +148,7 @@ async fn test_list_sessions_and_empty_load() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_step_execution_metadata_query_and_pagination() { +async fn step_query_paginates_results_newest_first() { let cp = connect_or_fail().await; let session_id = unique_session_id("paginate"); @@ -195,7 +195,7 @@ async fn test_step_execution_metadata_query_and_pagination() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_error_persistence_roundtrip() { +async fn error_events_survive_postgres_roundtrip() { let cp = connect_or_fail().await; let session_id = unique_session_id("err"); @@ -205,7 +205,7 @@ async fn test_error_persistence_roundtrip() { .with_tag("t") .with_context(serde_json::json!({"a":1})); - state.errors.get_mut().push(err.clone()); + state.errors.get_mut().push(err); let checkpoint = Checkpoint { session_id: session_id.clone(), step: 1, @@ -226,7 +226,7 @@ async fn test_error_persistence_roundtrip() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_idempotent_save() { +async fn duplicate_save_is_idempotent() { let cp = connect_or_fail().await; let session_id = unique_session_id("idempotent"); @@ -254,7 +254,7 @@ async fn test_idempotent_save() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_concurrency_check() { +async fn save_with_stale_expected_step_is_rejected() { let cp = connect_or_fail().await; let session_id = unique_session_id("concurrency"); @@ -310,7 +310,7 @@ async fn test_concurrency_check() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_out_of_order_writes_do_not_regress_latest() { +async fn out_of_order_write_does_not_overwrite_higher_step() { let cp = connect_or_fail().await; let session_id = unique_session_id("out_of_order"); @@ -371,7 +371,7 @@ async fn test_out_of_order_writes_do_not_regress_latest() { } #[tokio::test(flavor = "multi_thread", worker_threads = 4)] -async fn test_concurrent_writers_only_one_wins_concurrency_check() { +async fn concurrent_writers_only_one_wins_concurrency_check() { let cp = Arc::new(connect_or_fail().await); let session_id = unique_session_id("concurrent_writers"); @@ -441,13 +441,8 @@ async fn test_concurrent_writers_only_one_wins_concurrency_check() { .into_iter() .filter(|r| r.is_ok()) .count(); - let err_count = [res_a.as_ref(), res_b.as_ref()] - .into_iter() - .filter(|r| r.is_err()) - .count(); assert_eq!(ok_count, 1, "exactly one writer should succeed"); - assert_eq!(err_count, 1, "exactly one writer should fail"); // Latest must be step 2, with one of the markers. let loaded = cp.load_latest(&session_id).await.unwrap().unwrap(); diff --git a/tests/runtimes_persistence_sqlite.rs b/tests/runtimes_persistence_sqlite.rs index f7aae03..f664462 100644 --- a/tests/runtimes_persistence_sqlite.rs +++ b/tests/runtimes_persistence_sqlite.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "sqlite")] +#![cfg(feature = "sqlite")] use chrono::Utc; use rustc_hash::FxHashMap; use weavegraph::channels::Channel; @@ -11,7 +11,7 @@ mod common; use common::*; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_sqlite_checkpointer_roundtrip() { +async fn checkpoint_state_and_metadata_survive_sqlite_roundtrip() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect sqlite memory"); @@ -66,7 +66,7 @@ async fn test_sqlite_checkpointer_roundtrip() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_list_sessions_and_empty_load() { +async fn saved_sessions_appear_in_list_and_missing_session_returns_none() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect"); @@ -99,7 +99,7 @@ async fn test_list_sessions_and_empty_load() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_step_execution_metadata_query_and_pagination() { +async fn step_query_paginates_results_newest_first() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect"); @@ -146,7 +146,7 @@ async fn test_step_execution_metadata_query_and_pagination() { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_error_persistence_roundtrip() { +async fn error_events_survive_sqlite_roundtrip() { let cp = SQLiteCheckpointer::connect("sqlite::memory:") .await .expect("connect"); @@ -156,7 +156,7 @@ async fn test_error_persistence_roundtrip() { .with_tag("t") .with_context(serde_json::json!({"a":1})); - state.errors.get_mut().push(err.clone()); + state.errors.get_mut().push(err); let checkpoint = Checkpoint { session_id: "err_sess".into(), step: 1, diff --git a/tests/runtimes_replay.rs b/tests/runtimes_replay.rs index 3007506..3eb6b03 100644 --- a/tests/runtimes_replay.rs +++ b/tests/runtimes_replay.rs @@ -40,7 +40,7 @@ fn replay_event_comparison_supports_custom_normalizer() { } #[test] -fn replay_run_comparison_checks_state_and_events() { +fn replay_run_comparison_matches_equal_runs_and_reports_conformance_error_on_state_mismatch() { let left_state = VersionedState::builder() .with_extra("value", json!(1)) .build(); @@ -66,7 +66,7 @@ fn replay_run_comparison_checks_state_and_events() { } #[test] -fn replay_comparison_constructors_and_assertion_errors_preserve_differences() { +fn replay_comparison_stores_differences_and_assert_matches_returns_conformance_error() { assert!(ReplayComparison::matched().assert_matches().is_ok()); let comparison = ReplayComparison::with_differences(vec!["first".into(), "second".into()]); diff --git a/tests/runtimes_runner.rs b/tests/runtimes_runner.rs index d2f5815..2cf8b62 100644 --- a/tests/runtimes_runner.rs +++ b/tests/runtimes_runner.rs @@ -26,14 +26,13 @@ use weavegraph::{FrontierCommand, NodeRoute}; mod common; use common::*; -// Removed ad-hoc NodeA/NodeB; using common TestNode/FailingNode helpers instead. - fn make_test_app() -> weavegraph::app::App { - let mut builder = GraphBuilder::new(); - builder = builder.add_node(NodeKind::Custom("test".into()), TestNode { name: "test" }); - builder = builder.add_edge(NodeKind::Start, NodeKind::Custom("test".into())); - builder = builder.add_edge(NodeKind::Custom("test".into()), NodeKind::End); - builder.compile().unwrap() + GraphBuilder::new() + .add_node(NodeKind::Custom("test".into()), TestNode { name: "test" }) + .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) + .add_edge(NodeKind::Custom("test".into()), NodeKind::End) + .compile() + .unwrap() } #[derive(Default)] @@ -197,7 +196,7 @@ impl Node for ClockProbeNode { } #[tokio::test] -async fn test_conditional_edge_routing() { +async fn conditional_edge_routes_to_labeled_target_based_on_state() { let pred: EdgePredicate = std::sync::Arc::new(|snap: StateSnapshot| { if snap.extra.contains_key("go_yes") { vec!["Y".to_string()] @@ -227,47 +226,44 @@ async fn test_conditional_edge_routing() { .extra .get_mut() .insert("go_yes".to_string(), serde_json::json!(1)); - match runner - .create_session("sess1".to_string(), state.clone()) - .await - .unwrap() - { - SessionInit::Fresh => {} - SessionInit::Resumed { .. } => panic!("expected fresh session"), - } - let report = runner + + assert_eq!( + runner + .create_session("sess1".to_string(), state.clone()) + .await + .unwrap(), + SessionInit::Fresh + ); + let StepResult::Completed(rep) = runner .run_step("sess1", StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(rep) = report { - assert!(rep.next_frontier.contains(&NodeKind::Custom("Y".into()))); - assert!(!rep.next_frontier.contains(&NodeKind::Custom("N".into()))); - } else { - panic!("Expected completed step"); - } - let state2 = state_with_user("hi"); - match runner - .create_session("sess2".to_string(), state2.clone()) - .await - .unwrap() - { - SessionInit::Fresh => {} - SessionInit::Resumed { .. } => panic!("expected fresh session"), - } - let report2 = runner + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert!(rep.next_frontier.contains(&NodeKind::Custom("Y".into()))); + assert!(!rep.next_frontier.contains(&NodeKind::Custom("N".into()))); + + assert_eq!( + runner + .create_session("sess2".to_string(), state_with_user("hi")) + .await + .unwrap(), + SessionInit::Fresh + ); + let StepResult::Completed(rep2) = runner .run_step("sess2", StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(rep2) = report2 { - assert!(rep2.next_frontier.contains(&NodeKind::Custom("N".into()))); - assert!(!rep2.next_frontier.contains(&NodeKind::Custom("Y".into()))); - } else { - panic!("Expected completed step"); - } + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert!(rep2.next_frontier.contains(&NodeKind::Custom("N".into()))); + assert!(!rep2.next_frontier.contains(&NodeKind::Custom("Y".into()))); } #[tokio::test] -async fn runner_event_stream_only_once() { +async fn event_stream_returns_none_on_second_call() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) @@ -288,17 +284,16 @@ async fn runner_event_stream_only_once() { } #[tokio::test] -async fn test_create_session() { +async fn create_session_returns_fresh_and_is_retrievable() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("hello"); let result = runner - .create_session("test_session".into(), initial_state) + .create_session("test_session".into(), state_with_user("hello")) .await .unwrap(); assert_eq!(result, SessionInit::Fresh); @@ -306,7 +301,7 @@ async fn test_create_session() { } #[tokio::test] -async fn test_builder_custom_checkpointer_takes_precedence_over_enum() { +async fn builder_custom_checkpointer_takes_precedence_over_enum_type() { let app = make_test_app(); let session_id = "builder-custom-precedence"; let checkpoint = checkpoint_from_state( @@ -333,7 +328,7 @@ async fn test_builder_custom_checkpointer_takes_precedence_over_enum() { } #[tokio::test] -async fn test_runtime_config_custom_checkpointer_takes_precedence() { +async fn runtime_config_custom_checkpointer_restores_existing_session() { let session_id = "runtime-config-custom-precedence"; let checkpoint = checkpoint_from_state( session_id, @@ -364,75 +359,72 @@ async fn test_runtime_config_custom_checkpointer_takes_precedence() { "custom checkpointer should be invoked" ); assert_eq!(probe.save_calls(), 0); - assert!(probe.load_calls() > 0); } #[tokio::test] -async fn test_run_step_basic() { +async fn run_step_executes_scheduled_nodes_and_returns_step_report() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("hello"); assert_eq!( runner - .create_session("test_session".into(), initial_state) + .create_session("test_session".into(), state_with_user("hello")) .await .unwrap(), SessionInit::Fresh ); - let result = runner + let StepResult::Completed(report) = runner .run_step("test_session", StepOptions::default()) - .await; - assert!(result.is_ok()); - - if let Ok(StepResult::Completed(report)) = result { - assert_eq!(report.step, 1); - assert_eq!(report.ran_nodes.len(), 1); - assert!( - report - .barrier_outcome - .updated_channels - .contains(&"messages") - ); - } else { - panic!("Expected completed step, got: {:?}", result); - } + .await + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert_eq!(report.step, 1); + assert_eq!(report.ran_nodes.len(), 1); + assert!( + report + .barrier_outcome + .updated_channels + .contains(&"messages") + ); } #[tokio::test] -async fn test_run_until_complete() { +async fn run_until_complete_returns_state_with_all_node_outputs() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = VersionedState::new_with_user_message("hello"); assert_eq!( runner - .create_session("test_session".into(), initial_state) + .create_session( + "test_session".into(), + VersionedState::new_with_user_message("hello") + ) .await .unwrap(), SessionInit::Fresh ); - let result = runner.run_until_complete("test_session").await; - assert!(result.is_ok()); - - let final_state = result.unwrap(); - // user + test node message - assert_eq!(final_state.messages.len(), 2); + let final_state = runner + .run_until_complete("test_session") + .await + .expect("run_until_complete should succeed"); + assert_eq!(final_state.messages.len(), 2); // user + test node message assert_message_contains(&final_state, "ran:test:step:1"); } #[tokio::test] -async fn test_iterative_invocation_processes_identical_inputs() { +async fn iterative_session_accumulates_state_across_invocations() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -473,7 +465,7 @@ async fn test_iterative_invocation_processes_identical_inputs() { } #[tokio::test] -async fn test_iterative_session_rejects_invalid_entry_without_creating_session() { +async fn iterative_session_creation_rejects_invalid_entry_node() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -500,7 +492,7 @@ async fn test_iterative_session_rejects_invalid_entry_without_creating_session() } #[tokio::test] -async fn test_iterative_invocation_rejects_invalid_entry_without_applying_input() { +async fn iterative_invocation_rejects_invalid_entry_without_mutating_state() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -533,7 +525,7 @@ async fn test_iterative_invocation_rejects_invalid_entry_without_applying_input( } #[tokio::test] -async fn test_iterative_invocation_rejects_unregistered_custom_entry() { +async fn iterative_invocation_rejects_unregistered_custom_entry_node() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -562,7 +554,7 @@ async fn test_iterative_invocation_rejects_unregistered_custom_entry() { } #[tokio::test] -async fn test_iterative_custom_entry_runs_from_registered_node() { +async fn iterative_custom_entry_executes_graph_from_registered_node() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("first".into()), TickAccumulatorNode) .add_node(NodeKind::Custom("second".into()), TickAccumulatorNode) @@ -602,7 +594,7 @@ async fn test_iterative_custom_entry_runs_from_registered_node() { } #[tokio::test] -async fn test_iterative_event_stream_stays_open_until_finished() { +async fn iterative_event_stream_stays_open_across_invocations_until_finished() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -669,7 +661,7 @@ async fn test_iterative_event_stream_stays_open_until_finished() { } #[tokio::test] -async fn test_iterative_event_stream_reports_errors_without_closing_until_finished() { +async fn iterative_event_stream_reports_error_events_without_closing_stream() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("fail".into()), FailingNode::default()) .add_edge(NodeKind::Start, NodeKind::Custom("fail".into())) @@ -719,7 +711,7 @@ async fn test_iterative_event_stream_reports_errors_without_closing_until_finish } #[tokio::test] -async fn test_finish_iterative_session_reports_missing_session() { +async fn finish_iterative_session_errors_on_unknown_session_id() { let app = make_iterative_app(); let mut runner = AppRunner::builder() .app(app) @@ -739,7 +731,7 @@ async fn test_finish_iterative_session_reports_missing_session() { } #[tokio::test] -async fn test_iterative_invocation_resumes_latest_checkpoint() { +async fn iterative_session_resumes_from_checkpoint_after_restart() { const SESSION_ID: &str = "iterative-resume"; let mut uninterrupted = AppRunner::builder() @@ -823,7 +815,7 @@ async fn test_iterative_invocation_resumes_latest_checkpoint() { } #[tokio::test] -async fn test_runtime_clock_reaches_node_context_and_events() { +async fn mock_clock_timestamp_propagates_to_node_context_and_events() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("clock".into()), ClockProbeNode) .add_edge(NodeKind::Start, NodeKind::Custom("clock".into())) @@ -872,7 +864,7 @@ async fn test_runtime_clock_reaches_node_context_and_events() { } #[tokio::test] -async fn test_runner_metadata_reports_graph_runtime_and_backends() { +async fn run_metadata_reports_graph_hash_and_backend_identifiers() { let app = make_iterative_app(); let graph_hash = app.graph_definition_hash(); let runner = AppRunner::builder() @@ -919,7 +911,7 @@ impl Node for WorkerNode { } #[tokio::test] -async fn test_frontier_command_replace_routes_nodes() { +async fn frontier_replace_command_redirects_execution_to_replacement_node() { let app = GraphBuilder::new() .add_node(NodeKind::Custom("controller".into()), ReplaceController) .add_node(NodeKind::Custom("worker".into()), WorkerNode) @@ -937,67 +929,60 @@ async fn test_frontier_command_replace_routes_nodes() { .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("control"); runner - .create_session("frontier-session".into(), initial_state) + .create_session("frontier-session".into(), state_with_user("control")) .await .expect("create session"); - let first_step = runner + let StepResult::Completed(first_report) = runner .run_step("frontier-session", StepOptions::default()) .await - .expect("first step"); - - match first_step { - StepResult::Completed(report) => { - assert_eq!( - report.ran_nodes, - vec![NodeKind::Custom("controller".into())] - ); - assert_eq!(report.barrier_outcome.frontier_commands.len(), 1); - match &report.barrier_outcome.frontier_commands[0].1 { - FrontierCommand::Replace(routes) => { - let kinds: Vec = routes.iter().map(NodeRoute::to_node_kind).collect(); - assert_eq!(kinds.len(), 1); - assert_eq!(kinds[0], NodeKind::Custom("worker".into())); - } - other => panic!("expected replace command, got {other:?}"), - } - } - other => panic!("expected completed step, got {other:?}"), - } + .expect("first step") + else { + panic!("expected completed step"); + }; + assert_eq!( + first_report.ran_nodes, + vec![NodeKind::Custom("controller".into())] + ); + assert_eq!(first_report.barrier_outcome.frontier_commands.len(), 1); + let FrontierCommand::Replace(routes) = &first_report.barrier_outcome.frontier_commands[0].1 + else { + panic!( + "expected replace command, got {:?}", + first_report.barrier_outcome.frontier_commands[0].1 + ); + }; + let kinds: Vec = routes.iter().map(NodeRoute::to_node_kind).collect(); + assert_eq!(kinds, vec![NodeKind::Custom("worker".into())]); - let second_step = runner + let StepResult::Completed(second_report) = runner .run_step("frontier-session", StepOptions::default()) .await - .expect("second step"); - - match second_step { - StepResult::Completed(report) => { - assert!( - report - .ran_nodes - .contains(&NodeKind::Custom("worker".into())) - ); - } - other => panic!("expected completed step, got {other:?}"), - } + .expect("second step") + else { + panic!("expected completed step"); + }; + assert!( + second_report + .ran_nodes + .contains(&NodeKind::Custom("worker".into())) + ); } #[tokio::test] -async fn test_interrupt_before() { +async fn interrupt_before_pauses_step_before_named_node() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("hello"); assert_eq!( runner - .create_session("test_session".into(), initial_state) + .create_session("test_session".into(), state_with_user("hello")) .await .unwrap(), SessionInit::Fresh @@ -1008,29 +993,28 @@ async fn test_interrupt_before() { ..Default::default() }; - let result = runner.run_step("test_session", options).await; - assert!(result.is_ok()); - - if let Ok(StepResult::Paused(paused)) = result { - assert!(matches!(paused.reason, PausedReason::BeforeNode(_))); - } else { - panic!("Expected paused step, got: {:?}", result); - } + let StepResult::Paused(paused) = runner + .run_step("test_session", options) + .await + .expect("run_step should succeed") + else { + panic!("expected paused step"); + }; + assert!(matches!(paused.reason, PausedReason::BeforeNode(_))); } #[tokio::test] -async fn test_interrupt_after() { +async fn interrupt_after_pauses_step_after_named_node() { let app = make_test_app(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("hello"); assert_eq!( runner - .create_session("test_session".into(), initial_state) + .create_session("test_session".into(), state_with_user("hello")) .await .unwrap(), SessionInit::Fresh @@ -1041,23 +1025,21 @@ async fn test_interrupt_after() { ..Default::default() }; - let result = runner.run_step("test_session", options).await; - assert!(result.is_ok()); - - if let Ok(StepResult::Paused(paused)) = result { - assert!(matches!(paused.reason, PausedReason::AfterNode(_))); - } else { - panic!("Expected paused step, got: {:?}", result); - } + let StepResult::Paused(paused) = runner + .run_step("test_session", options) + .await + .expect("run_step should succeed") + else { + panic!("expected paused step"); + }; + assert!(matches!(paused.reason, PausedReason::AfterNode(_))); } #[tokio::test] -async fn test_resume_from_checkpoint() { +async fn session_resumes_step_and_state_from_sqlite_checkpoint() { let temp_dir = tempfile::tempdir().unwrap(); let db_path = temp_dir.path().join("test_resume.db"); - // Build the app with a runtime config that points SQLite to our temp path, - // avoiding any process-wide environment mutation. let app = GraphBuilder::new() .add_node(NodeKind::Custom("test".into()), TestNode { name: "test" }) .add_edge(NodeKind::Start, NodeKind::Custom("test".into())) @@ -1085,16 +1067,15 @@ async fn test_resume_from_checkpoint() { SessionInit::Fresh ); - let step1_result = runner1 + let StepResult::Completed(step1_report) = runner1 .run_step(session_id, StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(report) = step1_result { - assert_eq!(report.step, 1); - assert!(!report.ran_nodes.is_empty()); - } else { - panic!("Expected completed step"); - } + .expect("first run_step should succeed") + else { + panic!("expected completed step"); + }; + assert_eq!(step1_report.step, 1); + assert!(!step1_report.ran_nodes.is_empty()); let session_after_step1 = runner1.get_session(session_id).unwrap().clone(); assert_eq!(session_after_step1.step, 1); @@ -1105,12 +1086,11 @@ async fn test_resume_from_checkpoint() { .checkpointer(CheckpointerType::SQLite) .build() .await; - let resume_result = runner2 - .create_session(session_id.into(), initial_state) - .await - .unwrap(); assert!(matches!( - resume_result, + runner2 + .create_session(session_id.into(), initial_state) + .await + .unwrap(), SessionInit::Resumed { checkpoint_step: 1 } )); let resumed_session = runner2.get_session(session_id).unwrap(); @@ -1120,12 +1100,10 @@ async fn test_resume_from_checkpoint() { resumed_session.state.messages.len(), session_after_step1.state.messages.len() ); - - // No environment cleanup necessary; the DB URL was provided via runtime config. } #[tokio::test] -async fn test_multi_target_conditional_edge() { +async fn conditional_edge_fans_out_to_multiple_targets_when_predicate_returns_many() { let multi_pred: EdgePredicate = std::sync::Arc::new(|snap: StateSnapshot| { if snap.extra.contains_key("fan_out") { vec!["A".to_string(), "B".to_string(), "C".to_string()] @@ -1134,7 +1112,7 @@ async fn test_multi_target_conditional_edge() { } }); - let gb = GraphBuilder::new() + let app = GraphBuilder::new() .add_node(NodeKind::Custom("Root".into()), TestNode { name: "root" }) .add_node(NodeKind::Custom("A".into()), TestNode { name: "A" }) .add_node(NodeKind::Custom("B".into()), TestNode { name: "B" }) @@ -1148,9 +1126,9 @@ async fn test_multi_target_conditional_edge() { .add_edge(NodeKind::Custom("B".into()), NodeKind::End) .add_edge(NodeKind::Custom("C".into()), NodeKind::End) .add_edge(NodeKind::Custom("Single".into()), NodeKind::End) - .add_conditional_edge(NodeKind::Custom("Root".into()), multi_pred); - - let app = gb.compile().unwrap(); + .add_conditional_edge(NodeKind::Custom("Root".into()), multi_pred) + .compile() + .unwrap(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) @@ -1167,41 +1145,53 @@ async fn test_multi_target_conditional_edge() { .await .unwrap(); - let step1 = runner + let StepResult::Completed(step1_report) = runner .run_step("multi_test", StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(report) = step1 { - assert_eq!(report.ran_nodes, vec![NodeKind::Custom("Root".into())]); - assert_eq!(report.next_frontier.len(), 3); - assert!(report.next_frontier.contains(&NodeKind::Custom("A".into()))); - assert!(report.next_frontier.contains(&NodeKind::Custom("B".into()))); - assert!(report.next_frontier.contains(&NodeKind::Custom("C".into()))); - } else { - panic!("Expected completed step"); - } + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert_eq!( + step1_report.ran_nodes, + vec![NodeKind::Custom("Root".into())] + ); + assert_eq!(step1_report.next_frontier.len(), 3); + assert!( + step1_report + .next_frontier + .contains(&NodeKind::Custom("A".into())) + ); + assert!( + step1_report + .next_frontier + .contains(&NodeKind::Custom("B".into())) + ); + assert!( + step1_report + .next_frontier + .contains(&NodeKind::Custom("C".into())) + ); - let state2 = state_with_user("test2"); runner - .create_session("single_test".to_string(), state2) + .create_session("single_test".to_string(), state_with_user("test2")) .await .unwrap(); - let step2 = runner + let StepResult::Completed(step2_report) = runner .run_step("single_test", StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(report) = step2 { - assert_eq!( - report.next_frontier, - vec![NodeKind::Custom("Single".into())] - ); - } else { - panic!("Expected completed step"); - } + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert_eq!( + step2_report.next_frontier, + vec![NodeKind::Custom("Single".into())] + ); } #[tokio::test] -async fn test_conditional_edge_with_invalid_targets() { +async fn conditional_edge_filters_unregistered_targets_from_frontier() { let mixed_pred: EdgePredicate = std::sync::Arc::new(|_snap: StateSnapshot| { vec![ "Valid".to_string(), @@ -1210,7 +1200,7 @@ async fn test_conditional_edge_with_invalid_targets() { ] }); - let gb = GraphBuilder::new() + let app = GraphBuilder::new() .add_node(NodeKind::Custom("Root".into()), TestNode { name: "root" }) .add_node(NodeKind::Custom("Valid".into()), TestNode { name: "valid" }) .add_edge( @@ -1219,68 +1209,69 @@ async fn test_conditional_edge_with_invalid_targets() { ) .add_edge(NodeKind::Custom("Valid".into()), NodeKind::End) .add_edge(NodeKind::Start, NodeKind::Custom("Root".into())) - .add_conditional_edge(NodeKind::Custom("Root".into()), mixed_pred); - - let app = gb.compile().unwrap(); + .add_conditional_edge(NodeKind::Custom("Root".into()), mixed_pred) + .compile() + .unwrap(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let state = state_with_user("test"); runner - .create_session("mixed_test".to_string(), state) + .create_session("mixed_test".to_string(), state_with_user("test")) .await .unwrap(); - let step = runner + let StepResult::Completed(report) = runner .run_step("mixed_test", StepOptions::default()) .await - .unwrap(); - if let StepResult::Completed(report) = step { - assert_eq!(report.next_frontier.len(), 2); - assert!( - report - .next_frontier - .contains(&NodeKind::Custom("Valid".into())) - ); - assert!(report.next_frontier.contains(&NodeKind::End)); - assert!( - !report - .next_frontier - .contains(&NodeKind::Custom("Invalid".into())) - ); - } else { - panic!("Expected completed step"); - } + .expect("run_step should succeed") + else { + panic!("expected completed step"); + }; + assert_eq!(report.next_frontier.len(), 2); + assert!( + report + .next_frontier + .contains(&NodeKind::Custom("Valid".into())) + ); + assert!(report.next_frontier.contains(&NodeKind::End)); + assert!( + !report + .next_frontier + .contains(&NodeKind::Custom("Invalid".into())) + ); } #[tokio::test] -async fn test_error_event_appended_on_failure() { - let mut gb = GraphBuilder::new(); - gb = gb.add_node(NodeKind::Custom("X".into()), FailingNode::default()); - gb = gb.add_edge(NodeKind::Start, NodeKind::Custom("X".into())); - gb = gb.add_edge(NodeKind::Custom("X".into()), NodeKind::End); - - let app = gb.compile().unwrap(); +async fn node_failure_appends_scoped_error_event_to_state() { + let app = GraphBuilder::new() + .add_node(NodeKind::Custom("X".into()), FailingNode::default()) + .add_edge(NodeKind::Start, NodeKind::Custom("X".into())) + .add_edge(NodeKind::Custom("X".into()), NodeKind::End) + .compile() + .unwrap(); let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - let initial_state = state_with_user("hello"); assert!(matches!( runner - .create_session("err_sess".into(), initial_state) + .create_session("err_sess".into(), state_with_user("hello")) .await .unwrap(), SessionInit::Fresh )); - let res = runner.run_step("err_sess", StepOptions::default()).await; - assert!(res.is_err()); + assert!( + runner + .run_step("err_sess", StepOptions::default()) + .await + .is_err() + ); let sess = runner.get_session("err_sess").unwrap(); let errors_snapshot = sess.state.errors.snapshot(); @@ -1289,13 +1280,13 @@ async fn test_error_event_appended_on_failure() { "expected errors to be present in errors channel" ); - let error_event = &errors_snapshot[0]; - assert!(matches!( - error_event.scope, - weavegraph::channels::errors::ErrorScope::Node { .. } - )); - if let weavegraph::channels::errors::ErrorScope::Node { kind, step } = &error_event.scope { - assert_eq!(kind, "Custom:X"); - assert_eq!(step, &1); - } + let weavegraph::channels::errors::ErrorScope::Node { kind, step } = &errors_snapshot[0].scope + else { + panic!( + "expected Node error scope, got {:?}", + errors_snapshot[0].scope + ); + }; + assert_eq!(kind, "Custom:X"); + assert_eq!(step, &1); } diff --git a/tests/runtimes_types.rs b/tests/runtimes_types.rs index 25b69b4..85c01b2 100644 --- a/tests/runtimes_types.rs +++ b/tests/runtimes_types.rs @@ -1,22 +1,21 @@ use weavegraph::runtimes::types::*; #[test] -fn test_session_id_creation() { +fn session_id_reflects_given_string() { let id = SessionId::new("test_session"); assert_eq!(id.as_str(), "test_session"); assert_eq!(id.to_string(), "test_session"); } #[test] -fn test_session_id_generation() { +fn generated_session_ids_are_unique() { let id1 = SessionId::generate(); let id2 = SessionId::generate(); - // Generated IDs should be different assert_ne!(id1, id2); } #[test] -fn test_step_number_arithmetic() { +fn step_number_next_increments_value_and_zero_is_initial() { let step = StepNumber::new(5); assert_eq!(step.value(), 5); assert_eq!(step.next().value(), 6); @@ -28,8 +27,6 @@ fn test_step_number_arithmetic() { } #[test] -fn test_step_number_saturation() { - let max_step = StepNumber::new(u64::MAX); - let next = max_step.next(); - assert_eq!(next.value(), u64::MAX); // Should saturate, not overflow +fn step_number_next_saturates_at_max_without_overflow() { + assert_eq!(StepNumber::new(u64::MAX).next().value(), u64::MAX); } diff --git a/tests/schedulers.rs b/tests/schedulers.rs index f05872a..c7c4612 100644 --- a/tests/schedulers.rs +++ b/tests/schedulers.rs @@ -4,9 +4,7 @@ use serde_json::json; use std::sync::Arc; use weavegraph::event_bus::EventBus; use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; -use weavegraph::schedulers::scheduler::{ - Scheduler, SchedulerRunContext, SchedulerState, StepRunResult, -}; +use weavegraph::schedulers::scheduler::{Scheduler, SchedulerRunContext, SchedulerState}; use weavegraph::state::StateSnapshot; use weavegraph::types::NodeKind; use weavegraph::utils::clock::MockClock; @@ -33,7 +31,7 @@ impl Node for SchedulerContextProbe { } #[tokio::test] -async fn test_superstep_propagates_node_error() { +async fn failing_node_causes_superstep_to_return_scheduler_error() { let sched = Scheduler::new(4); let mut state = SchedulerState::default(); let mut nodes: FxHashMap> = FxHashMap::default(); @@ -70,31 +68,27 @@ async fn test_superstep_propagates_node_error() { } #[test] -fn test_should_run_and_record_seen() { +fn node_reruns_only_when_snapshot_version_advances() { let sched = Scheduler::new(4); let mut state = SchedulerState::default(); - let id = "Other(\"A\")"; // legacy label encoding used in versions_seen + let id = "Other(\"A\")"; // string key used by record_seen / should_run - // No record -> should run let snap1 = create_test_snapshot(1, 1); assert!(sched.should_run(&state, id, &snap1)); - // Record seen -> no change -> should not run sched.record_seen(&mut state, id, &snap1); assert!(!sched.should_run(&state, id, &snap1)); - // Version bump on messages -> should run let snap2 = create_test_snapshot(2, 1); assert!(sched.should_run(&state, id, &snap2)); - // Record bump, then only extra increases -> should run sched.record_seen(&mut state, id, &snap2); let snap3 = create_test_snapshot(2, 3); assert!(sched.should_run(&state, id, &snap3)); } #[tokio::test] -async fn test_superstep_skips_end_and_nochange() { +async fn end_node_is_always_skipped_and_unchanged_snapshot_prevents_rerun() { let sched = Scheduler::new(8); let mut state = SchedulerState::default(); let nodes = make_test_registry(); @@ -105,9 +99,9 @@ async fn test_superstep_skips_end_and_nochange() { ]; let event_bus = EventBus::default(); - // First run: nothing recorded, both A and B should run; End skipped. + // Fresh snapshot: A and B run; End is always skipped. let snap = create_test_snapshot(1, 1); - let res1: StepRunResult = sched + let res1 = sched .superstep( &mut state, &nodes, @@ -119,7 +113,6 @@ async fn test_superstep_skips_end_and_nochange() { .await .unwrap(); - // All ran except End let ran1: std::collections::HashSet<_> = res1.ran_nodes.iter().cloned().collect(); assert!(ran1.contains(&NodeKind::Custom("A".into()))); assert!(ran1.contains(&NodeKind::Custom("B".into()))); @@ -127,7 +120,7 @@ async fn test_superstep_skips_end_and_nochange() { assert!(res1.skipped_nodes.contains(&NodeKind::End)); assert_eq!(res1.outputs.len(), 2); - // Record_seen happened inside superstep; with same snapshot, nothing should run now. + // Same snapshot: superstep recorded the versions; nothing runs again. let res2 = sched .superstep( &mut state, @@ -141,20 +134,19 @@ async fn test_superstep_skips_end_and_nochange() { .unwrap(); assert!(res2.ran_nodes.is_empty()); - // Both A and B plus End appear in skipped (version-gated or End) let skipped2: std::collections::HashSet<_> = res2.skipped_nodes.iter().cloned().collect(); assert!(skipped2.contains(&NodeKind::Custom("A".into()))); assert!(skipped2.contains(&NodeKind::Custom("B".into()))); assert!(skipped2.contains(&NodeKind::End)); assert!(res2.outputs.is_empty()); - // Increase messages version -> A and B should run again + // Bumped messages version: A and B run again; End remains skipped. let snap_bump = create_test_snapshot(2, 1); let res3 = sched .superstep( &mut state, &nodes, - frontier.clone(), + frontier, snap_bump, 3, SchedulerRunContext::new(event_bus.get_emitter()), @@ -168,7 +160,7 @@ async fn test_superstep_skips_end_and_nochange() { } #[tokio::test] -async fn test_superstep_outputs_order_agnostic() { +async fn concurrent_superstep_ran_nodes_preserves_frontier_order_and_outputs_are_set_equal() { let nodes = make_delayed_registry(); let frontier = vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())]; let snap = create_test_snapshot(1, 1); @@ -179,7 +171,7 @@ async fn test_superstep_outputs_order_agnostic() { .superstep( &mut state, &nodes, - frontier.clone(), + frontier, snap, 1, SchedulerRunContext::new(event_bus.get_emitter()), @@ -203,7 +195,7 @@ async fn test_superstep_outputs_order_agnostic() { } #[tokio::test] -async fn test_superstep_serialized_with_limit_1() { +async fn concurrency_limit_one_serializes_execution_and_output_order_matches_ran_nodes() { let nodes = make_delayed_registry(); let frontier = vec![NodeKind::Custom("A".into()), NodeKind::Custom("B".into())]; let snap = create_test_snapshot(1, 1); @@ -214,7 +206,7 @@ async fn test_superstep_serialized_with_limit_1() { .superstep( &mut state, &nodes, - frontier.clone(), + frontier, snap, 1, SchedulerRunContext::new(event_bus.get_emitter()), @@ -232,7 +224,7 @@ async fn test_superstep_serialized_with_limit_1() { } #[tokio::test] -async fn test_scheduler_run_context_injects_clock_and_invocation_id() { +async fn run_context_clock_and_invocation_id_are_visible_to_nodes() { let sched = Scheduler::new(1); let mut state = SchedulerState::default(); let mut nodes: FxHashMap> = FxHashMap::default(); diff --git a/tests/state_channels.rs b/tests/state_channels.rs index 6ba9f4a..bc07ebd 100644 --- a/tests/state_channels.rs +++ b/tests/state_channels.rs @@ -1,6 +1,5 @@ use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; -use weavegraph::channels::Channel; use weavegraph::message::{Message, Role}; use weavegraph::node::NodePartial; use weavegraph::state::{StateKey, StateSlotError, VersionedState}; @@ -41,7 +40,7 @@ impl Serialize for AlwaysFailsSerialize { const FAILING_SLOT: StateKey = StateKey::new("wg", "failing", 1); #[test] -fn test_new_with_user_message_initializes_fields() { +fn new_with_user_message_creates_state_with_one_user_message() { let s = VersionedState::new_with_user_message("hello"); let snap = s.snapshot(); assert_eq!(snap.messages.len(), 1); @@ -55,7 +54,7 @@ fn test_new_with_user_message_initializes_fields() { } #[test] -fn test_new_with_messages_initializes_fields() { +fn new_with_messages_creates_state_with_all_supplied_messages() { let messages = vec![ Message::with_role(Role::User, "hello"), Message::with_role(Role::Assistant, "hi there"), @@ -74,42 +73,41 @@ fn test_new_with_messages_initializes_fields() { } #[test] -fn test_snapshot_is_deep_copy() { +fn snapshot_is_independent_of_subsequent_mutations() { let mut s = VersionedState::new_with_user_message("x"); let snap = s.snapshot(); - s.messages.get_mut()[0].content = "changed".into(); - s.extra - .get_mut() - .insert("k".into(), Value::String("v".into())); + s.add_message("user", "second"); + s.add_extra("k", Value::String("v".into())); + assert_eq!(snap.messages.len(), 1); assert_eq!(snap.messages[0].content, "x"); assert!(!snap.extra.contains_key("k")); } #[test] -fn test_new_with_messages_snapshot_is_deep_copy() { +fn snapshot_from_multi_message_init_is_independent_of_subsequent_mutations() { let mut state = VersionedState::new_with_messages(vec![ Message::with_role(Role::User, "original"), Message::with_role(Role::Assistant, "response"), ]); let snapshot = state.snapshot(); - state.messages.get_mut()[0].content = "changed".into(); - state - .extra - .get_mut() - .insert("k".into(), Value::String("v".into())); + state.add_message("user", "third"); + state.add_extra("k", Value::String("v".into())); + assert_eq!(snapshot.messages.len(), 2); assert_eq!(snapshot.messages[0].content, "original"); assert_eq!(snapshot.messages[1].content, "response"); assert!(!snapshot.extra.contains_key("k")); } #[test] -fn test_extra_flexible_types() { - let mut s = VersionedState::new_with_user_message("y"); - s.extra.get_mut().insert("number".into(), json!(123)); - s.extra.get_mut().insert("text".into(), json!("abc")); - s.extra.get_mut().insert("array".into(), json!([1, 2, 3])); +fn extra_slot_accepts_number_string_and_array_json_values() { + let s = VersionedState::builder() + .with_user_message("y") + .with_extra("number", json!(123)) + .with_extra("text", json!("abc")) + .with_extra("array", json!([1, 2, 3])) + .build(); let snap = s.snapshot(); assert_eq!(snap.extra["number"], json!(123)); assert_eq!(snap.extra["text"], json!("abc")); @@ -117,28 +115,29 @@ fn test_extra_flexible_types() { } #[test] -fn test_clone_is_deep() { +fn clone_does_not_share_state_with_original() { let mut s = VersionedState::new_with_user_message("msg"); - s.extra - .get_mut() - .insert("k1".into(), Value::String("v1".into())); + s.add_extra("k1", Value::String("v1".into())); let cloned = s.clone(); - s.messages.get_mut()[0].content = "changed".into(); - s.extra - .get_mut() - .insert("k2".into(), Value::String("v2".into())); - assert_ne!(cloned.messages.snapshot(), s.messages.snapshot()); - assert_ne!(cloned.extra.snapshot(), s.extra.snapshot()); - assert_eq!(cloned.messages.snapshot()[0].content, "msg"); + s.add_message("user", "second"); + s.add_extra("k2", Value::String("v2".into())); + + let orig_snap = s.snapshot(); + let clone_snap = cloned.snapshot(); + + assert_ne!(orig_snap.messages, clone_snap.messages); + assert_ne!(orig_snap.extra, clone_snap.extra); + assert_eq!(clone_snap.messages.len(), 1); + assert_eq!(clone_snap.messages[0].content, "msg"); assert_eq!( - cloned.extra.snapshot().get("k1"), + clone_snap.extra.get("k1"), Some(&Value::String("v1".into())) ); - assert!(!cloned.extra.snapshot().contains_key("k2")); + assert!(!clone_snap.extra.contains_key("k2")); } #[test] -fn test_builder_pattern() { +fn builder_creates_state_with_all_supplied_messages_and_extra() { let state = VersionedState::builder() .with_user_message("Hello") .with_assistant_message("Hi there!") @@ -162,7 +161,7 @@ fn test_builder_pattern() { } #[test] -fn test_convenience_methods() { +fn add_message_and_add_extra_append_to_existing_state() { let mut state = VersionedState::new_with_user_message("Initial"); let _ = state .add_message("assistant", "Response") @@ -180,7 +179,7 @@ fn test_convenience_methods() { } #[test] -fn test_typed_state_slots_round_trip() { +fn typed_slot_stores_and_retrieves_value_under_storage_key() { let portfolio = PortfolioSnapshot { cash_cents: 12_345, position_count: 2, @@ -202,7 +201,7 @@ fn test_typed_state_slots_round_trip() { } #[test] -fn test_state_key_accessors_and_schema_versions() { +fn state_key_encodes_namespace_name_and_version_in_storage_key() { assert_eq!(PORTFOLIO.namespace(), "wq"); assert_eq!(PORTFOLIO.name(), "portfolio"); assert_eq!(PORTFOLIO.schema_version(), 1); @@ -212,7 +211,7 @@ fn test_state_key_accessors_and_schema_versions() { } #[test] -fn test_typed_state_slots_missing_and_optional_reads() { +fn missing_typed_slot_returns_none_for_get_and_errors_for_require() { let snapshot = VersionedState::builder().build().snapshot(); assert_eq!(snapshot.get_typed(PORTFOLIO).unwrap(), None); @@ -223,7 +222,7 @@ fn test_typed_state_slots_missing_and_optional_reads() { } #[test] -fn test_typed_state_slots_report_deserialization_errors_with_key() { +fn corrupt_typed_slot_deserialization_error_includes_storage_key() { let state = VersionedState::builder() .with_extra( &PORTFOLIO.storage_key(), @@ -244,7 +243,7 @@ fn test_typed_state_slots_report_deserialization_errors_with_key() { } #[test] -fn test_typed_state_slots_report_serialization_errors_with_key() { +fn unserializable_typed_slot_serialization_error_includes_storage_key() { let builder_error = VersionedState::builder() .with_typed_extra(FAILING_SLOT, AlwaysFailsSerialize) .unwrap_err(); @@ -270,7 +269,7 @@ fn test_typed_state_slots_report_serialization_errors_with_key() { } #[test] -fn test_typed_state_slots_schema_versions_can_coexist() { +fn typed_slots_at_different_schema_versions_coexist_independently() { let v1 = PortfolioSnapshot { cash_cents: 100, position_count: 1, @@ -293,7 +292,7 @@ fn test_typed_state_slots_schema_versions_can_coexist() { } #[test] -fn test_versioned_state_add_typed_extra_chains_and_overwrites_slot() { +fn add_typed_extra_overwrites_previous_value_for_same_slot() { let first = PortfolioSnapshot { cash_cents: 10, position_count: 1, @@ -314,7 +313,7 @@ fn test_versioned_state_add_typed_extra_chains_and_overwrites_slot() { } #[test] -fn test_node_partial_with_typed_extra() { +fn node_partial_stores_typed_value_under_correct_storage_key() { let portfolio = PortfolioSnapshot { cash_cents: 500, position_count: 1, @@ -335,7 +334,7 @@ fn test_node_partial_with_typed_extra() { } #[test] -fn test_node_partial_with_typed_extra_merges_with_existing_extra_and_overwrites_same_slot() { +fn node_partial_typed_extra_overwrites_slot_and_preserves_other_keys() { let old = PortfolioSnapshot { cash_cents: 1, position_count: 1, diff --git a/tests/streaming_sse.rs b/tests/streaming_sse.rs index 4380be4..af392cf 100644 --- a/tests/streaming_sse.rs +++ b/tests/streaming_sse.rs @@ -68,7 +68,8 @@ async fn handler( #[tokio::test(flavor = "multi_thread")] #[ignore] -async fn axum_sse_example_streams_until_completion() -> Result<(), Box> { +async fn axum_sse_handler_streams_events_until_end_marker() -> Result<(), Box> +{ let app = Arc::new( GraphBuilder::new() .add_node(NodeKind::Custom("test".into()), TestNode) diff --git a/tests/telemetry.rs b/tests/telemetry.rs index fcbd00b..a16ffd5 100644 --- a/tests/telemetry.rs +++ b/tests/telemetry.rs @@ -11,9 +11,7 @@ fn render_event_includes_colors_and_context() { let fmt = PlainFormatter::with_mode(FormatterMode::Colored); let ev = Event::node_message_with_meta("nodeX", 7, "ScopeX", "hello"); let render = fmt.render_event(&ev); - // Context should be set to scope label assert_eq!(render.context.as_deref(), Some("ScopeX")); - // Lines should contain colored body and reset code let joined = render.join_lines(); assert!(joined.contains(LINE_COLOR)); assert!(joined.contains(RESET_COLOR)); @@ -37,14 +35,13 @@ fn render_errors_formats_scope_lines_and_details() { let mut e2 = ErrorEvent::app(WeaveError::msg("oops")); e2.when = now; - let renders = fmt.render_errors(&[e1.clone(), e2.clone()]); + let renders = fmt.render_errors(&[e1, e2]); assert_eq!(renders.len(), 2); - // First render: should include colored scope, error, cause, tags, and context - let r0 = renders[0].clone(); - let head = r0.lines[0].clone(); - assert!(head.contains(CONTEXT_COLOR)); - assert!(head.contains(RESET_COLOR)); + // Runner scope + let r0 = &renders[0]; + assert!(r0.lines[0].contains(CONTEXT_COLOR)); + assert!(r0.lines[0].contains(RESET_COLOR)); let body = r0.lines.join(""); assert!(body.contains("error: boom")); assert!(body.contains("cause: inner")); @@ -55,36 +52,16 @@ fn render_errors_formats_scope_lines_and_details() { Some("Runner { session: \"sess\", step: 3 }") ); - // Second render: App scope with minimal fields - let r1 = renders[1].clone(); - let hdr = r1.lines[0].clone(); - assert!(hdr.contains("App")); + // App scope + let r1 = &renders[1]; + assert!(r1.lines[0].contains("App")); let body1 = r1.lines.join(""); assert!(body1.contains("error: oops")); - // no cause/tags/context lines should appear assert!(!body1.contains("cause:")); assert!(!body1.contains("tags:")); assert!(!body1.contains("context:")); } -#[test] -fn formatter_mode_colored_includes_ansi_codes() { - let fmt = PlainFormatter::with_mode(FormatterMode::Colored); - let ev = Event::node_message_with_meta("test_node", 1, "TestScope", "test message"); - let render = fmt.render_event(&ev); - let output = render.join_lines(); - - // Should contain ANSI color codes - assert!( - output.contains(LINE_COLOR), - "Colored mode should include LINE_COLOR" - ); - assert!( - output.contains(RESET_COLOR), - "Colored mode should include RESET_COLOR" - ); -} - #[test] fn formatter_mode_plain_excludes_ansi_codes() { let fmt = PlainFormatter::with_mode(FormatterMode::Plain); @@ -175,7 +152,7 @@ fn formatter_mode_plain_nested_errors_exclude_colors() { } #[test] -fn formatter_mode_auto_default() { +fn formatter_mode_defaults_to_auto_and_new_matches_default() { // FormatterMode::Auto should be the default let mode = FormatterMode::default(); assert_eq!(mode, FormatterMode::Auto); diff --git a/tests/types.rs b/tests/types.rs index 4139edc..91a53ce 100644 --- a/tests/types.rs +++ b/tests/types.rs @@ -2,7 +2,7 @@ use weavegraph::runtimes::types::{SessionId, StepNumber}; use weavegraph::types::{ChannelType, NodeKind}; #[test] -fn test_nodekind_predicates() { +fn nodekind_predicates_match_variant() { assert!(NodeKind::Start.is_start()); assert!(!NodeKind::Start.is_end()); assert!(!NodeKind::Start.is_custom()); @@ -18,7 +18,7 @@ fn test_nodekind_predicates() { } #[test] -fn test_nodekind_encode_decode() { +fn nodekind_encode_decode_roundtrip() { let test_cases = vec![ (NodeKind::Start, "Start"), (NodeKind::End, "End"), @@ -38,7 +38,7 @@ fn test_nodekind_encode_decode() { } #[test] -fn test_display() { +fn nodekind_and_channel_type_display_as_expected_strings() { assert_eq!(NodeKind::Start.to_string(), "Start"); assert_eq!(NodeKind::End.to_string(), "End"); assert_eq!( @@ -52,7 +52,7 @@ fn test_display() { } #[test] -fn test_nodekind_helper_targets() { +fn nodekind_target_strings_match_variant() { // as_target on variants assert_eq!(NodeKind::Start.as_target(), "Start"); assert_eq!(NodeKind::End.as_target(), "End"); @@ -64,7 +64,7 @@ fn test_nodekind_helper_targets() { } #[test] -fn test_serde_support() { +fn nodekind_and_channel_type_serde_roundtrip() { let nodes = vec![ NodeKind::Start, NodeKind::End, @@ -87,14 +87,14 @@ fn test_serde_support() { // Runtime types tests #[test] -fn test_session_id_creation() { +fn session_id_preserves_given_string() { let id = SessionId::new("test_session"); assert_eq!(id.as_str(), "test_session"); assert_eq!(id.to_string(), "test_session"); } #[test] -fn test_session_id_generation() { +fn session_id_generate_produces_unique_values() { let id1 = SessionId::generate(); let id2 = SessionId::generate(); // Generated IDs should be different @@ -102,7 +102,7 @@ fn test_session_id_generation() { } #[test] -fn test_step_number_arithmetic() { +fn step_number_next_increments_and_zero_is_initial() { let step = StepNumber::new(5); assert_eq!(step.value(), 5); assert_eq!(step.next().value(), 6); @@ -114,7 +114,7 @@ fn test_step_number_arithmetic() { } #[test] -fn test_step_number_saturation() { +fn step_number_next_saturates_at_max() { let max_step = StepNumber::new(u64::MAX); let next = max_step.next(); assert_eq!(next.value(), u64::MAX); // Should saturate, not overflow diff --git a/tests/utils.rs b/tests/utils.rs index 2853e5f..7bb66b1 100644 --- a/tests/utils.rs +++ b/tests/utils.rs @@ -6,7 +6,7 @@ use weavegraph::utils::id_generator::*; use weavegraph::utils::json_ext::*; #[test] -fn test_collections_helpers() { +fn extra_map_stores_and_retrieves_typed_values() { let mut map = new_extra_map(); map.insert_string("name", "test"); map.insert_number("count", 42); @@ -22,7 +22,7 @@ fn test_collections_helpers() { } #[test] -fn test_json_ext_deep_merge_and_path() { +fn deep_merge_combines_nested_objects_and_path_lookup_works() { let left = json!({"a": 1, "b": {"x": 10}}); let right = json!({"b": {"y": 20}, "c": 3}); let merged = deep_merge(&left, &right, MergeStrategy::DeepMerge).unwrap(); @@ -33,7 +33,7 @@ fn test_json_ext_deep_merge_and_path() { } #[test] -fn test_id_generator_basics() { +fn run_id_has_run_prefix_and_seeded_config_produces_unique_ids() { let id_gen = IdGenerator::new(); let run = id_gen.generate_run_id(); assert!(run.starts_with("run-")); @@ -50,7 +50,7 @@ fn test_id_generator_basics() { } #[test] -fn test_deterministic_rng() { +fn deterministic_rng_with_same_seed_produces_same_sequence() { let mut r1 = DeterministicRng::new(42); let mut r2 = DeterministicRng::new(42); assert_eq!(r1.random_u64(), r2.random_u64()); @@ -58,7 +58,7 @@ fn test_deterministic_rng() { } #[test] -fn test_clock_utils() { +fn mock_clock_advance_and_elapsed_check_work() { let mut clock = MockClock::new(1000); assert_eq!(clock.now(), 1000); clock.advance_secs(10); From 4cc427839895445e0ac98fb07397da4eeb7ffc1b Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 20:33:00 -0400 Subject: [PATCH 10/15] final revisions and switch from MIT to Apache-2.0 --- .github/workflows/ci.yml | 228 ++-- CONTRIBUTING.md | 4 + Cargo.toml | 2 +- LICENSE | 223 +++- README.md | 16 +- benches/event_bus_throughput.rs | 65 +- deny.toml | 24 +- docker-compose.yml | 74 +- docs/ARCHITECTURE.md | 631 +++------- docs/STREAMING.md | 708 +++-------- examples/advanced_patterns.rs | 136 +-- examples/basic_nodes.rs | 24 +- examples/convenience_streaming.rs | 4 +- examples/graph_execution.rs | 100 +- examples/scheduler_fanout.rs | 55 +- examples/streaming_events.rs | 180 ++- src/lib.rs | 18 +- tests/common/testing.rs | 4 +- tests/event_bus.rs | 1494 +++++++++++------------- tests/nodes.rs | 15 +- tests/runtimes_persistence_postgres.rs | 785 +++++++------ tests/runtimes_runner.rs | 161 ++- tests/smoke.rs | 6 +- tests/state_channels.rs | 51 +- 24 files changed, 2177 insertions(+), 2831 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c513dda..e808b9c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,167 +1,167 @@ -name: CI - -on: - push: - branches: - - main - pull_request: - -env: - CARGO_TERM_COLOR: always - RUST_VERSION: 1.90.0 - -jobs: - fmt: - name: fmt - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 +# Weavegraph CI — runs on push to main and on every pull request. +name: "CI" +# triggers +on: # event hooks + push: # push events + branches: ["main"] + pull_request: # all PR events + types: [opened, synchronize, reopened] + workflow_dispatch: # allow manual runs from the GitHub UI +# build env +env: # build-time defaults + CARGO_TERM_COLOR: "always" + RUST_VERSION: "1.90.0" +# jobs +jobs: # CI pipeline + # ── format ─────────────────────────────────────────────────────────────── + fmt: # formatting check + name: cargo fmt + runs-on: ubuntu-24.04 + steps: # fmt steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # fmt toolchain toolchain: ${{ env.RUST_VERSION }} components: rustfmt - - uses: Swatinem/rust-cache@v2 - with: + - uses: "Swatinem/rust-cache@v2" + with: # fmt cache prefix-key: fmt - - name: cargo fmt --check - run: cargo fmt --all -- --check - + - run: cargo fmt --all -- --check + # ─────────────────────────────────────────────────────────────────────── + # ── lint ───────────────────────────────────────────────────────────────── clippy: - name: clippy (${{ matrix.toolchain }} / ${{ matrix.track }}) - runs-on: ubuntu-latest - strategy: + name: cargo clippy (${{ matrix.toolchain }}) + runs-on: ubuntu-24.04 + strategy: # toolchain matrix fail-fast: false - matrix: + matrix: # toolchain variants include: - toolchain: "1.90.0" - track: required experimental: false - - toolchain: "stable" - track: canary + - toolchain: stable experimental: true continue-on-error: ${{ matrix.experimental }} - steps: - - uses: actions/checkout@v4 + steps: # clippy steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: - toolchain: ${{ matrix.toolchain }} + with: # clippy toolchain + toolchain: "${{ matrix.toolchain }}" components: clippy - - uses: Swatinem/rust-cache@v2 - with: + - uses: "Swatinem/rust-cache@v2" + with: # clippy cache prefix-key: clippy-${{ matrix.toolchain }} - - name: cargo clippy - run: cargo clippy --workspace --all-targets --all-features -- -D warnings - + - run: cargo clippy --workspace --all-targets --all-features -- -D warnings +# --- + # ── test ───────────────────────────────────────────────────────────────── test: - name: test (${{ matrix.toolchain }} / ${{ matrix.track }}) - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: + name: cargo test (${{ matrix.toolchain }}) + runs-on: ubuntu-24.04 + strategy: # test matrix + fail-fast: false # continue even if one toolchain fails + matrix: # test toolchains include: - toolchain: "1.90.0" - track: required experimental: false - - toolchain: "stable" - track: canary + - toolchain: stable experimental: true continue-on-error: ${{ matrix.experimental }} - steps: - - uses: actions/checkout@v4 + steps: # test steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: - toolchain: ${{ matrix.toolchain }} - - uses: Swatinem/rust-cache@v2 - with: + with: # test toolchain + toolchain: "${{ matrix.toolchain }}" + - uses: "Swatinem/rust-cache@v2" + with: # test cache prefix-key: test-${{ matrix.toolchain }} - - name: cargo test (lib only - postgres requires external service) - run: cargo test --lib --all-features + # postgres feature tests require an external service; run lib-only here + - run: cargo test --lib --all-features + # ── docs ───────────────────────────────────────────────────────────────── doc: - name: doc - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 + name: cargo doc + runs-on: ubuntu-24.04 + steps: # doc steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # doc toolchain toolchain: nightly - - uses: Swatinem/rust-cache@v2 - with: + - uses: "Swatinem/rust-cache@v2" + with: # doc cache prefix-key: doc - - name: cargo doc - env: + - name: Build docs — deny rustdoc warnings + env: # rustdoc flags RUSTDOCFLAGS: "--cfg docsrs -D warnings" run: cargo +nightly doc --workspace --all-features --no-deps - + # ─────────────────────────────────────────────────────────────────────── + # ── semver ─────────────────────────────────────────────────────────────── semver-checks: name: cargo semver-checks - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 + runs-on: ubuntu-24.04 + steps: # semver steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # semver toolchain toolchain: stable - uses: taiki-e/install-action@v2 - with: + with: # semver install tool: cargo-semver-checks - - uses: Swatinem/rust-cache@v2 - with: + - uses: "Swatinem/rust-cache@v2" + with: # semver cache prefix-key: semver-checks - - name: Check semver - # cargo-semver-checks requires rustc >= 1.91.0; run on stable, not pinned MSRV - run: cargo +stable semver-checks check-release --workspace + # semver-checks requires rustc >= 1.91.0; run on stable rather than pinned MSRV + - run: cargo +stable semver-checks check-release --workspace - deny: - name: cargo deny - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 + # ── deny ───────────────────────────────────────────────────────────────── + deny: # license/advisory check + name: "cargo deny" + runs-on: ubuntu-24.04 + steps: # deny steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # deny toolchain toolchain: ${{ env.RUST_VERSION }} - - uses: taiki-e/install-action@v2 - with: - tool: cargo-deny - - uses: Swatinem/rust-cache@v2 - with: - prefix-key: deny + - uses: "taiki-e/install-action@v2" + with: # deny install + tool: "cargo-deny" + - uses: "Swatinem/rust-cache@v2" + with: # deny cache + prefix-key: "deny" - run: cargo deny check - - machete: - name: cargo machete - runs-on: ubuntu-latest - continue-on-error: true - steps: - - uses: actions/checkout@v4 + # ─────────────────────────────────────────────────────────────────────── + # ── machete ────────────────────────────────────────────────────────────── + machete: # unused-dep check + name: "cargo machete" + runs-on: ubuntu-24.04 + continue-on-error: true # non-blocking + steps: # machete steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # machete toolchain toolchain: ${{ env.RUST_VERSION }} - - uses: taiki-e/install-action@v2 - with: - tool: cargo-machete - - uses: Swatinem/rust-cache@v2 - with: - prefix-key: machete + - uses: "taiki-e/install-action@v2" + with: # machete install + tool: "cargo-machete" + - uses: "Swatinem/rust-cache@v2" + with: # machete cache + prefix-key: "machete" - run: cargo machete --with-metadata + # ── benchmarks ─────────────────────────────────────────────────────────── benchmarks: name: benchmark regression - runs-on: ubuntu-latest - # Only run on pushes to main to establish baselines and detect regressions + runs-on: ubuntu-24.04 if: github.ref == 'refs/heads/main' - steps: - - uses: actions/checkout@v4 + steps: # bench steps + - uses: "actions/checkout@v4" - uses: dtolnay/rust-toolchain@stable - with: + with: # bench toolchain toolchain: ${{ env.RUST_VERSION }} - - uses: Swatinem/rust-cache@v2 - with: + - uses: "Swatinem/rust-cache@v2" + with: # bench cache prefix-key: bench - - name: Run benchmarks - run: cargo bench --workspace - - name: Store benchmark results - uses: actions/upload-artifact@v4 - with: + - run: cargo bench --workspace + - uses: actions/upload-artifact@v4 + with: # artifact upload name: benchmark-results path: target/criterion retention-days: 30 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c3373be..6085acb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -152,3 +152,7 @@ This project follows the [Contributor Covenant Code of Conduct](CODE_OF_CONDUCT. - Consult the [documentation](https://docs.rs/weavegraph) Thank you for helping make Weavegraph better! + +## License + +By contributing to Weavegraph, you agree that your contributions will be licensed under the [Apache License, Version 2.0](LICENSE). diff --git a/Cargo.toml b/Cargo.toml index 92c77b4..df92454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "weavegraph" version = "0.6.0" edition = "2024" description = "Graph-driven, concurrent agent workflow framework with versioned state, deterministic barrier merges, and rich diagnostics." -license = "MIT" +license = "Apache-2.0" repository = "https://github.com/Idleness76/weavegraph" readme = "README.md" keywords = ["graph", "workflow", "concurrency", "agents", "tracing"] diff --git a/LICENSE b/LICENSE index 3f51c8b..d645695 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,202 @@ -MIT License - -Copyright (c) 2025 weavegraph Contributors - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 111126e..c33a9a3 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,9 @@ [![Crates.io](https://img.shields.io/crates/v/weavegraph.svg)](https://crates.io/crates/weavegraph) [![Documentation](https://docs.rs/weavegraph/badge.svg)](https://docs.rs/weavegraph) -[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) +[![License: Apache-2.0](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) **Graph-driven, concurrent agent workflow framework for Rust.** - --- > **Pre-1.0 Status** @@ -111,7 +110,7 @@ async fn main() -> Result<(), Box> { For testing and ephemeral workflows use the InMemory checkpointer: -```rust +~~~rust use weavegraph::runtimes::{AppRunner, CheckpointerType}; // After compiling the graph into an `App`: @@ -119,9 +118,9 @@ let runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() - .await; -``` - + .await + .expect("runner build"); +~~~ Run the comprehensive test suite: ```bash @@ -162,7 +161,6 @@ Before merging or cutting a release, run full local parity checks: `ci-local.sh` intentionally fails if required tools are missing (`cargo-semver-checks`, `cargo-deny`) so a local pass is a meaningful signal for CI. ## Resources - - **[Migration Guide](docs/MIGRATION.md)** - Upgrade paths between pre-1.0 releases - **[Architecture Guide](docs/ARCHITECTURE.md)** - Deep dive into core design and internals - **[Examples Directory](examples/)** - Runnable patterns: graph execution, scheduling, streaming, persistence, and more @@ -177,11 +175,11 @@ We welcome contributions! See [CONTRIBUTING.md](CONTRIBUTING.md). ## License -MIT — see [LICENSE](LICENSE). +Apache-2.0 — see [LICENSE](LICENSE). ## 🔗 Links - [Documentation](https://docs.rs/weavegraph) - [Crates.io](https://crates.io/crates/weavegraph) - [Repository](https://github.com/Idleness76/weavegraph) -- [Issues](https://github.com/Idleness76/weavegraph/issues) +- [Issues](https://github.com/Idleness76/weavegraph/issues) — bug reports, feature requests, and discussion diff --git a/benches/event_bus_throughput.rs b/benches/event_bus_throughput.rs index 2b33c72..b3bd3e0 100644 --- a/benches/event_bus_throughput.rs +++ b/benches/event_bus_throughput.rs @@ -1,36 +1,33 @@ use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; -use tokio::runtime::Runtime; -use weavegraph::event_bus::{Event, EventBus}; - -const BATCH_SIZES: &[usize] = &[64, 256, 1024]; - -async fn publish_batch(bus: &EventBus, batch: usize) { - bus.listen_for_events(); - let emitter = bus.get_emitter(); - for i in 0..batch { - emitter - .emit(Event::diagnostic("bench", format!("message-{i}"))) +use tokio::runtime::Runtime; // bench harness runtime +use weavegraph::event_bus::{Event, EventBus}; // types under benchmark +// --- +const SIZES: &[usize] = &[64, 256, 1024]; // batch sweep +// --- +async fn drain(bus: &EventBus, n: usize) { + bus.listen_for_events(); // arm the internal sink + let tx = bus.get_emitter(); // cloneable emitter handle + for seq in 0..n { + tx.emit(Event::diagnostic("throughput", format!("msg-{seq}"))) .expect("emit"); - } - bus.stop_listener().await; -} - -fn event_bus_throughput(c: &mut Criterion) { - let runtime = Runtime::new().expect("runtime"); - let mut group = c.benchmark_group("event_bus_publish"); - - for &batch in BATCH_SIZES { - group.throughput(Throughput::Elements(batch as u64)); - group.bench_with_input(BenchmarkId::from_parameter(batch), &batch, |b, &size| { - b.to_async(&runtime).iter(|| async { - let bus = EventBus::default(); - publish_batch(&bus, size).await; - }); - }); - } - - group.finish(); -} - -criterion_group!(benches, event_bus_throughput); -criterion_main!(benches); + } // end emit loop + bus.stop_listener().await; // flush and stop +} // end drain +// --- +fn throughput(c: &mut Criterion) { + let rt = Runtime::new().expect("tokio runtime"); + let mut grp = c.benchmark_group("event_bus_publish"); // group label + // sweep over batch sizes + for &n in SIZES { + grp.throughput(Throughput::Elements(n as u64)); + grp.bench_with_input(BenchmarkId::from_parameter(n), &n, |b, &count| { + b.to_async(&rt).iter(|| async { + drain(&EventBus::default(), count).await; + }); // end iter + }); // end bench_with_input + } // end sizes loop + grp.finish(); +} // end throughput +// --- +criterion_group!(benches, throughput); // register bench group +criterion_main!(benches); // harness entry point diff --git a/deny.toml b/deny.toml index 2d2b46e..428fb09 100644 --- a/deny.toml +++ b/deny.toml @@ -1,24 +1,22 @@ +[advisories] +ignore = [] + [licenses] allow = [ - "MIT", "Apache-2.0", + "BSL-1.0", "BSD-2-Clause", "BSD-3-Clause", + "CDLA-Permissive-2.0", "ISC", + "LGPL-2.1-or-later", + "MIT", + "MPL-2.0", "OpenSSL", - "Zlib", "Unicode-3.0", "Unlicense", - "LGPL-2.1-or-later", - "BSL-1.0", - "CDLA-Permissive-2.0", - "MPL-2.0", -] - -[advisories] -ignore = [] - + "Zlib", +] # allowed SPDX identifiers [bans] -# Multiple versions of these crates are unavoidable transitive dependencies -# They come from incompatible version ranges in upstream dependencies +# Multiple crate versions are sometimes unavoidable due to transitive dependency version ranges. multiple-versions = "warn" diff --git a/docker-compose.yml b/docker-compose.yml index df2d4be..7c23ff4 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,43 +1,43 @@ -services: - ollama: - container_name: ollama - image: ollama/ollama:latest - volumes: - - ollama_data:/root/.ollama - ports: - - "11434:11434" +networks: + weavegraph: + name: weavegraph + +volumes: + ollama_data: # Ollama model weights and runtime data + postgres_data: # PostgreSQL data directory + +services: # containers used by local dev and integration tests + ollama: # local LLM inference service + container_name: ollama # named for easy docker exec access + image: "ollama/ollama:latest" # always pull the latest tag + restart: unless-stopped # restart on crash but not on explicit stop + tty: true # allocate a pseudo-TTY for ollama's interactive output networks: - weavegraph + ports: + - "11434:11434" + volumes: + - ollama_data:/root/.ollama + environment: # runtime config for ollama server + OLLAMA_HOST: "0.0.0.0" + OLLAMA_KEEP_ALIVE: "24h" + # postgres — used by integration tests that require the postgres feature flag + postgres: # weavegraph test database + container_name: weavegraph_postgres # unique name for CI scripts + image: "postgres:18-alpine" # small Alpine-based image restart: unless-stopped - tty: true - environment: - - OLLAMA_KEEP_ALIVE=24h - - OLLAMA_HOST=0.0.0.0 - - postgres: - container_name: weavegraph_postgres - image: postgres:18-alpine + networks: # attach to shared overlay network + - weavegraph # shared overlay network + ports: + - "5432:5432" + volumes: # persist data across container restarts + - postgres_data:/var/lib/postgresql/data # named volume environment: + POSTGRES_DB: weavegraph_test POSTGRES_USER: weavegraph POSTGRES_PASSWORD: weavegraph - POSTGRES_DB: weavegraph_test - ports: - - "5432:5432" - networks: - - weavegraph - restart: unless-stopped - volumes: - - postgres_data:/var/lib/postgresql/data - healthcheck: - test: ["CMD-SHELL", "pg_isready -U weavegraph -d weavegraph_test"] - interval: 5s - timeout: 5s - retries: 5 - -volumes: - ollama_data: - postgres_data: - -networks: - weavegraph: - name: weavegraph + healthcheck: # wait for postgres to be ready before tests run + test: ["CMD-SHELL", "pg_isready -U weavegraph -d weavegraph_test"] # health probe + interval: 5s # check every 5 seconds + timeout: 5s # query must complete within 5 seconds + retries: 5 # mark unhealthy after 5 consecutive failures diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 3fd8761..df534dd 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -1,465 +1,190 @@ -# Architecture Overview - -Comprehensive technical documentation for Weavegraph's internal design and module organization. - -**Related Documentation:** -- [Quickstart](QUICKSTART.md) - Core concepts, messages, state, and graphs -- [Operations Guide](OPERATIONS.md) - Event streaming, persistence, testing, and production -- [Documentation Index](INDEX.md) - Complete reference with anchor links - -## 🎓 Project Background - -Weavegraph originated as a capstone project for a Rust online course, developed by contributors with Python/TypeScript backgrounds and experience with LangGraph and LangChain. The goal was to bring similar graph-based workflow capabilities to Rust while leveraging its performance, safety, and concurrency advantages. - -While rooted in educational exploration, Weavegraph continues active development well beyond the classroom setting. The core architecture is solid and the framework is functional, but as an early beta release (v0.2.x), it's still maturing; use with awareness of ongoing API evolution. - - -| Crate | Purpose | Highlights | -| ----- | ------- | ---------- | -| `weavegraph` | Executes concurrent, stateful graphs with structured observability. | Graph builder + runtime, event bus, checkpointing, reducers, scheduler. | -| `wg-ragsmith` | Provides ingestion, semantic chunking, and storage utilities for RAG workloads. | HTML/JSON parsers, semantic chunkers, SQLite vector store helpers. | - - -## Overview flowchart of the app (mermaid) - -```mermaid -flowchart TB - -subgraph Client - user[Client App or UI] +# Weavegraph crate anatomy + +- Start with [QUICKSTART.md](QUICKSTART.md) if you want the public API before the internals. +- Read [STREAMING.md](STREAMING.md) for the event delivery helpers described here. +- Use [OPERATIONS.md](OPERATIONS.md) for deployment, persistence, and troubleshooting notes. +- Browse [INDEX.md](INDEX.md) for the rest of the published guides. + +This file explains the runtime pieces that live under `src/` in the `weavegraph` crate. The codebase separates graph authoring from session execution: `graphs::GraphBuilder` assembles a workflow definition, `app::App` stores the compiled package, and `runtimes::AppRunner` turns that package into live sessions with checkpoints and structured events. + +## Module ledger +| Runtime slice | Primary modules | What that slice is responsible for | +| --- | --- | --- | +| Graph definition | `graphs::{builder,edges,compilation,iteration}` | Register nodes, store unconditional and conditional edges, validate topology, and expose graph iteration helpers. | +| Compiled facade | `app` | Hold the validated node registry plus graph metadata, then expose `invoke`, `invoke_streaming`, `invoke_with_channel`, and `invoke_with_sinks`. | +| Session engine | `runtimes::{runner,execution,session,streaming,observer}` | Create sessions, execute supersteps, emit completion markers, and manage iterative invocations. | +| Persistence backends | `runtimes::{checkpointer,checkpointer_sqlite,checkpointer_postgres,persistence,replay}` | Save and restore checkpoints in memory, SQLite, or Postgres. | +| Scheduling | `schedulers::scheduler` | Decide which frontier nodes run now, skip already-consumed snapshots, and cap concurrent work. | +| State model | `state`, `channels`, `message`, `control`, `types` | Represent the versioned workflow state, user/assistant messages, extra JSON data, errors, and routing commands. | +| Merge policy | `reducers::{reducer_registry,add_messages,map_merge,add_errors}` | Apply `NodePartial` deltas to the versioned channels in a deterministic order. | +| Observability | `event_bus::{bus,hub,event,sink,diagnostics}` | Broadcast node and framework events to sinks, subscribers, and health diagnostics. | +| Optional integrations | `llm`, `telemetry`, `utils` | Attach LLM adapters, render telemetry, and provide shared helpers such as IDs, clocks, and JSON utilities. | + +## Source tree snapshot +~~~text +src/ +├── app.rs +├── graphs/ +│ ├── builder.rs +│ ├── compilation.rs +│ ├── edges.rs +│ └── iteration.rs +├── runtimes/ +│ ├── runner.rs +│ ├── execution.rs +│ ├── session.rs +│ ├── streaming.rs +│ ├── checkpointer.rs +│ ├── checkpointer_sqlite.rs +│ └── checkpointer_postgres.rs +├── event_bus/ +│ ├── bus.rs +│ ├── hub.rs +│ ├── event.rs +│ ├── sink.rs +│ └── diagnostics.rs +├── schedulers/scheduler.rs +├── reducers/ +├── state.rs +├── node.rs +└── lib.rs +~~~ + +## End-to-end control path +~~~mermaid +flowchart LR +subgraph AuthoringSurface[Authoring surface] +builder[GraphBuilder chain] +registry[Reducer registry] end - -subgraph Build - gb[GraphBuilder] +subgraph CompiledArtifact[Compiled artifact] +app[App definition] end - -subgraph Runtime - app[App] - sched[Scheduler] - router[Router: Edges and Commands] - barrier[Barrier Applier] +subgraph SessionHost[Per-session runtime] +runner[AppRunner] +sched[Frontier scheduler] +barrier[Barrier merge] end - -subgraph Nodes - usernode[Custom User Nodes] - llmnode[LLM Node] - toolnode[Tool Node] +subgraph StateCells[Versioned state] +store[VersionedState] +snap[StateSnapshot clone] end - -subgraph State - vstate[Versioned State] - snap[State Snapshot] +subgraph NodeLayer[Executable nodes] +workers[Node implementations] end - -subgraph Reducers - redreg[Reducer Registry] +subgraph IOEdges[Persistence and observers] +cp[Checkpointer backend] +bus[EventBus and EventHub] +clients[Subscribers or sinks] end - -subgraph Checkpoint - cpif[Checkpointer: SQLite/InMemory] -end - -subgraph EventBus - eventbus[Event Bus with Sinks] -end - -subgraph Rig - rigad[Rig Adapter] - llmprov[LLM Provider: Ollama/MCP] -end - -subgraph Tools - toolreg[Tool Registry] - exttools[External Tools] -end - -user --> gb -gb -->|compile| app - -user -->|invoke/invoke_streaming| app - -app --> sched -sched -->|creates| snap -vstate --> snap - -sched --> usernode -sched --> llmnode -sched --> toolnode - -usernode -->|NodePartial| barrier -llmnode -->|NodePartial| barrier -toolnode -->|NodePartial| barrier -redreg --> barrier -barrier -->|merges updates| vstate - -snap --> router -app --> router -router -->|next frontier| sched - -llmnode --> rigad -rigad --> llmprov -llmprov --> rigad -rigad --> llmnode - -toolnode --> toolreg -toolnode --> exttools -exttools --> toolnode - -barrier --> cpif - -app --> eventbus -eventbus -->|streams events| user -``` - ---- - -## Workspace Topology - -``` -docs/ → Architectural plans, production hardening roadmap. -weavegraph/ → Core orchestration crate (library + examples + tests). -wg-ragsmith/ → RAG utilities crate (library + examples + tests). -data/ → Local development databases (ignored in version control). -external/ → Vendor snapshots (RAGatouille, raptor) kept outside the workspace. -.github/workflows/ → Continuous integration pipelines. -ARCHITECTURE.md → This document. -``` - -The workspace targets Rust 1.89 as the minimum supported version and enables 2024 edition -features across both crates. - ---- - -## `weavegraph` Crate - -`weavegraph` implements the runtime that powers concurrent, graph-based workflows. The library -is organised around a handful of core modules: - -| Module | Highlights | -| ------ | ---------- | -| `graphs::{builder, edges, compilation}` | `GraphBuilder` DSL for wiring nodes, unconditional and conditional edges, and compiling into a runnable `App`. | -| `app` | High-level façade that owns compiled nodes/edges, reducer registry, and runtime config. Provides `invoke`, `invoke_streaming`, and event stream APIs. | -| `runtimes::{runner, checkpointer_*, runtime_config}` | `AppRunner` drives supersteps, coordinates the scheduler, applies barriers, and persists to SQLite (via `sqlx::migrate!`). | -| `schedulers` | Dependency-aware scheduler that fans out runnable nodes and enforces bounded concurrency. | -| `node` | `Node` trait, `NodeContext`, `NodePartial`, and error types used by application code. | -| `state`, `channels`, `reducers` | Versioned state model split across message/extra/error channels with deterministic merge reducers. | -| `event_bus` | Broadcast-based event hub with sinks (stdout, memory, channel, JSON Lines) and streaming helpers for web servers or CLIs. Events support JSON serialization for log aggregation. | -| `telemetry`, `utils` | Tracing helpers, deterministic RNG, clocks, ID generators, and collection utilities. | - -### Authoring Nodes & State - -Weavegraph applications revolve around three building blocks: nodes, state, and graphs. - -> **Note:** `NodeKind::Start` and `NodeKind::End` are virtual structural endpoints. -> You never register them with `add_node`; attempts to do so are ignored with a warning. -> Define only your executable (custom) nodes and connect them with edges from `Start` and to `End`. - -```rust -use weavegraph::{ - graphs::GraphBuilder, - message::{Message, Role}, - node::{Node, NodeContext, NodePartial}, - state::VersionedState, - types::NodeKind, -}; -use async_trait::async_trait; - -struct GreetingNode; - -#[async_trait] -impl Node for GreetingNode { - async fn run( - &self, - _snapshot: weavegraph::state::StateSnapshot, - ctx: NodeContext, - ) -> Result { - ctx.emit("greeting", "Saying hi!")?; - Ok(NodePartial::new().with_messages(vec![Message::with_role( - Role::Assistant, - "Hello!", - )])) - } -} - -let app = GraphBuilder::new() - .add_node(NodeKind::Custom("greet".into()), GreetingNode) - .add_edge(NodeKind::Start, NodeKind::Custom("greet".into())) - .add_edge(NodeKind::Custom("greet".into()), NodeKind::End) - .compile()?; - -let initial = VersionedState::new_with_user_message("Hi?"); -let result = app.invoke(initial).await?; -``` - -**Key practices:** - -- Prefer typed roles with `Message::with_role(Role::...)` - see [Messages](QUICKSTART.md#messages) -- Build state with `VersionedState::new_with_user_message` or the builder pattern - see [State Management](QUICKSTART.md#state) -- Use `NodeContext::emit*` helpers for telemetry instead of writing directly to stdout -- Return structured errors (`NodeError::MissingInput`, `NodeError::Provider`, `NodeError::Other`) or populate `NodePartial::with_errors` for recoverable issues - see [Error Handling](OPERATIONS.md#errors) - -### Custom Reducers {#custom-reducers} - -Weavegraph supports custom reducers for extending or replacing channel update behavior. By default, -three reducers are registered: - -- **Message channel**: `AddMessages` – Appends messages to the message list -- **Extra channel**: `MapMerge` – Shallow merges JSON objects in the extra data map -- **Error channel**: `AddErrors` – Appends error events to the error list - -To register custom reducers: - -```rust +builder -->|compile + validate| app +registry -->|attached during build| app +app -->|spawn invocation| runner +runner --> sched +store --> snap +sched -->|dispatch frontier| workers +snap --> workers +workers -->|NodePartial values| barrier +barrier -->|reduced updates| store +barrier -->|autosave checkpoint| cp +workers -->|ctx.emit / llm events| bus +runner -->|framework diagnostics| bus +bus --> clients +~~~ + +## GraphBuilder and compile-time checks +`GraphBuilder` starts empty and accepts fluent registration calls. `add_node` stores only executable `NodeKind::Custom` entries; attempts to register `Start` or `End` are ignored with a warning because those are structural markers rather than real nodes. +- `add_edge` records a fixed adjacency from one node kind to another. +- `add_conditional_edge` stores a predicate that inspects `StateSnapshot` and chooses the next frontier at runtime. +- `with_runtime_config` and `with_event_bus_config` carry runtime settings into the finished `App`. +- `with_reducer` appends one reducer to a channel, while `with_reducer_registry` swaps the whole merge policy. +`GraphBuilder::compile()` delegates to `graphs::compilation` before constructing the `App`. The validator rejects missing entry edges, unconditional cycles, duplicate edges, references to unknown custom nodes, and edges that originate from `End`. Reachability checks for “reachable from Start” and “has a route to End” run when the graph has no conditional edges, because predicates can hide the true path until execution time. + +## `App` compared with `AppRunner` +`app::App` is the reusable compiled definition. It owns the node map, unconditional edge map, conditional edge list, reducer registry, and `RuntimeConfig`. Cloning an `App` is cheap enough for request handlers because the expensive work already happened at compile time. +`runtimes::AppRunner` is the execution host. It owns the per-run session table, the chosen `EventBus`, the optional checkpointer, the autosave flag, and optional observer or clock injection. One `App` can therefore serve many runners with separate event sinks or persistence settings. +The main entry points line up like this: +- `App::invoke` builds a runner from the runtime config and waits for the final `VersionedState`. +- `App::invoke_with_channel` appends a `ChannelSink` and returns `(Result, flume::Receiver)`. +- `App::invoke_with_sinks` keeps the configured sinks and appends any extra sinks supplied by the caller. +- `App::invoke_streaming` allocates a fresh bus plus `EventStream`, spawns the workflow on Tokio, and hands back `(InvocationHandle, EventStream)`. +- `AppRunner::builder()` is the lower-level route when you need full control over the bus, checkpoint backend, or iterative session lifecycle. +`AppRunner` also exposes `create_session`, `create_iterative_session`, `invoke_next`, and `finish_iterative_session`, which is why the runner remains the escape hatch for multi-turn applications. + +## Versioned state, snapshots, and barrier reduction +The workflow state lives in `state::VersionedState`, which groups three independent channels: message history, arbitrary extra JSON, and accumulated error events. Each channel tracks its own version so the scheduler can tell whether a node has already consumed the current data. +Nodes never mutate the shared state directly. Each `Node::run` receives an immutable `StateSnapshot` plus a `NodeContext`, then returns a `NodePartial`. A partial can carry any combination of: +- appended messages, +- extra key/value updates, +- recoverable error events, +- a frontier command that changes routing. +`App::apply_barrier` collects all partials from one superstep, merges them into one aggregate update, and runs the reducer registry channel by channel. Channel versions increase only when the content actually changed, which keeps scheduler decisions deterministic. +The built-in reducer registry is assembled in `ReducerRegistry::default()`: +- `AddMessages` appends emitted messages. +- `MapMerge` performs a shallow JSON merge for extras and treats `null` as key deletion. +- `AddErrors` appends recoverable error entries. +Custom reducers can be stacked on a channel in registration order, so middleware-style validation or post-processing is possible without replacing the whole runtime. + +### Replacing one merge rule +~~~rust,no_run use std::sync::Arc; +use weavegraph::graphs::GraphBuilder; use weavegraph::reducers::{Reducer, ReducerRegistry}; +use weavegraph::state::VersionedState; +use weavegraph::node::NodePartial; use weavegraph::types::ChannelType; - -// Define a custom reducer -struct MyCustomReducer; - -impl Reducer for MyCustomReducer { +struct LastWriteWins; +impl Reducer for LastWriteWins { + fn definition_label(&self) -> &'static str { "docs::last_write_wins" } fn apply(&self, state: &mut VersionedState, update: &NodePartial) { - // Custom merge logic here - } -} - -// Register during graph building -let app = GraphBuilder::new() - .add_node(...) - .with_reducer(ChannelType::Message, Arc::new(MyCustomReducer)) - .compile()?; - -// Or replace the entire registry -let custom_registry = ReducerRegistry::new() - .with_reducer(ChannelType::Message, Arc::new(MyCustomReducer)); - -let app = GraphBuilder::new() - .add_node(...) - .with_reducer_registry(custom_registry) - .compile()?; -``` - -Multiple reducers can be registered for the same channel and will be applied in registration order. -This enables middleware-style processing, validation, or transformation of channel updates during -barrier synchronization. - -### Execution Flow - -1. **Authoring** – Build a graph with `GraphBuilder`, registering nodes (implementations of `Node`) - and the edges that connect them. Conditional edges can inspect `StateSnapshot` at runtime. - See [Graph Building](QUICKSTART.md#graphs) for details. -2. **Compilation** – `GraphBuilder::compile()` validates topology and produces an `App`. -3. **Invocation** – `App::invoke()` (or streaming variants like `invoke_streaming`, `invoke_with_channel`) - constructs an `AppRunner` with the chosen checkpointer (`InMemory` or SQLite), and event bus configuration. - See [Event Streaming](OPERATIONS.md#event-streaming) for streaming patterns. -4. **Scheduling** – The scheduler selects runnable nodes, issues `NodeContext`s, and executes - nodes concurrently. Each node returns a `NodePartial` with channel deltas and optional - control-flow directives. -5. **Barrier & Reduction** – Reducers merge channel updates deterministically, update the - versioned state, and hand control back to the scheduler for the next superstep. - See [Custom Reducers](#custom-reducers) above. -6. **Persistence & Observability** – Checkpointer snapshots state into SQLite (when enabled), - the event bus broadcasts diagnostics / LLM chunk streams, and telemetry surfaces to sinks. - See [Persistence](OPERATIONS.md#persistence) and [Event Streaming](OPERATIONS.md#event-streaming). - -### Optional Features - -* `rig` – Enables Rig-based LLM support (Ollama/MCP integrations). -* `llm` – Backward-compatible alias to `rig` for 0.3.x (planned removal in 0.4.0). -* `sqlite-migrations` – Turns on SQLite-backed persistence (default). -* `examples` – Pulls in extra dependencies used by a subset of examples (e.g. `reqwest`, `scraper`). - -### Tests & Examples - -* `weavegraph/tests/` – Covers state channels, reducers, scheduler semantics, checkpointer, and event bus. - See [Testing](OPERATIONS.md#testing) for running tests and patterns. -* `examples/` – Progressive walkthroughs: - * `basic_nodes.rs`, `graph_execution.rs`, `scheduler_fanout.rs` show core messaging and state channels. - See [Messages](QUICKSTART.md#messages) and [State](QUICKSTART.md#state). - * `advanced_patterns.rs` covers conditional routing and control-flow helpers. - * `streaming_events.rs`, `convenience_streaming.rs` demonstrate the - broadcast event bus and web-friendly streaming patterns. - See [Event Streaming](OPERATIONS.md#event-streaming). - * `event_backpressure.rs`, `json_serialization.rs`, `errors_pretty.rs` cover production-facing - concerns like lag handling, JSON sinks, and pretty diagnostics. - ---- - -### Backpressure and Drop Policy - -The event bus uses a bounded broadcast channel (default capacity: 1024 events per subscriber). -When a subscriber falls behind faster producers, the following semantics apply: - -- Slow subscribers receive a lag notice and skip older events (no blocking of producers) -- Missed events are counted and exposed via sink diagnostics -- A WARN log entry is emitted with the number of dropped events and the running total -- Streams continue from the most recent position for graceful degradation under load - -To adjust capacity, configure the event bus when using `App::invoke_streaming` or construct -an `EventBus` directly with custom capacity via `EventBus::with_capacity`. - -For practical guidance and code samples, see: -- [Event Streaming](OPERATIONS.md#event-streaming) for patterns and sink configuration -- `docs/STREAMING.md` for detailed tuning guidance - -## `wg-ragsmith` Crate - -`wg-ragsmith` contains the ingestion and vector-store tooling used by RAG pipelines. It can be -used standalone or pulled into `weavegraph` via the `examples` feature. - -| Module | Highlights | -| ------ | ---------- | -| `ingestion::{cache, chunk, resume}` | Disk-backed document cache, chunk-to-ingestion conversion, and resumable pipeline tracking. | -| `semantic_chunking::{html, json, segmenter, embeddings, service}` | HTML/JSON preprocessors, statistical breakpoint strategies, mock/real embedding providers, and the async chunking service. | -| `stores::sqlite` | `SqliteChunkStore` built on `rig-sqlite` + `sqlite-vec`, including schema, vec3 registration, and helper methods to upsert/search chunks. | -| `types` | `RagError` and supporting data structures for ingestion/persistence. | - -### Examples - -* `examples/rust_book_pipeline.rs` – Async ingestion pipeline that scrapes the Rust book, - chunks and embeds sections, and writes them into SQLite. -* `examples/query_chunks.rs` & `query_db.sh` – Smoke tests showing how to query stored chunks. - -These examples share environment variables with the weavegraph RAG demo (see `.env.example`). - -### Feature Flags - -* `semantic-chunking-tiktoken` (default) – OpenAI tiktoken tokeniser. -* `semantic-chunking-rust-bert` – Enables Rust-BERT based embedding pipeline. -* `semantic-chunking-segtok` – Alternative segmentation strategy. - ---- - -## Shared Operational Pieces - -* **Tooling** – Standard Rust tooling (`cargo fmt`, `cargo clippy`, `cargo test`, - `cargo +nightly doc`, `cargo deny`, `cargo machete`) plus `sqlx` migrations keep local - workflows and CI aligned. -* **CI/CD** – `.github/workflows/ci.yml` runs required checks on `1.90.0`, uses current - stable as a canary lane, and validates docs on `nightly` with - `RUSTDOCFLAGS="--cfg docsrs -D warnings"`. -* **Migrations** – `weavegraph/migrations` houses the `sqlx` migration set for the SQLite - checkpointer. Use `sqlx migrate` to apply or rollback changes. -* **Docs** – `docs/` captures forward-looking design documents (event bus refactor, - control-flow commands, hybrid RAG pipeline) and the production readiness plan. Use - this architecture document as the entry point. - ---- -## petgraph Comparison - -Weavegraph's graph implementation was designed with workflow execution in mind, making different -tradeoffs than the general-purpose [petgraph](https://github.com/petgraph/petgraph) crate. -This section documents the architectural differences and integration opportunities. - -### Architecture Comparison - -| Aspect | Weavegraph | petgraph | -|--------|-----------|----------| -| **Primary Use Case** | Workflow orchestration with async node execution | General graph algorithms and data structures | -| **Graph Type** | Custom `FxHashMap>` adjacency | `Graph`, `StableGraph`, `GraphMap`, `MatrixGraph` | -| **Node Identity** | `NodeKind` enum (Start/End/Custom) | `NodeIndex` (u32 handle) | -| **Node Data** | Nodes carry `Arc` trait objects | Generic node weight type `N` | -| **Edge Storage** | HashMap adjacency list with conditional predicates | Compact edge list with indices | -| **Edge Data** | Unconditional or `EdgePredicate` closures | Generic edge weight type `E` | -| **Cycle Detection** | Custom DFS with three-color marking | `petgraph::algo::is_cyclic_directed` | -| **Reachability** | Custom BFS from Start | `petgraph::algo::has_path_connecting` | -| **Algorithms** | Validation-focused (cycles, reachability, deadends) | Rich library (Dijkstra, MST, SCC, isomorphism, max flow) | -| **Async Support** | First-class (nodes are async) | None (pure data structure) | -| **Serialization** | Custom JSON via serde | `serde-1` feature | - -### Key Differences Explained - -**Why Weavegraph uses a custom graph:** - -1. **Domain-Specific Semantics** — `NodeKind::Start` and `NodeKind::End` are virtual structural - endpoints that cannot be registered as executable nodes. This enables clear workflow boundaries - without special-casing in user code. - -2. **Conditional Edges** — petgraph edges are static data. Weavegraph edges can be runtime - predicates (`EdgePredicate`) that inspect state to determine routing. This is fundamental to - agent decision-making workflows. - -3. **Execution Context** — Nodes aren't just data; they're async executables with access to - `NodeContext` for event emission and metadata. petgraph's node weights are passive data. - -4. **Validation Errors** — Compilation produces domain-specific errors like `UnknownNode`, - `CycleDetected { path }`, `UnreachableFromStart`, and `DeadendNode` with actionable context. - -**petgraph advantages:** - -1. **Battle-tested** — 3.7k+ GitHub stars, 144 contributors, extensive production usage -2. **Memory-efficient** — Compact edge storage, cache-friendly node indices -3. **Algorithm library** — Dijkstra, topological sort, strongly connected components, etc. -4. **Index stability** — `StableGraph` maintains valid indices through mutations - -### Integration Approach - -Weavegraph takes a **selective adoption** approach rather than replacing its core graph: - -```rust -// Feature-gated behind `petgraph-compat` -#[cfg(feature = "petgraph-compat")] -impl From<&CompiledGraph> for petgraph::Graph { - fn from(graph: &CompiledGraph) -> Self { - // Convert for visualization or analysis + if let Some(patch) = &update.extra { + for (key, value) in patch { state.extra.get_mut().insert(key.clone(), value.clone()); } + } } } -``` - -**Current integrations:** - -- **Graph iteration API** — `Graph::nodes()` and `Graph::edges()` iterators mirror petgraph idioms -- **Topological sort** — `Graph::topological_sort()` for deterministic node ordering -- **DOT export** — Optional petgraph-based visualization via `dot` format - -**Future opportunities:** - -- **Advanced routing** — Use petgraph's shortest path for "fastest path to End" analysis -- **Cycle detection fallback** — Validate against petgraph's implementation -- **Graph visualization** — Generate DOT/GraphViz output for debugging - -### When to Use petgraph Directly - -Use petgraph when you need: -- Pure graph algorithms without execution semantics -- Memory-optimal large graph storage -- Pre-built algorithms (MST, max flow, isomorphism) -- Static graph analysis tooling - -Use Weavegraph when you need: -- Async node execution with state management -- Conditional runtime routing based on state -- Event streaming and observability -- Checkpoint/resume workflow persistence -- LLM agent orchestration patterns - -### Code Example: Hybrid Usage - -```rust -use weavegraph::graphs::GraphBuilder; -use weavegraph::types::NodeKind; - -// Build workflow with Weavegraph -let builder = GraphBuilder::new() - .add_node(NodeKind::Custom("analyze".into()), AnalyzeNode) - .add_node(NodeKind::Custom("summarize".into()), SummarizeNode) - .add_edge(NodeKind::Start, NodeKind::Custom("analyze".into())) - .add_edge(NodeKind::Custom("analyze".into()), NodeKind::Custom("summarize".into())) - .add_edge(NodeKind::Custom("summarize".into()), NodeKind::End); - -// Convert to petgraph for analysis (feature-gated) -#[cfg(feature = "petgraph-compat")] -{ - use weavegraph::graphs::PetgraphConversion; - let pg = builder.to_petgraph(); - - // Use petgraph algorithms - let topo_order = petgraph::algo::toposort(&pg.graph, None)?; - let dot = petgraph::dot::Dot::new(&pg.graph); - println!("DOT output:\n{:?}", dot); -} - -// Execute with Weavegraph -let app = builder.compile()?; -let result = app.invoke(initial_state).await?; -``` +let registry = ReducerRegistry::new().with_reducer(ChannelType::Extra, Arc::new(LastWriteWins)); +let _builder = GraphBuilder::new().with_reducer_registry(registry); +~~~ + +## Scheduler behavior +The scheduler is intentionally small. `Scheduler` stores only a concurrency limit; `SchedulerState` carries the `versions_seen` map that remembers which message and extra versions each node already processed. +A superstep works in this order: +1. Pull the current frontier from the session state. +2. Skip `Start` and `End`, because they are structural nodes rather than user code. +3. Compare the snapshot versions against `versions_seen`; a node reruns only when a tracked version increased or the node has never been seen. +4. Execute eligible nodes concurrently up to `concurrency_limit`. +5. Hand the collected `NodePartial` values to the barrier reducer. +6. Compute the next frontier from unconditional edges, conditional edges, and any `FrontierCommand` values. +`create_session` seeds the scheduler with `available_parallelism()` from the host process, while `Scheduler::new(0)` still clamps to one worker so the runtime never creates a zero-width executor. +`NodeContext` is the bridge between the scheduler and user code. It carries the node identifier, step number, an event emitter, and optional clock or invocation metadata. The helper methods `emit`, `emit_diagnostic`, `emit_llm_chunk`, `emit_llm_final`, and `emit_llm_error` all route through that context. + +## Event bus and streaming hooks +The event system is split into two layers. +- `EventHub` owns the Tokio broadcast channel and tracks dropped-event metrics. +- `EventBus` manages worker tasks for sinks, subscribes clients, exposes diagnostics, and closes the channel when a run ends. +Every subscription becomes an `EventStream`. That stream can be consumed with `recv`, `try_recv`, `into_blocking_iter`, `into_async_stream`, or `next_timeout`. Slow subscribers do not block producers; the hub logs lag and increments the drop counter instead. +`event_bus::Event` has three variants: +- `Event::Node` for `NodeContext::emit` traffic, +- `Event::Diagnostic` for framework markers and generic telemetry, +- `Event::LLM` for chunk, final, and error notifications emitted by LLM-oriented nodes. +Sink delivery happens in per-sink worker tasks, and each worker calls `handle()` inside `spawn_blocking`, which lets a file sink or stdout sink perform blocking I/O without stalling async node execution. +A separate diagnostics channel tracks sink failures. `EventBus::diagnostics()` returns a `DiagnosticsStream`, and `EventBus::sink_health()` returns the last known counters and timestamps for each sink. Those diagnostics stay isolated from the main event feed unless `DiagnosticsConfig.emit_to_events` is enabled. +The runtime closes streams with framework markers rather than silent shutdown: +- `STREAM_END_SCOPE` means the whole event stream is finished and the channel will close. +- `INVOCATION_END_SCOPE` marks the end of one `invoke_next` call while keeping an iterative stream alive. +For runnable examples, see [examples/streaming_events.rs](../examples/streaming_events.rs), [examples/convenience_streaming.rs](../examples/convenience_streaming.rs), and [examples/production_streaming.rs](../examples/production_streaming.rs). + +## Checkpoint implementations +Checkpoint persistence is defined by the `Checkpointer` trait with three core operations: `save`, `load_latest`, and `list_sessions`. +- `InMemoryCheckpointer` keeps only the newest snapshot for each session and is the lightest option for tests or one-shot runs. +- `SQLiteCheckpointer` stores full step history in a local database and can apply embedded migrations when the `sqlite-migrations` feature is active. +- `PostgresCheckpointer` stores the same checkpoint model in PostgreSQL and can apply embedded migrations behind the `postgres-migrations` feature. +The shared `Checkpoint` struct captures the session identifier, step number, full `VersionedState`, current frontier, scheduler `versions_seen`, concurrency limit, created timestamp, and step-level metadata such as ran nodes, skipped nodes, and updated channels. `restore_session_state` reconstructs a runnable session from that record. +Checkpointing happens at the runner layer, not inside nodes. That keeps user code focused on business logic while `AppRunner` decides whether to autosave, resume from storage, or persist barrier results after each superstep. + +## Suggested next reads +- [examples/graph_execution.rs](../examples/graph_execution.rs) shows the basic compile-and-run path. +- [examples/scheduler_fanout.rs](../examples/scheduler_fanout.rs) demonstrates concurrent frontier execution. +- [examples/advanced_patterns.rs](../examples/advanced_patterns.rs) covers conditional routing and control commands. +- [STREAMING.md](STREAMING.md) narrows in on the live event APIs. +- [OPERATIONS.md](OPERATIONS.md) covers production tuning, checkpoint deployment, and diagnostics policy. diff --git a/docs/STREAMING.md b/docs/STREAMING.md index e3c5be8..ae3ed09 100644 --- a/docs/STREAMING.md +++ b/docs/STREAMING.md @@ -1,547 +1,167 @@ -# Streaming Events Quickstart - -This guide shows you how to stream workflow events to web clients using Weavegraph's `EventStream` helpers and, when needed, the legacy channel-based sinks. - -## Choose Your Pattern - -| Scenario | API | Event Consumption | Notes | Example | -|----------|-----|-------------------|-------|---------| -| CLI / scripts | `App::invoke_with_channel` | flume receiver | Simplest to wire progress bars, returns `(Result, Receiver)` | `examples/convenience_streaming.rs` | -| CLI with multiple sinks | `App::invoke_with_sinks` | sinks + optional channel | Inject stdout/file sinks without touching `AppRunner` | same as above | -| Web servers / SSE/WebSocket | `App::invoke_streaming` | `EventStream` (async/iter/poll) | Preferred for live streaming; emits `STREAM_END_SCOPE` sentinel when finished | `examples/streaming_events.rs` | -| Full control | `AppRunner::builder()` | custom `EventBus` | Use when you need per-request isolation or reuse a runner | `examples/streaming_events.rs` | - -### ⭐ Simple Patterns (Convenience Methods) - -For CLI tools and simple scripts, use the new convenience methods: - -```rust -// Pattern 1: Single channel (simplest) -let (result, events) = app.invoke_with_channel(initial_state).await; - -// Pattern 2: Multiple sinks -app.invoke_with_sinks( - initial_state, - vec![Box::new(StdOutSink::default()), Box::new(ChannelSink::new(tx))] +# Delivering live events from Weavegraph + +- Use [ARCHITECTURE.md](ARCHITECTURE.md) for the larger runtime map behind these helpers. +- Use [OPERATIONS.md](OPERATIONS.md) when you need deployment, persistence, or troubleshooting guidance. +- Runnable samples live in [examples/convenience_streaming.rs](../examples/convenience_streaming.rs), [examples/streaming_events.rs](../examples/streaming_events.rs), and [examples/production_streaming.rs](../examples/production_streaming.rs). + +Weavegraph exposes three convenience entry points on `App` and one shared subscription type, `EventStream`. Pick the helper based on where the events need to go: a `flume` receiver, a custom sink list, or a broadcast stream that feeds SSE or WebSocket code. + +## Which helper to reach for +| If you need... | Call this API | What you get back | Typical fit | +| --- | --- | --- | --- | +| a quick receiver in a CLI or test harness | `App::invoke_with_channel` | `(Result, flume::Receiver)` | progress output, smoke tests, ad-hoc scripts | +| extra sinks in addition to the runtime defaults | `App::invoke_with_sinks` | `Result` | stdout + JSONL, memory capture, metrics fan-out | +| a background run plus a first-class subscription | `App::invoke_streaming` | `(InvocationHandle, EventStream)` | SSE, WebSocket, long-lived dashboards | +| total control over session lifetime | `AppRunner::builder()` | your own runner and your own bus | iterative sessions, custom checkpointing, bespoke wiring | + +## `invoke_with_channel` +`invoke_with_channel` is the smallest streaming surface. The method builds an `EventBus` from `RuntimeConfig`, appends a `ChannelSink`, starts the workflow, and hands the caller the paired receiver. +Use it when a single consumer is enough and you do not need to manage `EventStream` directly. +~~~rust,no_run +use tokio::time::{timeout, Duration}; +use weavegraph::state::VersionedState; +# async fn sample(app: weavegraph::app::App) -> Result<(), Box> { +let seed = VersionedState::new_with_user_message("compile daily report"); +let (result, rx) = app.invoke_with_channel(seed).await; +let collector = tokio::spawn(async move { + while let Ok(event) = timeout(Duration::from_millis(250), rx.recv_async()).await.unwrap_or_else(|_| Err(flume::RecvError::Disconnected)) { + println!("{} :: {}", event.scope_label().unwrap_or("unknown"), event.message()); + } // receiver loop +}); // spawned collector +let finished = result?; +collector.await?; +assert!(!finished.snapshot().messages.is_empty()); +# Ok(()) } +~~~ +A few practical notes: +- The workflow still runs to completion even if the receiver is slow. +- If the receiver disappears, `ChannelSink` reports a broken pipe and the bus keeps serving any remaining sinks. +- The result half resolves to the final `VersionedState`, so this helper still works for one-shot command-line tools. + +## `invoke_with_sinks` +`invoke_with_sinks` keeps the sinks already described by `RuntimeConfig.event_bus` and appends the sinks you pass in. This is the helper to use when one invocation needs multiple outputs at once. +~~~rust,no_run +use weavegraph::event_bus::{ChannelSink, JsonLinesSink, StdOutSink}; +use weavegraph::state::VersionedState; +# async fn sample(app: weavegraph::app::App) -> Result<(), Box> { +let (tx, rx) = flume::unbounded(); +let state = VersionedState::new_with_user_message("refresh cache"); +let _final_state = app.invoke_with_sinks( + state, + vec![ + Box::new(StdOutSink::default()), + Box::new(JsonLinesSink::to_stdout()), + Box::new(ChannelSink::new(tx)), + ], ).await?; -``` - -**When to use:** Single-execution scenarios, CLI tools, progress monitoring - -**Example:** `cargo run --example convenience_streaming` - -### Production Pattern (Web Servers) - -Use [`App::invoke_streaming`](../src/app.rs) to launch the workflow and get an `EventStream` you can forward to SSE/WebSocket clients: - -```rust -use std::sync::Arc; -use axum::response::sse::{Event as SseEvent, Sse}; -use futures_util::StreamExt; -use tokio::{signal, sync::Mutex}; -use weavegraph::event_bus::STREAM_END_SCOPE; - -let (invocation, events) = app.invoke_streaming(initial_state).await; -let invocation = Arc::new(Mutex::new(Some(invocation))); - -let sse_stream = async_stream::stream! { - let mut stream = events.into_async_stream(); - while let Some(event) = stream.next().await { - yield Ok::( - SseEvent::default().json_data(event.clone()).unwrap() - ); - if event.scope_label() == Some(STREAM_END_SCOPE) { - break; - } - } -}; - -let response = Sse::new(sse_stream); - -tokio::spawn({ - let invocation = Arc::clone(&invocation); - async move { - tokio::select! { - _ = async { - if let Some(handle) = invocation.lock().await.take() { - if let Err(err) = handle.join().await { - tracing::error!("workflow failed: {err}"); - } - } - } => {} - _ = signal::ctrl_c() => { - if let Some(handle) = invocation.lock().await.take() { - handle.abort(); - } - } - } - } -}); - -response -``` - -**When to use:** SSE/WebSocket transports (or as a base for similar streaming adapters). The stream closes automatically when the sentinel diagnostic with scope `STREAM_END_SCOPE` arrives. - -**Example:** `cargo run --example streaming_events` - ---- - -## ⚠️ Notes on legacy patterns - -`App::invoke_with_channel` and `invoke_with_sinks` remain available for scripts that prefer flume channels or multiple sinks. Under the hood they now use the same broadcast hub as `invoke_streaming`. - -## Quick Start - -Run the self-contained example: - -```bash -cargo run --example streaming_events -``` - -This demonstrates the core pattern without requiring additional dependencies. - -## Key Components - -### 1. EventStream - -`EventStream` represents the broadcast output of the EventBus. Convert it to different consumption styles: - -```rust -let (invocation, events) = app.invoke_streaming(initial_state).await; - -// Async iterator (SSE/WebSocket) -let mut stream = events.into_async_stream(); -while let Some(event) = stream.next().await { /* ... */ } - -// Blocking iterator (CLI tools) -for event in events.into_blocking_iter() { /* ... */ } - -// Timed polling -if let Some(event) = events.next_timeout(Duration::from_secs(1)).await { /* ... */ } -``` - -`next_timeout` skips over lag notifications automatically—if the stream logs a warning about dropped events, consider increasing the configured buffer (see below). - -### 2. Legacy ChannelSink (Optional) - -If you still prefer channel-based forwarding, the convenience helpers continue to work: - -```rust -let (result, events) = app.invoke_with_channel(initial_state).await; -``` - -- The `Event` enum now includes `Node`, `Diagnostic`, **and** `LLM` variants—remember to handle the streaming case (`Event::LLM`). - -Every `invoke_streaming` run ends with a diagnostic whose scope equals `STREAM_END_SCOPE`. Use it to notify clients that the workflow has finished and the event stream is about to close. - -## Tuning Buffer Capacity - -- Default capacity is `1024` events per broadcast channel. -- Increase the buffer with `RuntimeConfig::default().with_event_bus(EventBusConfig::new(capacity, sinks))`. -- Slow consumers trigger a `weavegraph::event_bus` warning (`event stream lagged; dropped events`) and increment `EventStream::dropped()`. -- Benchmark with `cargo bench --bench event_bus_throughput` to validate settings for your workload. - -## Web Framework Integration - -### Pattern for HTTP Streaming (Axum Example) - -```rust -let (invocation, events) = app.invoke_streaming(initial_state).await; - -tokio::spawn(async move { - if let Err(err) = invocation.join().await { - tracing::error!("workflow failed: {err}"); - } -}); - -let sse_stream = events - .into_async_stream() - .map(|event| SseEvent::default().json_data(event).unwrap()); - -Sse::new(sse_stream) -``` - -### Required Dependencies (for Axum) - -Add to your `Cargo.toml`: - -```toml -[dependencies] -axum = "0.7" -futures-util = "0.3" -# flume is already a dependency of weavegraph -``` - -## Architecture Flow - -```text -┌──────────────┐ -│ HTTP Request │ -└──────┬───────┘ - │ - ▼ -┌──────────────────────────────────────┐ -│ HTTP Handler │ -│ 1. Create mpsc channel │ -│ 2. Create EventBus + ChannelSink │ -│ 3. Spawn workflow task │ -│ 4. Return SSE stream immediately │ -└──────────────────────────────────────┘ - │ │ - │ spawn │ return SSE - ▼ ▼ -┌──────────────────┐ ┌────────────────┐ -│ Background Task │ │ Client Stream │ -│ ┌──────────────┐ │ │ ┌────────────┐ │ -│ │ AppRunner │ │ │ │ SSE Stream │ │ -│ │ + EventBus │ │ │ │ ← channel │ │ -│ └──────┬───────┘ │ │ └────────────┘ │ -│ │ │ └────────────────┘ -│ ▼ │ -│ ┌──────────┐ │ -│ │ Workflow │ │ -│ │ Nodes │ │ -│ └────┬─────┘ │ -│ │ │ -│ ctx.emit() │ -│ │ │ -│ ▼ │ -│ ┌──────────┐ │ -│ │ EventBus │───┼──→ ChannelSink ──→ mpsc ──→ Client -│ └──────────┘ │ -└──────────────────┘ -``` - -## Key Patterns - -### 1. One ChannelSink Per Client - -Each HTTP connection should get its own channel: - -```rust -// GOOD: New channel per request -async fn handler() -> Sse { - let (tx, rx) = flume::unbounded(); // ✓ Per-request channel - let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); - // ... +while let Ok(event) = rx.try_recv() { + eprintln!("mirrored event: {}", event.message()); } - -// BAD: Shared channel across requests -static SHARED_CHANNEL: OnceCell> = OnceCell::new(); // ✗ Don't do this -``` - -### 2. Background Task for Long-Running Workflows - -Spawn the workflow execution so the handler can return immediately: - -```rust -// GOOD: Non-blocking handler -tokio::spawn(async move { - runner.run_until_complete(&session_id).await -}); -return Sse::new(stream); // Returns immediately - -// BAD: Blocking handler -runner.run_until_complete(&session_id).await?; // ✗ Blocks until completion -return Sse::new(stream); -``` - -### 3. Event Filtering - -Filter events by scope or type before sending to clients: - -```rust -let stream = UnboundedReceiverStream::new(rx) - .filter(|event| { - matches!(event.event_type(), EventType::Node | EventType::Diagnostic) +# Ok(()) } +~~~ +This helper is useful when you want a plain-text operator feed, a structured log sink, and a receiver for application code without building `AppRunner` by hand. + +## `invoke_streaming` +`invoke_streaming` is the preferred interface for web transports. It creates a fresh `EventBus`, subscribes an `EventStream`, spawns the workflow in the background, and returns an `InvocationHandle` that can be awaited or aborted. +~~~rust,no_run +use futures_util::StreamExt as _; +use weavegraph::event_bus::{Event, STREAM_END_SCOPE}; +use weavegraph::state::VersionedState; +# async fn sample(app: weavegraph::app::App) -> Result<(), Box> { +let seed = VersionedState::new_with_user_message("stream this run"); +let (handle, events) = app.invoke_streaming(seed).await; +let reader = tokio::spawn(async move { + let mut feed = events.into_async_stream(); + loop { + match feed.next().await { + Some(Event::Diagnostic(marker)) if marker.scope() == STREAM_END_SCOPE => break, + Some(event) => println!("event => {}", event.message()), + None => break, + } // next event branch + } // async stream loop +}); // spawned reader +let final_state = handle.join().await?; +reader.await?; +assert!(final_state.snapshot().messages.len() >= 1); +# Ok(()) } +~~~ +Two lifecycle rules matter here: +- Dropping the `InvocationHandle` aborts the background workflow task. +- Dropping only the `EventStream` does not stop the run; call `abort()` or `join()` on the handle when the client disconnects. +The sentinel diagnostic with scope `STREAM_END_SCOPE` is the clean end-of-stream marker. The runtime emits it before closing the channel so HTTP code can flush a final frame and then terminate the connection intentionally. + +## `EventStream` operations +`EventStream` wraps a Tokio broadcast receiver and exposes several consumption styles. +| Method | What it does | Good for | +| --- | --- | --- | +| `recv().await` | wait for the next event or a lag/closed error | direct async loops | +| `try_recv()` | poll without blocking | manual event pumps | +| `into_async_stream()` | produce a boxed `futures` stream and silently skip lagged slots | SSE or WebSocket adapters | +| `into_blocking_iter()` | iterate from synchronous code | CLI bridges or thread-based tools | +| `next_timeout(duration).await` | wait up to a deadline and skip lag notices | periodic polling loops | +| `with_shutdown(watch_rx)` | stop the stream when external shutdown flips to `true` | server-managed cancellation | +Lag handling is built into the hub rather than the consumer. When a receiver falls behind, the hub records dropped messages, logs a warning, and keeps the producer side moving. + +## Server-Sent Events shape +`examples/production_streaming.rs` shows the full Axum integration, but the essential flow is short: start the workflow, map each `Event` into an SSE frame, and stop when `STREAM_END_SCOPE` arrives. +~~~rust,no_run +use axum::response::sse::{Event as AxumSseFrame, Sse}; +use futures_util::StreamExt as _; +use weavegraph::event_bus::{Event, STREAM_END_SCOPE}; +# async fn sse(app: weavegraph::app::App, initial: weavegraph::state::VersionedState) { +let (handle, events) = app.invoke_streaming(initial).await; +let sse_stream = events.into_async_stream().map(|event| { + let _closing_frame = matches!(&event, Event::Diagnostic(marker) if marker.scope() == STREAM_END_SCOPE); + Ok::<_, std::convert::Infallible>(AxumSseFrame::default().json_data(event).expect("serialize event")) +}); // SSE frame mapper +let join_task = tokio::spawn(async move { + match handle.join().await { + Ok(_) => {} + Err(err) => tracing::error!("workflow join failed: {err}"), + } // join outcome +}); // spawned join task +let _ = join_task; +let _response = Sse::new(sse_stream); +# } +~~~ +If you need to cancel when the client vanishes, keep the handle in shared state and call `abort()` from the disconnect path, just like the production example does. + +## Buffer sizing and diagnostics +The broadcast side is configured through `RuntimeConfig` and `EventBusConfig`. +~~~rust,no_run +use weavegraph::runtimes::{DiagnosticsConfig, EventBusConfig, RuntimeConfig, SinkConfig}; +let runtime = RuntimeConfig::default().with_event_bus( + EventBusConfig::new(2048, vec![SinkConfig::StdOut]).with_diagnostics(DiagnosticsConfig { + enabled: true, + buffer_capacity: Some(512), + emit_to_events: false, }) - .map(|event| Ok(SseEvent::default().json_data(event).unwrap())); -``` - -## Testing - -Test your streaming setup with `curl`: - -```bash -# Start your server -cargo run - -# In another terminal, stream events -curl -N http://localhost:3000/stream - -# You should see SSE events: -# event: workflow-event -# data: {"type":"node","message":"Processing...","scope":"worker","timestamp":"..."} -``` - -## Further Reading - -- **`streaming_events.rs`** - Self-contained example (no web framework) -- (Legacy) Older LLM streaming demos were removed during the 0.2.0 refactor; use `examples/streaming_events.rs` as the canonical streaming pattern. -- **EventBus source**: `weavegraph/src/event_bus/` -- **AppRunner source**: `weavegraph/src/runtimes/runner.rs` - -## Common Issues - -### Events Not Appearing in Stream - -**Problem**: Workflow runs but no events in channel. - -**Solution**: Ensure you're injecting your custom `EventBus` into the runner (prefer `AppRunner::builder()`): - -```rust -// ✓ CORRECT: Custom EventBus -let mut runner = AppRunner::builder() - .app(app) - .checkpointer(CheckpointerType::InMemory) - .event_bus(bus) - .autosave(false) - .start_listener(true) - .build() - .await; - -// ✗ WRONG: Default EventBus (events go nowhere) -let final_state = app.invoke(state).await?; -``` - -### Stream Ends Immediately - -**Problem**: SSE connection closes right away. - -**Solution**: Make sure the workflow task is spawned, not awaited: - -```rust -// ✓ CORRECT: Spawned task -tokio::spawn(async move { runner.run_until_complete(&id).await }); -return Sse::new(stream); // Returns immediately, stream stays open - -// ✗ WRONG: Awaited task -runner.run_until_complete(&id).await?; -return Sse::new(stream); // Stream already finished -``` - -### Missing Events at Start - -**Problem**: First few events are dropped. - -**Solution**: Create the channel and EventBus *before* starting the workflow: - -```rust -// ✓ CORRECT: Channel exists before workflow starts -let (tx, rx) = flume::unbounded(); -let bus = EventBus::with_sinks(vec![Box::new(ChannelSink::new(tx))]); -tokio::spawn(async move { /* run with bus */ }); -Sse::new(stream) // Events captured from the start - -// ✗ WRONG: Channel created after workflow starts -tokio::spawn(async move { /* already running */ }); -let (tx, rx) = flume::unbounded(); // Too late! -``` - ---- - -## Sink Diagnostics: Monitoring Failures - -Weavegraph provides **opt-in diagnostics** for monitoring event sink health without disrupting your main event stream. This is useful for production observability and debugging sink-specific issues. - -### Quick Start: No Changes Needed - -Existing code works unchanged—diagnostics are isolated and optional: - -```rust -// ✓ Constructing and using EventBus exactly as before -let bus = EventBus::with_sinks(vec![Box::new(StdOutSink::default())]); -app_runner.run_until_complete(&session).await?; - -// ✓ No changes needed for EventStream consumers -let mut events = bus.subscribe(); -while let Ok(event) = events.recv().await { /* ... */ } - -// ✓ No changes needed for sinks like StdOutSink, MemorySink, or ChannelSink -``` - -### Opt-In: Subscribe to Diagnostics - -To monitor sink failures, subscribe to the diagnostics stream: - -```rust -use weavegraph::event_bus::EventBus; - -let bus = EventBus::with_sinks(vec![ - Box::new(StdOutSink::default()), - Box::new(ChannelSink::new(tx)), -]); - -// Subscribe to diagnostics (doesn't affect main event stream) -let mut diags = bus.diagnostics(); - -tokio::spawn(async move { - while let Ok(diagnostic) = diags.recv().await { - eprintln!( - "[{}] Sink '{}' error #{}: {}", - diagnostic.when.format("%H:%M:%S"), - diagnostic.sink, - diagnostic.occurrence, - diagnostic.error - ); - } -}); - -// Main event stream continues independently -let mut events = bus.subscribe(); -``` - -**DiagnosticsStream API** mirrors `EventStream`: -- `recv()` → blocking receive -- `try_recv()` → non-blocking poll (returns `Empty`, `Closed`, or `Ok(diagnostic)`) -- `into_async_stream()` → convert to `futures::Stream` -- `next_timeout(duration)` → receive with timeout - -### Health Snapshots - -Query aggregated sink health at any time without subscribing: - -```rust -// Get current health for all sinks -let health = bus.sink_health(); - -for entry in health { - println!( - "Sink '{}': {} errors, last: {:?}", - entry.sink, - entry.error_count, - entry.last_error.as_deref().unwrap_or("none") - ); -} -``` - -**Use cases:** -- Health check endpoints in web servers -- Periodic alerting without continuous monitoring -- Post-mortem analysis after workflow completion - -### Configuration Options - -Control diagnostics behavior via `EventBusConfig`: - -```rust -use weavegraph::runtimes::{RuntimeConfig, EventBusConfig, DiagnosticsConfig}; - -// Disable diagnostics entirely (saves memory) -let config = RuntimeConfig::default() - .with_event_bus( - EventBusConfig::with_stdout_only() - .with_diagnostics(DiagnosticsConfig { - enabled: false, - buffer_capacity: None, - emit_to_events: false, - }) - ); - -// Enable diagnostics with custom buffer -let config = RuntimeConfig::default() - .with_event_bus( - EventBusConfig::with_stdout_only() - .with_diagnostics(DiagnosticsConfig { - enabled: true, - buffer_capacity: Some(512), // Default: same as event bus capacity - emit_to_events: false, - }) - ); - -// Emit diagnostics to BOTH the diagnostics stream AND main event stream -// ⚠️ Caution: Only use when sinks cannot create feedback loops -let config = RuntimeConfig::default() - .with_event_bus( - EventBusConfig::with_stdout_only() - .with_diagnostics(DiagnosticsConfig { - enabled: true, - buffer_capacity: None, - emit_to_events: true, // Also emit Event::Diagnostic to main stream - }) - ); -``` - -**When to use `emit_to_events: true`:** -- You have a single monitoring sink that won't fail on diagnostic events -- You want diagnostics visible in existing event consumers (logs, metrics, etc.) -- You understand the risk of cascading failures (e.g., a sink that fails on all events will emit diagnostics that trigger more failures) - -**Default behavior (`emit_to_events: false`):** -- Diagnostics are isolated to the dedicated diagnostics stream -- Main event stream is unaffected by sink failures -- Safer for most production use cases - -### Custom Sink Naming - -Override the default sink name for clearer diagnostics: - -```rust -use std::borrow::Cow; -use weavegraph::event_bus::EventSink; - -struct DatabaseSink { /* ... */ } - -impl EventSink for DatabaseSink { - fn name(&self) -> Cow<'static, str> { - Cow::Borrowed("postgres_events_sink") - } - - fn handle(&mut self, event: &Event) -> std::io::Result<()> { - // ... write to database - Ok(()) - } -} - -// Diagnostics will show "postgres_events_sink" instead of generic type name -let bus = EventBus::with_sink(DatabaseSink { /* ... */ }); -let health = bus.sink_health(); -assert_eq!(health[0].sink, "postgres_events_sink"); -``` - -### Example: Health Monitoring in Axum - -```rust -use axum::{Json, routing::get, Router}; -use serde_json::json; -use std::sync::Arc; - -async fn health_check(bus: Arc) -> Json { - let health = bus.sink_health(); - let any_errors = health.iter().any(|h| h.error_count > 0); - - Json(json!({ - "status": if any_errors { "degraded" } else { "healthy" }, - "sinks": health.iter().map(|h| json!({ - "name": h.sink, - "errors": h.error_count, - "last_error": h.last_error, - "last_error_at": h.last_error_at, - })).collect::>() - })) -} - -let app = Router::new() - .route("/health", get(health_check)) - .with_state(Arc::new(event_bus)); -``` - -### Troubleshooting - -**Q: Diagnostics stream returns `Closed` immediately** - -A: Diagnostics may be disabled in config. Check `EventBusConfig::diagnostics.enabled`. - -**Q: I'm not seeing diagnostics for sink failures I know are happening** - -A: Ensure you're subscribing *before* the workflow starts, or check if the diagnostics buffer is lagging (broadcast receivers drop messages when full). - -**Q: Health snapshot shows zero errors but I see tracing logs** - -A: Diagnostics tracking is disabled. Set `diagnostics.enabled: true` in config. - -**Q: Can I get diagnostics for a specific sink only?** - -A: Filter by `diagnostic.sink` name after receiving. The stream contains all sink diagnostics. - +); +let _ = runtime; +~~~ +What these settings change: +- `buffer_capacity` sets the size of the broadcast ring used by `EventHub`. +- diagnostics capacity can match or diverge from the main event buffer. +- `emit_to_events` controls whether sink failures stay isolated in `DiagnosticsStream` or are also mirrored into the primary event feed. +At runtime you can inspect sink status with `EventBus::sink_health()` and subscribe to failure notifications with `EventBus::diagnostics()`. That is the supported way to watch sink health without polluting the normal client stream. + +## Event payload families +Every helper above ultimately delivers the same `event_bus::Event` enum. +- `Event::Node` carries structured node output emitted through `NodeContext::emit`. +- `Event::Diagnostic` carries framework markers, including `STREAM_END_SCOPE` and `INVOCATION_END_SCOPE`. +- `Event::LLM` carries chunk, final, and error events produced by `emit_llm_chunk`, `emit_llm_final`, or `emit_llm_error`. +Because every variant exposes `scope_label()` and `message()`, many consumers can stay generic and only branch when they need variant-specific metadata. + +## When to drop down to `AppRunner` +The convenience methods cover the common one-shot cases. Use `AppRunner::builder()` when you need any of the following: +- an explicitly supplied `EventBus`, +- a custom or preconnected `Checkpointer`, +- iterative sessions driven by `invoke_next`, +- a runner that survives across multiple logical inputs. +That lower-level API is the same engine used underneath the convenience helpers, so moving down a level does not change event semantics; it only exposes more of the wiring. + +## Where to explore next +- [examples/convenience_streaming.rs](../examples/convenience_streaming.rs) demonstrates the channel and multi-sink helpers. +- [examples/streaming_events.rs](../examples/streaming_events.rs) shows the generic `EventStream` pattern. +- [examples/production_streaming.rs](../examples/production_streaming.rs) shows SSE plus checkpoint-backed execution. +- [ARCHITECTURE.md](ARCHITECTURE.md) explains how the bus, scheduler, reducers, and checkpointers fit together. diff --git a/examples/advanced_patterns.rs b/examples/advanced_patterns.rs index 34ecb26..4105e57 100644 --- a/examples/advanced_patterns.rs +++ b/examples/advanced_patterns.rs @@ -1,31 +1,19 @@ -//! Advanced node patterns and error handling examples. +//! Advanced workflow patterns. //! -//! This example demonstrates sophisticated workflow patterns that go beyond basic -//! node execution. It showcases: +//! This example chains together three higher-level node styles: +//! a retrying API call, a router driven by state, and a transformation stage +//! that records an audit log. The goal is to show how richer workflows can stay +//! observable without giving up small, composable nodes. //! -//! ## Core Patterns -//! - **Complex error handling**: Retry logic, graceful fallbacks, and error recovery -//! - **Conditional node execution**: Dynamic routing based on state conditions -//! - **Rich state transformations**: Complex data processing and enrichment -//! - **Advanced NodePartial usage**: Efficient partial state updates +//! The pipeline used here is: +//! API call -> route decision -> transformation pass //! -//! ## Key Learning Points -//! - How to implement robust external service integration -//! - Patterns for conditional workflow routing -//! - Techniques for data transformation and validation -//! - Best practices for error handling in distributed workflows -//! -//! ## Architecture -//! The example creates a pipeline: API Call → Router → Transformer -//! Each stage demonstrates different advanced patterns while maintaining -//! composability and observability. -//! -//! Run with: `cargo run --example advanced_patterns` +//! Run with: `cargo run --example advanced_patterns --features examples` use async_trait::async_trait; use serde_json::json; use std::collections::HashMap; -use std::sync::Arc; +use std::sync::Arc as SharedEmitter; use weavegraph::channels::Channel; use weavegraph::event_bus::EventBus; use weavegraph::message::{Message, Role}; @@ -39,26 +27,16 @@ use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitEx type ExampleResult = std::result::Result>; -/// A node that simulates external API calls with potential failures and retry logic. -/// -/// This node demonstrates enterprise-grade patterns for external service integration: -/// -/// ## Features -/// - **Configurable failure simulation**: Adjustable failure rates for testing -/// - **Retry logic**: Configurable retry attempts with detailed logging -/// - **Rich error reporting**: Comprehensive error metadata and context -/// - **Graceful degradation**: Structured failure handling +/// Simulates an external dependency that may fail before eventually succeeding. /// -/// ## Use Cases -/// - External API integration (REST, GraphQL, gRPC) -/// - Database operations with connection retry -/// - File system operations with temporary failures -/// - Any operation that might fail temporarily +/// The node emits attempt-by-attempt progress, records a structured result on +/// success, and returns a regular `NodeError` once the retry budget is spent. +/// It is intentionally small so the retry pattern is easy to adapt elsewhere. /// -/// ## Configuration -/// - `service_name`: Human-readable service identifier for logging -/// - `failure_rate`: Probability of failure (0.0 = never fail, 1.0 = always fail) -/// - `max_retries`: Maximum number of retry attempts before giving up +/// Configuration fields: +/// - `service_name`: label used in emitted messages and result payloads +/// - `failure_rate`: `0.0` always succeeds, `1.0` always fails +/// - `max_retries`: number of attempts before returning an error pub struct ApiCallNode { /// Human-readable name for this service (used in logging and errors) pub service_name: String, @@ -107,14 +85,14 @@ impl Node for ApiCallNode { }), ); - let partial = NodePartial::new() - .with_messages(vec![Message::with_role( - Role::System, - &format!("{} API call completed successfully", self.service_name), - )]) - .with_extra(extra); - - return Ok(partial); + let message = Message::with_role( + Role::System, + &format!("{} API call completed successfully", self.service_name), + ); + let result = NodePartial::new() + .with_extra(extra) + .with_messages(vec![message]); + return Ok(result); } else { ctx.emit("retry", format!("Attempt {} failed, retrying...", attempt))?; } @@ -128,28 +106,15 @@ impl Node for ApiCallNode { } } -/// A conditional router node that directs workflow execution based on state conditions. +/// Chooses a route label by comparing one state key against a set of values. /// -/// This node implements sophisticated routing logic that can dynamically alter -/// workflow execution paths based on runtime state. It's essential for building -/// adaptive workflows that respond to data conditions. +/// This is a lightweight example of data-driven routing. The selected route is +/// written back into `extra` so later stages can branch, log, or inspect the +/// decision without re-running the matching logic. /// -/// ## Features -/// - **Multi-condition routing**: Support for complex condition sets -/// - **Fallback routing**: Default route when no conditions match -/// - **Rich routing metadata**: Comprehensive decision logging -/// - **Type-safe conditions**: JSON value-based condition matching -/// -/// ## Use Cases -/// - User role-based workflow routing -/// - Data quality-based processing paths -/// - Feature flag-driven execution -/// - A/B testing workflow variants -/// - Error severity-based handling paths -/// -/// ## Configuration -/// - `route_key`: The state key to evaluate for routing decisions -/// - `conditions`: Map of route names to expected values for that key +/// Configuration fields: +/// - `route_key`: extra-data entry inspected during routing +/// - `conditions`: route name to expected JSON value pub struct ConditionalRouterNode { /// The key in the state's extra data to use for routing decisions pub route_key: String, @@ -198,12 +163,12 @@ impl Node for ConditionalRouterNode { }), ); - Ok(NodePartial::new() - .with_messages(vec![Message::with_role( - Role::System, - &format!("Routed to: {}", selected_route), - )]) - .with_extra(extra)) + let route_message = + Message::with_role(Role::System, &format!("Routed to: {selected_route}")); + let update = NodePartial::new() + .with_extra(extra) + .with_messages(vec![route_message]); + Ok(update) } } @@ -327,12 +292,14 @@ impl Node for DataTransformerNode { extra.insert("transformation_log".to_string(), json!(transformation_log)); - Ok(NodePartial::new() - .with_messages(vec![Message::with_role( - Role::Assistant, - &format!("Applied {} transformations", transformation_log.len()), - )]) - .with_extra(extra)) + let summary = Message::with_role( + Role::Assistant, + &format!("Applied {} transformations", transformation_log.len()), + ); + let update = NodePartial::new() + .with_extra(extra) + .with_messages(vec![summary]); + Ok(update) } } @@ -397,9 +364,8 @@ async fn main() -> ExampleResult<()> { max_retries: 3, }; - let emitter = event_bus.get_emitter(); - - let ctx1 = NodeContext::new("api_call", 1, Arc::clone(&emitter)); + let event_emitter = event_bus.get_emitter(); + let ctx1 = NodeContext::new("api_call", 1, SharedEmitter::clone(&event_emitter)); // Demonstrate both success and failure scenarios match api_node.run(state.snapshot(), ctx1).await { @@ -445,7 +411,7 @@ async fn main() -> ExampleResult<()> { max_retries: 2, }; - let ctx1_1 = NodeContext::new("metrics_api", 1, Arc::clone(&emitter)); + let ctx1_1 = NodeContext::new("metrics_api", 1, SharedEmitter::clone(&event_emitter)); match failing_api_node.run(state.snapshot(), ctx1_1).await { Ok(result) => { @@ -479,7 +445,7 @@ async fn main() -> ExampleResult<()> { }, }; - let ctx2 = NodeContext::new("router", 2, Arc::clone(&emitter)); + let ctx2 = NodeContext::new("router", 2, SharedEmitter::clone(&event_emitter)); let result2 = router_node.run(state.snapshot(), ctx2).await?; if let Some(messages) = result2.messages { @@ -517,7 +483,7 @@ async fn main() -> ExampleResult<()> { ], }; - let ctx3 = NodeContext::new("transformer", 3, Arc::clone(&emitter)); + let ctx3 = NodeContext::new("transformer", 3, SharedEmitter::clone(&event_emitter)); let result3 = transformer_node.run(state.snapshot(), ctx3).await?; if let Some(messages) = result3.messages { diff --git a/examples/basic_nodes.rs b/examples/basic_nodes.rs index f45244b..76cbc55 100644 --- a/examples/basic_nodes.rs +++ b/examples/basic_nodes.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use serde_json::json; -use std::sync::Arc; +use std::sync::Arc as SharedEmitter; use tracing::info; use tracing_error::ErrorLayer; use tracing_subscriber::fmt::format::FmtSpan; @@ -73,9 +73,8 @@ impl Node for MessageCounterNode { format!("{} finished processing", self.node_name), )?; - Ok(NodePartial::new() - .with_messages(vec![Message::with_role(Role::Assistant, &summary)]) - .with_extra(extra)) + let update = NodePartial::new().with_extra(extra); + Ok(update.with_messages(vec![Message::with_role(Role::Assistant, &summary)])) } } @@ -132,7 +131,8 @@ impl Node for ValidationNode { extra.insert("validated_fields".to_string(), json!(self.required_fields)); extra.insert("message_count_ok".to_string(), json!(true)); - Ok(NodePartial::new().with_extra(extra)) + let validation = NodePartial::new().with_extra(extra); + Ok(validation) } } @@ -201,9 +201,8 @@ impl Node for AggregatorNode { ctx.emit("completed", "Data aggregation completed")?; - Ok(NodePartial::new() - .with_messages(vec![Message::with_role(Role::Assistant, &summary)]) - .with_extra(extra)) + let update = NodePartial::new().with_extra(extra); + Ok(update.with_messages(vec![Message::with_role(Role::Assistant, &summary)])) } } @@ -259,9 +258,8 @@ async fn main() -> ExampleResult<()> { node_name: "CounterExample".to_string(), }; - let emitter = event_bus.get_emitter(); - - let ctx1 = NodeContext::new("counter-1", 2, Arc::clone(&emitter)); + let event_emitter = event_bus.get_emitter(); + let ctx1 = NodeContext::new("counter-1", 2, SharedEmitter::clone(&event_emitter)); let result1 = counter_node.run(state.snapshot(), ctx1).await?; @@ -286,7 +284,7 @@ async fn main() -> ExampleResult<()> { min_message_count: 1, }; - let ctx2 = NodeContext::new("validator-1", 3, Arc::clone(&emitter)); + let ctx2 = NodeContext::new("validator-1", 3, SharedEmitter::clone(&event_emitter)); let result2 = validation_node.run(state.snapshot(), ctx2).await?; @@ -300,7 +298,7 @@ async fn main() -> ExampleResult<()> { info!("\n📈 Running AggregatorNode..."); let aggregator_node = AggregatorNode; - let ctx3 = NodeContext::new("aggregator-1", 4, Arc::clone(&emitter)); + let ctx3 = NodeContext::new("aggregator-1", 4, SharedEmitter::clone(&event_emitter)); let result3 = aggregator_node.run(state.snapshot(), ctx3).await?; diff --git a/examples/convenience_streaming.rs b/examples/convenience_streaming.rs index f896b85..1822348 100644 --- a/examples/convenience_streaming.rs +++ b/examples/convenience_streaming.rs @@ -172,8 +172,8 @@ async fn main() -> ExampleResult<()> { // Spawn background collector for channel let channel_collector = tokio::spawn(async move { let mut events = Vec::new(); - while let Ok(event) = rx.recv_async().await { - events.push(event); + while let Ok(next_event) = rx.recv_async().await { + events.extend([next_event]); } events }); diff --git a/examples/graph_execution.rs b/examples/graph_execution.rs index c5a0a32..9106af1 100644 --- a/examples/graph_execution.rs +++ b/examples/graph_execution.rs @@ -1,19 +1,19 @@ -//! Graph Execution: Basic Graph Building and Execution +//! Graph execution walkthrough. //! -//! This demonstration showcases the fundamental graph building and execution patterns -//! in the Weavegraph framework. It covers basic workflow construction, state management, -//! barrier operations, and error handling scenarios. +//! This example assembles a small workflow, runs it, mutates the resulting +//! state, and applies manual barrier updates. It is meant as a compact tour of +//! graph compilation, state snapshots, and lower-level runtime hooks. //! -//! What You'll Learn: -//! 1. Modern Message Construction: Typed roles with `Message::with_role()` -//! 2. State Management: Working with versioned state and snapshots -//! 3. Graph Building: Creating workflows with nodes and edges -//! 4. Barrier Operations: Manual state updates and version management -//! 5. Error Handling: Validation and expected failure scenarios +//! Highlights: +//! 1. Constructing messages with typed roles. +//! 2. Preparing a `VersionedState` with extras. +//! 3. Compiling a workflow with fan-out edges. +//! 4. Applying barrier updates by hand. +//! 5. Inspecting saturation and no-op behavior. //! -//! Running This Example: +//! Run with: //! ```bash -//! cargo run --example graph_execution +//! cargo run --example graph_execution --features examples //! ``` use async_trait::async_trait; @@ -168,32 +168,19 @@ async fn run_example() -> ExampleResult<()> { // ✅ STEP 2: Modern Graph Building info!("\n🔗 Step 2: Building workflow graph with modern GraphBuilder"); + let initializer_id = NodeKind::from("initializer"); + let processor_a_id = NodeKind::from("processor-a"); + let processor_b_id = NodeKind::from("processor-b"); + let app = GraphBuilder::new() - .add_node( - NodeKind::Custom("Initializer".into()), - SimpleNode::new("Initializer"), - ) - .add_node( - NodeKind::Custom("ProcessorA".into()), - SimpleNode::new("ProcessorA"), - ) - .add_node( - NodeKind::Custom("ProcessorB".into()), - SimpleNode::new("ProcessorB"), - ) - // Create a processing pipeline: Start -> A -> B -> End - .add_edge(NodeKind::Start, NodeKind::Custom("Initializer".into())) - .add_edge( - NodeKind::Custom("Initializer".into()), - NodeKind::Custom("ProcessorA".into()), - ) - .add_edge( - NodeKind::Custom("ProcessorA".into()), - NodeKind::Custom("ProcessorB".into()), - ) - .add_edge(NodeKind::Custom("ProcessorB".into()), NodeKind::End) - // Add a secondary path: Start -> B (for demonstration of fan-out) - .add_edge(NodeKind::Start, NodeKind::Custom("ProcessorB".into())) + .add_node(initializer_id.clone(), SimpleNode::new("Initializer")) + .add_node(processor_a_id.clone(), SimpleNode::new("ProcessorA")) + .add_node(processor_b_id.clone(), SimpleNode::new("ProcessorB")) + .add_edge(NodeKind::Start, initializer_id.clone()) + .add_edge(initializer_id.clone(), processor_a_id.clone()) + .add_edge(processor_a_id, processor_b_id.clone()) + .add_edge(processor_b_id.clone(), NodeKind::End) + .add_edge(NodeKind::Start, processor_b_id) .compile()?; info!(" ✓ Graph compiled successfully"); @@ -290,25 +277,25 @@ async fn run_example() -> ExampleResult<()> { )]) .with_extra(extra_b); - let run_ids = vec![ - NodeKind::Custom("VirtualA".into()), - NodeKind::Custom("VirtualB".into()), + let barrier_sources = vec![ + NodeKind::from("virtual-left"), + NodeKind::from("virtual-right"), ]; - let barrier_outcome = app - .apply_barrier(&mut barrier_state, &run_ids, vec![partial_a, partial_b]) + let barrier_result = app + .apply_barrier( + &mut barrier_state, + &barrier_sources, + vec![partial_a, partial_b], + ) .await .map_err(|e| std::io::Error::other(format!("Barrier operation failed: {e}")))?; + let changed_channels = &barrier_result.updated_channels; + let barrier_error_count = barrier_result.errors.len(); info!(" ✓ Barrier applied successfully"); - info!( - " ✓ Updated channels: {:?}", - barrier_outcome.updated_channels - ); - info!( - " ✓ Errors recorded at barrier: {}", - barrier_outcome.errors.len() - ); + info!(" ✓ Barrier touched channels: {changed_channels:?}"); + info!(" ✓ Barrier recorded {barrier_error_count} error(s)"); let barrier_snapshot = barrier_state.snapshot(); info!( @@ -326,27 +313,26 @@ async fn run_example() -> ExampleResult<()> { let noop_partials = vec![NodePartial::new().with_messages(vec![])]; // Empty - should not update - let noop_outcome = app + let noop_result = app .apply_barrier(&mut barrier_state, &[], noop_partials) .await .map_err(|e| std::io::Error::other(format!("No-op barrier failed: {e}")))?; let post_noop_version = barrier_state.messages.version(); + let noop_channels = &noop_result.updated_channels; + let noop_error_total = noop_result.errors.len(); info!(" ✓ No-op barrier completed"); info!( " ✓ Version unchanged: {} -> {} (expected same)", pre_noop_version, post_noop_version ); - info!(" ✓ Updated channels: {:?}", noop_outcome.updated_channels); - info!( - " ✓ Errors recorded at barrier: {}", - noop_outcome.errors.len() - ); + info!(" ✓ No-op barrier updated channels: {noop_channels:?}"); + info!(" ✓ No-op barrier recorded {noop_error_total} error(s)"); // ✅ STEP 6: Error Handling Demonstrations info!("\n❌ Step 6: Demonstrating error handling and validation"); - // (Removed obsolete test: entry point validation no longer enforced. Start/End are virtual.) + // Start and End are virtual nodes, so the older entry-point validation demo no longer applies. info!(" 🧪 Skipping deprecated entry-point error test (Start/End now virtual)."); // Test 3: Version saturation behavior diff --git a/examples/scheduler_fanout.rs b/examples/scheduler_fanout.rs index 49e957d..f0a7e12 100644 --- a/examples/scheduler_fanout.rs +++ b/examples/scheduler_fanout.rs @@ -157,56 +157,37 @@ async fn run_scheduler_fanout() -> ExampleResult<()> { // ✅ STEP 2: Building a Complex Graph for Scheduler Example info!("\n🔗 Step 2: Building complex graph with dependencies and fan-out"); + let initializer_id = NodeKind::from("fanout-init"); + let analyzer_id = NodeKind::from("fanout-analyzer"); + let processor_a_id = NodeKind::from("fanout-processor-a"); + let processor_b_id = NodeKind::from("fanout-processor-b"); + let synthesizer_id = NodeKind::from("fanout-synth"); + let app = GraphBuilder::new() .add_node( - NodeKind::Custom("Initializer".into()), + initializer_id.clone(), SchedulerDemoNode::new("Initializer", 50), ) + .add_node(analyzer_id.clone(), SchedulerDemoNode::new("Analyzer", 200)) .add_node( - NodeKind::Custom("Analyzer".into()), - SchedulerDemoNode::new("Analyzer", 200), - ) - .add_node( - NodeKind::Custom("ProcessorA".into()), + processor_a_id.clone(), SchedulerDemoNode::new("ProcessorA", 150), ) .add_node( - NodeKind::Custom("ProcessorB".into()), + processor_b_id.clone(), SchedulerDemoNode::new("ProcessorB", 100), ) .add_node( - NodeKind::Custom("Synthesizer".into()), + synthesizer_id.clone(), SchedulerDemoNode::new("Synthesizer", 300), ) - // (Removed concrete End node registration – End is virtual) - // Create complex dependency graph: - // Start fans out to Analyzer and ProcessorA - .add_edge(NodeKind::Start, NodeKind::Custom("Initializer".into())) - .add_edge( - NodeKind::Custom("Initializer".into()), - NodeKind::Custom("Analyzer".into()), - ) - .add_edge( - NodeKind::Custom("Initializer".into()), - NodeKind::Custom("ProcessorA".into()), - ) - // Analyzer feeds into ProcessorB - .add_edge( - NodeKind::Custom("Analyzer".into()), - NodeKind::Custom("ProcessorB".into()), - ) - // Both ProcessorA and ProcessorB feed into Synthesizer - .add_edge( - NodeKind::Custom("ProcessorA".into()), - NodeKind::Custom("Synthesizer".into()), - ) - .add_edge( - NodeKind::Custom("ProcessorB".into()), - NodeKind::Custom("Synthesizer".into()), - ) - // Synthesizer feeds into End - .add_edge(NodeKind::Custom("Synthesizer".into()), NodeKind::End) - // .set_entry(NodeKind::Start) // removed: Start is virtual, no explicit entry required + .add_edge(NodeKind::Start, initializer_id.clone()) + .add_edge(initializer_id.clone(), analyzer_id.clone()) + .add_edge(initializer_id, processor_a_id.clone()) + .add_edge(analyzer_id, processor_b_id.clone()) + .add_edge(processor_a_id, synthesizer_id.clone()) + .add_edge(processor_b_id, synthesizer_id.clone()) + .add_edge(synthesizer_id, NodeKind::End) .compile()?; info!(" ✓ Complex graph compiled successfully"); diff --git a/examples/streaming_events.rs b/examples/streaming_events.rs index ef0d3e7..fe252da 100644 --- a/examples/streaming_events.rs +++ b/examples/streaming_events.rs @@ -1,75 +1,51 @@ -//! # Streaming Events Example +//! # Streaming workflow events //! -//! This example demonstrates how to stream events from a Weavegraph workflow -//! using `App::invoke_streaming`. This pattern is the foundation for building -//! real-time web dashboards, SSE endpoints, or WebSocket connections without -//! wiring `AppRunner` by hand. +//! This example focuses on the per-invocation event stream returned by +//! `App::invoke_streaming`. It is a good fit when a caller wants live progress +//! updates plus the final workflow state from the same request. //! -//! ## What This Example Shows +//! ## Highlights //! -//! 1. **Invoking the workflow** - `App::invoke_streaming(initial_state)` -//! 2. **Consuming events** - Convert `EventStream` into an async iterator -//! 3. **Forwarding to clients** - Serialize events to JSON/SSE/WebSocket frames -//! 4. **Awaiting completion** - Join the workflow handle for the final state +//! - build a small graph once and reuse it for multiple requests +//! - turn `EventStream` into an async stream with `into_async_stream` +//! - serialize each event into JSON for an SSE or WebSocket layer +//! - stop reading when the `STREAM_END_SCOPE` sentinel arrives //! -//! ## Architecture +//! ## Event flow sketch //! //! ```text -//! ┌─────────────────┐ -//! │ Workflow Node │ ──ctx.emit()──┐ -//! └─────────────────┘ │ -//! ▼ -//! ┌────────────────────────────────────────┐ -//! │ EventHub (broadcasts to all streams) │ -//! └─────────┬──────────────────────────────┘ -//! │ -//! ├─────────────┐ -//! ▼ ▼ -//! StdOutSink EventStream ──→ Your Code / SSE / WebSocket +//! node execution -> ctx.emit(...) -> event hub -> invoke_streaming receiver //! ``` //! -//! ## Web Integration -//! -//! For Axum/HTTP integration, convert the `EventStream` returned by -//! `invoke_streaming` into SSE frames. +//! ## SSE shape //! //! ```ignore //! use axum::response::sse::{Event as SseEvent, Sse}; -//! use futures_util::StreamExt; -//! -//! async fn stream_handler(State(app): State>) -> Sse<_> { -//! let (workflow, events) = app.invoke_streaming(initial_state).await; -//! +//! use futures_util::StreamExt as _; +//! async fn sse_handler(State(app): State>) -> Sse<_> { +//! let (workflow_handle, stream_rx) = app.invoke_streaming(initial_state).await; //! tokio::spawn(async move { -//! if let Err(err) = workflow.await.and_then(|res| res) { -//! tracing::error!("workflow failed: {err}"); -//! } +//! let _ = workflow_handle.join().await; //! }); -//! -//! let sse_stream = events.into_async_stream().map(|event| { -//! Ok(SseEvent::default().json_data(event).unwrap()) +//! let sse_events = stream_rx.into_async_stream().map(|item| { +//! Ok(SseEvent::default().json_data(item).unwrap()) //! }); -//! Sse::new(sse_stream) +//! Sse::new(sse_events) //! } //! ``` //! -//! **Key Points:** -//! - `App::invoke_streaming` handles the `AppRunner` boilerplate for you -//! - `EventStream::into_async_stream` is ideal for SSE/WebSocket integrations -//! - Drop-in convenience methods (`invoke_with_channel`, `invoke_with_sinks`) remain for simple scripts -//! -//! ## Run This Example +//! ## Run it //! //! ```bash -//! cargo run --example streaming_events +//! cargo run --example streaming_events --features examples //! ``` use async_trait::async_trait; -use futures_util::StreamExt; +use futures_util::StreamExt as _; use serde_json::json; use weavegraph::{ - channels::Channel, + channels::Channel as _, event_bus::{Event, STREAM_END_SCOPE}, graphs::GraphBuilder, message::{Message, Role}, @@ -84,6 +60,23 @@ use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitEx type ExampleResult = std::result::Result>; +fn event_kind_label(event: &Event) -> &'static str { + match event { + Event::Node(_) => "node", + Event::Diagnostic(_) => "diagnostic", + Event::LLM(_) => "llm", + } +} + +fn event_payload(event: &Event) -> serde_json::Value { + json!({ + "kind": event_kind_label(event), + "scope": event.scope_label().unwrap_or("workflow"), + "message": event.message(), + "observed_at": chrono::Utc::now().to_rfc3339(), + }) +} + fn init_tracing() { tracing_subscriber::registry() .with( @@ -145,66 +138,47 @@ impl Node for ProcessingNode { async fn main() -> ExampleResult<()> { init_tracing(); - info!("=== Streaming Events Example ===\n"); - - // 1. Build the workflow graph (compile once, reuse many times) - info!("Building workflow graph..."); - let app = GraphBuilder::new() - .add_node(NodeKind::Custom("Processor".into()), ProcessingNode) - .add_edge(NodeKind::Start, NodeKind::Custom("Processor".into())) - .add_edge(NodeKind::Custom("Processor".into()), NodeKind::End) - .compile()?; - - let initial_state = VersionedState::new_with_user_message("Process my data"); - let (invocation, event_stream) = app.invoke_streaming(initial_state).await; - - // 2. Consume streamed events as they arrive - info!("📡 Streaming events (these could be sent to a web client):\n"); - - let events_task = tokio::spawn(async move { - let mut count = 0usize; - let mut events = event_stream.into_async_stream(); - while let Some(event) = events.next().await { - count += 1; - let json_payload = json!({ - "type": match &event { - Event::Node(_) => "node", - Event::Diagnostic(_) => "diagnostic", - Event::LLM(_) => "llm", - }, - "scope": event.scope_label(), - "message": event.message(), - "timestamp": chrono::Utc::now().to_rfc3339(), - }); - - info!( - "📨 Stream event: {}", - serde_json::to_string_pretty(&json_payload)? - ); - - if event.scope_label() == Some(STREAM_END_SCOPE) { - info!("✅ Received STREAM_END_SCOPE sentinel; closing stream"); - break; + info!("=== Streaming workflow events ==="); + + let worker_id = NodeKind::from("processor-stream"); + let stream_builder = GraphBuilder::new() + .add_node(worker_id.clone(), ProcessingNode) + .add_edge(NodeKind::Start, worker_id.clone()) + .add_edge(worker_id, NodeKind::End); + let app = stream_builder.compile()?; + + let seed_state = VersionedState::new_with_user_message("Process dashboard data"); + let (run_handle, live_events) = app.invoke_streaming(seed_state).await; + + info!("📡 Forwarding event payloads as structured JSON"); + let collect_events = async move { + let mut stream = live_events.into_async_stream(); + let mut seen = 0usize; + + while let Some(next_event) = stream.next().await { + seen = seen.saturating_add(1); + let payload = event_payload(&next_event); + info!("📨 {}", serde_json::to_string_pretty(&payload)?); + + if matches!(next_event.scope_label(), Some(scope) if scope == STREAM_END_SCOPE) { + info!("✅ saw the terminal stream marker"); + return Ok::(seen); } } - Ok::(count) - }); - let final_state = invocation.join().await?; - let _event_count = events_task + Ok(seen) + }; + let collector = tokio::spawn(collect_events); + + let completed_state = run_handle.join().await?; + let delivered = collector .await - .map_err(|e| std::io::Error::other(e.to_string()))??; - - info!( - "🧾 Final state contains {} message(s)", - final_state.messages.snapshot().len() - ); - - info!("\n=== Example Complete ==="); - info!("\n💡 Next Steps:"); - info!(" - Use this pattern with Axum for SSE endpoints"); - info!(" - Use `invoke_with_channel` when you need a flume receiver"); - info!(" - Filter events by scope before streaming"); + .map_err(|join_err| std::io::Error::other(join_err.to_string()))??; + let final_messages = completed_state.messages.snapshot(); + + info!("🧾 Final state kept {} message(s)", final_messages.len()); + info!("📊 Stream delivered {delivered} event(s)"); + info!("💡 Pair this with Axum SSE, or swap to invoke_with_channel for flume consumers."); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index c992fc8..620dac5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,10 @@ #![cfg_attr(docsrs, feature(doc_cfg))] -//! ```text -//! GraphBuilder -> App::compile -> AppRunner -//! | | -//! | +-> Scheduler -> Nodes -> NodePartial -//! | | -//! | +-> Reducers -> VersionedState -//! | +-> EventBus (diagnostics / LLM) -//! | -//! +-> RuntimeConfig (persistence, sinks, execution knobs) -//! ``` +//! Weavegraph at a glance: +//! [`graphs::GraphBuilder`] describes the workflow topology and compiles it into an [`app::App`]. +//! [`runtimes::AppRunner`] executes the graph with schedulers, reducers, and node implementations. +//! [`state::VersionedState`] carries messages, typed extras, and errors across node boundaries. +//! [`event_bus`] exposes diagnostics and streaming hooks for observing a running workflow. //! //! Weavegraph is a graph-driven workflow framework for concurrent, stateful execution. //! You define nodes and edges with [`graphs::GraphBuilder`], compile to an [`app::App`], @@ -158,5 +153,4 @@ pub mod state; pub mod telemetry; pub mod types; pub mod utils; - -pub use control::{FrontierCommand, NodeRoute}; +pub use self::control::{FrontierCommand, NodeRoute}; diff --git a/tests/common/testing.rs b/tests/common/testing.rs index aca48c6..653d2cf 100644 --- a/tests/common/testing.rs +++ b/tests/common/testing.rs @@ -23,8 +23,8 @@ mod tests { #[tokio::test] async fn test_testnode_construction() { let node = TestNode { name: "example" }; - let bus = weavegraph::event_bus::EventBus::default(); - let ctx = NodeContext::new("test_node", 1, bus.get_emitter()); + let event_bus = weavegraph::event_bus::EventBus::default(); + let ctx = NodeContext::new("test_node", 1, event_bus.get_emitter()); let snapshot = VersionedState::builder().build().snapshot(); let result = node.run(snapshot, ctx).await; assert!(result.is_ok()); diff --git a/tests/event_bus.rs b/tests/event_bus.rs index 4450378..df90145 100644 --- a/tests/event_bus.rs +++ b/tests/event_bus.rs @@ -1,611 +1,650 @@ -use chrono::Utc; -use futures_util::{StreamExt, pin_mut}; -use proptest::prelude::*; -use rustc_hash::FxHashMap; +use async_trait::async_trait; +use chrono::Utc as ChronoUtc; +use futures_util::StreamExt; +use proptest::{collection::hash_map, prelude::*}; +use rustc_hash::FxHashMap as FastMap; use serde_json::{Number, Value, json}; -use std::fmt; -use std::sync::Arc; -use std::sync::Mutex; -use std::time::Duration; -use weavegraph::channels::Channel; +use std::{ + fmt as stdfmt, fs, + io::{self, Cursor, Write}, + path::PathBuf, + sync::{Arc, Mutex}, + time::{Duration, SystemTime, UNIX_EPOCH}, +}; +use weavegraph::channels::Channel as _; use weavegraph::event_bus::{ - ChannelSink, Event, EventBus, EventEmitter, EventSink, INVOCATION_END_SCOPE, JsonLinesSink, - LLMStreamingEvent, MemorySink, NodeEvent, STREAM_END_SCOPE, + ChannelSink, EmitterError, Event, EventBus, EventEmitter, EventHub, EventSink, + INVOCATION_END_SCOPE, JsonLinesSink, LLMStreamingEvent, MemorySink, NodeEvent, + STREAM_END_SCOPE, }; -use weavegraph::node::NodeContext; +use weavegraph::graphs::GraphBuilder; +use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; +use weavegraph::state::{StateSnapshot, VersionedState}; +use weavegraph::types::NodeKind; +use weavegraph::utils::clock::MockClock; -#[tokio::test] -async fn stop_listener_flushes_pending_events() { +fn test_bus() -> (EventBus, MemorySink) { let sink = MemorySink::new(); - let sink_snapshot = sink.clone(); - let bus = EventBus::with_sink(sink); + (EventBus::with_sink(sink.clone()), sink) +} - bus.listen_for_events(); +fn scope_message_pairs(events: &[Event]) -> Vec<(Option, String)> { + events + .iter() + .map(|event| { + ( + event.scope_label().map(str::to_owned), + event.message().to_owned(), + ) + }) + .collect() +} - let emitter = bus.get_emitter(); - emitter - .emit(Event::node_message_with_meta( - "test-node", - 42, - "scope", - "payload", - )) - .unwrap(); +fn as_node_event(event: &Event) -> &NodeEvent { + match event { + Event::Node(node) => node, + other => panic!("expected node event, got {other:?}"), + } +} - tokio::time::sleep(std::time::Duration::from_millis(10)).await; +fn as_llm_event(event: &Event) -> &LLMStreamingEvent { + match event { + Event::LLM(llm) => llm, + other => panic!("expected llm event, got {other:?}"), + } +} - bus.stop_listener().await; +#[derive(Clone)] +struct MirrorWriter { + buffer: Arc>>>, +} +impl Write for MirrorWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.buffer + .lock() + .expect("MirrorWriter mutex poisoned") + .write(buf) + } + fn flush(&mut self) -> io::Result<()> { + self.buffer + .lock() + .expect("MirrorWriter mutex poisoned") + .flush() + } +} +fn shared_buffer_text(buffer: &Arc>>>) -> String { + let locked = buffer.lock().expect("shared buffer mutex poisoned"); + String::from_utf8(locked.get_ref().clone()).expect("buffer should contain valid utf-8") +} +fn parse_jsonl(buffer: &Arc>>>) -> Vec { + shared_buffer_text(buffer) + .lines() + .map(|line| serde_json::from_str(line).expect("line should be valid json")) + .collect() +} +fn repo_jsonl_path(label: &str) -> PathBuf { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time should be after unix epoch") + .as_nanos(); + let directory = std::env::current_dir() + .expect("current directory should be available") + .join("target") + .join("test-artifacts"); + fs::create_dir_all(&directory).expect("test artifact directory should be creatable"); + directory.join(format!("{label}-{unique}.jsonl")) +} +async fn wait_for_workers() { + tokio::time::sleep(Duration::from_millis(40)).await; +} +#[derive(Clone, Default)] +struct CapturingEmitter { + recorded: Arc>>, +} +impl CapturingEmitter { + fn snapshot(&self) -> Vec { + self.recorded + .lock() + .expect("CapturingEmitter mutex poisoned") + .clone() + } +} +impl stdfmt::Debug for CapturingEmitter { + fn fmt(&self, f: &mut stdfmt::Formatter<'_>) -> stdfmt::Result { + f.debug_struct("CapturingEmitter") + .field("event_count", &self.snapshot().len()) + .finish() + } +} +impl EventEmitter for CapturingEmitter { + fn emit(&self, event: Event) -> Result<(), EmitterError> { + self.recorded + .lock() + .expect("CapturingEmitter mutex poisoned") + .push(event); + Ok(()) + } +} +#[derive(Debug)] +struct NoopTerminalNode; +#[async_trait] +impl Node for NoopTerminalNode { + async fn run(&self, _: StateSnapshot, _: NodeContext) -> Result { + Ok(NodePartial::default()) + } +} - let entries = sink_snapshot.snapshot(); - assert_eq!(entries.len(), 1); - assert_eq!(entries[0].message(), "payload"); +fn ascii_text() -> impl Strategy { + proptest::string::string_regex("[A-Za-z0-9 _-]{0,24}").expect("regex should compile") } -#[tokio::test] -async fn stopping_without_events_is_noop() { - let bus = EventBus::with_sink(MemorySink::new()); - bus.listen_for_events(); - bus.stop_listener().await; +fn leaf_json_value() -> impl Strategy { + prop_oneof![ + Just(Value::Null), + ascii_text().prop_map(Value::String), + any::().prop_map(Value::Bool), + any::().prop_map(|value| Value::Number(Number::from(value))), + ] +} +fn any_event() -> impl Strategy { + let diagnostics = + (ascii_text(), ascii_text()).prop_map(|(scope, message)| Event::diagnostic(scope, message)); + let plain_nodes = ( + prop::option::of(ascii_text()), + prop::option::of(any::()), + ascii_text(), + ascii_text(), + ) + .prop_map(|(node_id, step, scope, message)| { + Event::Node(NodeEvent::new(node_id, step, scope, message)) + }); + let nodes_with_metadata = ( + ascii_text(), + any::(), + ascii_text(), + ascii_text(), + hash_map(ascii_text(), leaf_json_value(), 0..4), + ) + .prop_map(|(node_id, step, scope, message, metadata)| { + Event::node_message_with_metadata( + node_id, + step, + scope, + message, + metadata.into_iter().collect(), + ) + }); + let llm_events = ( + prop::option::of(ascii_text()), + prop::option::of(ascii_text()), + prop::option::of(ascii_text()), + ascii_text(), + hash_map(ascii_text(), leaf_json_value(), 0..4), + any::(), + ) + .prop_map( + |(session_id, node_id, stream_id, chunk, metadata, is_final)| { + let mut builder = LLMStreamingEvent::builder(chunk) + .is_final(is_final) + .metadata(metadata.into_iter().collect()); + if let Some(id) = session_id { + builder = builder.session_id(id); + } + if let Some(id) = node_id { + builder = builder.node_id(id); + } + if let Some(id) = stream_id { + builder = builder.stream_id(id); + } + Event::LLM(builder.build()) + }, + ); + + prop_oneof![diagnostics, plain_nodes, nodes_with_metadata, llm_events] } #[tokio::test] -async fn memory_sink_captures_events_with_scope_and_messages() { - let sink = MemorySink::new(); - let sink_snapshot = sink.clone(); - let bus = EventBus::with_sink(sink); - +async fn stop_listener_flushes_pending_events() { + let (bus, sink) = test_bus(); bus.listen_for_events(); - - let emitter = bus.get_emitter(); - - emitter - .emit(Event::node_message("Scope1", "one")) - .expect("emit one"); - emitter - .emit(Event::node_message("Scope1", "two")) - .expect("emit two"); - - emitter - .emit(Event::diagnostic("Scope2", "three")) - .expect("emit three"); - emitter - .emit(Event::diagnostic("Scope2", "four")) - .expect("emit four"); - - tokio::time::sleep(std::time::Duration::from_millis(20)).await; + let publisher = bus.get_emitter(); + publisher + .emit(Event::node_message_with_meta( + "node-a", 42, "scope-a", "payload", + )) + .expect("emit should succeed"); + wait_for_workers().await; bus.stop_listener().await; - let entries = sink_snapshot.snapshot(); - assert_eq!(entries.len(), 4); - - assert_eq!(entries[0].scope_label(), Some("Scope1")); - assert_eq!(entries[0].message(), "one"); - - assert_eq!(entries[1].scope_label(), Some("Scope1")); - assert_eq!(entries[1].message(), "two"); - - assert_eq!(entries[2].scope_label(), Some("Scope2")); - assert_eq!(entries[2].message(), "three"); - - assert_eq!(entries[3].scope_label(), Some("Scope2")); - assert_eq!(entries[3].message(), "four"); + let events = sink.snapshot(); + assert_eq!(events.len(), 1); + assert_eq!( + (events[0].message(), as_node_event(&events[0]).node_id()), + ("payload", Some("node-a")) + ); } - #[tokio::test] -async fn multiple_listen_calls_are_idempotent() { - let sink = MemorySink::new(); - let sink_snapshot = sink.clone(); +async fn stopping_idle_listener_is_safe() { + let (_, sink) = test_bus(); let bus = EventBus::with_sink(sink); - - // Call listen multiple times; only one listener should be active. - bus.listen_for_events(); bus.listen_for_events(); + bus.stop_listener().await; + bus.stop_listener().await; +} +#[tokio::test] +async fn memory_sink_keeps_scope_and_message_sequence() { + let (bus, sink) = test_bus(); bus.listen_for_events(); - - let emitter = bus.get_emitter(); - emitter.emit(Event::node_message("S", "a")).unwrap(); - emitter.emit(Event::node_message("S", "b")).unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(20)).await; + let publisher = bus.get_emitter(); + let sequence = [ + Event::node_message("Scope1", "one"), + Event::node_message("Scope1", "two"), + Event::diagnostic("Scope2", "three"), + Event::diagnostic("Scope2", "four"), + ]; + for item in sequence { + publisher.emit(item).expect("emit should succeed"); + } + wait_for_workers().await; bus.stop_listener().await; - let entries = sink_snapshot.snapshot(); - assert_eq!(entries.len(), 2); - assert!(entries.iter().any(|e| e.message() == "a")); - assert!(entries.iter().any(|e| e.message() == "b")); + assert_eq!( + scope_message_pairs(&sink.snapshot()), + vec![ + (Some("Scope1".to_owned()), "one".to_owned()), + (Some("Scope1".to_owned()), "two".to_owned()), + (Some("Scope2".to_owned()), "three".to_owned()), + (Some("Scope2".to_owned()), "four".to_owned()), + ] + ); } - #[tokio::test] -async fn memory_sink_preserves_order_under_concurrency() { - use tokio::task; +async fn repeated_listen_calls_keep_single_worker_per_sink() { + let (bus, sink) = test_bus(); + for _ in 0..3 { + bus.listen_for_events(); + } + let publisher = bus.get_emitter(); + publisher + .emit(Event::node_message("single", "first")) + .expect("first emit should succeed"); + publisher + .emit(Event::node_message("single", "second")) + .expect("second emit should succeed"); + wait_for_workers().await; + bus.stop_listener().await; - let sink = MemorySink::new(); - let sink_snapshot = sink.clone(); - let bus = EventBus::with_sink(sink); + let events = sink.snapshot(); + assert_eq!( + [events[0].message(), events[1].message()], + ["first", "second"] + ); +} +#[tokio::test] +async fn concurrent_publishers_preserve_send_order() { + let (bus, sink) = test_bus(); bus.listen_for_events(); - - let emitter = bus.get_emitter(); - let mut handles = Vec::new(); - let total = 20u32; - for i in 0..total { - let emitter = Arc::clone(&emitter); - handles.push(task::spawn(async move { - // Stagger sends to establish a deterministic order. - tokio::time::sleep(std::time::Duration::from_millis((i * 2) as u64)).await; - emitter - .emit(Event::node_message("ORDER", format!("m{i}"))) - .expect("emit"); + let source = bus.get_emitter(); + let count = 12u32; + let mut tasks = Vec::new(); + for index in 0..count { + let publisher = Arc::clone(&source); + tasks.push(tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis((index * 2) as u64)).await; + publisher + .emit(Event::node_message("ORDER", format!("msg-{index}"))) + .expect("emit should succeed"); })); } - - for h in handles { - let _ = h.await; + for task in tasks { + task.await.expect("task should join"); } - - tokio::time::sleep(std::time::Duration::from_millis((total * 3) as u64)).await; bus.stop_listener().await; - let entries = sink_snapshot.snapshot(); - assert_eq!(entries.len() as u32, total); - for (idx, entry) in entries.iter().enumerate() { - let expected = format!("m{idx}"); - assert_eq!( - entry.message(), - &expected, - "entry {idx} should have message {expected}, got: {}", - entry.message() - ); - } + let ordered_messages: Vec = sink + .snapshot() + .into_iter() + .map(|event| event.message().to_owned()) + .collect(); + let expected: Vec = (0..count).map(|index| format!("msg-{index}")).collect(); + assert_eq!(ordered_messages, expected); } - #[tokio::test] -async fn channel_sink_forwards_events() { +async fn channel_sink_delivers_events_to_receiver() { let (tx, rx) = flume::unbounded(); let bus = EventBus::with_sink(ChannelSink::new(tx)); bus.listen_for_events(); + let publisher = bus.get_emitter(); + publisher + .emit(Event::diagnostic("channel", "hello world")) + .expect("emit should succeed"); - bus.get_emitter() - .emit(Event::diagnostic("test", "hello world")) - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - let received = rx.recv_async().await.unwrap(); - assert_eq!(received.message(), "hello world"); - assert_eq!(received.scope_label(), Some("test")); + let received = tokio::time::timeout(Duration::from_secs(1), rx.recv_async()) + .await + .expect("receiver should wake in time") + .expect("channel should carry an event"); + assert_eq!( + (received.scope_label(), received.message()), + (Some("channel"), "hello world") + ); } #[tokio::test] -async fn multi_sink_broadcast() { +async fn bus_broadcasts_each_event_to_every_sink() { let memory = MemorySink::new(); let (tx, rx) = flume::unbounded(); - let bus = EventBus::with_sinks(vec![ Box::new(memory.clone()), Box::new(ChannelSink::new(tx)), ]); bus.listen_for_events(); + let publisher = bus.get_emitter(); + publisher + .emit(Event::diagnostic("broadcast", "fan out")) + .expect("emit should succeed"); + wait_for_workers().await; + bus.stop_listener().await; - bus.get_emitter() - .emit(Event::diagnostic("test", "broadcast message")) - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - - let memory_events = memory.snapshot(); - assert_eq!(memory_events.len(), 1); - assert_eq!(memory_events[0].message(), "broadcast message"); - - let channel_event = rx.recv_async().await.unwrap(); - assert_eq!(channel_event.message(), "broadcast message"); + assert_eq!(memory.snapshot()[0].message(), "fan out"); + let mirrored = tokio::time::timeout(Duration::from_secs(1), rx.recv_async()) + .await + .expect("receiver should wake in time") + .expect("channel should carry an event"); + assert_eq!(mirrored.message(), "fan out"); } - #[tokio::test] -async fn add_sink_dynamically() { - let bus = EventBus::default(); // Starts with StdOutSink +async fn sink_added_after_start_receives_later_events() { + let bus = EventBus::with_sinks(Vec::new()); bus.listen_for_events(); - let (tx, rx) = flume::unbounded(); bus.add_sink(ChannelSink::new(tx)); + let publisher = bus.get_emitter(); + publisher + .emit(Event::diagnostic("dynamic", "arrived later")) + .expect("emit should succeed"); + let received = tokio::time::timeout(Duration::from_secs(1), rx.recv_async()) + .await + .expect("receiver should wake in time") + .expect("channel should carry an event"); + assert_eq!(received.message(), "arrived later"); +} +#[tokio::test(flavor = "current_thread")] +async fn stop_listener_waits_for_every_sink() { + let first = MemorySink::new(); + let second = MemorySink::new(); + let bus = EventBus::with_sinks(vec![Box::new(first.clone()), Box::new(second.clone())]); + bus.listen_for_events(); + let publisher = bus.get_emitter(); - bus.get_emitter() - .emit(Event::diagnostic("test", "dynamic sink")) - .unwrap(); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + for index in 0..10 { + publisher + .emit(Event::diagnostic("drain", format!("msg-{index}"))) + .expect("emit should succeed"); + } + wait_for_workers().await; + bus.stop_listener().await; - let received = rx.recv_async().await.unwrap(); - assert_eq!(received.message(), "dynamic sink"); + assert_eq!((first.snapshot().len(), second.snapshot().len()), (10, 10)); } - #[tokio::test] -async fn channel_sink_handles_dropped_receiver() { - use std::io::ErrorKind; - - let (tx, rx) = flume::unbounded(); - let mut sink = ChannelSink::new(tx); - - drop(rx); +async fn stop_listener_can_interrupt_active_publish_loop() { + let bus = Arc::new(EventBus::with_sink(MemorySink::new())); + bus.listen_for_events(); + let publisher = bus.get_emitter(); + let emission = tokio::spawn(async move { + for index in 0..1_000u32 { + let _ = publisher.emit(Event::diagnostic("stress", format!("{index}"))); + tokio::task::yield_now().await; + } + }); + tokio::time::sleep(Duration::from_millis(20)).await; - let event = Event::diagnostic("test", "msg"); - let result = sink.handle(&event); + let stop_result = tokio::time::timeout(Duration::from_secs(1), bus.stop_listener()).await; + assert!(stop_result.is_ok(), "stop_listener should finish promptly"); - assert!(result.is_err()); - assert_eq!(result.unwrap_err().kind(), ErrorKind::BrokenPipe); + emission.abort(); + let _ = emission.await; } - #[tokio::test] -async fn async_stream_adapter_yields_events() { - let bus = EventBus::with_sink(MemorySink::new()); - let emitter = bus.get_emitter(); +async fn listener_can_restart_after_stop() { + let (bus, sink) = test_bus(); + bus.listen_for_events(); + let first_cycle = bus.get_emitter(); + first_cycle + .emit(Event::diagnostic("cycle-1", "first")) + .expect("first emit should succeed"); + wait_for_workers().await; + bus.stop_listener().await; + + bus.listen_for_events(); + let second_cycle = bus.get_emitter(); + second_cycle + .emit(Event::diagnostic("cycle-2", "second")) + .expect("second emit should succeed"); + wait_for_workers().await; + bus.stop_listener().await; + + assert_eq!( + scope_message_pairs(&sink.snapshot()), + [ + (Some("cycle-1".to_owned()), "first".to_owned()), + (Some("cycle-2".to_owned()), "second".to_owned()), + ] + ); +} +#[tokio::test(flavor = "current_thread")] +async fn async_stream_adapter_produces_published_event() { + let (bus, _) = test_bus(); + let publisher = bus.get_emitter(); let stream = bus.subscribe().into_async_stream(); - pin_mut!(stream); - emitter + tokio::pin!(stream); + publisher .emit(Event::diagnostic("async", "stream")) - .expect("emit"); - - let event = stream.next().await.expect("event"); - assert_eq!(event.message(), "stream"); - assert_eq!(event.scope_label(), Some("async")); + .expect("emit should succeed"); + let observed = stream.next().await.expect("stream should yield an event"); + assert_eq!(observed.message(), "stream"); + assert_eq!( + scope_message_pairs(std::slice::from_ref(&observed)), + [(Some("async".to_owned()), "stream".to_owned())] + ); } - -#[tokio::test] -async fn next_timeout_reports_timeouts_and_events() { - let bus = EventBus::with_sink(MemorySink::new()); - let emitter = bus.get_emitter(); - let mut stream = bus.subscribe(); - +#[tokio::test(flavor = "current_thread")] +async fn next_timeout_distinguishes_idle_streams_from_published_events() { + let (bus, _) = test_bus(); + let publisher = bus.get_emitter(); + let mut subscription = bus.subscribe(); assert!( - stream - .next_timeout(Duration::from_millis(10)) + subscription + .next_timeout(Duration::from_millis(20)) .await .is_none() ); - - emitter + publisher .emit(Event::diagnostic("timeout", "delivered")) - .expect("emit"); - - let event = stream - .next_timeout(Duration::from_secs(1)) + .expect("emit should succeed"); + let observed = subscription + .next_timeout(Duration::from_millis(500)) .await - .expect("event after emit"); - assert_eq!(event.message(), "delivered"); - assert_eq!(event.scope_label(), Some("timeout")); + .expect("stream should receive the published event"); + assert_eq!(observed.message(), "delivered"); + assert_eq!( + scope_message_pairs(std::slice::from_ref(&observed)), + [(Some("timeout".to_owned()), "delivered".to_owned())] + ); } - -#[tokio::test] -async fn blocking_iterator_receives_events() { - let bus = EventBus::with_sink(MemorySink::new()); - let emitter = bus.get_emitter(); - let iter = bus.subscribe().into_blocking_iter(); - - let handle = tokio::task::spawn_blocking(move || { - let mut iter = iter; - iter.next() +#[tokio::test(flavor = "current_thread")] +async fn blocking_iterator_receives_next_event() { + let (bus, _) = test_bus(); + let publisher = bus.get_emitter(); + let blocking_iter = bus.subscribe().into_blocking_iter(); + let worker = tokio::task::spawn_blocking(move || { + let mut blocking_iter = blocking_iter; + blocking_iter.next() }); - tokio::time::sleep(Duration::from_millis(10)).await; - emitter + publisher .emit(Event::diagnostic("blocking", "iter")) - .expect("emit"); - - let event = handle.await.expect("join").expect("event"); - assert_eq!(event.message(), "iter"); - assert_eq!(event.scope_label(), Some("blocking")); + .expect("emit should succeed"); + let observed = worker + .await + .expect("blocking task should join") + .expect("iterator should yield one event"); + assert_eq!(observed.message(), "iter"); + assert_eq!( + scope_message_pairs(std::slice::from_ref(&observed)), + [(Some("blocking".to_owned()), "iter".to_owned())] + ); } - -#[tokio::test] -async fn event_stream_closes_when_bus_dropped() { - use std::time::Duration; - - let mut stream = { - let bus = EventBus::with_sink(MemorySink::new()); +#[tokio::test(flavor = "current_thread")] +async fn subscription_closes_after_bus_drop() { + let mut subscription = { + let (bus, _) = test_bus(); bus.listen_for_events(); bus.subscribe() }; - assert!( - stream + subscription .next_timeout(Duration::from_millis(50)) .await - .is_none(), - "expected broadcast stream to close after EventBus drop" + .is_none() ); } - -#[tokio::test] -async fn stop_listener_drains_multiple_sinks() { - use std::time::Duration; - - let sink1 = MemorySink::new(); - let sink2 = MemorySink::new(); - let snapshot1 = sink1.clone(); - let snapshot2 = sink2.clone(); - - let bus = EventBus::with_sinks(vec![Box::new(sink1), Box::new(sink2)]); - bus.listen_for_events(); - - let emitter = bus.get_emitter(); - for i in 0..10 { - emitter - .emit(Event::diagnostic("test", format!("msg {i}"))) - .unwrap(); - } - - tokio::time::sleep(Duration::from_millis(50)).await; - bus.stop_listener().await; - - assert_eq!(snapshot1.snapshot().len(), 10); - assert_eq!(snapshot2.snapshot().len(), 10); -} - -#[tokio::test] -async fn stop_listener_during_emission() { - use tokio::task; - - let bus = Arc::new(EventBus::with_sink(MemorySink::new())); - bus.listen_for_events(); - - let emitter = bus.get_emitter(); - let emit_task = task::spawn(async move { - for i in 0..1000u32 { - let _ = emitter.emit(Event::diagnostic("stress", format!("{i}"))); - task::yield_now().await; - } - }); - - tokio::time::sleep(std::time::Duration::from_millis(20)).await; - bus.stop_listener().await; - emit_task.abort(); -} - -#[tokio::test] -async fn restart_after_stop() { - use std::time::Duration; - - let sink = MemorySink::new(); - let snapshot = sink.clone(); - let bus = EventBus::with_sink(sink); - - bus.listen_for_events(); - bus.get_emitter() - .emit(Event::diagnostic("cycle1", "msg1")) - .unwrap(); - tokio::time::sleep(Duration::from_millis(10)).await; - bus.stop_listener().await; - - bus.listen_for_events(); - bus.get_emitter() - .emit(Event::diagnostic("cycle2", "msg2")) - .unwrap(); - tokio::time::sleep(Duration::from_millis(10)).await; - bus.stop_listener().await; - - let events = snapshot.snapshot(); - assert_eq!(events.len(), 2); - assert_eq!(events[0].message(), "msg1"); - assert_eq!(events[1].message(), "msg2"); -} - -#[tokio::test] -async fn invoke_streaming_emits_terminal_event() { - use async_trait::async_trait; - use futures_util::StreamExt; - use weavegraph::graphs::GraphBuilder; - use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; - use weavegraph::state::{StateSnapshot, VersionedState}; - use weavegraph::types::NodeKind; - - struct TerminalNode; - - #[async_trait] - impl Node for TerminalNode { - async fn run(&self, _: StateSnapshot, _: NodeContext) -> Result { - Ok(NodePartial::default()) - } +#[tokio::test(flavor = "current_thread")] +async fn invoke_streaming_appends_terminal_scope_event() { + let terminal_kind = NodeKind::Custom("terminal".to_owned()); + let builder = GraphBuilder::new() + .add_node(terminal_kind.clone(), NoopTerminalNode) + .add_edge(NodeKind::Start, terminal_kind.clone()) + .add_edge(terminal_kind, NodeKind::End); + let app = builder.compile().expect("graph should compile"); + let request = VersionedState::new_with_user_message("finish"); + let (handle, stream) = app.invoke_streaming(request).await; + let collector = + tokio::spawn(async move { stream.into_async_stream().collect::>().await }); + let final_state = handle.join().await.expect("workflow should finish"); + let emitted_messages = final_state.messages.snapshot().len(); + assert_eq!(emitted_messages, 1); + let events = collector.await.expect("collector should join"); + let terminal = events.last().expect("stream should emit a terminal event"); + assert_eq!(terminal.scope_label(), Some(STREAM_END_SCOPE)); +} +#[tokio::test(flavor = "current_thread")] +async fn hub_metrics_count_lagged_events() { + use tokio::sync::broadcast::error::RecvError as BroadcastRecvError; + let lag_hub = EventHub::new(1); + let publisher = lag_hub.emitter(); + let mut subscription = lag_hub.subscribe(); + for message in ["first", "second"] { + publisher + .emit(Event::diagnostic("metrics", message)) + .expect("emit should succeed"); } - - let app = GraphBuilder::new() - .add_node(NodeKind::Custom("terminal".into()), TerminalNode) - .add_edge(NodeKind::Start, NodeKind::Custom("terminal".into())) - .add_edge(NodeKind::Custom("terminal".into()), NodeKind::End) - .compile() - .expect("graph"); - - let initial = VersionedState::new_with_user_message("finish"); - let (handle, event_stream) = app.invoke_streaming(initial).await; - - let collector = tokio::spawn(async move { - let mut collected = Vec::new(); - let mut stream = event_stream.into_async_stream(); - while let Some(event) = stream.next().await { - collected.push(event); - } - collected - }); - - let final_state = handle.join().await.expect("workflow"); - assert_eq!(final_state.messages.snapshot().len(), 1); - - let events = collector.await.expect("collector join"); - let end_event = events.last().expect("at least one terminal event"); - assert_eq!(end_event.scope_label(), Some(STREAM_END_SCOPE)); -} - -#[tokio::test] -async fn event_hub_metrics_track_drops() { - use tokio::sync::broadcast::error::RecvError; - use weavegraph::event_bus::EventHub; - - let hub = EventHub::new(1); - let emitter = hub.emitter(); - let mut stream = hub.subscribe(); - - emitter - .emit(Event::diagnostic("metrics", "first")) - .expect("emit first event"); - emitter - .emit(Event::diagnostic("metrics", "second")) - .expect("emit second event"); - - let missed = match stream.recv().await { - Err(RecvError::Lagged(missed)) => missed, - Ok(event) => { - panic!("expected lagged error, received event: {:?}", event); - } - Err(err) => panic!("unexpected recv error: {err:?}"), + let dropped_by_lag = match subscription.recv().await { + Err(BroadcastRecvError::Lagged(count)) => count, + Ok(event) => panic!("expected lagged error, received {event:?}"), + Err(other) => panic!("unexpected recv error: {other:?}"), }; - - assert_eq!(missed, 1); - - let metrics = hub.metrics(); - assert_eq!(metrics.capacity, 1); - assert_eq!(metrics.dropped, 1); -} - -#[test] -fn event_bus_metrics_expose_capacity() { - let bus = EventBus::default(); - let metrics = bus.metrics(); - assert_eq!(metrics.capacity, 1024); - assert_eq!(metrics.dropped, 0); -} - -#[derive(Default)] -struct RecordingEmitter { - events: Arc>>, -} - -impl RecordingEmitter { - fn record(&self, event: Event) { - self.events - .lock() - .expect("RecordingEmitter mutex poisoned") - .push(event); - } - - fn snapshot(&self) -> Vec { - self.events - .lock() - .expect("RecordingEmitter mutex poisoned") - .clone() - } -} - -impl fmt::Debug for RecordingEmitter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RecordingEmitter") - .field( - "event_count", - &self - .events - .lock() - .expect("RecordingEmitter mutex poisoned") - .len(), - ) - .finish() - } -} - -impl EventEmitter for RecordingEmitter { - fn emit(&self, event: Event) -> Result<(), weavegraph::event_bus::EmitterError> { - self.record(event); - Ok(()) - } -} - -#[test] -fn node_context_emits_all_event_variants() { - let emitter = Arc::new(RecordingEmitter::default()); - let event_emitter: Arc = emitter.clone(); - let ctx = NodeContext::new("node-a", 7, event_emitter); - - ctx.emit("progress", "started").unwrap(); - ctx.emit_diagnostic("diagnostic", "all good").unwrap(); - - let mut metadata = FxHashMap::default(); - metadata.insert("token_count".into(), json!(42)); - ctx.emit_llm_chunk( - Some("session-1".into()), - Some("stream-1".into()), - "chunk text", - Some(metadata), - ) - .unwrap(); - - ctx.emit_llm_final( - Some("session-1".into()), - Some("stream-1".into()), - "final chunk", - None, - ) - .unwrap(); - - ctx.emit_llm_error( - Some("session-1".into()), - Some("stream-1".into()), - "error occurred", - ) - .unwrap(); - - let events = emitter.snapshot(); - assert_eq!(events.len(), 5); - - match &events[0] { - Event::Node(node) => { - assert_eq!(node.node_id(), Some("node-a")); - assert_eq!(node.step(), Some(7)); - assert_eq!(node.scope(), "progress"); - assert_eq!(node.message(), "started"); - } - other => panic!("expected node event, got {other:?}"), - } - - match &events[1] { - Event::Diagnostic(diag) => { - assert_eq!(diag.scope(), "diagnostic"); - assert_eq!(diag.message(), "all good"); - } - other => panic!("expected diagnostic event, got {other:?}"), - } - - match &events[2] { - Event::LLM(llm) => { - assert_eq!(llm.session_id(), Some("session-1")); - assert_eq!(llm.node_id(), Some("node-a")); - assert_eq!(llm.stream_id(), Some("stream-1")); - assert!(!llm.is_final()); - assert_eq!(llm.chunk(), "chunk text"); - assert_eq!(llm.metadata().get("token_count"), Some(&json!(42))); - } - other => panic!("expected LLM chunk event, got {other:?}"), - } - - match &events[3] { - Event::LLM(llm) => { - assert!(llm.is_final()); - assert_eq!(llm.chunk(), "final chunk"); - assert!(llm.metadata().is_empty()); - } - other => panic!("expected final LLM event, got {other:?}"), - } - - match &events[4] { - Event::LLM(llm) => { - assert!(llm.is_final()); - assert_eq!(llm.chunk(), "error occurred"); - assert_eq!(llm.metadata().get("severity"), Some(&json!("error"))); - } - other => panic!("expected LLM error event, got {other:?}"), - } + let health = lag_hub.metrics(); + assert_eq!((dropped_by_lag, health.capacity, health.dropped), (1, 1, 1)); +} +#[cfg_attr(test, test)] +fn default_bus_keeps_expected_capacity_metrics() { + let default_report = EventBus::default().metrics(); + assert_eq!((default_report.capacity, default_report.dropped), (1024, 0)); +} +#[cfg_attr(test, test)] +fn node_context_emit_adds_identity_and_runtime_metadata() { + let recorder = CapturingEmitter::default(); + let emitter: Arc = Arc::new(recorder.clone()); + let mut context = NodeContext::new("node-a", 7, emitter); + context.invocation_id = Some("invoke-1".to_owned()); + context.clock = Some(Arc::new(MockClock::new(123))); + + context + .emit("progress", "started") + .expect("node event should emit"); + + let events = recorder.snapshot(); + assert_eq!(events.len(), 1); + let node = as_node_event(&events[0]); + assert_eq!(node.node_id(), Some("node-a")); + assert_eq!(node.step(), Some(7)); + assert_eq!(node.scope(), "progress"); + assert_eq!(node.message(), "started"); + assert_eq!( + node.metadata().get("invocation_id"), + Some(&json!("invoke-1")) + ); + assert_eq!(node.metadata().get("now_unix_ms"), Some(&json!(123_000))); +} +#[cfg_attr(any(test), test)] +fn node_context_emits_diagnostics_and_llm_variants() { + let recorder = CapturingEmitter::default(); + let emitter: Arc = Arc::new(recorder.clone()); + let context = NodeContext::new("node-a", 7, emitter); + let mut llm_metadata = FastMap::default(); + llm_metadata.insert("token_count".to_owned(), json!(42)); + + context + .emit_diagnostic("diagnostic", "all good") + .expect("diagnostic should emit"); + context + .emit_llm_chunk( + Some("session-1".to_owned()), + Some("stream-1".to_owned()), + "chunk text", + Some(llm_metadata), + ) + .expect("chunk event should emit"); + context + .emit_llm_final( + Some("session-1".to_owned()), + Some("stream-1".to_owned()), + "final chunk", + None, + ) + .expect("final event should emit"); + context + .emit_llm_error( + Some("session-1".to_owned()), + Some("stream-1".to_owned()), + "error occurred", + ) + .expect("error event should emit"); + + let captured = recorder.snapshot(); + assert_eq!(captured.len(), 4); + let Event::Diagnostic(diagnostic) = &captured[0] else { + panic!("expected diagnostic event, got {:?}", captured[0]); + }; + assert_eq!( + (diagnostic.scope(), diagnostic.message()), + ("diagnostic", "all good") + ); + let chunk = as_llm_event(&captured[1]); + assert_eq!(chunk.session_id(), Some("session-1")); + assert_eq!(chunk.node_id(), Some("node-a")); + assert_eq!(chunk.stream_id(), Some("stream-1")); + assert!(!chunk.is_final()); + assert_eq!(chunk.chunk(), "chunk text"); + assert_eq!(chunk.metadata().get("token_count"), Some(&json!(42))); + let final_chunk = as_llm_event(&captured[2]); + assert!(final_chunk.is_final()); + assert_eq!(final_chunk.chunk(), "final chunk"); + assert!(final_chunk.metadata().is_empty()); + let error_event = as_llm_event(&captured[3]); + assert!(error_event.is_final()); + assert_eq!(error_event.chunk(), "error occurred"); + assert_eq!( + error_event.metadata().get("severity"), + Some(&json!("error")) + ); } - -#[test] -fn node_event_metadata_defaults_when_deserializing_legacy_payloads() { - let legacy = r#"{"node_id":"legacy-node","step":3,"scope":"legacy","message":"old"}"#; - let event: NodeEvent = serde_json::from_str(legacy).expect("legacy node event should decode"); +#[cfg_attr(any(test), test)] +fn legacy_node_payloads_decode_with_empty_metadata() { + let payload = r#"{"node_id":"legacy-node","step":3,"scope":"legacy","message":"old"}"#; + let event: NodeEvent = serde_json::from_str(payload).expect("legacy payload should decode"); assert_eq!(event.node_id(), Some("legacy-node")); assert_eq!(event.step(), Some(3)); @@ -615,129 +654,46 @@ fn node_event_metadata_defaults_when_deserializing_legacy_payloads() { } #[test] -fn node_event_runtime_metadata_is_preserved_and_structured_fields_win_collisions() { - let mut metadata = FxHashMap::default(); - metadata.insert("custom".to_string(), json!({ "nested": true })); - metadata.insert("node_id".to_string(), json!("spoofed-node")); - metadata.insert("step".to_string(), json!(0)); +fn structured_metadata_beats_conflicting_runtime_keys() { + let mut metadata = FastMap::default(); + metadata.insert("custom".to_owned(), json!({"nested": true})); + metadata.insert("node_id".to_owned(), json!("spoofed-node")); + metadata.insert("step".to_owned(), json!(0)); - let event = Event::node_message_with_metadata("real-node", 42, "scope", "message", metadata); - let value = event.to_json_value(); + let json = Event::node_message_with_metadata("real-node", 42, "scope", "message", metadata) + .to_json_value(); - assert_eq!(value["metadata"]["custom"], json!({ "nested": true })); - assert_eq!(value["metadata"]["node_id"], "real-node"); - assert_eq!(value["metadata"]["step"], 42); + assert_eq!(json["metadata"]["custom"], json!({"nested": true})); + assert_eq!(json["metadata"]["node_id"], "real-node"); + assert_eq!(json["metadata"]["step"], 42); } #[test] -fn stream_scope_constants_are_distinct_and_stable() { +fn stream_boundary_scope_constants_remain_distinct() { assert_eq!(STREAM_END_SCOPE, "__weavegraph_stream_end__"); assert_eq!(INVOCATION_END_SCOPE, "__weavegraph_invocation_end__"); assert_ne!(STREAM_END_SCOPE, INVOCATION_END_SCOPE); } -fn text_strategy() -> impl Strategy { - proptest::string::string_regex("[A-Za-z0-9 _-]{0,32}").unwrap() -} - -fn json_value_strategy() -> impl Strategy { - prop_oneof![ - Just(Value::Null), - text_strategy().prop_map(Value::String), - any::().prop_map(Value::Bool), - prop::num::f64::NORMAL.prop_map(|f| { - let bounded = f.clamp(-1_000_000.0, 1_000_000.0).trunc(); - Number::from_f64(bounded).map_or(Value::Number(Number::from(0)), Value::Number) - }), - ] -} - -fn event_strategy() -> impl Strategy { - let diagnostic = (text_strategy(), text_strategy()) - .prop_map(|(scope, message)| Event::diagnostic(scope, message)); - - let plain_node = ( - prop::option::of(text_strategy()), - prop::option::of(any::()), - text_strategy(), - text_strategy(), - ) - .prop_map(|(node_id, step, scope, message)| { - Event::Node(NodeEvent::new(node_id, step, scope, message)) - }); - - let node_with_metadata = ( - text_strategy(), - any::(), - text_strategy(), - text_strategy(), - prop::collection::hash_map(text_strategy(), json_value_strategy(), 0..4), - ) - .prop_map(|(node_id, step, scope, message, metadata)| { - let meta: FxHashMap = metadata.into_iter().collect(); - Event::node_message_with_metadata(node_id, step, scope, message, meta) - }); - - let llm = ( - prop::option::of(text_strategy()), - prop::option::of(text_strategy()), - prop::option::of(text_strategy()), - text_strategy(), - prop::collection::hash_map(text_strategy(), json_value_strategy(), 0..4), - any::(), - ) - .prop_map( - |(session_id, node_id, stream_id, chunk, metadata, is_final)| { - let meta: FxHashMap = metadata.into_iter().collect(); - let mut b = LLMStreamingEvent::builder(chunk) - .is_final(is_final) - .metadata(meta); - if let Some(id) = session_id { - b = b.session_id(id); - } - if let Some(id) = node_id { - b = b.node_id(id); - } - if let Some(id) = stream_id { - b = b.stream_id(id); - } - Event::LLM(b.build()) - }, - ); - - prop_oneof![diagnostic, plain_node, node_with_metadata, llm] -} - -proptest! { - #[test] - fn event_serialization_roundtrip(event in event_strategy()) { - let json = serde_json::to_string(&event).expect("serialize"); - let decoded: Event = serde_json::from_str(&json).expect("deserialize"); - prop_assert_eq!(decoded, event); - } -} - -// JSON serialization - -struct SharedWriter(Arc>>>); - -impl std::io::Write for SharedWriter { - fn write(&mut self, buf: &[u8]) -> std::io::Result { - self.0 - .lock() - .expect("SharedWriter mutex poisoned") - .write(buf) - } - - fn flush(&mut self) -> std::io::Result<()> { - self.0.lock().expect("SharedWriter mutex poisoned").flush() - } +#[test] +fn event_json_roundtrip_is_lossless() { + let mut runner = proptest::test_runner::TestRunner::new(ProptestConfig::with_cases(64)); + runner + .run(&any_event(), |sampled_event| { + let encoded_event = + serde_json::to_string(&sampled_event).expect("event should serialize"); + let restored: Event = + serde_json::from_str(&encoded_event).expect("event should deserialize"); + prop_assert_eq!(sampled_event, restored); + Ok(()) + }) + .expect("proptest cases should pass"); } #[test] -fn node_event_serializes_type_scope_message_and_metadata() { - let event = Event::node_message_with_meta("router", 5, "routing", "Processing request"); - let json = event.to_json_value(); +fn node_event_json_contains_type_message_and_metadata() { + let json = + Event::node_message_with_meta("router", 5, "routing", "Processing request").to_json_value(); assert_eq!(json["type"], "node"); assert_eq!(json["scope"], "routing"); @@ -748,9 +704,8 @@ fn node_event_serializes_type_scope_message_and_metadata() { } #[test] -fn node_event_without_node_id_or_step_serializes_with_null_fields() { - let event = Event::node_message("test_scope", "test message"); - let json = event.to_json_value(); +fn node_event_without_optional_fields_keeps_metadata_object_empty() { + let json = Event::node_message("test_scope", "test message").to_json_value(); assert_eq!(json["type"], "node"); assert_eq!(json["scope"], "test_scope"); @@ -761,36 +716,33 @@ fn node_event_without_node_id_or_step_serializes_with_null_fields() { } #[test] -fn diagnostic_event_serializes_type_scope_message_and_empty_metadata() { - let event = Event::diagnostic("error_scope", "Something went wrong"); - let json = event.to_json_value(); +fn diagnostic_json_uses_empty_metadata_map() { + let json = Event::diagnostic("error_scope", "Something went wrong").to_json_value(); assert_eq!(json["type"], "diagnostic"); assert_eq!(json["scope"], "error_scope"); assert_eq!(json["message"], "Something went wrong"); assert!(json["timestamp"].is_string()); - assert!(json["metadata"].is_object()); - let metadata = json["metadata"].as_object().unwrap(); - assert!(metadata.is_empty()); + assert_eq!(json["metadata"], json!({})); } #[test] -fn llm_event_serializes_all_fields_to_json_value() { - let mut metadata = FxHashMap::default(); - metadata.insert("content_type".to_string(), json!("reasoning")); - metadata.insert("token_count".to_string(), json!(42)); - - let timestamp = Utc::now(); - let llm_event = LLMStreamingEvent::builder("Thinking step by step...") - .session_id("session-123") - .node_id("node-abc") - .stream_id("stream-xyz") - .metadata(metadata) - .timestamp(timestamp) - .build(); - let event = Event::LLM(llm_event); - let json = event.to_json_value(); +fn llm_json_serializes_ids_flags_and_timestamp() { + let mut metadata = FastMap::default(); + metadata.insert("content_type".to_owned(), json!("reasoning")); + metadata.insert("token_count".to_owned(), json!(42)); + let timestamp = ChronoUtc::now(); + let event = Event::LLM( + LLMStreamingEvent::builder("Thinking step by step...") + .session_id("session-123") + .node_id("node-abc") + .stream_id("stream-xyz") + .metadata(metadata) + .timestamp(timestamp) + .build(), + ); + let json = event.to_json_value(); assert_eq!(json["type"], "llm"); assert_eq!(json["message"], "Thinking step by step..."); assert_eq!(json["metadata"]["session_id"], "session-123"); @@ -803,184 +755,152 @@ fn llm_event_serializes_all_fields_to_json_value() { } #[test] -fn llm_event_final_chunk_sets_is_final_and_nulls_optional_ids() { - let llm_event = LLMStreamingEvent::builder("Final chunk") - .stream_id("stream-999") - .is_final(true) - .build(); - let event = Event::LLM(llm_event); - let json = event.to_json_value(); - - assert_eq!(json["type"], "llm"); - assert_eq!(json["metadata"]["is_final"], true); - assert_eq!(json["metadata"]["stream_id"], "stream-999"); - assert!(json["metadata"]["session_id"].is_null()); - assert!(json["metadata"]["node_id"].is_null()); -} - -#[test] -fn json_string_output_is_compact() { - let event = Event::diagnostic("test", "message"); - let json_str = event.to_json_string().unwrap(); - - assert!(json_str.contains("\"type\":\"diagnostic\"")); - assert!(json_str.contains("\"scope\":\"test\"")); - assert!(json_str.contains("\"message\":\"message\"")); - assert!(!json_str.contains(" ")); // No indentation -} - -#[test] -fn json_pretty_output_has_indentation() { - let event = Event::node_message("test", "hello"); - let json_str = event.to_json_pretty().unwrap(); - - assert!(json_str.contains(" \"type\": \"node\"")); - assert!(json_str.contains(" \"scope\": \"test\"")); - assert!(json_str.contains(" \"message\": \"hello\"")); -} +fn compact_and_pretty_json_helpers_preserve_llm_payload() { + let timestamp = ChronoUtc::now(); + let event = Event::LLM( + LLMStreamingEvent::builder("hello") + .stream_id("stream-1") + .timestamp(timestamp) + .build(), + ); -#[test] -fn json_string_round_trips_to_parsed_value() { - let original = Event::node_message_with_meta("node1", 10, "scope1", "msg1"); - let json_str = original.to_json_string().unwrap(); - let parsed: Value = serde_json::from_str(&json_str).unwrap(); + let compact = event + .to_json_string() + .expect("compact json should serialize"); + let pretty = event + .to_json_pretty() + .expect("pretty json should serialize"); + let compact_value: Value = serde_json::from_str(&compact).expect("compact json should parse"); + let pretty_value: Value = serde_json::from_str(&pretty).expect("pretty json should parse"); - assert_eq!(parsed["type"], "node"); - assert_eq!(parsed["metadata"]["node_id"], "node1"); - assert_eq!(parsed["metadata"]["step"], 10); + assert_eq!(compact_value, pretty_value); + assert!(compact.contains("\"type\":\"llm\"")); + assert!(pretty.contains(" \"type\": \"llm\"")); } #[tokio::test] -async fn jsonlines_sink_writes_one_line_per_event() { - use std::io::Cursor; - +async fn jsonlines_sink_writes_one_compact_json_object_per_line() { let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); - let buffer_clone = buffer.clone(); - - let mut sink = JsonLinesSink::new(Box::new(SharedWriter(buffer))); - - let event1 = Event::diagnostic("test1", "first message"); - let event2 = Event::node_message("test2", "second message"); - - sink.handle(&event1).unwrap(); - sink.handle(&event2).unwrap(); + let writer = MirrorWriter { + buffer: Arc::clone(&buffer), + }; + let mut sink = JsonLinesSink::new(Box::new(writer)); - let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); - let output = String::from_utf8(locked.get_ref().clone()).unwrap(); - let lines: Vec<&str> = output.lines().collect(); + sink.handle(&Event::diagnostic("test1", "first message")) + .expect("first write should succeed"); + sink.handle(&Event::node_message("test2", "second message")) + .expect("second write should succeed"); + let lines = parse_jsonl(&buffer); assert_eq!(lines.len(), 2); - - let json1: Value = serde_json::from_str(lines[0]).unwrap(); - assert_eq!(json1["type"], "diagnostic"); - assert_eq!(json1["scope"], "test1"); - assert_eq!(json1["message"], "first message"); - - let json2: Value = serde_json::from_str(lines[1]).unwrap(); - assert_eq!(json2["type"], "node"); - assert_eq!(json2["scope"], "test2"); - assert_eq!(json2["message"], "second message"); + assert_eq!(lines[0]["type"], "diagnostic"); + assert_eq!(lines[0]["scope"], "test1"); + assert_eq!(lines[0]["message"], "first message"); + assert_eq!(lines[1]["type"], "node"); + assert_eq!(lines[1]["scope"], "test2"); + assert_eq!(lines[1]["message"], "second message"); } #[tokio::test] -async fn jsonlines_sink_pretty_print_indents_output() { - use std::io::Cursor; - +async fn pretty_jsonlines_sink_includes_indentation() { let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); - let buffer_clone = buffer.clone(); - - let mut sink = JsonLinesSink::with_pretty_print(Box::new(SharedWriter(buffer))); - - let event = Event::diagnostic("pretty_test", "formatted output"); - sink.handle(&event).unwrap(); + let writer = MirrorWriter { + buffer: Arc::clone(&buffer), + }; + let mut sink = JsonLinesSink::with_pretty_print(Box::new(writer)); - let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); - let output = String::from_utf8(locked.get_ref().clone()).unwrap(); + sink.handle(&Event::diagnostic("pretty_test", "formatted output")) + .expect("write should succeed"); - // Pretty printed JSON should have indentation + let output = shared_buffer_text(&buffer); assert!(output.contains(" \"type\": \"diagnostic\"")); assert!(output.contains(" \"scope\": \"pretty_test\"")); } #[tokio::test] -async fn jsonlines_sink_writes_events_to_file() { - use std::fs; - - let temp_file = tempfile::NamedTempFile::new().unwrap(); - let path = temp_file.path().to_path_buf(); +async fn jsonlines_sink_to_file_persists_every_event() { + let path = repo_jsonl_path("event-bus-jsonl"); { - let mut sink = JsonLinesSink::to_file(&path).unwrap(); - - let event1 = Event::node_message_with_meta("file_node", 1, "file_scope", "first"); - let event2 = Event::diagnostic("file_scope", "second"); - let event3 = Event::node_message("file_scope", "third"); - - sink.handle(&event1).unwrap(); - sink.handle(&event2).unwrap(); - sink.handle(&event3).unwrap(); - } // sink dropped, file flushed + let mut sink = JsonLinesSink::to_file(&path).expect("file sink should open"); + sink.handle(&Event::node_message_with_meta( + "file_node", + 1, + "file_scope", + "first", + )) + .expect("first write should succeed"); + sink.handle(&Event::diagnostic("file_scope", "second")) + .expect("second write should succeed"); + sink.handle(&Event::node_message("file_scope", "third")) + .expect("third write should succeed"); + } - let contents = fs::read_to_string(&path).unwrap(); - let lines: Vec<&str> = contents.lines().collect(); + let contents = fs::read_to_string(&path).expect("jsonl file should be readable"); + let lines: Vec = contents + .lines() + .map(|line| serde_json::from_str(line).expect("line should be valid json")) + .collect(); + let _ = fs::remove_file(&path); assert_eq!(lines.len(), 3); - - let json1: Value = serde_json::from_str(lines[0]).unwrap(); - assert_eq!(json1["metadata"]["node_id"], "file_node"); - assert_eq!(json1["metadata"]["step"], 1); - - let json2: Value = serde_json::from_str(lines[1]).unwrap(); - assert_eq!(json2["type"], "diagnostic"); - - let json3: Value = serde_json::from_str(lines[2]).unwrap(); - assert_eq!(json3["message"], "third"); + assert_eq!(lines[0]["metadata"]["node_id"], "file_node"); + assert_eq!(lines[0]["metadata"]["step"], 1); + assert_eq!(lines[1]["type"], "diagnostic"); + assert_eq!(lines[2]["message"], "third"); } #[tokio::test] -async fn jsonlines_sink_flushes_after_each_event() { - use std::io::Cursor; - +async fn jsonlines_sink_flushes_after_each_handle_call() { let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); - let buffer_clone = buffer.clone(); - - let mut sink = JsonLinesSink::new(Box::new(SharedWriter(buffer))); - - let event = Event::diagnostic("flush_test", "should be flushed immediately"); - sink.handle(&event).unwrap(); + let writer = MirrorWriter { + buffer: Arc::clone(&buffer), + }; + let mut sink = JsonLinesSink::new(Box::new(writer)); - // handle() flushes after each event, so the buffer is immediately readable - let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); - let output = String::from_utf8(locked.get_ref().clone()).unwrap(); + sink.handle(&Event::diagnostic( + "flush_test", + "should be visible immediately", + )) + .expect("write should succeed"); - assert!(output.contains("\"message\":\"should be flushed immediately\"")); + let output = shared_buffer_text(&buffer); + assert!(output.contains("\"message\":\"should be visible immediately\"")); } #[tokio::test] -async fn jsonlines_sink_with_event_bus_captures_all_events() { - let buffer = Arc::new(Mutex::new(std::io::Cursor::new(Vec::new()))); - let buffer_clone = buffer.clone(); - - let sink = JsonLinesSink::new(Box::new(SharedWriter(buffer))); +async fn jsonlines_sink_integrates_with_event_bus() { + let buffer = Arc::new(Mutex::new(Cursor::new(Vec::new()))); + let sink = JsonLinesSink::new(Box::new(MirrorWriter { + buffer: Arc::clone(&buffer), + })); let bus = EventBus::with_sink(sink); bus.listen_for_events(); - let emitter = bus.get_emitter(); + emitter .emit(Event::node_message("integration", "message1")) - .unwrap(); + .expect("first emit should succeed"); emitter .emit(Event::diagnostic("integration", "message2")) - .unwrap(); + .expect("second emit should succeed"); - tokio::time::sleep(Duration::from_millis(50)).await; + wait_for_workers().await; bus.stop_listener().await; - let locked = buffer_clone.lock().expect("test buffer mutex poisoned"); - let output = String::from_utf8(locked.get_ref().clone()).unwrap(); - let lines: Vec<&str> = output.lines().collect(); - + let lines = parse_jsonl(&buffer); assert_eq!(lines.len(), 2); - let json1: Value = serde_json::from_str(lines[0]).unwrap(); - assert_eq!(json1["message"], "message1"); + assert_eq!(lines[0]["message"], "message1"); + assert_eq!(lines[1]["message"], "message2"); +} + +#[tokio::test] +async fn channel_sink_reports_broken_pipe_when_receiver_is_gone() { + let (tx, rx) = flume::unbounded(); + let mut sink = ChannelSink::new(tx); + drop(rx); + + let result = sink.handle(&Event::diagnostic("test", "msg")); + let error = result.expect_err("dropped receiver should produce an error"); + assert_eq!(error.kind(), io::ErrorKind::BrokenPipe); } diff --git a/tests/nodes.rs b/tests/nodes.rs index 99d2b53..3f801bd 100644 --- a/tests/nodes.rs +++ b/tests/nodes.rs @@ -10,12 +10,13 @@ use weavegraph::utils::collections::new_extra_map; fn make_ctx(step: u64) -> (NodeContext, EventBus) { let event_bus = EventBus::default(); - event_bus.listen_for_events(); - let ctx = NodeContext::new("test-node", step, event_bus.get_emitter()); - (ctx, event_bus) + let emitter = event_bus.get_emitter(); + let _events = event_bus.subscribe(); + let context = NodeContext::new("test-node", step, emitter); + (context, event_bus) } -#[tokio::test] +#[tokio::test(flavor = "current_thread")] async fn node_context_exposes_id_and_step() { let (ctx, _event_bus) = make_ctx(5); assert_eq!(ctx.node_id, "test-node"); @@ -62,11 +63,11 @@ fn node_partial_with_errors_leaves_other_fields_none() { assert_eq!(partial.errors, Some(errors)); } -#[tokio::test] +#[tokio::test(flavor = "current_thread")] async fn emit_fails_when_event_bus_dropped() { let (ctx, event_bus) = make_ctx(1); drop(event_bus); - tokio::task::yield_now().await; + tokio::time::sleep(std::time::Duration::from_millis(1)).await; let result = ctx.emit("scope", "message"); assert!(matches!(result, Err(NodeContextError::EventBusUnavailable))); } @@ -162,7 +163,7 @@ async fn node_returns_partial_with_expected_message_role() { async fn node_run_propagates_event_bus_disconnect_as_error() { let (ctx, event_bus) = make_ctx(0); drop(event_bus); - tokio::task::yield_now().await; + tokio::time::sleep(std::time::Duration::from_millis(1)).await; let node = DummyNode; let snapshot = VersionedState::new_with_user_message("dummy").snapshot(); let result = node.run(snapshot, ctx).await; diff --git a/tests/runtimes_persistence_postgres.rs b/tests/runtimes_persistence_postgres.rs index d4ab6f2..c1d11e2 100644 --- a/tests/runtimes_persistence_postgres.rs +++ b/tests/runtimes_persistence_postgres.rs @@ -1,456 +1,477 @@ -//! PostgreSQL checkpointer integration tests. -//! -//! These tests require a running PostgreSQL instance. Set the environment variable -//! `WEAVEGRAPH_POSTGRES_TEST_URL` to point to your test database, e.g.: -//! -//! ```bash -//! export WEAVEGRAPH_POSTGRES_TEST_URL="postgresql://weavegraph:weavegraph@localhost/weavegraph_test" -//! docker-compose up -d postgres -//! cargo test --features postgres-migrations runtimes_persistence_postgres -//! ``` -//! -//! Each test uses unique session IDs to ensure test independence. - -#![cfg(feature = "postgres")] - -use chrono::Utc; -use rustc_hash::FxHashMap; +#![cfg(all(feature = "postgres"))] + +use rustc_hash::FxHashMap as FastMap; +use serde_json::{Value, json}; use std::sync::Arc; use tokio::sync::Barrier; -use weavegraph::channels::Channel; -use weavegraph::channels::errors::{ErrorEvent, WeaveError}; -use weavegraph::message::Role; -use weavegraph::runtimes::checkpointer_postgres::StepQuery as PgStepQuery; -use weavegraph::runtimes::{Checkpoint, Checkpointer, PostgresCheckpointer}; -use weavegraph::types::NodeKind; - -mod common; -use common::*; - -/// Get the test database URL from environment or use default docker-compose URL. -fn get_test_db_url() -> String { - std::env::var("WEAVEGRAPH_POSTGRES_TEST_URL").unwrap_or_else(|_| { - "postgresql://weavegraph:weavegraph@localhost:5432/weavegraph_test".into() - }) +use weavegraph::{ + channels::{ + Channel, + errors::{ErrorEvent, ErrorScope, WeaveError}, + }, + message::Role, + runtimes::{Checkpoint, Checkpointer, PostgresCheckpointer, checkpointer_postgres::StepQuery}, + state::VersionedState, + types::NodeKind as Kind, +}; + +#[path = "common/mod.rs"] +mod support; +use support::state_with_user; + +fn postgres_test_url() -> String { + match std::env::var("WEAVEGRAPH_POSTGRES_TEST_URL") { + Ok(url) => url, + Err(_) => "postgresql://weavegraph:weavegraph@localhost:5432/weavegraph_test".to_owned(), + } } -/// Connect to Postgres or panic with helpful message. -async fn connect_or_fail() -> PostgresCheckpointer { - let db_url = get_test_db_url(); - PostgresCheckpointer::connect(&db_url) - .await - .unwrap_or_else(|e| { - panic!( - "Failed to connect to Postgres at {db_url}: {e}\n\ - Start Postgres with: docker-compose up -d postgres" - ) - }) +async fn postgres_checkpointer() -> PostgresCheckpointer { + let url = postgres_test_url(); + let connection = PostgresCheckpointer::connect(&url).await; + match connection { + Ok(store) => store, + Err(error) => panic!("failed to connect to postgres at {url}: {error}"), + } } -/// Helper to generate unique session IDs for test isolation. -fn unique_session_id(prefix: &str) -> String { - format!("{}_{}", prefix, uuid::Uuid::new_v4()) +fn session_id(name: &str) -> String { + format!("wg-pg-{name}-{}", uuid::Uuid::new_v4()) } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn checkpoint_state_and_metadata_survive_postgres_roundtrip() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("roundtrip"); - let mut state = state_with_user("hello"); - state - .extra - .get_mut() - .insert("k".into(), serde_json::json!(42)); - - let mut versions_seen: FxHashMap> = FxHashMap::default(); - versions_seen.insert( - "Start".into(), - FxHashMap::from_iter([("messages".into(), 1_u64), ("extra".into(), 1_u64)]), - ); +fn state_with_entries( + prompt: &str, + entries: impl IntoIterator, +) -> VersionedState { + entries + .into_iter() + .fold(state_with_user(prompt), |mut draft, (field, value)| { + draft.extra.get_mut().insert(field.to_owned(), value); + draft + }) +} - let cp_struct = Checkpoint { - session_id: session_id.clone(), - step: 1, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: versions_seen.clone(), - concurrency_limit: 4, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], +fn checkpoint_record(session: &str, step_number: u64, state: VersionedState) -> Checkpoint { + Checkpoint { + session_id: session.to_owned(), + step: step_number, + state, + frontier: Vec::from([Kind::End]), + versions_seen: FastMap::default(), + concurrency_limit: 2, + created_at: chrono::Utc::now(), + ran_nodes: vec![], skipped_nodes: vec![], - updated_channels: vec!["messages".to_string()], - }; + updated_channels: vec![], + } +} - cp.save(cp_struct.clone()).await.expect("save"); +async fn latest_checkpoint(store: &PostgresCheckpointer, session: &str) -> Checkpoint { + let Ok(found) = store.load_latest(session).await else { + panic!("load_latest failed for {session}"); + }; + let Some(current) = found else { + panic!("expected persisted checkpoint for {session}"); + }; + current +} - let loaded = cp - .load_latest(&session_id) - .await - .expect("load_latest") - .expect("Some checkpoint"); - assert_eq!(loaded.step, 1); - assert_eq!(loaded.frontier, vec![NodeKind::End]); - assert_eq!( - loaded - .versions_seen - .get("Start") - .and_then(|m| m.get("messages")) - .copied(), - Some(1) +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] +async fn checkpoint_state_and_metadata_survive_postgres_roundtrip() { + let store = postgres_checkpointer().await; + let session = session_id("checkpoint-state-roundtrip"); + + let mut persisted = checkpoint_record( + &session, + 7, + state_with_entries( + "hello from postgres roundtrip", + [("counter", json!(7)), ("status", json!("persisted"))], + ), ); - assert_eq!(loaded.state.messages.snapshot()[0].role, Role::User); - assert_extra_has(&loaded.state, "k"); + persisted.frontier = vec![Kind::Custom("resume-node".to_owned()), Kind::End]; + persisted.concurrency_limit = 4; + persisted.ran_nodes = vec![Kind::Start]; + persisted.skipped_nodes = vec![Kind::Custom("Skipped".to_owned())]; + persisted.updated_channels = vec!["messages".to_owned(), "extra".to_owned()]; + + let mut seen_channels = FastMap::default(); + seen_channels.insert("messages".to_owned(), 3); + seen_channels.insert("extra".to_owned(), 2); + persisted + .versions_seen + .insert("ResumeNode".to_owned(), seen_channels); + + store.save(persisted).await.expect("save checkpoint"); + + let loaded = latest_checkpoint(&store, &session).await; + let persisted_message_version = loaded + .versions_seen + .get("ResumeNode") + .and_then(|table| table.get("messages")) + .copied(); + + assert_eq!(loaded.step, 7); + assert_eq!(loaded.concurrency_limit, 4); assert_eq!( - loaded.state.extra.snapshot().get("k"), - Some(&serde_json::json!(42)) + loaded.frontier, + vec![Kind::Custom("resume-node".to_owned()), Kind::End] ); + assert_eq!(persisted_message_version, Some(3)); + + let messages = loaded.state.messages.snapshot(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, Role::User); + assert_eq!(messages[0].content, "hello from postgres roundtrip"); + + let extra = loaded.state.extra.snapshot(); + assert_eq!(extra.get("counter"), Some(&json!(7))); + assert_eq!(extra.get("status"), Some(&json!("persisted"))); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn saved_sessions_appear_in_list_and_missing_session_returns_none() { - let cp = connect_or_fail().await; - - // Use unique session IDs for this test - let prefix = format!("list_test_{}", uuid::Uuid::new_v4()); - let session_ids: Vec = (0..3).map(|i| format!("{prefix}_s{i}")).collect(); - - for s_id in &session_ids { - let state = state_with_user("x"); - let cp_struct = Checkpoint { - session_id: s_id.clone(), - step: 1, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![], - skipped_nodes: vec![NodeKind::End], - updated_channels: vec![], - }; - cp.save(cp_struct).await.unwrap(); + let store = postgres_checkpointer().await; + let created_sessions = [ + session_id("listed-a"), + session_id("listed-b"), + session_id("listed-c"), + ]; + + for name in &created_sessions { + let snapshot = checkpoint_record(name, 1, state_with_user("session listing")); + store.save(snapshot).await.expect("save checkpoint"); } - let all_sessions = cp.list_sessions().await.unwrap(); - // Check that our test sessions are in the list - for s_id in &session_ids { + let sessions = store.list_sessions().await.expect("list sessions"); + for wanted in created_sessions { assert!( - all_sessions.contains(s_id), - "Session {s_id} should be in list" + sessions.iter().any(|listed| listed == &wanted), + "missing {wanted}" ); } - // Test loading nonexistent session - let nonexistent = format!("nonexistent_{}", uuid::Uuid::new_v4()); - let res = cp.load_latest(&nonexistent).await.unwrap(); - assert!(res.is_none()); + let absent = session_id("missing-session"); + let missing = store + .load_latest(&absent) + .await + .expect("load missing session"); + assert!(missing.is_none()); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn step_query_paginates_results_newest_first() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("paginate"); - - for step in 1..=5 { - let state = state_with_user(&format!("step {step}")); - let checkpoint = Checkpoint { - session_id: session_id.clone(), - step, - state, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: if step % 2 == 0 { - vec![NodeKind::Start] - } else { - vec![] - }, - skipped_nodes: vec![NodeKind::End], - updated_channels: vec!["messages".to_string()], + let store = postgres_checkpointer().await; + let session = session_id("step-query-pagination"); + + for recorded_step in 1_u64..=5 { + let mut snapshot = checkpoint_record( + &session, + recorded_step, + state_with_entries( + &format!("state recorded at step {recorded_step}"), + [ + ("step", json!(recorded_step)), + ("bucket", json!(recorded_step % 2)), + ], + ), + ); + snapshot.ran_nodes = if recorded_step % 2 == 0 { + vec![Kind::Start] + } else { + vec![Kind::Custom("worker".to_owned())] }; - cp.save(checkpoint).await.expect("save checkpoint"); + snapshot.skipped_nodes = vec![Kind::End]; + snapshot.updated_channels = vec!["messages".to_owned()]; + store.save(snapshot).await.expect("save checkpoint"); } - let result = cp - .query_steps( - &session_id, - PgStepQuery { - limit: Some(2), - offset: Some(0), - ..Default::default() - }, - ) + let newest_query = StepQuery { + offset: Some(0), + limit: Some(2), + ..StepQuery::default() + }; + let newest_page = store + .query_steps(&session, newest_query) + .await + .expect("query newest page"); + + assert_eq!(newest_page.page_info.total_count, 5); + assert_eq!(newest_page.page_info.page_size, 2); + assert_eq!(newest_page.page_info.offset, 0); + assert!(newest_page.page_info.has_next_page); + assert_eq!( + newest_page + .checkpoints + .iter() + .map(|item| item.step) + .collect::>(), + vec![5, 4] + ); + + let next_query = StepQuery { + offset: Some(2), + limit: Some(2), + ..StepQuery::default() + }; + let next_page = store + .query_steps(&session, next_query) .await - .expect("query steps"); - assert_eq!(result.page_info.total_count, 5); - assert_eq!(result.page_info.page_size, 2); - assert_eq!(result.page_info.offset, 0); - assert!(result.page_info.has_next_page); - assert_eq!(result.checkpoints.len(), 2); - assert_eq!(result.checkpoints[0].step, 5); - assert_eq!(result.checkpoints[1].step, 4); + .expect("query next page"); + + assert_eq!( + next_page + .checkpoints + .iter() + .map(|item| item.step) + .collect::>(), + vec![3, 2] + ); + assert_eq!(next_page.page_info.total_count, 5); + assert!(next_page.page_info.has_next_page); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn error_events_survive_postgres_roundtrip() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("err"); - let mut state = state_with_user("err"); + let store = postgres_checkpointer().await; + let session = session_id("error-events-roundtrip"); + + let mut state = state_with_user("error persistence"); + state.errors.get_mut().push( + ErrorEvent::runner( + session.clone(), + 3, + WeaveError::msg("boom").with_details(json!({"code": "E_BANG"})), + ) + .with_tags(vec!["postgres".to_owned(), "roundtrip".to_owned()]) + .with_context(json!({"retryable": false})), + ); - let err = ErrorEvent::app(WeaveError::msg("boom")) - .with_tag("t") - .with_context(serde_json::json!({"a":1})); + let mut persisted = checkpoint_record(&session, 3, state); + persisted.updated_channels = vec!["errors".to_owned()]; + store.save(persisted).await.expect("save checkpoint"); - state.errors.get_mut().push(err); - let checkpoint = Checkpoint { - session_id: session_id.clone(), - step: 1, - state, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec!["errors".into()], - }; - cp.save(checkpoint).await.unwrap(); - let loaded = cp.load_latest(&session_id).await.unwrap().unwrap(); + let loaded = latest_checkpoint(&store, &session).await; let errors = loaded.state.errors.snapshot(); - assert_eq!(errors.len(), 1); - assert_eq!(errors[0].error.message, "boom"); + let [event] = &errors[..] else { + panic!("expected one error event"); + }; + assert_eq!(event.error.message, "boom"); + assert_eq!(event.error.details, json!({"code": "E_BANG"})); + assert_eq!(event.tags, vec!["postgres", "roundtrip"]); + assert_eq!(event.context, json!({"retryable": false})); + assert!(matches!(event.scope, ErrorScope::Runner { step: 3, .. })); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn duplicate_save_is_idempotent() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("idempotent"); - let state = state_with_user("test"); - - let checkpoint = Checkpoint { - session_id: session_id.clone(), - step: 1, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec![], - }; + let store = postgres_checkpointer().await; + let session = session_id("duplicate-save"); + + let snapshot = checkpoint_record( + &session, + 1, + state_with_entries("same checkpoint twice", [("dedupe", json!(true))]), + ); + + store.save(snapshot.clone()).await.expect("first save"); + store.save(snapshot).await.expect("second save"); - // Save the same checkpoint twice - should not fail (upsert behavior) - cp.save(checkpoint.clone()).await.expect("first save"); - cp.save(checkpoint).await.expect("second save (idempotent)"); + let history = store + .query_steps( + &session, + StepQuery { + limit: Some(10), + ..StepQuery::default() + }, + ) + .await + .expect("query deduplicated history"); + let latest = latest_checkpoint(&store, &session).await; - let loaded = cp.load_latest(&session_id).await.unwrap().unwrap(); - assert_eq!(loaded.step, 1); + assert_eq!(history.page_info.total_count, 1); + assert_eq!(history.checkpoints.len(), 1); + assert_eq!(history.checkpoints[0].step, latest.step); + assert_eq!( + latest.state.extra.snapshot().get("dedupe"), + Some(&json!(true)) + ); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn save_with_stale_expected_step_is_rejected() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("concurrency"); - let state = state_with_user("test"); - - // Save step 1 - let checkpoint1 = Checkpoint { - session_id: session_id.clone(), - step: 1, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec![], - }; - cp.save(checkpoint1).await.expect("save step 1"); - - // Try to save step 2 with correct expected_last_step - let checkpoint2 = Checkpoint { - session_id: session_id.clone(), - step: 2, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![], - skipped_nodes: vec![], - updated_channels: vec![], - }; - cp.save_with_concurrency_check(checkpoint2.clone(), Some(1)) + let store = postgres_checkpointer().await; + let session = session_id("stale-expected-step"); + + let first = checkpoint_record(&session, 1, state_with_user("first write")); + store.save(first).await.expect("save first checkpoint"); + + let second = checkpoint_record(&session, 2, state_with_user("second write")); + store + .save_with_concurrency_check(second, Some(1)) .await - .expect("save step 2 with correct check"); + .expect("save second checkpoint with matching expected step"); - // Try to save step 3 with wrong expected_last_step (should fail) - let checkpoint3 = Checkpoint { - session_id: session_id.clone(), - step: 3, - state, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![], - skipped_nodes: vec![], - updated_channels: vec![], - }; - let result = cp.save_with_concurrency_check(checkpoint3, Some(1)).await; - assert!(result.is_err(), "should fail with wrong expected step"); + let stale = checkpoint_record(&session, 3, state_with_user("stale write")); + let failure = store + .save_with_concurrency_check(stale, Some(1)) + .await + .expect_err("stale expected step should fail"); + + assert!(failure.to_string().contains("concurrency conflict")); + + let latest = latest_checkpoint(&store, &session).await; + let history = store + .query_steps( + &session, + StepQuery { + limit: Some(10), + ..StepQuery::default() + }, + ) + .await + .expect("query history after rejected write"); + + assert_eq!(latest.step, 2); + assert_eq!(latest.state.messages.snapshot()[0].content, "second write"); + assert_eq!(history.page_info.total_count, 2); + assert_eq!( + history + .checkpoints + .iter() + .map(|item| item.step) + .collect::>(), + vec![2, 1] + ); } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +#[tokio::test(worker_threads = 2, flavor = "multi_thread")] async fn out_of_order_write_does_not_overwrite_higher_step() { - let cp = connect_or_fail().await; - - let session_id = unique_session_id("out_of_order"); - - // Save a higher step first. - let mut state_step_5 = state_with_user("step 5"); - state_step_5 - .extra - .get_mut() - .insert("marker".into(), serde_json::json!(5)); - - let checkpoint5 = Checkpoint { - session_id: session_id.clone(), - step: 5, - state: state_step_5, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec![], - }; - - cp.save(checkpoint5).await.expect("save step 5"); - - // Then save a lower step later (out-of-order). - let mut state_step_2 = state_with_user("step 2"); - state_step_2 - .extra - .get_mut() - .insert("marker".into(), serde_json::json!(2)); - - let checkpoint2 = Checkpoint { - session_id: session_id.clone(), - step: 2, - state: state_step_2, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec![], - }; + let store = postgres_checkpointer().await; + let session = session_id("out-of-order-write"); + + let highest = checkpoint_record( + &session, + 9, + state_with_entries( + "higher step wins", + [("winner", json!("high-step")), ("step", json!(9))], + ), + ); + store.save(highest).await.expect("save higher step"); + + let late_lower = checkpoint_record( + &session, + 3, + state_with_entries( + "lower step arrives later", + [("winner", json!("late-low-step")), ("step", json!(3))], + ), + ); + store.save(late_lower).await.expect("save lower step"); - cp.save(checkpoint2) + let latest = latest_checkpoint(&store, &session).await; + let history = store + .query_steps( + &session, + StepQuery { + limit: Some(10), + ..StepQuery::default() + }, + ) .await - .expect("save step 2 (out-of-order)"); + .expect("query out-of-order history"); - // Latest must remain at step 5 and retain the step-5 snapshot. - let loaded = cp.load_latest(&session_id).await.unwrap().unwrap(); - assert_eq!(loaded.step, 5); + assert_eq!(latest.step, 9); + assert_eq!( + latest.state.extra.snapshot().get("winner"), + Some(&json!("high-step")) + ); + assert_eq!( + latest.state.messages.snapshot()[0].content, + "higher step wins" + ); + assert_eq!(history.page_info.total_count, 2); assert_eq!( - loaded.state.extra.snapshot().get("marker"), - Some(&serde_json::json!(5)) + history + .checkpoints + .iter() + .map(|item| item.step) + .collect::>(), + vec![9, 3] ); } -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[tokio::test(worker_threads = 4, flavor = "multi_thread")] async fn concurrent_writers_only_one_wins_concurrency_check() { - let cp = Arc::new(connect_or_fail().await); - - let session_id = unique_session_id("concurrent_writers"); - let state = state_with_user("base"); - - // Seed step 1 so expected_last_step = 1 is a valid check. - let checkpoint1 = Checkpoint { - session_id: session_id.clone(), - step: 1, - state: state.clone(), - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![NodeKind::Start], - skipped_nodes: vec![], - updated_channels: vec![], - }; - cp.save(checkpoint1).await.expect("save step 1"); + let store = Arc::new(postgres_checkpointer().await); + let session = session_id("concurrent-writers"); - let barrier = Arc::new(Barrier::new(3)); + let seed = checkpoint_record(&session, 1, state_with_user("seed step")); + store.save(seed).await.expect("save seed checkpoint"); - let make_checkpoint2 = |marker: i64| { - let mut s = state_with_user("step 2"); - s.extra - .get_mut() - .insert("marker".into(), serde_json::json!(marker)); - Checkpoint { - session_id: session_id.clone(), - step: 2, - state: s, - frontier: vec![NodeKind::End], - versions_seen: FxHashMap::default(), - concurrency_limit: 1, - created_at: Utc::now(), - ran_nodes: vec![], - skipped_nodes: vec![], - updated_channels: vec![], - } - }; + let barrier = Arc::new(Barrier::new(3)); - let cp_a = Arc::clone(&cp); - let barrier_a = Arc::clone(&barrier); - let checkpoint_a = make_checkpoint2(111); - let handle_a = tokio::spawn(async move { - barrier_a.wait().await; - cp_a.save_with_concurrency_check(checkpoint_a, Some(1)) + let first_store = Arc::clone(&store); + let first_barrier = Arc::clone(&barrier); + let first_session = session.clone(); + let first_writer = tokio::spawn(async move { + let candidate = checkpoint_record( + &first_session, + 2, + state_with_entries( + "writer one", + [("writer", json!("writer-one")), ("step", json!(2))], + ), + ); + first_barrier.wait().await; + first_store + .save_with_concurrency_check(candidate, Some(1)) .await }); - let cp_b = Arc::clone(&cp); - let barrier_b = Arc::clone(&barrier); - let checkpoint_b = make_checkpoint2(222); - let handle_b = tokio::spawn(async move { - barrier_b.wait().await; - cp_b.save_with_concurrency_check(checkpoint_b, Some(1)) + let second_store = Arc::clone(&store); + let second_barrier = Arc::clone(&barrier); + let second_session = session.clone(); + let second_writer = tokio::spawn(async move { + let candidate = checkpoint_record( + &second_session, + 2, + state_with_entries( + "writer two", + [("writer", json!("writer-two")), ("step", json!(2))], + ), + ); + second_barrier.wait().await; + second_store + .save_with_concurrency_check(candidate, Some(1)) .await }); - // Release both tasks at the same time. barrier.wait().await; - let res_a = handle_a.await.expect("task a join"); - let res_b = handle_b.await.expect("task b join"); - - let ok_count = [res_a.as_ref(), res_b.as_ref()] + let first_outcome = first_writer.await.expect("join writer one"); + let second_outcome = second_writer.await.expect("join writer two"); + let successes = [first_outcome.is_ok(), second_outcome.is_ok()] .into_iter() - .filter(|r| r.is_ok()) + .filter(|won| *won) .count(); - assert_eq!(ok_count, 1, "exactly one writer should succeed"); + let latest = latest_checkpoint(&store, &session).await; + let history = store + .query_steps( + &session, + StepQuery { + limit: Some(10), + ..StepQuery::default() + }, + ) + .await + .expect("query concurrent writer history"); + let elected = latest.state.extra.snapshot().get("writer").cloned(); - // Latest must be step 2, with one of the markers. - let loaded = cp.load_latest(&session_id).await.unwrap().unwrap(); - assert_eq!(loaded.step, 2); - let snapshot = loaded.state.extra.snapshot(); - let marker = snapshot.get("marker"); - assert!( - marker == Some(&serde_json::json!(111)) || marker == Some(&serde_json::json!(222)), - "latest marker should match one of the winning writers" - ); + assert_eq!(successes, 1); + assert_eq!(latest.step, 2); + assert!(elected == Some(json!("writer-one")) || elected == Some(json!("writer-two"))); + assert_eq!(history.page_info.total_count, 2); } diff --git a/tests/runtimes_runner.rs b/tests/runtimes_runner.rs index 2cf8b62..c3644b9 100644 --- a/tests/runtimes_runner.rs +++ b/tests/runtimes_runner.rs @@ -6,22 +6,22 @@ use std::time::Duration; use async_trait::async_trait; use serde_json::json; -use weavegraph::channels::Channel; -use weavegraph::event_bus::{ - EventBus, EventStream, INVOCATION_END_SCOPE, MemorySink, STREAM_END_SCOPE, +use weavegraph::{ + FrontierCommand, NodeRoute, + channels::Channel, + event_bus::{EventBus, EventStream, INVOCATION_END_SCOPE, MemorySink, STREAM_END_SCOPE}, + graphs::{EdgePredicate, GraphBuilder}, + message::{Message, Role}, + node::{Node, NodeContext, NodeError, NodePartial}, + runtimes::{ + AppRunner, Checkpoint, Checkpointer, CheckpointerType, PausedReason, RuntimeConfig, + SessionInit, SessionState, StepOptions, StepResult, + }, + schedulers::{Scheduler, SchedulerState}, + state::{StateSnapshot, VersionedState}, + types::NodeKind, + utils::clock::MockClock, }; -use weavegraph::graphs::{EdgePredicate, GraphBuilder}; -use weavegraph::message::{Message, Role}; -use weavegraph::node::{Node, NodeContext, NodeError, NodePartial}; -use weavegraph::runtimes::{ - AppRunner, Checkpoint, Checkpointer, CheckpointerType, PausedReason, RuntimeConfig, - SessionInit, SessionState, StepOptions, StepResult, -}; -use weavegraph::schedulers::{Scheduler, SchedulerState}; -use weavegraph::state::{StateSnapshot, VersionedState}; -use weavegraph::types::NodeKind; -use weavegraph::utils::clock::MockClock; -use weavegraph::{FrontierCommand, NodeRoute}; mod common; use common::*; @@ -262,27 +262,22 @@ async fn conditional_edge_routes_to_labeled_target_based_on_state() { assert!(!rep2.next_frontier.contains(&NodeKind::Custom("Y".into()))); } -#[tokio::test] -async fn event_stream_returns_none_on_second_call() { - let app = make_test_app(); +#[::tokio::test] +async fn event_stream_stays_claimed_after_initial_subscription_is_dropped() { let mut runner = AppRunner::builder() - .app(app) + .app(make_test_app()) .checkpointer(CheckpointerType::InMemory) .build() .await; - - let stream = runner - .event_stream() - .expect("first event_stream call should succeed"); - drop(stream); - - let result = runner.event_stream(); - assert!( - result.is_none(), - "expected None on second event_stream call" + drop( + runner + .event_stream() + .expect("initial event stream subscription should exist"), ); + if runner.event_stream().is_some() { + panic!("event stream remains unavailable after the first subscription"); + } } - #[tokio::test] async fn create_session_returns_fresh_and_is_retrievable() { let app = make_test_app(); @@ -881,96 +876,82 @@ async fn run_metadata_reports_graph_hash_and_backend_identifiers() { assert_eq!(metadata.clock_mode, "configured"); } -#[derive(Debug, Clone)] -struct ReplaceController; - -#[async_trait] -impl Node for ReplaceController { +#[derive(Clone, Debug)] +struct FrontierRedirectNode; +#[async_trait::async_trait] +impl Node for FrontierRedirectNode { async fn run( &self, _snapshot: StateSnapshot, _ctx: NodeContext, ) -> Result { - Ok(NodePartial::new().with_frontier_replace(vec![NodeKind::Custom("worker".into())])) + let replacement = NodeKind::Custom("replacement".to_owned()); + Ok(NodePartial::new().with_frontier_replace([replacement])) } } - -#[derive(Debug, Clone)] -struct WorkerNode; - -#[async_trait] -impl Node for WorkerNode { +#[derive(Clone, Debug)] +struct ReplacementMessageNode; +#[async_trait::async_trait] +impl Node for ReplacementMessageNode { async fn run( &self, _snapshot: StateSnapshot, _ctx: NodeContext, ) -> Result { - Ok(NodePartial::new() - .with_messages(vec![Message::with_role(Role::Assistant, "worker-run")])) + let message = Message::with_role(Role::Assistant, "replacement-hit"); + Ok(NodePartial::new().with_messages(vec![message])) } } - -#[tokio::test] -async fn frontier_replace_command_redirects_execution_to_replacement_node() { - let app = GraphBuilder::new() - .add_node(NodeKind::Custom("controller".into()), ReplaceController) - .add_node(NodeKind::Custom("worker".into()), WorkerNode) - .add_edge(NodeKind::Start, NodeKind::Custom("controller".into())) - .add_edge( - NodeKind::Custom("controller".into()), - NodeKind::Custom("worker".into()), - ) - .add_edge(NodeKind::Custom("worker".into()), NodeKind::End) - .compile() - .unwrap(); - +#[::tokio::test] +async fn frontier_replace_command_uses_replacement_frontier_on_next_step() { + let controller = NodeKind::Custom("controller".to_owned()); + let replacement = NodeKind::Custom("replacement".to_owned()); + let session_id = "frontier-session"; + let app = { + let builder = GraphBuilder::new() + .add_node(controller.clone(), FrontierRedirectNode) + .add_node(replacement.clone(), ReplacementMessageNode) + .add_edge(NodeKind::Start, controller.clone()) + .add_edge(controller.clone(), replacement.clone()) + .add_edge(controller.clone(), NodeKind::End) + .add_edge(replacement.clone(), NodeKind::End); + builder.compile().expect("test graph should compile") + }; let mut runner = AppRunner::builder() .app(app) .checkpointer(CheckpointerType::InMemory) .build() .await; - runner - .create_session("frontier-session".into(), state_with_user("control")) + .create_session(session_id.to_owned(), state_with_user("control")) .await - .expect("create session"); - - let StepResult::Completed(first_report) = runner - .run_step("frontier-session", StepOptions::default()) + .expect("session setup should succeed"); + let first_outcome = runner + .run_step(session_id, StepOptions::default()) .await - .expect("first step") - else { + .expect("first step should complete"); + let StepResult::Completed(first_step) = first_outcome else { panic!("expected completed step"); }; - assert_eq!( - first_report.ran_nodes, - vec![NodeKind::Custom("controller".into())] - ); - assert_eq!(first_report.barrier_outcome.frontier_commands.len(), 1); - let FrontierCommand::Replace(routes) = &first_report.barrier_outcome.frontier_commands[0].1 + assert_eq!(first_step.ran_nodes, vec![controller.clone()]); + let [(origin, FrontierCommand::Replace(routes))] = + first_step.barrier_outcome.frontier_commands.as_slice() else { - panic!( - "expected replace command, got {:?}", - first_report.barrier_outcome.frontier_commands[0].1 - ); + panic!("expected exactly one replace frontier command"); }; - let kinds: Vec = routes.iter().map(NodeRoute::to_node_kind).collect(); - assert_eq!(kinds, vec![NodeKind::Custom("worker".into())]); - - let StepResult::Completed(second_report) = runner - .run_step("frontier-session", StepOptions::default()) + assert_eq!(origin, &controller); + let next_nodes: Vec<_> = routes.iter().map(NodeRoute::to_node_kind).collect(); + assert_eq!(next_nodes, vec![replacement.clone()]); + assert_eq!(first_step.next_frontier, vec![replacement.clone()]); + let second_outcome = runner + .run_step(session_id, StepOptions::default()) .await - .expect("second step") - else { + .expect("second step should complete"); + let StepResult::Completed(second_step) = second_outcome else { panic!("expected completed step"); }; - assert!( - second_report - .ran_nodes - .contains(&NodeKind::Custom("worker".into())) - ); + assert_eq!(second_step.ran_nodes, vec![replacement]); } - #[tokio::test] async fn interrupt_before_pauses_step_before_named_node() { let app = make_test_app(); diff --git a/tests/smoke.rs b/tests/smoke.rs index 53d57c6..ffe7f0b 100644 --- a/tests/smoke.rs +++ b/tests/smoke.rs @@ -18,10 +18,12 @@ use std::process::Command; /// Helper to run an example and verify it succeeds with output fn run_example(example_name: &str) { - let result = Command::new("cargo") + let Ok(result) = Command::new("cargo") .args(["run", "--example", example_name]) .output() - .unwrap_or_else(|_| panic!("Failed to run example: {}", example_name)); + else { + panic!("failed to run example: {example_name}"); + }; assert!( result.status.success(), diff --git a/tests/state_channels.rs b/tests/state_channels.rs index bc07ebd..ed72448 100644 --- a/tests/state_channels.rs +++ b/tests/state_channels.rs @@ -54,24 +54,23 @@ fn new_with_user_message_creates_state_with_one_user_message() { } #[test] -fn new_with_messages_creates_state_with_all_supplied_messages() { - let messages = vec![ +fn new_with_messages_keeps_existing_history_and_starts_channels_at_version_one() { + let history = vec![ Message::with_role(Role::User, "hello"), Message::with_role(Role::Assistant, "hi there"), ]; - let state = VersionedState::new_with_messages(messages.clone()); - let snapshot = state.snapshot(); - - assert_eq!(snapshot.messages.len(), 2); - assert_eq!(snapshot.messages[0], messages[0]); - assert_eq!(snapshot.messages[1], messages[1]); - assert_eq!(snapshot.messages_version, 1); - assert!(snapshot.extra.is_empty()); - assert_eq!(snapshot.extra_version, 1); - assert!(snapshot.errors.is_empty()); - assert_eq!(snapshot.errors_version, 1); + let snapshot = VersionedState::new_with_messages(history.clone()).snapshot(); + assert_eq!(snapshot.messages, history); + assert_eq!( + ( + snapshot.messages_version, + snapshot.extra_version, + snapshot.errors_version, + ), + (1, 1, 1) + ); + assert_eq!((snapshot.extra.len(), snapshot.errors.len()), (0, 0)); } - #[test] fn snapshot_is_independent_of_subsequent_mutations() { let mut s = VersionedState::new_with_user_message("x"); @@ -84,22 +83,22 @@ fn snapshot_is_independent_of_subsequent_mutations() { } #[test] -fn snapshot_from_multi_message_init_is_independent_of_subsequent_mutations() { - let mut state = VersionedState::new_with_messages(vec![ +fn snapshot_from_seeded_history_stays_unchanged_after_later_updates() { + let mut state = VersionedState::new_with_messages(Vec::from([ Message::with_role(Role::User, "original"), Message::with_role(Role::Assistant, "response"), - ]); - let snapshot = state.snapshot(); - + ])); + let frozen_snapshot = state.snapshot(); state.add_message("user", "third"); - state.add_extra("k", Value::String("v".into())); - - assert_eq!(snapshot.messages.len(), 2); - assert_eq!(snapshot.messages[0].content, "original"); - assert_eq!(snapshot.messages[1].content, "response"); - assert!(!snapshot.extra.contains_key("k")); + state.add_extra("k", Value::String("v".to_owned())); + let contents = frozen_snapshot + .messages + .iter() + .map(|message| message.content.as_str()) + .collect::>(); + assert_eq!(contents.as_slice(), ["original", "response"]); + assert!(!frozen_snapshot.extra.contains_key("k")); } - #[test] fn extra_slot_accepts_number_string_and_array_json_values() { let s = VersionedState::builder() From cbba271f473c4d512523b830f375a8881f51f60a Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 20:38:42 -0400 Subject: [PATCH 11/15] final lint --- tests/channels.rs | 2 +- tests/runtimes_persistence_postgres.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/channels.rs b/tests/channels.rs index fbb03e0..26a3164 100644 --- a/tests/channels.rs +++ b/tests/channels.rs @@ -507,7 +507,7 @@ fn pretty_print_formats_app_event_with_context() { .with_context(json!({})); event.when = when; - let out = pretty_print(&vec![event]); + let out = pretty_print(&[event]); assert!(out.contains("display")); } diff --git a/tests/runtimes_persistence_postgres.rs b/tests/runtimes_persistence_postgres.rs index c1d11e2..93b68a1 100644 --- a/tests/runtimes_persistence_postgres.rs +++ b/tests/runtimes_persistence_postgres.rs @@ -1,4 +1,4 @@ -#![cfg(all(feature = "postgres"))] +#![cfg(feature = "postgres")] use rustc_hash::FxHashMap as FastMap; use serde_json::{Value, json}; From 03d3e8856ace9830d8f7e060ee0af2de65313a60 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sat, 23 May 2026 20:55:10 -0400 Subject: [PATCH 12/15] Prepare 0.7.0 release notes and docs Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- CHANGELOG.md | 39 ++++++++++++- CONTRIBUTING.md | 12 ++-- Cargo.toml | 38 ++++++++++++- README.md | 4 +- SECURITY.md | 8 +-- docs/INDEX.md | 2 +- docs/MIGRATION.md | 93 +++++++++++++++++++++++++++++-- docs/OPERATIONS.md | 4 +- examples/README.md | 6 +- examples/basic_nodes.rs | 2 +- examples/convenience_streaming.rs | 2 +- examples/errors_pretty.rs | 4 ++ examples/event_backpressure.rs | 2 + examples/json_serialization.rs | 2 +- examples/production_streaming.rs | 4 +- examples/scheduler_fanout.rs | 2 + 16 files changed, 197 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ce3b5ac..bd2fa4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,38 @@ All notable changes to Weavegraph will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.7.0] - 2026-05-23 + +### Added + +- `LLMStreamingEventBuilder`, returned by `LLMStreamingEvent::builder(chunk)`, for constructing LLM streaming events without the removed eight-argument constructor. The builder supports optional `session_id`, `node_id`, `stream_id`, `is_final`, `scope`, `metadata`, and `timestamp` fields before `.build()`. + +### Changed + +- Internal code-quality passes across `src/app.rs`, `src/channels/`, `src/control.rs`, `src/event_bus/`, `src/graphs/`, `src/llm/`, `src/message.rs`, `src/node.rs`, `src/reducers/`, `src/runtimes/`, `src/schedulers/`, `src/state.rs`, `src/telemetry/`, `src/types.rs`, and `src/utils/`: documentation condensed, dead code removed, control flow simplified, and private helpers renamed for clarity. No behavior changes are intended for these maintenance-only revisions. +- `conditional_edges()` now returns `&[ConditionalEdge]` instead of `&Vec`. Most callers coerce transparently; explicit `&Vec<_>` annotations should be changed to slices. +- `WeaveError` now derives `Default`; behavior is unchanged. +- `EventBus::with_sink` and `Channel::set_version` were cosmetically cleaned up without semantic changes. +- `MetricsObserver` now relies on the default no-op `RuntimeObserver::on_invocation_start` implementation instead of an explicit empty override. +- Migration SQL files were rederived with equivalent schemas but different file contents. Existing databases that already ran these migrations may need their `sqlx` migration checksums updated or may need to be regenerated. + +### Removed + +- **Breaking:** `LLMStreamingEvent::new` was removed. Use `LLMStreamingEvent::builder(...)` or the existing `chunk_event`, `final_event`, and `error_event` factory methods. +- **Breaking:** `NodeContext::emit_node()` was removed. Use `NodeContext::emit(scope, message)`, which has the same signature and behavior. +- **Breaking:** `PersistenceError::MissingField` was removed. +- **Breaking:** `SQLiteCheckpointerError` was removed. Use `CheckpointerError` directly. +- **Breaking:** `checkpointer_postgres_helpers` and `checkpointer_sqlite_helpers` modules were removed; their contents are now private implementation details of the corresponding checkpointers. +- **Breaking:** Unused public utility surface was trimmed: `IdError`, `ParsedId`, `IdGenerator::{generate_uuid, generate_random_id, parse_id, current_counter, reset_counter}`, `id_utils`, `JsonValueExt::deep_clone`, `merge_inspector`, `message_id_helpers`, and `type_guards`. +- `tmp/license-relicensing-audit.md` was removed from git tracking. + +### Fixed + +- Removed misapplied `#[must_use]` attributes from mutating `VersionedState::{add_message, add_extra}` methods. +- Removed the private `VersionedStateBuilder::new()` constructor in favor of `VersionedStateBuilder::default()`. +- Tightened private runtime helper visibility and made SQLite `row_to_checkpoint` synchronous because it had no await points. +- Fixed the `IdGenerator` doctest to avoid using `gen`, a reserved keyword in Rust 2024. + ## [0.6.0] - 2026-05-11 ### Added @@ -191,7 +223,12 @@ Initial stable release. Core features: --- -[unreleased]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.2.0...HEAD +[unreleased]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.7.0...HEAD +[0.7.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.6.0...weavegraph-v0.7.0 +[0.6.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.5.0...weavegraph-v0.6.0 +[0.5.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.4.0...weavegraph-v0.5.0 +[0.4.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.3.0...weavegraph-v0.4.0 +[0.3.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.2.0...weavegraph-v0.3.0 [0.2.0]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.1.3...weavegraph-v0.2.0 [0.1.3]: https://github.com/Idleness76/weavegraph/compare/weavegraph-0.1.2...weavegraph-v0.1.3 [0.1.2]: https://github.com/Idleness76/weavegraph/compare/weavegraph-v0.1.1...weavegraph-0.1.2 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6085acb..876d918 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,12 +1,12 @@ # Contributing to Weavegraph -Thank you for your interest in contributing to Weavegraph! This project welcomes contributions from developers of all skill levels. Weavegraph is in active development (v0.2.x released, targeting v0.3.0 API stabilization) with ongoing improvements based on real-world usage and community feedback. +Thank you for your interest in contributing to Weavegraph! This project welcomes contributions from developers of all skill levels. Weavegraph is in active development (v0.7.x, moving toward 1.0 API stabilization) with ongoing improvements based on real-world usage and community feedback. ## Getting Started ### Prerequisites -- Rust 1.89 or later +- Rust 1.90 or later - Basic familiarity with async Rust and the `tokio` runtime - Understanding of graph-based workflows is helpful but not required @@ -26,9 +26,9 @@ Thank you for your interest in contributing to Weavegraph! This project welcomes 3. **Run examples to understand the framework**: ```bash - cargo run --example basic_nodes - cargo run --example advanced_patterns - cargo run --example streaming_events + cargo run --example basic_nodes --features examples + cargo run --example advanced_patterns --features examples + cargo run --example streaming_events --features examples ``` 4. **Set up local services** (optional): @@ -45,7 +45,7 @@ Before submitting a PR, run local CI checks to catch issues early: # Quick checks (fmt, clippy, test, doc) ./scripts/ci-quick.sh -# Full CI suite (includes MSRV 1.89, deny, machete) +# Full CI suite (includes MSRV 1.90, deny, machete) ./scripts/ci-local.sh ``` diff --git a/Cargo.toml b/Cargo.toml index df92454..af57e9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "weavegraph" -version = "0.6.0" +version = "0.7.0" edition = "2024" description = "Graph-driven, concurrent agent workflow framework with versioned state, deterministic barrier merges, and rich diagnostics." license = "Apache-2.0" @@ -124,6 +124,42 @@ examples = ["reqwest", "scraper"] metrics = ["dep:metrics"] petgraph-compat = ["petgraph"] +[[example]] +name = "basic_nodes" +required-features = ["examples"] + +[[example]] +name = "graph_execution" +required-features = ["examples"] + +[[example]] +name = "scheduler_fanout" +required-features = ["examples"] + +[[example]] +name = "advanced_patterns" +required-features = ["examples"] + +[[example]] +name = "convenience_streaming" +required-features = ["examples"] + +[[example]] +name = "streaming_events" +required-features = ["examples"] + +[[example]] +name = "event_backpressure" +required-features = ["examples"] + +[[example]] +name = "json_serialization" +required-features = ["examples"] + +[[example]] +name = "errors_pretty" +required-features = ["examples"] + [[example]] name = "production_streaming" required-features = ["postgres", "examples"] diff --git a/README.md b/README.md index c33a9a3..15b9943 100644 --- a/README.md +++ b/README.md @@ -35,10 +35,10 @@ Add to your `Cargo.toml`: ```toml [dependencies] -weavegraph = "0.5" +weavegraph = "0.7" ``` -> **Note:** Examples and instructions in this README are current as of 0.5.x. For upgrade notes across pre-1.0 releases, see [MIGRATION.md](docs/MIGRATION.md). +> **Note:** Examples and instructions in this README are current as of 0.7.x. For upgrade notes across pre-1.0 releases, see [MIGRATION.md](docs/MIGRATION.md). ## Dependency Compatibility diff --git a/SECURITY.md b/SECURITY.md index fde1a18..687bb59 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -6,9 +6,9 @@ We actively support the following versions of Weavegraph with security updates: | Version | Supported | | ------- | ------------------ | -| 0.2.x | :white_check_mark: | -| 0.1.x | :x: | -| < 0.1.0 | :x: | +| 0.7.x | :white_check_mark: | +| 0.6.x | :white_check_mark: | +| < 0.6.0 | :x: | ## Reporting a Vulnerability @@ -116,4 +116,4 @@ If you have questions about this security policy, please open a discussion on Gi --- -**Last updated**: 2026-03-06 +**Last updated**: 2026-05-23 diff --git a/docs/INDEX.md b/docs/INDEX.md index f2b99ce..5bcc35a 100644 --- a/docs/INDEX.md +++ b/docs/INDEX.md @@ -20,6 +20,6 @@ Complete documentation for building workflows with Weavegraph. - [MIGRATION.md](MIGRATION.md) - Upgrade notes and migration guides by release - [Schema Definitions](schemas/) - JSON schemas for event and error payloads -- [CHANGELOG.md](../CHANGELOG.md) - Release history (placeholder until added in §0.P.3) +- [CHANGELOG.md](../CHANGELOG.md) - Release history and release notes - [Examples](../examples/) - Runnable code for all major patterns - [STREAMING.md](STREAMING.md) - Event streaming quickstart guide diff --git a/docs/MIGRATION.md b/docs/MIGRATION.md index 031f2a0..87a8f80 100644 --- a/docs/MIGRATION.md +++ b/docs/MIGRATION.md @@ -5,6 +5,86 @@ migration guidance for upgrading your code. --- +## v0.7.0 + +### Overview + +v0.7.0 is a release-preparation cleanup that removes unused public API surface, +adds an ergonomic LLM streaming event builder, and tightens docs/examples around +the `examples` feature flag. Most changes are internal maintenance; the items +below are the user-visible migrations. + +### Breaking: `LLMStreamingEvent::new` removed + +The eight-argument `LLMStreamingEvent::new(...)` constructor has been removed. +Use the builder when constructing custom events: + +```rust +use weavegraph::event_bus::{LLMStreamingEvent, LLMStreamingEventScope}; + +let event = LLMStreamingEvent::builder(chunk) + .session_id(session_id) + .node_id(node_id) + .stream_id(stream_id) + .is_final(true) + .scope(LLMStreamingEventScope::Chunk) + .metadata(metadata) + .build(); +``` + +For standard cases, the existing `chunk_event`, `final_event`, and `error_event` +factory methods remain available and are still preferred. + +### Breaking: `NodeContext::emit_node()` removed + +`NodeContext::emit_node(scope, message)` was a redundant alias for +`NodeContext::emit(scope, message)`. + +**Migration**: replace calls directly: + +```rust +// Before +ctx.emit_node("progress", "started")?; + +// After +ctx.emit("progress", "started")?; +``` + +### Breaking: smaller public surfaces + +The following unused or internal-only public items were removed: + +- `PersistenceError::MissingField` +- `SQLiteCheckpointerError`; use `CheckpointerError` directly +- `checkpointer_postgres_helpers` and `checkpointer_sqlite_helpers` +- `IdError`, `ParsedId`, and unused `IdGenerator` helpers +- `id_utils`, `merge_inspector`, `message_id_helpers`, and `type_guards` +- `JsonValueExt::deep_clone()`; use `.clone()` instead + +### Breaking: `conditional_edges()` returns a slice + +`conditional_edges()` now returns `&[ConditionalEdge]` instead of +`&Vec`. Most callers need no change. If you explicitly +annotated the returned value as `&Vec<_>`, change the annotation to a slice: + +```rust +// Before +let edges: &Vec = config.conditional_edges(); + +// After +let edges: &[ConditionalEdge] = config.conditional_edges(); +``` + +### Deployment note: migration SQL checksums + +SQLite and PostgreSQL `0001_init.sql` files were rederived with equivalent +schemas but changed file contents. Existing databases that already ran these +migrations can trigger an `sqlx` checksum mismatch on the next migration run. +Either update the stored checksum in `_sqlx_migrations` after verifying the +schema, or regenerate the database. + +--- + ## v0.6.0 ### Overview @@ -117,7 +197,7 @@ cannot crash or abort a workflow invocation. Enable the `metrics` feature and attach `MetricsObserver` to export standard Prometheus-compatible metrics via any `metrics`-crate recorder (e.g. `metrics-exporter-prometheus`): ```toml -weavegraph = { version = "0.6", features = ["metrics"] } +weavegraph = { version = "0.7", features = ["metrics"] } metrics-exporter-prometheus = "0.17" ``` @@ -402,11 +482,12 @@ weavegraph = { version = "0.4", features = ["rig"] } identify internal diagnostic events when filtering the event stream. - `#![warn(missing_docs)]` is now enforced — all public API items are documented. - `examples/production_streaming.rs` — golden-path reference for Axum + SSE + - Postgres checkpointing (requires `--features postgres,examples`). + Postgres checkpointing (use `--features postgres-migrations,examples` for a + fresh database). --- -## v0.3.0 (Upcoming) +## v0.3.0 ### Breaking Changes @@ -697,7 +778,7 @@ let event = ErrorEvent::app(WeaveError::msg("startup failed")); --- -## v0.2.0 (Upcoming) +## v0.2.0 ### Breaking Changes @@ -926,6 +1007,10 @@ If you encounter issues during migration: | Weavegraph | Rust MSRV | rig-core | tokio | |------------|-----------|----------|-------| +| 0.7.x | 1.90.0 | 0.30.x | 1.x | +| 0.6.x | 1.90.0 | 0.30.x | 1.x | +| 0.5.x | 1.90.0 | 0.30.x | 1.x | +| 0.4.x | 1.90.0 | 0.30.x | 1.x | | 0.3.x | 1.90.0 | 0.30.x | 1.x | | 0.2.x | 1.89.0 | 0.28+ | 1.x | | 0.1.x | 1.89.0 | 0.28+ | 1.x | diff --git a/docs/OPERATIONS.md b/docs/OPERATIONS.md index 0679201..1e209d5 100644 --- a/docs/OPERATIONS.md +++ b/docs/OPERATIONS.md @@ -77,10 +77,10 @@ Rich tracing integration with configurable log levels: ```bash # Debug level for weavegraph modules -RUST_LOG=debug cargo run --example basic_nodes +RUST_LOG=debug cargo run --example basic_nodes --features examples # Error level globally, debug for weavegraph -RUST_LOG=error,weavegraph=debug cargo run --example advanced_patterns +RUST_LOG=error,weavegraph=debug cargo run --example advanced_patterns --features examples ``` ## Persistence {#persistence} diff --git a/examples/README.md b/examples/README.md index 15b80ca..a007d6b 100644 --- a/examples/README.md +++ b/examples/README.md @@ -5,9 +5,12 @@ Runnable examples for core Weavegraph patterns. ## Run an example ```bash -cargo run --example +cargo run --example --features examples ``` +Use `--features postgres-migrations,examples` for `production_streaming` when +running against a fresh database. + ## Core workflow examples - `basic_nodes` - Minimal node and graph setup. @@ -21,6 +24,7 @@ cargo run --example - `streaming_events` - Runtime/event bus streaming pattern for services. - `event_backpressure` - Handling lag and drop behavior under load. - `json_serialization` - Emitting machine-readable event payloads. +- `production_streaming` - PostgreSQL-backed Axum SSE example for production deployments. ## Error handling example diff --git a/examples/basic_nodes.rs b/examples/basic_nodes.rs index 76cbc55..5f7a061 100644 --- a/examples/basic_nodes.rs +++ b/examples/basic_nodes.rs @@ -6,7 +6,7 @@ //! - Return partial state updates //! - Use convenience constructors //! -//! Run with: `cargo run --example basic_nodes` +//! Run with: `cargo run --example basic_nodes --features examples` use async_trait::async_trait; use serde_json::json; diff --git a/examples/convenience_streaming.rs b/examples/convenience_streaming.rs index 1822348..94c6871 100644 --- a/examples/convenience_streaming.rs +++ b/examples/convenience_streaming.rs @@ -27,7 +27,7 @@ //! ## Run This Example //! //! ```bash -//! cargo run --example convenience_streaming +//! cargo run --example convenience_streaming --features examples //! ``` use async_trait::async_trait; diff --git a/examples/errors_pretty.rs b/examples/errors_pretty.rs index 9759c81..1be9b71 100644 --- a/examples/errors_pretty.rs +++ b/examples/errors_pretty.rs @@ -1,3 +1,7 @@ +//! Structured error formatting with explicit color-mode control. +//! +//! Run with: `cargo run --example errors_pretty --features examples` + use chrono::{TimeZone, Utc}; use serde_json::json; use weavegraph::channels::errors::{ErrorEvent, WeaveError, pretty_print, pretty_print_with_mode}; diff --git a/examples/event_backpressure.rs b/examples/event_backpressure.rs index 006933d..c03efbb 100644 --- a/examples/event_backpressure.rs +++ b/examples/event_backpressure.rs @@ -4,6 +4,8 @@ //! - How to detect lagged event streams //! - Metrics for monitoring dropped events //! - Proper handling of RecvError::Lagged +//! +//! Run with: `cargo run --example event_backpressure --features examples` use std::time::Duration; use tokio::time::sleep; diff --git a/examples/json_serialization.rs b/examples/json_serialization.rs index 3fa7201..4ca42d8 100644 --- a/examples/json_serialization.rs +++ b/examples/json_serialization.rs @@ -8,7 +8,7 @@ //! //! ## Run it //! ```bash -//! cargo run --example json_serialization +//! cargo run --example json_serialization --features examples //! ``` use rustc_hash::FxHashMap; diff --git a/examples/production_streaming.rs b/examples/production_streaming.rs index 5c26bab..bf77278 100644 --- a/examples/production_streaming.rs +++ b/examples/production_streaming.rs @@ -43,14 +43,14 @@ //! ## Feature Requirements //! //! ```bash -//! cargo run --example production_streaming --features postgres,examples +//! cargo run --example production_streaming --features postgres-migrations,examples //! ``` //! //! Set `DATABASE_URL` before running: //! //! ```bash //! export DATABASE_URL="postgres://postgres:postgres@localhost/weavegraph" -//! cargo run --example production_streaming --features postgres,examples +//! cargo run --example production_streaming --features postgres-migrations,examples //! ``` //! //! ## Testing diff --git a/examples/scheduler_fanout.rs b/examples/scheduler_fanout.rs index f0a7e12..9f4ab32 100644 --- a/examples/scheduler_fanout.rs +++ b/examples/scheduler_fanout.rs @@ -10,6 +10,8 @@ //! //! The scheduler ensures nodes only execute when their dependencies are ready, //! providing efficient concurrent execution while respecting the dependency graph. +//! +//! Run with: `cargo run --example scheduler_fanout --features examples` use async_trait::async_trait; use rustc_hash::FxHashMap; From fa88c31cbdb40eb09a1918ec9108df8523d6800f Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sun, 24 May 2026 11:11:56 -0400 Subject: [PATCH 13/15] better documentation for users for SQL migration patch --- CHANGELOG.md | 2 +- docs/MIGRATION.md | 176 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 171 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd2fa4d..0160b06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `WeaveError` now derives `Default`; behavior is unchanged. - `EventBus::with_sink` and `Channel::set_version` were cosmetically cleaned up without semantic changes. - `MetricsObserver` now relies on the default no-op `RuntimeObserver::on_invocation_start` implementation instead of an explicit empty override. -- Migration SQL files were rederived with equivalent schemas but different file contents. Existing databases that already ran these migrations may need their `sqlx` migration checksums updated or may need to be regenerated. +- **Breaking (existing databases — hard failure):** Migration SQL files were rederived with equivalent schemas but changed file contents. Any existing database that already ran the `0001_init` migration will produce a hard `sqlx` checksum error on the **very next `connect()` call**, and the application will refuse to start. This is **not** a schema change — no tables, columns, or data are altered — but the connection is unconditionally rejected until the stored checksum record is corrected. SQLite users are affected by default (`sqlite-migrations` is a default feature). PostgreSQL users are affected only if the `postgres-migrations` feature was explicitly enabled. **See the [v0.7.0 migration guide](docs/MIGRATION.md#breaking-migration-sql-checksum-mismatch-existing-databases) for exact fix SQL and step-by-step Docker upgrade instructions.** ### Removed diff --git a/docs/MIGRATION.md b/docs/MIGRATION.md index 87a8f80..59dfa9a 100644 --- a/docs/MIGRATION.md +++ b/docs/MIGRATION.md @@ -75,13 +75,177 @@ let edges: &Vec = config.conditional_edges(); let edges: &[ConditionalEdge] = config.conditional_edges(); ``` -### Deployment note: migration SQL checksums +### Breaking: Migration SQL checksum mismatch (existing databases) -SQLite and PostgreSQL `0001_init.sql` files were rederived with equivalent -schemas but changed file contents. Existing databases that already ran these -migrations can trigger an `sqlx` checksum mismatch on the next migration run. -Either update the stored checksum in `_sqlx_migrations` after verifying the -schema, or regenerate the database. +#### Severity + +**Hard failure. The application will not start against an existing database +until the fix is applied.** + +`sqlx` stores a SHA-384 checksum of every migration file in the +`_sqlx_migrations` table at the time the migration first runs. On every +subsequent `connect()` call it rechecks the file on disk against the stored +checksum. The `0001_init.sql` files for both SQLite and PostgreSQL were +rederived in 0.7.0 with equivalent schemas but changed file contents. The +checksums no longer match, and `sqlx` responds with a hard error: + +``` +error: migration 1/migrate `0001_init` was previously applied but has been modified +``` + +This is raised inside `SQLiteCheckpointer::connect()` / +`PostgresCheckpointer::connect()` before any graph logic runs. The application +will fail to initialize and will not continue. + +> **This is not a schema change.** No tables, columns, indexes, or data are +> altered. The database contents are safe. Only the checksum bookkeeping record +> needs updating. + +#### Who is affected + +| Backend | Feature flag | Default? | Affected? | +|------------|-----------------------|----------|-----------------------------------| +| SQLite | `sqlite-migrations` | **Yes** | All users unless explicitly opted out | +| PostgreSQL | `postgres-migrations` | No | Only users who explicitly enabled this feature | + +If you use SQLite and did not set `default-features = false` in your +`Cargo.toml` dependency declaration, you are affected. + +#### When does it happen + +The error fires on the **first `connect()` call after upgrading** to 0.7.0. +If the database file already has data from a previous run, the mismatch is +detected immediately and the call returns an error. There is no grace period +and no warning — it is a hard error. + +#### Can the fix be applied while the application is running? + +Only partially. If the application is already running and has an open +connection pool, that pool is unaffected until it reconnects. However: + +- Any new process that calls `connect()` (e.g. a restarted container, a new + worker process, or a second instance in a scaled deployment) will fail + immediately. +- The safest approach is to apply the fix SQL **before** deploying 0.7.0 to + any environment. The fix is a single-row `UPDATE` with no locking impact on + application traffic. + +#### Fix: SQLite + +Connect to the database file with any SQLite client (`sqlite3`, DB Browser for +SQLite, etc.) and run: + +```sql +UPDATE _sqlx_migrations +SET checksum = x'3b3263ea3c19ba500ad4f6535b9589ea011a6215a09dc40447e3b4756aebc6bf75b067213afc8692f719706611a8f81b' +WHERE version = 1; +``` + +Verify the row was updated before proceeding: + +```sql +SELECT version, checksum FROM _sqlx_migrations WHERE version = 1; +-- Should return one row; checksum will display as a hex BLOB. +``` + +#### Fix: PostgreSQL + +Connect as a user with `UPDATE` permission on `_sqlx_migrations` and run: + +```sql +UPDATE _sqlx_migrations +SET checksum = '\x5db62b4ff42843f429d889f5445a4117b0ff3b4cd185fae1d1b7685a8d3b37cd3d7a8da4265e95c9b0ab5b6efc9ac343'::bytea +WHERE version = 1; +``` + +Verify: + +```sql +SELECT version, encode(checksum, 'hex') AS checksum FROM _sqlx_migrations WHERE version = 1; +``` + +The `encode(...)` output should be: +``` +5db62b4ff42843f429d889f5445a4117b0ff3b4cd185fae1d1b7685a8d3b37cd3d7a8da4265e95c9b0ab5b6efc9ac343 +``` + +#### Example: Docker deployment with automatic image pulls (PostgreSQL) + +Many production PostgreSQL deployments use Docker Compose or a container +orchestrator (Kubernetes, Nomad, ECS, etc.) configured to pull and restart the +application container automatically when a new image is published. In this +setup, upgrading to 0.7.0 without patching the checksum first will cause the +application container to crash-loop immediately after starting — before it can +serve any traffic — because `connect()` is called at startup. + +**Apply the fix before deploying the new image.** Step-by-step: + +1. **Do not pull the new image yet.** Keep the existing 0.6.x container running + while you patch the database. + +2. **Connect to the PostgreSQL container** (adjust the container name, host, + port, user, and database name to match your deployment): + + ```bash + docker exec -it \ + psql -U -d + ``` + + Or, if PostgreSQL is not in Docker but is a managed service (RDS, Cloud + SQL, etc.), use `psql` from any host that can reach it: + + ```bash + psql "postgresql://:@:/" + ``` + +3. **Run the checksum update:** + + ```sql + UPDATE _sqlx_migrations + SET checksum = '\x5db62b4ff42843f429d889f5445a4117b0ff3b4cd185fae1d1b7685a8d3b37cd3d7a8da4265e95c9b0ab5b6efc9ac343'::bytea + WHERE version = 1; + -- Expected: UPDATE 1 + ``` + +4. **Verify the update succeeded:** + + ```sql + SELECT version, encode(checksum, 'hex') AS checksum FROM _sqlx_migrations WHERE version = 1; + ``` + + Confirm the hex value matches exactly: + ``` + 5db62b4ff42843f429d889f5445a4117b0ff3b4cd185fae1d1b7685a8d3b37cd3d7a8da4265e95c9b0ab5b6efc9ac343 + ``` + +5. **Exit `psql`** (`\q`) and disconnect from the database container. + +6. **Now pull and deploy the 0.7.0 image.** Because the checksum record + already reflects the new file, `connect()` will succeed and the container + will start normally. + + ```bash + docker pull /weavegraph: + docker compose up -d # or however you restart your stack + ``` + +7. **Confirm the container is healthy** before considering the upgrade complete: + + ```bash + docker ps # check STATUS column + docker logs --tail 50 + ``` + +If you missed this step and the container is already crash-looping, the +database is unharmed. Apply the `UPDATE` in step 3, then restart the container. + +#### Option: regenerate the database (development / test only) + +If the database contains only test or ephemeral data, the simplest fix is to +delete the database file (SQLite) or drop and recreate the schema (PostgreSQL) +so that `sqlx` runs the migration fresh against the new file. + +**Do not do this with production data.** --- From 37b9b8eb1f2efbea57c295a38f3122c57baa5231 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sun, 24 May 2026 11:46:36 -0400 Subject: [PATCH 14/15] fix: resolve rustdoc warnings blocking nightly doc build - event_bus/bus.rs: remove redundant explicit link targets on StdOutSink and EventSink (rustdoc::redundant-explicit-links) - state.rs: remove duplicate code block stanza that left an orphaned closing fence, triggering rustdoc::invalid-rust-codeblocks Fixes all three -D warnings failures in the nightly doc job. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/event_bus/bus.rs | 4 ++-- src/state.rs | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/event_bus/bus.rs b/src/event_bus/bus.rs index 0195969..78738c7 100644 --- a/src/event_bus/bus.rs +++ b/src/event_bus/bus.rs @@ -40,10 +40,10 @@ const DEFAULT_BUFFER_CAPACITY: usize = 1024; /// /// # Available Sinks /// -/// - [`StdOutSink`](crate::event_bus::StdOutSink) — write events to stdout (default) +/// - [`StdOutSink`] — write events to stdout (default) /// - [`ChannelSink`](crate::event_bus::ChannelSink) — stream events to async channels /// - [`MemorySink`](crate::event_bus::MemorySink) — capture events for testing -/// - Custom sinks implementing [`EventSink`](crate::event_bus::EventSink) +/// - Custom sinks implementing [`EventSink`] pub struct EventBus { sinks: Arc>>, hub: Arc, diff --git a/src/state.rs b/src/state.rs index fe56a12..2d1f8b2 100644 --- a/src/state.rs +++ b/src/state.rs @@ -232,10 +232,6 @@ pub enum StateSlotError { /// assert_eq!(snapshot.messages.len(), 1); /// assert_eq!(snapshot.extra.get("session_id"), Some(&json!("sess_123"))); /// ``` -/// let snapshot = state.snapshot(); -/// assert_eq!(snapshot.messages.len(), 1); -/// assert_eq!(snapshot.extra.get("session_id"), Some(&json!("sess_123"))); -/// ``` #[derive(Clone, Debug, PartialEq, Eq)] pub struct VersionedState { /// Conversation messages. From a439ef932a7942d0c732809e21678148aa1d35c6 Mon Sep 17 00:00:00 2001 From: Idleness76 Date: Sun, 24 May 2026 12:20:54 -0400 Subject: [PATCH 15/15] sort the license --- LICENSE | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/LICENSE b/LICENSE index d645695..4947287 100644 --- a/LICENSE +++ b/LICENSE @@ -174,29 +174,4 @@ incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. + END OF TERMS AND CONDITIONS \ No newline at end of file