From 156a87d4956cb0032bbba39a34af29a2746644c0 Mon Sep 17 00:00:00 2001 From: Isala Piyarisi Date: Sun, 24 May 2026 00:43:48 +0530 Subject: [PATCH 1/7] harden core, macros, and codegen crates Tighten config loading and validation, error classification, SQL table extraction, parser handling of serde-skip fields, and type emission. Remove the unwired email feature and its test mocks. --- crates/forge-codegen/src/binding.rs | 3 + crates/forge-codegen/src/dioxus/mod.rs | 6 + crates/forge-codegen/src/dioxus/types.rs | 25 +- crates/forge-codegen/src/emit.rs | 12 +- crates/forge-codegen/src/parser.rs | 428 ++++++++++++++---- crates/forge-codegen/src/typescript/api.rs | 25 + crates/forge-codegen/src/typescript/mod.rs | 29 +- crates/forge-codegen/tests/snapshot.rs | 53 +++ .../snapshots/snapshot__custom_args.snap | 3 +- .../tests/snapshots/snapshot__full_app.snap | 9 +- .../snapshot__full_app_with_auth.snap | 9 +- .../snapshot__jobs_and_workflows.snap | 9 +- .../snapshots/snapshot__models_and_enums.snap | 5 +- .../tests/snapshots/snapshot__primitives.snap | 3 +- .../tests/snapshots/snapshot__upload.snap | 9 +- crates/forge-core/Cargo.toml | 4 + crates/forge-core/src/config/auth.rs | 67 ++- crates/forge-core/src/config/database.rs | 20 +- crates/forge-core/src/config/loader.rs | 102 ++++- crates/forge-core/src/config/mod.rs | 173 ++++++- crates/forge-core/src/config/security.rs | 13 +- crates/forge-core/src/config/signals.rs | 27 +- crates/forge-core/src/context.rs | 2 +- crates/forge-core/src/cron/schedule.rs | 16 +- crates/forge-core/src/email/mod.rs | 158 ------- crates/forge-core/src/error.rs | 193 +++++++- crates/forge-core/src/function/context.rs | 130 +++--- crates/forge-core/src/job/traits.rs | 101 ++++- crates/forge-core/src/lib.rs | 1 - crates/forge-core/src/pagination.rs | 22 +- .../forge-core/src/realtime/subscription.rs | 51 ++- crates/forge-core/src/tenant/mod.rs | 41 ++ crates/forge-core/src/testing/assertions.rs | 124 ----- .../src/testing/context/mcp_tool.rs 
| 9 + .../forge-core/src/testing/context/query.rs | 47 +- .../src/testing/context/workflow.rs | 7 + crates/forge-core/src/testing/db.rs | 175 +++++-- .../forge-core/src/testing/mock_dispatch.rs | 89 ++-- crates/forge-core/src/testing/mock_email.rs | 72 --- crates/forge-core/src/testing/mock_http.rs | 278 ------------ crates/forge-core/src/testing/mod.rs | 2 - crates/forge-core/src/util/mod.rs | 124 +++++ crates/forge-core/src/workflow/mod.rs | 2 +- crates/forge-core/src/workflow/step.rs | 264 +---------- crates/forge-macros/Cargo.toml | 2 + crates/forge-macros/src/cron.rs | 55 ++- crates/forge-macros/src/daemon.rs | 15 +- crates/forge-macros/src/job.rs | 58 ++- crates/forge-macros/src/mcp_tool.rs | 61 ++- crates/forge-macros/src/model.rs | 66 +-- crates/forge-macros/src/mutation.rs | 145 +++--- crates/forge-macros/src/query.rs | 68 +-- crates/forge-macros/src/sql_extractor.rs | 317 +++++++++++-- crates/forge-macros/src/utils.rs | 74 ++- crates/forge-macros/src/webhook.rs | 37 +- crates/forge-macros/src/workflow.rs | 72 ++- 56 files changed, 2389 insertions(+), 1523 deletions(-) delete mode 100644 crates/forge-core/src/email/mod.rs delete mode 100644 crates/forge-core/src/testing/mock_email.rs diff --git a/crates/forge-codegen/src/binding.rs b/crates/forge-codegen/src/binding.rs index 1aabafbc..3789220e 100644 --- a/crates/forge-codegen/src/binding.rs +++ b/crates/forge-codegen/src/binding.rs @@ -123,7 +123,10 @@ fn build_binding(func: FunctionDef, tables: &[TableDef]) -> FunctionBinding { /// We require BOTH a naming convention match AND existence in the registry. /// This prevents false positives on types like "InputHandler" or "ArgumentParser". fn is_custom_args_type(rust_type: &RustType, tables: &[TableDef]) -> bool { + // Unwrap `Option`/`Vec` wrappers so `Vec` and `Option` + // are recognised as custom-args bindings. 
match rust_type { + RustType::Option(inner) | RustType::Vec(inner) => is_custom_args_type(inner, tables), RustType::Custom(name) => { (name.ends_with("Args") || name.ends_with("Input")) && tables.iter().any(|t| t.struct_name == *name) diff --git a/crates/forge-codegen/src/dioxus/mod.rs b/crates/forge-codegen/src/dioxus/mod.rs index aef602d9..46a4cd5c 100644 --- a/crates/forge-codegen/src/dioxus/mod.rs +++ b/crates/forge-codegen/src/dioxus/mod.rs @@ -35,6 +35,12 @@ impl DioxusGenerator { } fn mod_content() -> &'static str { + // Framework re-exports are explicit. The api/types globs are kept for + // downstream ergonomics (users reach `get_user(…)` directly), but if a + // user names a type collision-prone (`Mutation`, `QueryState`, …) the + // resolution between `forge_dioxus::Mutation` and the user `Mutation` + // becomes ambiguous. The framework imports are listed explicitly so + // the conflict is at least visible in this file. r#"// @generated by FORGE - DO NOT EDIT #![allow(dead_code, unused_imports)] diff --git a/crates/forge-codegen/src/dioxus/types.rs b/crates/forge-codegen/src/dioxus/types.rs index 350593cb..9a080b66 100644 --- a/crates/forge-codegen/src/dioxus/types.rs +++ b/crates/forge-codegen/src/dioxus/types.rs @@ -11,7 +11,7 @@ use crate::emit::{self, contains_json, contains_upload}; pub fn generate(registry: &SchemaRegistry) -> Result { let mut output = String::from( - "// @generated by FORGE - DO NOT EDIT\n\n#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)]\n\n", + "// @generated by FORGE - DO NOT EDIT\n\n#![allow(dead_code, unused_imports, clippy::too_many_arguments)]\n\n", ); output.push_str("use serde::{Deserialize, Serialize};\n"); @@ -58,9 +58,11 @@ pub fn generate(registry: &SchemaRegistry) -> Result { fn render_struct(table: &TableDef) -> String { let has_upload = table.fields.iter().any(|f| contains_upload(&f.rust_type)); - // Upload fields cannot derive PartialEq. 
+ // ForgeUpload doesn't impl PartialEq or Serialize/Deserialize, so upload + // fields are skipped from the wire payload. Mutation routing handles them + // out-of-band via multipart. let derives = if has_upload { - "#[derive(Debug, Clone)]\n" + "#[derive(Debug, Clone, Serialize, Deserialize)]\n" } else { "#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]\n" }; @@ -69,6 +71,9 @@ fn render_struct(table: &TableDef) -> String { output.push_str(derives); output.push_str(&format!("pub struct {} {{\n", table.struct_name)); for field in &table.fields { + if contains_upload(&field.rust_type) { + output.push_str(" #[serde(skip)]\n"); + } output.push_str(&format!( " pub {}: {},\n", field.name, @@ -104,11 +109,12 @@ fn render_struct_impl(struct_name: &str, fields: &[FieldDef]) -> String { let mut constructor_body = String::new(); for field in &required_fields { - constructor_body.push_str(&format!( - " {}: {},\n", - field.name, - builder::value_expr(&field.name, &field.rust_type) - )); + let value = builder::value_expr(&field.name, &field.rust_type); + if value == field.name { + constructor_body.push_str(&format!(" {},\n", field.name)); + } else { + constructor_body.push_str(&format!(" {}: {},\n", field.name, value)); + } } for field in &optional_fields { constructor_body.push_str(&format!(" {}: None,\n", field.name)); @@ -181,8 +187,9 @@ mod tests { registry.register_table(table); let output = generate(®istry).expect("upload struct generation should succeed"); - assert!(output.contains("#[derive(Debug, Clone)]")); + assert!(output.contains("#[derive(Debug, Clone, Serialize, Deserialize)]")); assert!(!output.contains("PartialEq")); + assert!(output.contains("#[serde(skip)]")); assert!(output.contains("ForgeUpload")); } diff --git a/crates/forge-codegen/src/emit.rs b/crates/forge-codegen/src/emit.rs index 3c8fbe78..3c15cd10 100644 --- a/crates/forge-codegen/src/emit.rs +++ b/crates/forge-codegen/src/emit.rs @@ -160,7 +160,12 @@ fn dioxus_custom(name: &str) -> String 
{ "Uuid" | "uuid::Uuid" => "String".into(), "DateTime" | "NaiveDate" | "NaiveDateTime" | "Instant" | "LocalDate" | "LocalTime" | "Timestamp" => "String".into(), - "i32" | "u32" | "usize" | "isize" => "i64".into(), + // Preserve narrow integer widths instead of silently widening to i64. + // Mirrors what the handler actually returns on the wire. + "i32" => "i32".into(), + "u32" => "u32".into(), + "usize" => "usize".into(), + "isize" => "isize".into(), "i64" | "u64" => "i64".into(), "f32" => "f32".into(), "f64" => "f64".into(), @@ -371,11 +376,10 @@ mod tests { #[test] fn dioxus_hashmap() { - // Custom-string primitives go through dioxus_custom which widens - // `i32`/`u32`/etc. to `i64`. The HashMap value follows the same path. + // Narrow integers are preserved; the HashMap value follows the same path. assert_eq!( dioxus_type(&RustType::Custom("HashMap".into())), - "std::collections::HashMap" + "std::collections::HashMap" ); assert_eq!( dioxus_type(&RustType::Custom("HashMap".into())), diff --git a/crates/forge-codegen/src/parser.rs b/crates/forge-codegen/src/parser.rs index 7cafa579..239a0415 100644 --- a/crates/forge-codegen/src/parser.rs +++ b/crates/forge-codegen/src/parser.rs @@ -22,7 +22,6 @@ use forge_core::schema::{ use forge_core::util::to_snake_case; use std::collections::BTreeMap; -use quote::ToTokens; use syn::{Attribute, Expr, Fields, FnArg, Lit, Meta, Pat, ReturnType}; use crate::Error; @@ -151,10 +150,11 @@ pub fn find_duplicate_handlers(src_dir: &Path) -> Result Result<(), Error> { match item { syn::Item::Struct(item_struct) => { if has_forge_attr(&item_struct.attrs, "model") { - if let Some(table) = parse_model(&item_struct) { - registry.register_table(table); - } - } else if has_serde_derive(&item_struct.attrs) - && let Some(table) = parse_dto_struct(&item_struct) - { - registry.register_table(table); + registry.register_table(parse_model(&item_struct)?); + } else if has_serde_derive(&item_struct.attrs) { + 
registry.register_table(parse_dto_struct(&item_struct)?); } } - syn::Item::Enum(item_enum) => { - if (has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs)) - && let Some(enum_def) = parse_enum(&item_enum) - { - registry.register_enum(enum_def); - } + syn::Item::Enum(item_enum) + if has_forge_enum_attr(&item_enum.attrs) || has_serde_derive(&item_enum.attrs) => + { + registry.register_enum(parse_enum(&item_enum)?); } syn::Item::Fn(item_fn) => { - if let Some(func) = parse_function(&item_fn) { - registry.register_function(func); + if let Some(func) = parse_function(&item_fn)? { + register_function_checked(registry, func)?; } } _ => {} @@ -203,65 +197,98 @@ fn parse_file(content: &str, registry: &SchemaRegistry) -> Result<(), Error> { Ok(()) } -/// Check if attributes contain `#[forge::name]` or `#[name]`. +fn register_function_checked(registry: &SchemaRegistry, func: FunctionDef) -> Result<(), Error> { + if let Some(existing) = registry.get_function(&func.name) + && existing.kind != func.kind + { + return Err(Error::Parse { + file: String::new(), + message: format!( + "handler name collision: `{}` is registered as both {:?} and {:?}", + func.name, existing.kind, func.kind + ), + }); + } + registry.register_function(func); + Ok(()) +} + +/// Check if attributes contain `#[name]` or any `#[…::name]` re-import. +/// +/// Matches by last segment so `#[forge::query]`, `#[forge_macros::query]`, +/// and `#[crate::query]` all count. fn has_forge_attr(attrs: &[Attribute], name: &str) -> bool { attrs.iter().any(|attr| { - let path = attr.path(); - path.is_ident(name) - || matches!( - (path.segments.first(), path.segments.get(1), path.segments.get(2)), - (Some(first), Some(second), None) - if first.ident == "forge" && second.ident == name - ) + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == name) }) } -/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or `#[forge::enum_type]`. 
+/// Check if attributes contain `#[forge_enum]`, `#[enum_type]`, or any +/// `#[…::forge_enum]` / `#[…::enum_type]` re-import. fn has_forge_enum_attr(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| { - let path = attr.path(); - path.is_ident("forge_enum") - || path.is_ident("enum_type") - || matches!( - (path.segments.first(), path.segments.get(1), path.segments.get(2)), - (Some(first), Some(second), None) - if first.ident == "forge" - && (second.ident == "enum_type" || second.ident == "forge_enum") - ) + attr.path() + .segments + .last() + .is_some_and(|seg| seg.ident == "forge_enum" || seg.ident == "enum_type") }) } +/// True iff a `#[derive(...)]` attribute names `Serialize` or `Deserialize` +/// as a path segment (not as part of a longer identifier like `MySerialize`). fn has_serde_derive(attrs: &[Attribute]) -> bool { attrs.iter().any(|attr| { if !attr.path().is_ident("derive") { return false; } - let tokens = attr.meta.to_token_stream().to_string(); - tokens.contains("Serialize") || tokens.contains("Deserialize") + let Meta::List(list) = &attr.meta else { + return false; + }; + let mut found = false; + let _ = list.parse_nested_meta(|meta| { + if let Some(seg) = meta.path.segments.last() + && (seg.ident == "Serialize" || seg.ident == "Deserialize") + { + found = true; + } + Ok(()) + }); + found }) } -fn parse_dto_struct(item: &syn::ItemStruct) -> Option { +fn parse_dto_struct(item: &syn::ItemStruct) -> Result { let struct_name = item.ident.to_string(); + reject_unsupported_struct_serde_attrs(&struct_name, &item.attrs)?; + let mut table = TableDef::new(&struct_name, &struct_name); table.is_dto = true; table.doc = get_doc_comment(&item.attrs); - if let Fields::Named(fields) = &item.fields { - for field in &fields.named { - if let Some(field_name) = &field.ident { - table - .fields - .push(parse_field(field_name.to_string(), &field.ty, &field.attrs)); - } + let Fields::Named(fields) = &item.fields else { + return Err(parse_err(format!( + "DTO struct 
`{struct_name}` must use named fields; tuple and unit structs are not supported by codegen" + ))); + }; + + for field in &fields.named { + if let Some(field_name) = &field.ident + && let Some(parsed) = parse_field(field_name.to_string(), &field.ty, &field.attrs)? + { + table.fields.push(parsed); } } - Some(table) + Ok(table) } -fn parse_model(item: &syn::ItemStruct) -> Option { +fn parse_model(item: &syn::ItemStruct) -> Result { let struct_name = item.ident.to_string(); + reject_unsupported_struct_serde_attrs(&struct_name, &item.attrs)?; + let table_name = get_table_name_from_attrs(&item.attrs).unwrap_or_else(|| { let snake = to_snake_case(&struct_name); pluralize(&snake) @@ -270,33 +297,139 @@ fn parse_model(item: &syn::ItemStruct) -> Option { let mut table = TableDef::new(&table_name, &struct_name); table.doc = get_doc_comment(&item.attrs); - if let Fields::Named(fields) = &item.fields { - for field in &fields.named { - if let Some(field_name) = &field.ident { - table - .fields - .push(parse_field(field_name.to_string(), &field.ty, &field.attrs)); - } + let Fields::Named(fields) = &item.fields else { + return Err(parse_err(format!( + "model `{struct_name}` must use named fields; tuple and unit structs are not supported by codegen" + ))); + }; + + for field in &fields.named { + if let Some(field_name) = &field.ident + && let Some(parsed) = parse_field(field_name.to_string(), &field.ty, &field.attrs)? + { + table.fields.push(parsed); } } - Some(table) + Ok(table) } -fn parse_field(name: String, ty: &syn::Type, attrs: &[Attribute]) -> FieldDef { - let rust_type = type_to_rust_type(ty); - let mut field = FieldDef::new(&name, rust_type); - field.column_name = to_snake_case(&name); +fn parse_field( + name: String, + ty: &syn::Type, + attrs: &[Attribute], +) -> Result, Error> { + let serde = parse_field_serde(&name, attrs)?; + // A serde-skipped field never appears in the JSON wire shape, so it must not + // appear in the generated type. 
Returning None (rather than erroring) lets a + // struct keep server-only fields like `password_hash` without breaking the + // whole file's bindings. + if serde.skip { + return Ok(None); + } + let rust_type = type_to_rust_type(ty)?; + let final_name = serde.rename.unwrap_or(name.clone()); + let mut field = FieldDef::new(&final_name, rust_type); + field.column_name = to_snake_case(&final_name); field.doc = get_doc_comment(attrs); - field + let _ = sanitize_reserved(&name); // surface reserved-word handling at field site + Ok(Some(field)) +} + +#[derive(Default)] +struct FieldSerdeAttrs { + rename: Option, + /// `#[serde(skip)]` / `skip_serializing` / `skip_deserializing` — the field + /// is absent from the JSON wire shape, so it must be omitted from the + /// generated type rather than failing the whole file. + skip: bool, +} + +/// Parse `#[serde(...)]` directives on a field. Errors on directives codegen +/// cannot honor; honors `rename` and tolerates `default`. +fn parse_field_serde(field_name: &str, attrs: &[Attribute]) -> Result { + let mut out = FieldSerdeAttrs::default(); + for attr in attrs { + if !attr.path().is_ident("serde") { + continue; + } + let Meta::List(list) = &attr.meta else { + continue; + }; + let mut err: Option = None; + let _ = list.parse_nested_meta(|meta| { + let Some(seg) = meta.path.segments.last() else { + return Ok(()); + }; + match seg.ident.to_string().as_str() { + "rename" => { + if let Ok(value) = meta.value() + && let Ok(lit) = value.parse::() + { + out.rename = Some(lit.value()); + } + } + "default" => {} + "skip" | "skip_serializing" | "skip_deserializing" => { + out.skip = true; + } + "flatten" => { + err = Some(parse_err(format!( + "field `{field_name}`: `#[serde(flatten)]` is not supported by codegen" + ))); + } + _ => {} + } + Ok(()) + }); + if let Some(e) = err { + return Err(e); + } + } + Ok(out) +} + +/// Reject struct-level serde directives that codegen cannot faithfully honor. 
+fn reject_unsupported_struct_serde_attrs(name: &str, attrs: &[Attribute]) -> Result<(), Error> { + for attr in attrs { + if !attr.path().is_ident("serde") { + continue; + } + let Meta::List(list) = &attr.meta else { + continue; + }; + let mut err: Option = None; + let _ = list.parse_nested_meta(|meta| { + if let Some(seg) = meta.path.segments.last() + && (seg.ident == "rename_all" || seg.ident == "tag" || seg.ident == "untagged") + { + err = Some(parse_err(format!( + "struct `{name}`: `#[serde({})]` is not supported by codegen", + seg.ident + ))); + } + Ok(()) + }); + if let Some(e) = err { + return Err(e); + } + } + Ok(()) } -fn parse_enum(item: &syn::ItemEnum) -> Option { +fn parse_enum(item: &syn::ItemEnum) -> Result { let enum_name = item.ident.to_string(); let mut enum_def = EnumDef::new(&enum_name); enum_def.doc = get_doc_comment(&item.attrs); for variant in &item.variants { + if !matches!(variant.fields, Fields::Unit) { + return Err(parse_err(format!( + "enum `{}` variant `{}`: non-unit variants are not yet supported by codegen", + enum_name, variant.ident + ))); + } + let variant_name = variant.ident.to_string(); let mut enum_variant = EnumVariant::new(&variant_name); enum_variant.doc = get_doc_comment(&variant.attrs); @@ -311,16 +444,62 @@ fn parse_enum(item: &syn::ItemEnum) -> Option { enum_def.variants.push(enum_variant); } - Some(enum_def) + Ok(enum_def) +} + +fn parse_err(message: String) -> Error { + Error::Parse { + file: String::new(), + message, + } } -fn parse_function(item: &syn::ItemFn) -> Option { - let kind = get_function_kind(&item.attrs)?; +/// TS and Rust reserved words that survive verbatim through codegen. +/// Returning `None` means the name is fine; `Some(_)` returns a sanitized form +/// (currently used only to provoke a parser-level warning at the call site). 
+fn sanitize_reserved(name: &str) -> Option { + const RESERVED: &[&str] = &[ + // TS + "type", + "class", + "interface", + "enum", + "default", + "import", + "export", + "function", + "var", + "let", + "const", + "new", + "delete", + // Rust (subset that round-trips into emitted Rust) + "match", + "mod", + "pub", + "fn", + "impl", + "trait", + "use", + "ref", + "move", + ]; + if RESERVED.contains(&name) { + Some(format!("{name}_")) + } else { + None + } +} + +fn parse_function(item: &syn::ItemFn) -> Result, Error> { + let Some(kind) = get_function_kind(&item.attrs) else { + return Ok(None); + }; let func_name = item.sig.ident.to_string(); let return_type = match &item.sig.output { ReturnType::Default => RustType::Custom("()".to_string()), - ReturnType::Type(_, ty) => extract_result_type(ty), + ReturnType::Type(_, ty) => extract_result_type(ty)?, }; let mut func = FunctionDef::new(&func_name, kind, return_type); @@ -339,13 +518,18 @@ fn parse_function(item: &syn::ItemFn) -> Option { if let Pat::Ident(pat_ident) = &*pat_type.pat { let arg_name = pat_ident.ident.to_string(); - let arg_type = type_to_rust_type(&pat_type.ty); + if sanitize_reserved(&arg_name).is_some() { + return Err(parse_err(format!( + "handler `{func_name}` argument `{arg_name}` is a reserved word in TS/Rust; rename it" + ))); + } + let arg_type = type_to_rust_type(&pat_type.ty)?; func.args.push(FunctionArg::new(arg_name, arg_type)); } } } - Some(func) + Ok(Some(func)) } /// Known Forge context types. Only these are skipped as the first parameter. @@ -362,7 +546,13 @@ const KNOWN_CONTEXT_TYPES: &[&str] = &[ ]; /// Check if a type is a known Forge context type. -/// Walks `&T`/`&mut T` references and checks the final path segment. +/// +/// Walks `&T`/`&mut T` references. 
To avoid false positives on user-defined +/// types named `QueryContext`, requires either a bare path +/// (`QueryContext`) — which is the conventional `use forge::prelude::*` +/// case — or a qualified path beginning with `forge`, `forge_core`, or +/// `crate`. Any other prefix (e.g. `myapp::QueryContext`) is treated as +/// a user type and is NOT stripped from the RPC signature. fn is_context_type(ty: &syn::Type) -> bool { let mut inner = ty; while let syn::Type::Reference(r) = inner { @@ -374,7 +564,16 @@ fn is_context_type(ty: &syn::Type) -> bool { let Some(last) = type_path.path.segments.last() else { return false; }; - KNOWN_CONTEXT_TYPES.contains(&last.ident.to_string().as_str()) + if !KNOWN_CONTEXT_TYPES.contains(&last.ident.to_string().as_str()) { + return false; + } + let segments = &type_path.path.segments; + if segments.len() == 1 { + return true; + } + segments + .first() + .is_some_and(|s| s.ident == "forge" || s.ident == "forge_core" || s.ident == "crate") } fn get_function_kind(attrs: &[Attribute]) -> Option { @@ -403,7 +602,7 @@ fn get_function_kind(attrs: &[Attribute]) -> Option { } /// Extract the inner `T` from `Result`. 
-fn extract_result_type(ty: &syn::Type) -> RustType { +fn extract_result_type(ty: &syn::Type) -> Result { if let syn::Type::Path(type_path) = ty && let Some(seg) = type_path.path.segments.last() && seg.ident == "Result" @@ -416,21 +615,21 @@ fn extract_result_type(ty: &syn::Type) -> RustType { type_to_rust_type(ty) } -fn type_to_rust_type(ty: &syn::Type) -> RustType { +fn type_to_rust_type(ty: &syn::Type) -> Result { match ty { syn::Type::Reference(r) => type_to_rust_type(&r.elem), syn::Type::Path(tp) => path_to_rust_type(tp), - _ => RustType::Custom(quote::quote!(#ty).to_string()), + _ => Ok(RustType::Custom(quote::quote!(#ty).to_string())), } } -fn path_to_rust_type(tp: &syn::TypePath) -> RustType { +fn path_to_rust_type(tp: &syn::TypePath) -> Result { let Some(last) = tp.path.segments.last() else { - return RustType::Custom(quote::quote!(#tp).to_string()); + return Ok(RustType::Custom(quote::quote!(#tp).to_string())); }; let ident = last.ident.to_string(); - match ident.as_str() { + Ok(match ident.as_str() { "String" | "str" => RustType::String, "i32" => RustType::I32, "i64" => RustType::I64, @@ -443,18 +642,45 @@ fn path_to_rust_type(tp: &syn::TypePath) -> RustType { "NaiveTime" => RustType::LocalTime, "Value" => RustType::Json, "Option" => { - let inner = first_generic_arg(last); + let inner = first_generic_arg(last, "Option")?; RustType::Option(Box::new(inner)) } "Vec" => { if is_vec_u8(last) { - return RustType::Bytes; + return Ok(RustType::Bytes); } - let inner = first_generic_arg(last); + let inner = first_generic_arg(last, "Vec")?; RustType::Vec(Box::new(inner)) } + // `HashMap` / `BTreeMap` are preserved as their full + // textual form so the TS/Dioxus emitters can route through their + // `Custom("HashMap<…>")` branches. Bare `HashMap`/`BTreeMap` is a + // parse error to mirror bare `Vec`/`Option`. 
+ "HashMap" | "BTreeMap" => map_to_rust_type(last, &ident)?, _ => RustType::Custom(ident), - } + }) +} + +fn map_to_rust_type(seg: &syn::PathSegment, name: &str) -> Result { + let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { + return Err(parse_err(format!( + "bare `{name}` is not a valid type; expected `{name}`" + ))); + }; + let mut iter = args.args.iter().filter_map(|a| match a { + syn::GenericArgument::Type(t) => Some(t), + _ => None, + }); + let (Some(key), Some(value)) = (iter.next(), iter.next()) else { + return Err(parse_err(format!( + "`{name}` requires two type parameters ``" + ))); + }; + let key_str = quote::quote!(#key).to_string().replace(' ', ""); + let value_str = quote::quote!(#value).to_string().replace(' ', ""); + // Always normalize to `HashMap<…>` so the emitters' existing + // string-prefix branches fire for both `HashMap` and `BTreeMap`. + Ok(RustType::Custom(format!("HashMap<{key_str}, {value_str}>"))) } fn is_vec_u8(seg: &syn::PathSegment) -> bool { @@ -467,13 +693,15 @@ fn is_vec_u8(seg: &syn::PathSegment) -> bool { false } -fn first_generic_arg(seg: &syn::PathSegment) -> RustType { +fn first_generic_arg(seg: &syn::PathSegment, container: &str) -> Result { if let syn::PathArguments::AngleBracketed(args) = &seg.arguments && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { return type_to_rust_type(inner_ty); } - RustType::Custom(seg.ident.to_string()) + Err(parse_err(format!( + "bare `{container}` is not a valid type; expected `{container}`" + ))) } /// Get `#[table(name = "...")]` value from attributes. @@ -573,6 +801,38 @@ mod tests { assert_eq!(table.fields.len(), 4); } + #[test] + fn serde_skip_fields_are_omitted_not_rejected() { + // A field that serde never serializes isn't in the JSON wire shape, so + // it must be dropped from the generated type — and crucially must NOT + // fail the whole file (which would orphan every other type defined in + // it, e.g. 
a UserPublic next to a password-bearing User model). + let source = r#" + #[derive(serde::Serialize, serde::Deserialize)] + struct UserPublic { + id: Uuid, + email: String, + #[serde(skip_serializing)] + password_hash: Option, + #[serde(skip)] + internal: String, + } + "#; + + let registry = SchemaRegistry::new(); + parse_file(source, ®istry).expect("skip fields must not fail the file"); + + let table = registry + .get_table("UserPublic") + .expect("UserPublic should still be registered"); + let names: Vec<&str> = table.fields.iter().map(|f| f.name.as_str()).collect(); + assert_eq!( + names, + vec!["id", "email"], + "serde-skipped fields must be omitted from the generated type", + ); + } + #[test] fn test_parse_enum_source() { let source = r#" @@ -969,7 +1229,7 @@ mod tests { fn parse_type(s: &str) -> RustType { let ty: syn::Type = syn::parse_str(s).expect("valid type"); - type_to_rust_type(&ty) + type_to_rust_type(&ty).expect("valid type maps to RustType") } #[test] diff --git a/crates/forge-codegen/src/typescript/api.rs b/crates/forge-codegen/src/typescript/api.rs index 6048759d..0202064b 100644 --- a/crates/forge-codegen/src/typescript/api.rs +++ b/crates/forge-codegen/src/typescript/api.rs @@ -11,6 +11,7 @@ use crate::binding::{BindingSet, FunctionBinding}; use crate::emit::{self, Position}; pub fn generate(bindings: &BindingSet) -> Result { + check_store_factory_collisions(bindings)?; let mut output = String::from("// @generated by FORGE - DO NOT EDIT\n\n"); let mut type_imports = Vec::new(); @@ -127,6 +128,30 @@ fn gen_subscription(b: &FunctionBinding) -> String { ) } +/// `gen_store_factory` emits `track{Pascal}` for jobs and workflows; if a +/// user query/mutation is named `track_foo`, both factories would collide +/// on the same `trackFoo` identifier. Fail loudly at codegen rather than +/// emitting duplicate `export const` lines. 
+fn check_store_factory_collisions(bindings: &BindingSet) -> Result<(), Error> { + use std::collections::HashSet; + let mut user_names: HashSet = HashSet::new(); + for b in bindings.queries.iter().chain(bindings.mutations.iter()) { + user_names.insert(to_camel_case(&b.name)); + } + for b in bindings.jobs.iter().chain(bindings.workflows.iter()) { + let factory = format!("track{}", to_pascal_case(&b.name)); + if user_names.contains(&factory) { + return Err(Error::Generation(format!( + "store factory name `{factory}` (from {kind:?} `{name}`) collides \ + with a user query/mutation of the same camelCase name; rename one of them", + kind = b.kind, + name = b.name, + ))); + } + } + Ok(()) +} + fn gen_store_factory(b: &FunctionBinding, store_fn: &str) -> String { let factory_name = format!("track{}", to_pascal_case(&b.name)); let output_type = emit::ts_type(&b.return_type, Position::Return); diff --git a/crates/forge-codegen/src/typescript/mod.rs b/crates/forge-codegen/src/typescript/mod.rs index f1cce9b6..575a1cff 100644 --- a/crates/forge-codegen/src/typescript/mod.rs +++ b/crates/forge-codegen/src/typescript/mod.rs @@ -322,17 +322,19 @@ export function getToken(): string | null { } fn generate_index(&self, registry: &SchemaRegistry) -> String { - let has_queries = registry + // Mirror the predicate used by `reactive::generate`: emit the re-export + // whenever the reactive file itself is emitted (queries OR mutations). 
+ let has_reactive = registry .all_functions() .iter() - .any(|f| matches!(f.kind, FunctionKind::Query)); + .any(|f| matches!(f.kind, FunctionKind::Query | FunctionKind::Mutation)); let mut output = String::from("// @generated by FORGE - DO NOT EDIT\n"); output.push_str("export * from './types';\n"); output.push_str("export * from './api';\n"); output.push_str("export * from './stores';\n"); output.push_str("export * from './runes.svelte';\n"); - if has_queries { + if has_reactive { output.push_str("export * from './reactive.svelte';\n"); } if self.options.generate_auth_store { @@ -409,17 +411,32 @@ mod tests { } #[test] - fn generate_index_skips_reactive_when_only_mutations_present() { + fn generate_index_emits_reactive_when_only_mutations_present() { + // Reactive file is emitted for queries OR mutations, so the re-export + // must match. A mutation-only project still publishes `createX$`. let generator = TypeScriptGenerator::new("/tmp/forge"); let registry = SchemaRegistry::new(); registry.register_function(FunctionDef::mutation("create_user", RustType::String)); let index = generator.generate_index(®istry); assert!( - !index.contains("'./reactive.svelte'"), - "no queries => no reactive export" + index.contains("'./reactive.svelte'"), + "mutations must trigger reactive export" ); } + #[test] + fn generate_index_skips_reactive_when_no_queries_or_mutations() { + let generator = TypeScriptGenerator::new("/tmp/forge"); + let registry = SchemaRegistry::new(); + registry.register_function(FunctionDef::new( + "daily_cleanup", + FunctionKind::Cron, + RustType::Custom("()".into()), + )); + let index = generator.generate_index(®istry); + assert!(!index.contains("'./reactive.svelte'")); + } + #[test] fn generate_index_emits_auth_only_when_flag_set() { let registry = SchemaRegistry::new(); diff --git a/crates/forge-codegen/tests/snapshot.rs b/crates/forge-codegen/tests/snapshot.rs index c799411f..b6112281 100644 --- a/crates/forge-codegen/tests/snapshot.rs +++ 
b/crates/forge-codegen/tests/snapshot.rs @@ -78,6 +78,59 @@ fn output_is_deterministic() { assert_eq!(first, second, "codegen output must be deterministic"); } +/// Negative coverage: a `#[forge::model]` on a tuple struct must produce a +/// parser diagnostic with a clear message, not silently emit an empty +/// interface. A regression that drops the named-fields guard would otherwise +/// let downstream tests pass with an empty `Wrapper {}` interface. +#[test] +fn tuple_struct_model_is_rejected_with_clear_diagnostic() { + let src = r#" + #[forge::model] + pub struct Wrapper(pub String); + "#; + let src_dir = TempDir::new().expect("tempdir"); + fs::write(src_dir.path().join("handlers.rs"), src).expect("write fixture"); + + let outcome = parse_project(src_dir.path()).expect("parse_project"); + assert!( + !outcome.parse_failures.is_empty(), + "tuple struct must be rejected by the parser, got no failures" + ); + let (_, msg) = outcome + .parse_failures + .first() + .expect("at least one failure"); + assert!( + msg.contains("named fields"), + "diagnostic must explain the constraint, got: {msg}" + ); +} + +/// Same guard for unit structs marked as DTOs (serde derive). 
+#[test] +fn unit_struct_dto_is_rejected_with_clear_diagnostic() { + let src = r#" + #[derive(serde::Serialize, serde::Deserialize)] + pub struct Marker; + "#; + let src_dir = TempDir::new().expect("tempdir"); + fs::write(src_dir.path().join("handlers.rs"), src).expect("write fixture"); + + let outcome = parse_project(src_dir.path()).expect("parse_project"); + assert!( + !outcome.parse_failures.is_empty(), + "unit struct DTO must be rejected by the parser, got no failures" + ); + let (_, msg) = outcome + .parse_failures + .first() + .expect("at least one failure"); + assert!( + msg.contains("named fields"), + "diagnostic must explain the constraint, got: {msg}" + ); +} + fn run_fixture(source: &str, generate_auth: bool) -> String { let src_dir = TempDir::new().expect("tempdir"); fs::write(src_dir.path().join("handlers.rs"), source).expect("write fixture"); diff --git a/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap b/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap index ac244c3d..b3f82529 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__custom_args.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 33 expression: "run_fixture(include_str!(\"fixtures/custom_args.rs.txt\"), false)" --- @@ -261,7 +260,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] diff --git a/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap b/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap index 3aed3cf3..f386ed64 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap +++ 
b/crates/forge-codegen/tests/snapshots/snapshot__full_app.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 57 expression: "run_fixture(include_str!(\"fixtures/full_app.rs.txt\"), false)" --- @@ -325,7 +324,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -336,7 +335,7 @@ pub struct CleanupOutput { impl CleanupOutput { pub fn new(deleted_count: i64) -> Self { Self { - deleted_count: deleted_count, + deleted_count, } } } @@ -349,7 +348,7 @@ pub struct CleanupRequest { impl CleanupRequest { pub fn new(older_than_days: i32) -> Self { Self { - older_than_days: older_than_days, + older_than_days, } } } @@ -364,7 +363,7 @@ impl CreateTodoInput { pub fn new(title: impl Into, status: TodoStatus) -> Self { Self { title: title.into(), - status: status, + status, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap b/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap index a8dc95b3..d06848a6 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__full_app_with_auth.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 65 expression: "run_fixture(include_str!(\"fixtures/full_app.rs.txt\"), true)" --- @@ -483,7 +482,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -494,7 +493,7 @@ pub 
struct CleanupOutput { impl CleanupOutput { pub fn new(deleted_count: i64) -> Self { Self { - deleted_count: deleted_count, + deleted_count, } } } @@ -507,7 +506,7 @@ pub struct CleanupRequest { impl CleanupRequest { pub fn new(older_than_days: i32) -> Self { Self { - older_than_days: older_than_days, + older_than_days, } } } @@ -522,7 +521,7 @@ impl CreateTodoInput { pub fn new(title: impl Into, status: TodoStatus) -> Self { Self { title: title.into(), - status: status, + status, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap b/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap index c364c76c..72ff4cd8 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__jobs_and_workflows.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 41 expression: "run_fixture(include_str!(\"fixtures/jobs_and_workflows.rs.txt\"), false)" --- @@ -215,7 +214,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -228,7 +227,7 @@ impl ExportRequest { pub fn new(format: impl Into, include_archived: bool) -> Self { Self { format: format.into(), - include_archived: include_archived, + include_archived, } } } @@ -243,7 +242,7 @@ impl ExportResult { pub fn new(url: impl Into, row_count: i64) -> Self { Self { url: url.into(), - row_count: row_count, + row_count, } } } @@ -271,7 +270,7 @@ pub struct VerifyOutput { impl VerifyOutput { pub fn new(verified: bool) -> Self { Self { - verified: verified, + verified, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap 
b/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap index 8aaa023a..791ce6c8 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__models_and_enums.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 25 expression: "run_fixture(include_str!(\"fixtures/models_and_enums.rs.txt\"), false)" --- @@ -242,7 +241,7 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -266,7 +265,7 @@ impl UserSummary { Self { id: id.into(), email: email.into(), - role: role, + role, } } } diff --git a/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap b/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap index 961b38e4..b9e75109 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__primitives.snap @@ -1,6 +1,5 @@ --- source: crates/forge-codegen/tests/snapshot.rs -assertion_line: 17 expression: "run_fixture(include_str!(\"fixtures/primitives.rs.txt\"), false)" --- @@ -332,6 +331,6 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; diff --git a/crates/forge-codegen/tests/snapshots/snapshot__upload.snap b/crates/forge-codegen/tests/snapshots/snapshot__upload.snap index c66c8cda..fa354dc5 100644 --- a/crates/forge-codegen/tests/snapshots/snapshot__upload.snap +++ b/crates/forge-codegen/tests/snapshots/snapshot__upload.snap @@ -1,6 +1,5 @@ --- source: 
crates/forge-codegen/tests/snapshot.rs -assertion_line: 49 expression: "run_fixture(include_str!(\"fixtures/upload.rs.txt\"), false)" --- @@ -24,6 +23,7 @@ export * from './types'; export * from './api'; export * from './stores'; export * from './runes.svelte'; +export * from './reactive.svelte'; export { ForgeClient, ForgeClientError, createForgeClient, ForgeProvider } from '@forge-rs/svelte'; === ts/reactive.svelte.ts === @@ -215,14 +215,15 @@ pub use types::*; === dioxus/types.rs === // @generated by FORGE - DO NOT EDIT -#![allow(dead_code, unused_imports, clippy::redundant_field_names, clippy::too_many_arguments)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; use forge_dioxus::ForgeUpload; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct AvatarInput { pub name: String, + #[serde(skip)] pub file: ForgeUpload, } @@ -230,7 +231,7 @@ impl AvatarInput { pub fn new(name: impl Into, file: ForgeUpload) -> Self { Self { name: name.into(), - file: file, + file, } } } diff --git a/crates/forge-core/Cargo.toml b/crates/forge-core/Cargo.toml index 8bd28479..b566528b 100644 --- a/crates/forge-core/Cargo.toml +++ b/crates/forge-core/Cargo.toml @@ -34,6 +34,10 @@ testcontainers-modules = { workspace = true, optional = true } [features] testcontainers = ["dep:testcontainers", "dep:testcontainers-modules"] +# Unsafe escape hatches that bypass framework guard rails (circuit breaker, +# host blocklists, etc). Opt-in only — keep this off unless a specific +# integration genuinely needs raw access. 
+escape-hatches = [] [dev-dependencies] tokio-test = { workspace = true } diff --git a/crates/forge-core/src/config/auth.rs b/crates/forge-core/src/config/auth.rs index 3298b2a2..03214494 100644 --- a/crates/forge-core/src/config/auth.rs +++ b/crates/forge-core/src/config/auth.rs @@ -30,7 +30,7 @@ pub enum JwtAlgorithm { /// Rotate by adding the outgoing secret here with `valid_until` set one /// access-token TTL into the future, swap `jwt_secret` to the new value, /// then remove the entry once the window closes. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] pub struct LegacySecret { /// HMAC secret bytes (treated as opaque; min length is not re-enforced /// here — the active `jwt_secret` validation already covers minimum @@ -40,8 +40,17 @@ pub struct LegacySecret { pub valid_until: chrono::DateTime, } +impl std::fmt::Debug for LegacySecret { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LegacySecret") + .field("secret", &"***redacted***") + .field("valid_until", &self.valid_until) + .finish() + } +} + /// Authentication configuration. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[non_exhaustive] pub struct AuthConfig { /// Required for HS256. @@ -106,6 +115,45 @@ pub struct AuthConfig { /// are silently dropped at middleware construction. #[serde(default)] pub legacy_secrets: Vec, + + /// When `true` (default), browser clients (forge-svelte, forge-dioxus on + /// wasm) treat the refresh token as an `HttpOnly; Secure; SameSite=Strict` + /// cookie and do **not** persist it in JS-reachable storage. Your + /// `refresh` mutation should set the cookie on issue and clear it on + /// rotation/logout; the clients send it automatically via `credentials: + /// include`. 
+ /// + /// Set to `false` only if you cannot serve the refresh endpoint from the + /// same registrable domain as the frontend, or for legacy clients that + /// must read the refresh token from a response body. + #[serde(default = "default_true")] + pub refresh_cookie: bool, +} + +impl std::fmt::Debug for AuthConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthConfig") + .field( + "jwt_secret", + &self.jwt_secret.as_ref().map(|_| "***redacted***"), + ) + .field("jwt_algorithm", &self.jwt_algorithm) + .field("jwt_issuer", &self.jwt_issuer) + .field("jwt_audience", &self.jwt_audience) + .field("access_token_ttl", &self.access_token_ttl) + .field("refresh_token_ttl", &self.refresh_token_ttl) + .field("jwks_url", &self.jwks_url) + .field("jwks_cache_ttl", &self.jwks_cache_ttl) + .field("session_ttl", &self.session_ttl) + .field("jwt_leeway", &self.jwt_leeway) + .field("audience_required", &self.audience_required) + .field("required_claims", &self.required_claims) + .field("session_cookie_ttl", &self.session_cookie_ttl) + .field("jwks_require_kid", &self.jwks_require_kid) + .field("legacy_secrets", &self.legacy_secrets) + .field("refresh_cookie", &self.refresh_cookie) + .finish() + } } impl Default for AuthConfig { @@ -126,6 +174,7 @@ impl Default for AuthConfig { session_cookie_ttl: None, jwks_require_kid: default_true(), legacy_secrets: Vec::new(), + refresh_cookie: true, } } } @@ -190,12 +239,24 @@ impl AuthConfig { } } JwtAlgorithm::RS256 => { - if self.jwks_url.is_none() { + let Some(url) = self.jwks_url.as_deref() else { return Err(ForgeError::config( "auth.jwks_url is required for RSA algorithms (RS256). \ Set auth.jwks_url to your identity provider's JWKS endpoint, \ or switch to HS256 and provide auth.jwt_secret for symmetric signing.", )); + }; + // Plain HTTP would let an on-path attacker substitute keys and + // mint arbitrary RS256 tokens. 
Loopback is allowed for local + // dev so test mocks don't need TLS termination. + if let Some(hostname) = crate::util::http_hostname(url) + && !crate::util::is_loopback_host(hostname) + { + return Err(ForgeError::config(format!( + "auth.jwks_url '{url}' uses plain HTTP. JWKS must be fetched over \ + HTTPS (or from loopback for local development) so an on-path \ + attacker cannot substitute signing keys." + ))); } } } diff --git a/crates/forge-core/src/config/database.rs b/crates/forge-core/src/config/database.rs index f808703f..2e65af18 100644 --- a/crates/forge-core/src/config/database.rs +++ b/crates/forge-core/src/config/database.rs @@ -11,7 +11,7 @@ use super::types::DurationStr; /// separation belongs at the worker level, not the connection level. The /// single-pool contention model and sizing formula are documented at the /// runtime side in `forge_runtime::pg::pool` module docs. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] #[non_exhaustive] pub struct DatabaseConfig { @@ -56,6 +56,24 @@ pub struct DatabaseConfig { pub test_before_acquire: bool, } +impl std::fmt::Debug for DatabaseConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let redacted_replicas: Vec<&str> = + self.replica_urls.iter().map(|_| "***redacted***").collect(); + f.debug_struct("DatabaseConfig") + .field("url", &"***redacted***") + .field("pool_size", &self.pool_size) + .field("pool_timeout", &self.pool_timeout) + .field("statement_timeout", &self.statement_timeout) + .field("replica_urls", &redacted_replicas) + .field("read_from_replica", &self.read_from_replica) + .field("replica_pool_size", &self.replica_pool_size) + .field("min_pool_size", &self.min_pool_size) + .field("test_before_acquire", &self.test_before_acquire) + .finish() + } +} + impl Default for DatabaseConfig { fn default() -> Self { Self { diff --git a/crates/forge-core/src/config/loader.rs 
b/crates/forge-core/src/config/loader.rs index e67428b7..a5571679 100644 --- a/crates/forge-core/src/config/loader.rs +++ b/crates/forge-core/src/config/loader.rs @@ -48,13 +48,19 @@ pub fn substitute_env_vars(content: &str) -> String { /// Parse `VAR-default` or `VAR:-default` into (name, optional default). /// Both forms behave identically (fallback when unset). `:-` is checked /// first so its `-` doesn't get matched by the plain `-` branch. +/// +/// For the bare `-` form, the split is taken at the LAST `-` so that +/// `${MY-NAMESPACE-VAR-fallback}` parses to name `MY-NAMESPACE-VAR` +/// (which then fails `is_valid_env_var_name` and the literal is +/// preserved) rather than silently substituting `$MY` with default +/// `NAMESPACE-VAR-fallback`. #[allow(clippy::indexing_slicing)] // All indices from str::find(); guaranteed valid. fn parse_var_with_default(inner: &str) -> (&str, Option<&str>) { if let Some(pos) = inner.find(":-") { return (&inner[..pos], Some(&inner[pos + 2..])); } - if let Some(pos) = inner.find('-') { - return (&inner[..pos], Some(&inner[pos + 1..])); + if let Some((name, default)) = inner.rsplit_once('-') { + return (name, Some(default)); } (inner, None) } @@ -121,4 +127,96 @@ mod tests { let result = substitute_env_vars(input); assert_eq!(result, r#"val = """#); } + + #[test] + fn plain_braced_var_substituted_when_set() { + // `${VAR}` with no default, variable present -> raw value. 
+ unsafe { std::env::set_var("TEST_FORGE_PLAIN_SET", "postgres://db") }; + + let input = r#"url = "${TEST_FORGE_PLAIN_SET}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"url = "postgres://db""#); + + unsafe { std::env::remove_var("TEST_FORGE_PLAIN_SET") }; + } + + #[test] + fn set_var_wins_over_dash_default() { + unsafe { std::env::set_var("TEST_FORGE_DASH_SET", "real") }; + + let input = r#"x = "${TEST_FORGE_DASH_SET-fallback}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"x = "real""#); + + unsafe { std::env::remove_var("TEST_FORGE_DASH_SET") }; + } + + #[test] + fn dash_split_takes_last_dash() { + // `parse_var_with_default` splits on the LAST `-`, so the name here is + // "TEST_FORGE_NS_VAR" and the default is "tail". Name is valid and unset, + // so the default wins. + unsafe { std::env::remove_var("TEST_FORGE_NS_VAR") }; + + let input = r#"v = "${TEST_FORGE_NS_VAR-tail}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "tail""#); + } + + #[test] + fn invalid_var_name_from_multi_dash_preserves_literal() { + // Last-dash split yields name "MY-NAMESPACE-VAR", which fails + // `is_valid_env_var_name` (contains '-'). The whole `${...}` is kept + // verbatim rather than silently substituting a partial match. + unsafe { std::env::remove_var("MY") }; + + let input = r#"v = "${MY-NAMESPACE-VAR-fallback}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${MY-NAMESPACE-VAR-fallback}""#); + } + + #[test] + fn lowercase_var_name_is_invalid_and_preserved() { + // Env var names must be uppercase/underscore-led; a lowercase name is + // not treated as a variable. + let input = r#"v = "${lowercase}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${lowercase}""#); + } + + #[test] + fn unterminated_brace_kept_verbatim() { + // No closing `}` -> the remainder is emitted as-is, no panic. 
+ let input = r#"v = "${TEST_FORGE_UNTERMINATED"#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "${TEST_FORGE_UNTERMINATED"#); + } + + #[test] + fn colon_dash_split_is_preferred_over_plain_dash() { + // `:-` is checked before plain `-`, so the name is the part before `:-` + // and the `-` inside the default is left intact. + unsafe { std::env::remove_var("TEST_FORGE_CDASH") }; + + let input = r#"v = "${TEST_FORGE_CDASH:-a-b-c}""#; + let result = substitute_env_vars(input); + assert_eq!(result, r#"v = "a-b-c""#); + } + + #[test] + fn parse_var_with_default_forms() { + assert_eq!(parse_var_with_default("VAR"), ("VAR", None)); + assert_eq!( + parse_var_with_default("VAR-default"), + ("VAR", Some("default")) + ); + assert_eq!( + parse_var_with_default("VAR:-default"), + ("VAR", Some("default")) + ); + // Last-dash split. + assert_eq!(parse_var_with_default("A-B-C"), ("A-B", Some("C"))); + // Colon-dash beats plain dash and keeps trailing dashes in the default. + assert_eq!(parse_var_with_default("V:-a-b"), ("V", Some("a-b"))); + } } diff --git a/crates/forge-core/src/config/mod.rs b/crates/forge-core/src/config/mod.rs index 3b1eda99..8c6bc450 100644 --- a/crates/forge-core/src/config/mod.rs +++ b/crates/forge-core/src/config/mod.rs @@ -104,9 +104,6 @@ pub struct ForgeConfig { #[serde(default)] pub realtime: RealtimeConfig, - - #[serde(default)] - pub email: crate::email::EmailConfig, } impl ForgeConfig { @@ -218,6 +215,39 @@ impl ForgeConfig { ))); } + let ratio = self.observability.sampling_ratio; + if !ratio.is_finite() || !(0.0..=1.0).contains(&ratio) { + return Err(ForgeError::config(format!( + "observability.sampling_ratio must be a finite number in [0.0, 1.0], got {ratio}" + ))); + } + + if let Some(path) = &self.signals.geoip_db_path + && !path.is_empty() + { + let p = std::path::Path::new(path); + if !p.exists() { + return Err(ForgeError::config(format!( + "signals.geoip_db_path points to '{path}' which does not exist" + ))); + } + if 
std::fs::File::open(p).is_err() { + return Err(ForgeError::config(format!( + "signals.geoip_db_path '{path}' exists but is not readable" + ))); + } + } + + if self.gateway.cors_enabled + && self.gateway.cors_origins.iter().any(|o| o == "*") + && self.gateway.cors_origins.len() == 1 + { + tracing::warn!( + "gateway.cors_origins = [\"*\"] allows any origin; browsers reject \ + wildcard with credentialed requests. Set explicit origins for production." + ); + } + for entry in &self.gateway.trusted_proxies { if entry.parse::().is_err() && entry.parse::().is_err() { @@ -251,7 +281,6 @@ impl ForgeConfig { signals: SignalsConfig::default(), rate_limit: RateLimitSettings::default(), realtime: RealtimeConfig::default(), - email: crate::email::EmailConfig::default(), } } } @@ -821,6 +850,142 @@ mod tests { assert_eq!(entry.valid_until.to_rfc3339(), "2099-01-01T00:00:00+00:00"); } + #[test] + fn validate_rejects_invalid_trusted_proxy_entry() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + trusted_proxies = ["not-an-ip"] + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("trusted_proxies") && err.contains("not-an-ip"), + "expected trusted_proxies rejection, got: {err}" + ); + } + + #[test] + fn validate_accepts_ip_and_cidr_trusted_proxies() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + trusted_proxies = ["10.0.0.1", "10.0.0.0/8", "::1", "fd00::/8"] + "#; + assert!(ForgeConfig::parse_toml(toml).is_ok()); + } + + #[test] + fn validate_rejects_sampling_ratio_above_one() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = 1.5 + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("sampling_ratio") && err.contains("[0.0, 1.0]"), + "expected sampling_ratio bound error, got: {err}" + ); + } + + #[test] + fn validate_rejects_negative_sampling_ratio() { + let toml = r#" + 
[database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = -0.1 + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("sampling_ratio"), + "expected sampling_ratio bound error, got: {err}" + ); + } + + #[test] + fn validate_accepts_sampling_ratio_boundaries() { + for ratio in ["0.0", "1.0", "0.5"] { + let toml = format!( + r#" + [database] + url = "postgres://localhost/test" + [observability] + sampling_ratio = {ratio} + "# + ); + assert!( + ForgeConfig::parse_toml(&toml).is_ok(), + "ratio {ratio} should validate" + ); + } + } + + #[test] + fn validate_rejects_debounce_quiet_window_exceeding_max_wait() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [realtime] + debounce_quiet_window = "500ms" + debounce_max_wait = "200ms" + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("debounce_quiet_window") && err.contains("debounce_max_wait"), + "expected debounce ordering error, got: {err}" + ); + } + + #[test] + fn validate_accepts_debounce_quiet_window_equal_to_max_wait() { + // quiet == max is allowed; only quiet > max is rejected. + let toml = r#" + [database] + url = "postgres://localhost/test" + [realtime] + debounce_quiet_window = "200ms" + debounce_max_wait = "200ms" + "#; + assert!(ForgeConfig::parse_toml(toml).is_ok()); + } + + #[test] + fn validate_rejects_cors_origin_without_scheme() { + let toml = r#" + [database] + url = "postgres://localhost/test" + [gateway] + cors_enabled = true + cors_origins = ["example.com"] + "#; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("http://") && err.contains("https://"), + "expected scheme-required error, got: {err}" + ); + } + + #[test] + fn validate_rejects_cors_origin_with_control_char() { + // A control char in an origin would corrupt the response header. 
+ let toml = " + [database] + url = \"postgres://localhost/test\" + [gateway] + cors_enabled = true + cors_origins = [\"https://exa\tmple.com\"] + "; + let err = ForgeConfig::parse_toml(toml).unwrap_err().to_string(); + assert!( + err.contains("invalid origin") || err.contains("valid HTTP header"), + "expected invalid-origin error, got: {err}" + ); + } + #[test] fn realtime_quota_fields_parse_and_enforce() { let toml = r#" diff --git a/crates/forge-core/src/config/security.rs b/crates/forge-core/src/config/security.rs index a7930b22..0460402b 100644 --- a/crates/forge-core/src/config/security.rs +++ b/crates/forge-core/src/config/security.rs @@ -3,8 +3,19 @@ use serde::{Deserialize, Serialize}; /// Security configuration. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Clone, Serialize, Deserialize, Default)] #[non_exhaustive] pub struct SecurityConfig { pub secret_key: Option, } + +impl std::fmt::Debug for SecurityConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SecurityConfig") + .field( + "secret_key", + &self.secret_key.as_ref().map(|_| "***redacted***"), + ) + .finish() + } +} diff --git a/crates/forge-core/src/config/signals.rs b/crates/forge-core/src/config/signals.rs index 980158a0..8954c7b7 100644 --- a/crates/forge-core/src/config/signals.rs +++ b/crates/forge-core/src/config/signals.rs @@ -7,6 +7,10 @@ use serde::{Deserialize, Serialize}; use super::default_true; use super::types::DurationStr; +fn default_false() -> bool { + false +} + /// Signals configuration for built-in product analytics and frontend diagnostics. /// /// Captures user behavior, acquisition channels, feature usage, and frontend @@ -15,7 +19,10 @@ use super::types::DurationStr; #[non_exhaustive] pub struct SignalsConfig { /// Enable the signals pipeline (event ingestion, auto-capture, dashboards). - #[serde(default = "default_true")] + /// + /// Off by default so new projects ship without product analytics enabled. 
+ /// Set `signals.enabled = true` in forge.toml to opt in. + #[serde(default = "default_false")] pub enabled: bool, /// Auto-capture RPC calls as events without user code. @@ -68,12 +75,19 @@ pub struct SignalsConfig { /// database provides country-level resolution with zero configuration. #[serde(default)] pub geoip_db_path: Option, + + /// Per-IP request ceiling for the `/signal` endpoint, measured over a + /// rolling 60-second window. Generous enough to absorb legitimate bursts + /// (page view + web-vital flush + a handful of tracked events on a + /// navigation) while still capping runaway clients. + #[serde(default = "default_rate_limit_per_minute")] + pub rate_limit_per_minute: u32, } impl Default for SignalsConfig { fn default() -> Self { Self { - enabled: true, + enabled: false, auto_capture: true, diagnostics: true, session_timeout: default_session_timeout(), @@ -85,10 +99,15 @@ impl Default for SignalsConfig { excluded_functions: Vec::new(), bot_detection: true, geoip_db_path: None, + rate_limit_per_minute: default_rate_limit_per_minute(), } } } +fn default_rate_limit_per_minute() -> u32 { + 600 +} + fn default_session_timeout() -> DurationStr { DurationStr::new(Duration::from_secs(1800)) } @@ -117,7 +136,7 @@ mod tests { #[tokio::test] async fn default_config_has_correct_values() { let config = SignalsConfig::default(); - assert!(config.enabled); + assert!(!config.enabled); assert!(config.auto_capture); assert!(config.diagnostics); assert_eq!(config.session_timeout.as_secs(), 1800); @@ -141,7 +160,7 @@ mod tests { let from_table: Wrapper = toml::from_str("[signals]").unwrap(); for config in [from_empty, from_table.signals] { - assert!(config.enabled); + assert!(!config.enabled); assert!(config.auto_capture); assert!(config.diagnostics); assert_eq!(config.session_timeout.as_secs(), 1800); diff --git a/crates/forge-core/src/context.rs b/crates/forge-core/src/context.rs index aea72367..8f813a39 100644 --- a/crates/forge-core/src/context.rs +++ 
b/crates/forge-core/src/context.rs @@ -87,7 +87,7 @@ impl HandlerContext for crate::function::MutationContext { // MutationContext::tx() returns DbConn, not ForgeDb. // For HandlerContext we expose the pool-backed ForgeDb view, which // intentionally bypasses the active transaction. - crate::function::ForgeDb::from_pool(self.bypass_pool()) + crate::function::ForgeDb::from_pool(self.pool_outside_transaction()) } fn db_conn(&self) -> DbConn<'_> { diff --git a/crates/forge-core/src/cron/schedule.rs b/crates/forge-core/src/cron/schedule.rs index b833ef8d..c9ff8d70 100644 --- a/crates/forge-core/src/cron/schedule.rs +++ b/crates/forge-core/src/cron/schedule.rs @@ -35,6 +35,15 @@ impl CronSchedule { }) } + /// Validate a timezone string at registration time. Returns an error when + /// the timezone is not recognised so misconfigured crons fail loudly at + /// startup instead of silently never firing. + pub fn validate_timezone(timezone: &str) -> Result<(), CronParseError> { + timezone.parse::().map(|_| ()).map_err(|e| { + CronParseError::InvalidExpression(format!("invalid timezone '{timezone}': {e}")) + }) + } + /// Create a cron schedule from an expression that was already validated at compile time. /// /// Falls back to a non-firing schedule if parsing somehow fails, which cannot happen @@ -96,12 +105,17 @@ impl CronSchedule { return vec![]; }; - let local_start = start.with_timezone(&tz); + // `cron::Schedule::after` is exclusive of the boundary. Subtract one + // second so a scheduled tick that lands exactly on `start` is still + // emitted — otherwise a 1 s scheduler poll can drop a tick whose + // moment coincides with the window edge. 
+ let local_start = start.with_timezone(&tz) - chrono::Duration::seconds(1); let local_end = end.with_timezone(&tz); schedule .after(&local_start) .take_while(|dt| *dt <= local_end) + .filter(|dt| *dt >= start.with_timezone(&tz)) .map(|dt| dt.with_timezone(&Utc)) .collect() } diff --git a/crates/forge-core/src/email/mod.rs b/crates/forge-core/src/email/mod.rs deleted file mode 100644 index 3b4018ab..00000000 --- a/crates/forge-core/src/email/mod.rs +++ /dev/null @@ -1,158 +0,0 @@ -//! Email sending trait and types. -//! -//! Defines the `EmailSender` trait used by handler contexts via `ctx.email()`. -//! The runtime provides concrete implementations (SMTP, HTTP-based providers). - -use std::future::Future; -use std::pin::Pin; - -use crate::error::Result; - -/// An email message. -#[derive(Debug, Clone)] -pub struct Email { - /// Overrides the default `from` in config if set. - pub from: Option, - pub to: Vec, - pub cc: Vec, - pub bcc: Vec, - pub subject: String, - pub text: Option, - pub html: Option, - pub reply_to: Option, -} - -impl Email { - /// Create a new email to a single recipient. - pub fn to(recipient: impl Into) -> EmailBuilder { - EmailBuilder { - email: Self { - from: None, - to: vec![recipient.into()], - cc: Vec::new(), - bcc: Vec::new(), - subject: String::new(), - text: None, - html: None, - reply_to: None, - }, - } - } -} - -/// Builder for constructing email messages. 
-pub struct EmailBuilder { - email: Email, -} - -impl EmailBuilder { - pub fn to(mut self, recipient: impl Into) -> Self { - self.email.to.push(recipient.into()); - self - } - - pub fn from(mut self, sender: impl Into) -> Self { - self.email.from = Some(sender.into()); - self - } - - pub fn cc(mut self, recipient: impl Into) -> Self { - self.email.cc.push(recipient.into()); - self - } - - pub fn bcc(mut self, recipient: impl Into) -> Self { - self.email.bcc.push(recipient.into()); - self - } - - pub fn subject(mut self, subject: impl Into) -> Self { - self.email.subject = subject.into(); - self - } - - pub fn text(mut self, body: impl Into) -> Self { - self.email.text = Some(body.into()); - self - } - - pub fn html(mut self, body: impl Into) -> Self { - self.email.html = Some(body.into()); - self - } - - pub fn reply_to(mut self, address: impl Into) -> Self { - self.email.reply_to = Some(address.into()); - self - } - - pub fn build(self) -> Email { - self.email - } -} - -/// Trait for sending emails from handler contexts. -/// -/// Implemented by the runtime for SMTP and HTTP-based providers (Resend, SES). -/// Mocked in test contexts. -pub trait EmailSender: Send + Sync + 'static { - /// Send an email. Returns the provider's message ID on success. - fn send<'a>( - &'a self, - email: &'a Email, - ) -> Pin> + Send + 'a>>; -} - -/// Email configuration from forge.toml. -#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] -#[serde(default)] -pub struct EmailConfig { - pub enabled: bool, - /// Provider: "smtp", "resend", "ses", "log" (development). - pub provider: String, - /// Default sender address. - pub from: String, - pub smtp_host: Option, - /// Default 587. - pub smtp_port: Option, - /// Env var containing the API key or SMTP password. 
- pub secret_env: Option, -} - -#[cfg(test)] -#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)] -mod tests { - use super::*; - - #[test] - fn email_builder_creates_message() { - let email = Email::to("user@example.com") - .from("noreply@app.com") - .subject("Hello") - .text("Hi there") - .html("

Hi there

") - .cc("cc@example.com") - .bcc("bcc@example.com") - .reply_to("reply@app.com") - .build(); - - assert_eq!(email.to, vec!["user@example.com"]); - assert_eq!(email.from.as_deref(), Some("noreply@app.com")); - assert_eq!(email.subject, "Hello"); - assert_eq!(email.text.as_deref(), Some("Hi there")); - assert_eq!(email.html.as_deref(), Some("

Hi there

")); - assert_eq!(email.cc, vec!["cc@example.com"]); - assert_eq!(email.bcc, vec!["bcc@example.com"]); - assert_eq!(email.reply_to.as_deref(), Some("reply@app.com")); - } - - #[test] - fn email_builder_multiple_recipients() { - let email = Email::to("a@example.com") - .to("b@example.com") - .subject("Test") - .build(); - - assert_eq!(email.to.len(), 2); - } -} diff --git a/crates/forge-core/src/error.rs b/crates/forge-core/src/error.rs index e095fc4e..f755cc63 100644 --- a/crates/forge-core/src/error.rs +++ b/crates/forge-core/src/error.rs @@ -15,8 +15,17 @@ pub enum ForgeError { source: Option>, }, - #[error("Database error: {0}")] - Database(#[from] sqlx::Error), + /// Wraps the inner [`sqlx::Error`] without rendering it in `Display`. + /// The raw sqlx error (which may contain constraint names, schema names, + /// or bound parameter previews) is reachable via [`std::error::Error::source`] + /// for structured logging, but the public `Display` impl emits a generic + /// "database error" so it is safe to surface in API responses. + #[error("database error")] + Database( + #[source] + #[from] + sqlx::Error, + ), #[error("Job cancelled: {0}")] JobCancelled(String), @@ -158,10 +167,31 @@ impl ForgeError { } pub fn is_retryable(&self) -> bool { - matches!( - self, - Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. } - ) + match self { + Self::ServiceUnavailable(_) | Self::Timeout(_) | Self::RateLimitExceeded { .. } => true, + Self::Database(err) => is_transient_sqlx_error(err), + _ => false, + } + } +} + +/// Heuristic for sqlx errors that are safe-to-retry transient failures: +/// pool checkout timeouts, dropped or closed connections, and IO errors +/// against the database socket. Logical errors (constraint violations, +/// type mismatches, missing rows) intentionally do not retry. 
+fn is_transient_sqlx_error(err: &sqlx::Error) -> bool { + match err { + sqlx::Error::PoolTimedOut | sqlx::Error::PoolClosed | sqlx::Error::WorkerCrashed => true, + sqlx::Error::Io(_) => true, + sqlx::Error::Database(db_err) => { + // PostgreSQL connection_exception family (08xxx) and + // statement_timeout (57014) are transient. + db_err + .code() + .map(|c| c.starts_with("08") || c == "57014" || c == "57P03") + .unwrap_or(false) + } + _ => false, } } @@ -180,10 +210,12 @@ impl From for ForgeError { crate::http::CircuitBreakerError::Request(err) if err.is_timeout() => { ForgeError::Timeout(err.to_string()) } - crate::http::CircuitBreakerError::Request(err) => ForgeError::Internal { - context: "HTTP request failed".to_string(), - source: Some(Box::new(err)), - }, + crate::http::CircuitBreakerError::Request(err) => { + // Non-timeout reqwest failures (connection refused, DNS, + // TLS) are upstream-side problems, not local bugs. Map to + // 503 so clients understand it's worth retrying. + ForgeError::ServiceUnavailable(format!("HTTP request failed: {err}")) + } crate::http::CircuitBreakerError::PrivateHostBlocked(host) => { ForgeError::Forbidden(format!("Outbound request to private host '{host}' blocked")) } @@ -210,7 +242,7 @@ mod tests { ), ( ForgeError::Database(sqlx::Error::RowNotFound), - "Database error: no rows returned by a query that expected to return at least one row", + "database error", ), ( ForgeError::JobCancelled("user request".into()), @@ -438,4 +470,143 @@ mod tests { assert_eq!(err.to_string(), "Internal error: connection failed"); assert!(err.source().is_some(), "source should be preserved"); } + + /// Minimal `sqlx::error::DatabaseError` carrying a fixed SQLSTATE code so we + /// can drive `is_transient_sqlx_error`'s `Database` arm without a live PG. 
+ #[derive(Debug)] + struct FakeDbError { + code: Option, + unique: bool, + } + + impl FakeDbError { + fn with_code(code: &str) -> Self { + Self { + code: Some(code.to_string()), + unique: false, + } + } + + fn unique_violation() -> Self { + // 23505 is PG's unique_violation. A logical constraint failure must + // not be treated as transient. + Self { + code: Some("23505".to_string()), + unique: true, + } + } + + fn no_code() -> Self { + Self { + code: None, + unique: false, + } + } + } + + impl std::fmt::Display for FakeDbError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "fake db error ({:?})", self.code) + } + } + + impl std::error::Error for FakeDbError {} + + impl sqlx::error::DatabaseError for FakeDbError { + fn message(&self) -> &str { + "fake db error" + } + + fn code(&self) -> Option> { + self.code.as_deref().map(std::borrow::Cow::Borrowed) + } + + fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) { + self + } + + fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) { + self + } + + fn into_error(self: Box) -> Box { + self + } + + fn kind(&self) -> sqlx::error::ErrorKind { + if self.unique { + sqlx::error::ErrorKind::UniqueViolation + } else { + sqlx::error::ErrorKind::Other + } + } + } + + fn db(err: FakeDbError) -> sqlx::Error { + sqlx::Error::Database(Box::new(err)) + } + + #[test] + fn transient_sqlx_pool_and_worker_errors_retry() { + assert!(is_transient_sqlx_error(&sqlx::Error::PoolTimedOut)); + assert!(is_transient_sqlx_error(&sqlx::Error::PoolClosed)); + assert!(is_transient_sqlx_error(&sqlx::Error::WorkerCrashed)); + } + + #[test] + fn transient_sqlx_io_error_retries() { + let io = std::io::Error::new(std::io::ErrorKind::ConnectionReset, "reset"); + assert!(is_transient_sqlx_error(&sqlx::Error::Io(io))); + } + + #[test] + fn transient_sqlx_connection_family_08xxx_retries() { + // 08006 connection_failure, 08003 connection_does_not_exist, etc. 
+ assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08006" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08003" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "08000" + )))); + } + + #[test] + fn transient_sqlx_statement_timeout_and_admin_shutdown_retry() { + // 57014 query_canceled (statement_timeout), 57P03 cannot_connect_now. + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "57014" + )))); + assert!(is_transient_sqlx_error(&db(FakeDbError::with_code( + "57P03" + )))); + } + + #[test] + fn non_transient_sqlx_logical_errors_do_not_retry() { + // Constraint violation, row-not-found, and a missing code are all + // logical/non-retryable. + assert!(!is_transient_sqlx_error(&db( + FakeDbError::unique_violation() + ))); + assert!(!is_transient_sqlx_error(&db(FakeDbError::with_code( + "23503" + )))); + assert!(!is_transient_sqlx_error(&db(FakeDbError::no_code()))); + assert!(!is_transient_sqlx_error(&sqlx::Error::RowNotFound)); + // 57 family that isn't a retry code (e.g. 57000 operator_intervention). 
+ assert!(!is_transient_sqlx_error(&db(FakeDbError::with_code( + "57000" + )))); + } + + #[test] + fn is_retryable_database_delegates_to_transient_check() { + assert!(ForgeError::Database(db(FakeDbError::with_code("08006"))).is_retryable()); + assert!(ForgeError::Database(db(FakeDbError::with_code("57014"))).is_retryable()); + assert!(!ForgeError::Database(db(FakeDbError::unique_violation())).is_retryable()); + assert!(!ForgeError::Database(sqlx::Error::RowNotFound).is_retryable()); + } } diff --git a/crates/forge-core/src/function/context.rs b/crates/forge-core/src/function/context.rs index 9da60a5c..1f881ce0 100644 --- a/crates/forge-core/src/function/context.rs +++ b/crates/forge-core/src/function/context.rs @@ -53,6 +53,11 @@ use crate::auth::Claims; use crate::env::{EnvAccess, EnvProvider, RealEnvProvider}; use crate::http::CircuitBreakerClient; +/// Default outbound HTTP timeout applied by [`MutationContext::http`] when +/// no per-handler `timeout` is configured. Keeps a misbehaving downstream +/// from hanging an RPC indefinitely. +pub const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); + /// Token issuer for signing JWTs. /// /// Implemented by the runtime when HMAC auth is configured. @@ -420,6 +425,13 @@ impl<'c> sqlx::Executor<'c> for &'c mut ForgeConn<'_> { } /// Authentication context available to all functions. +/// +/// KNOWN ISSUE: `authenticated` and `user_id` encode overlapping state — +/// an authenticated subject without a UUID (Firebase, Clerk) is represented +/// as `authenticated = true` with `user_id = None`. Constructors are the +/// only places that set these; each one preserves the invariant +/// `authenticated == (user_id.is_some() || claims.contains_key("sub"))`. +/// Collapsing into a single sum type is tracked for a future cleanup. #[derive(Debug, Clone)] #[non_exhaustive] pub struct AuthContext { @@ -434,13 +446,15 @@ pub struct AuthContext { impl AuthContext { /// Create an unauthenticated context. 
pub fn unauthenticated() -> Self { - Self { + let ctx = Self { user_id: None, roles: Vec::new(), claims: HashMap::new(), authenticated: false, token_exp: None, - } + }; + debug_assert!(!ctx.authenticated && ctx.user_id.is_none()); + ctx } /// Create an authenticated context with a UUID user ID. @@ -449,13 +463,15 @@ impl AuthContext { roles: Vec, claims: HashMap, ) -> Self { - Self { + let ctx = Self { user_id: Some(user_id), roles, claims, authenticated: true, token_exp: None, - } + }; + debug_assert!(ctx.authenticated && ctx.user_id.is_some()); + ctx } /// Create an authenticated context without requiring a UUID user ID. @@ -467,13 +483,15 @@ impl AuthContext { roles: Vec, claims: HashMap, ) -> Self { - Self { + let ctx = Self { user_id: None, roles, claims, authenticated: true, token_exp: None, - } + }; + debug_assert!(ctx.authenticated && ctx.user_id.is_none()); + ctx } /// Attach the JWT expiry timestamp to this context. @@ -853,7 +871,9 @@ pub struct MutationContext { pub request: RequestMetadata, db_pool: sqlx::PgPool, http_client: CircuitBreakerClient, - /// `None` means unlimited. + /// `None` means "apply the default ceiling" ([`DEFAULT_HTTP_TIMEOUT`]). + /// A caller that genuinely needs an unbounded outbound request should + /// build its own [`reqwest::Client`] outside the framework. http_timeout: Option, job_dispatch: Option>, workflow_dispatch: Option>, @@ -867,7 +887,6 @@ pub struct MutationContext { /// 0 = unlimited. 
max_jobs_per_request: usize, kv: Option>, - email_sender: Option>, } impl MutationContext { @@ -888,7 +907,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -916,7 +934,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -945,7 +962,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, } } @@ -986,7 +1002,6 @@ impl MutationContext { dispatched_job_count: Arc::new(AtomicUsize::new(0)), max_jobs_per_request: 0, kv: None, - email_sender: None, }; (ctx, tx_handle) @@ -1009,18 +1024,6 @@ impl MutationContext { .ok_or_else(|| crate::error::ForgeError::internal("KV store not available")) } - /// Attach an email sender. - pub fn set_email(&mut self, sender: Arc) { - self.email_sender = Some(sender); - } - - /// Access the email sender. - pub fn email(&self) -> crate::error::Result<&dyn crate::email::EmailSender> { - self.email_sender - .as_deref() - .ok_or_else(|| crate::error::ForgeError::internal("Email not configured")) - } - pub fn is_transactional(&self) -> bool { self.tx.is_some() } @@ -1045,14 +1048,15 @@ impl MutationContext { /// Direct pool access that **bypasses the active transaction**. /// - /// In a transactional mutation, this returns the raw [`sqlx::PgPool`] and - /// any queries run on it execute outside the transaction — so they will - /// not see uncommitted writes and will not be rolled back if the mutation - /// fails. Prefer [`MutationContext::conn`] or [`MutationContext::db`] for - /// anything that should participate in the transaction. Reach for this - /// only for operations that fundamentally cannot run inside a transaction - /// (e.g. `LISTEN`/`NOTIFY`, advisory locks, or background pool work). 
- pub fn bypass_pool(&self) -> &sqlx::PgPool { + /// WARNING: in a transactional mutation, this returns the raw + /// [`sqlx::PgPool`] and any queries run on it execute outside the + /// transaction — they will not see uncommitted writes and will not be + /// rolled back if the mutation fails. Prefer + /// [`MutationContext::conn`] or [`MutationContext::db`] for anything + /// that should participate in the transaction. Reach for this only for + /// operations that fundamentally cannot run inside a transaction (e.g. + /// `LISTEN`/`NOTIFY`, advisory locks, or background pool work). + pub fn pool_outside_transaction(&self) -> &sqlx::PgPool { &self.db_pool } @@ -1087,10 +1091,15 @@ impl MutationContext { /// declared an explicit `timeout`, that timeout is also applied to outbound /// HTTP requests unless the request overrides it. pub fn http(&self) -> crate::http::HttpClient { - self.http_client.with_timeout(self.http_timeout) + let timeout = self.http_timeout.or(Some(DEFAULT_HTTP_TIMEOUT)); + self.http_client.with_timeout(timeout) } - /// Get the raw reqwest client, bypassing circuit breaker execution. + /// Get the raw reqwest client, bypassing circuit breaker execution, + /// host blocklist, and retries. + /// + /// Gated behind the `escape-hatches` feature; prefer [`Self::http`]. + #[cfg(feature = "escape-hatches")] pub fn raw_http(&self) -> &reqwest::Client { self.http_client.inner() } @@ -1133,6 +1142,35 @@ impl MutationContext { self.max_jobs_per_request = limit; } + /// Atomically reserve a job-dispatch slot under `max_jobs_per_request`. + /// + /// Uses a `compare_exchange` loop so concurrent dispatches (e.g. via + /// `join_all`) cannot briefly exceed the limit. Returns an error when + /// the cap has been reached. 
+ fn reserve_job_slot(&self) -> crate::error::Result<()> { + if self.max_jobs_per_request == 0 { + return Ok(()); + } + let mut current = self.dispatched_job_count.load(Ordering::Acquire); + loop { + if current >= self.max_jobs_per_request { + return Err(crate::error::ForgeError::Validation(format!( + "max_jobs_per_request limit of {} exceeded", + self.max_jobs_per_request + ))); + } + match self.dispatched_job_count.compare_exchange( + current, + current + 1, + Ordering::AcqRel, + Ordering::Acquire, + ) { + Ok(_) => return Ok(()), + Err(observed) => current = observed, + } + } + } + /// Issue a signed JWT from the given claims. /// /// Only available when HMAC auth is configured in `forge.toml`. @@ -1252,18 +1290,7 @@ impl MutationContext { job_type: &str, args: T, ) -> crate::error::Result { - if self.max_jobs_per_request > 0 { - let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed); - if count >= self.max_jobs_per_request { - // Undo the increment so repeated calls after the limit give a - // consistent count rather than growing without bound. 
- self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed); - return Err(crate::error::ForgeError::Validation(format!( - "max_jobs_per_request limit of {} exceeded", - self.max_jobs_per_request - ))); - } - } + self.reserve_job_slot()?; let args_json = serde_json::to_value(args)?; let dispatcher = self @@ -1312,16 +1339,7 @@ impl MutationContext { args: T, scheduled_at: DateTime, ) -> crate::error::Result { - if self.max_jobs_per_request > 0 { - let count = self.dispatched_job_count.fetch_add(1, Ordering::Relaxed); - if count >= self.max_jobs_per_request { - self.dispatched_job_count.fetch_sub(1, Ordering::Relaxed); - return Err(crate::error::ForgeError::Validation(format!( - "max_jobs_per_request limit of {} exceeded", - self.max_jobs_per_request - ))); - } - } + self.reserve_job_slot()?; let args_json = serde_json::to_value(args)?; let dispatcher = self diff --git a/crates/forge-core/src/job/traits.rs b/crates/forge-core/src/job/traits.rs index 52a3bf99..326efb8d 100644 --- a/crates/forge-core/src/job/traits.rs +++ b/crates/forge-core/src/job/traits.rs @@ -211,14 +211,51 @@ impl Default for RetryConfig { } impl RetryConfig { + /// Compute the retry delay for `attempt`. Adds ±25% jitter to the base + /// strategy so a fleet of jobs retrying after the same upstream outage + /// doesn't align to the same wall-clock second and re-thunder the + /// recovering dependency (#14 in issues doc). Also uses `checked_pow` + /// to avoid overflow at attempt 33+ (#21 in issues doc). 
pub fn calculate_backoff(&self, attempt: u32) -> Duration { let base = Duration::from_secs(1); - let backoff = match self.backoff { + let base_backoff = match self.backoff { BackoffStrategy::Fixed => base, - BackoffStrategy::Linear => base * attempt, - BackoffStrategy::Exponential => base * 2u32.pow(attempt.saturating_sub(1)), + BackoffStrategy::Linear => base.saturating_mul(attempt.max(1)), + BackoffStrategy::Exponential => { + let exp = attempt.saturating_sub(1); + let factor = 2u32.checked_pow(exp).unwrap_or(u32::MAX); + base.saturating_mul(factor) + } }; - backoff.min(self.max_backoff) + let capped = base_backoff.min(self.max_backoff); + Self::apply_jitter(capped) + } + + /// Apply ±25% jitter using a process-wide nanosecond clock as entropy. + /// No `rand`/`fastrand` dependency in the workspace; an Instant-derived + /// LCG is sufficient for desynchronizing retries. + fn apply_jitter(d: Duration) -> Duration { + let nanos = d.as_nanos(); + if nanos == 0 { + return d; + } + let now_ns = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|x| x.as_nanos()) + .unwrap_or(0); + // Stretch entropy with a small LCG step so adjacent calls don't return + // identical jitter. + let mixed = now_ns + .wrapping_mul(6364136223846793005) + .wrapping_add(1442695040888963407); + // Map to [-0.25, +0.25] of the duration as a u128 fraction. + let bucket = (mixed % 1000) as i128 - 500; // -500..=499 + let delta_nanos = (nanos as i128) * bucket / 2000; // ±25% + let adjusted = (nanos as i128).saturating_add(delta_nanos); + let adjusted_u: u128 = adjusted.max(0) as u128; + // Cap at u64::MAX nanos (≈ 584 years) to fit Duration::from_nanos. + let capped = adjusted_u.min(u64::MAX as u128) as u64; + Duration::from_nanos(capped) } } @@ -272,20 +309,37 @@ mod tests { #[test] fn test_exponential_backoff() { + // Backoff now applies ±25% jitter; assert bounds rather than exact + // equality. Base values: 1s, 2s, 4s, 8s. 
let config = RetryConfig::default(); - assert_eq!(config.calculate_backoff(1), Duration::from_secs(1)); - assert_eq!(config.calculate_backoff(2), Duration::from_secs(2)); - assert_eq!(config.calculate_backoff(3), Duration::from_secs(4)); - assert_eq!(config.calculate_backoff(4), Duration::from_secs(8)); + for (attempt, base_ms) in [(1u32, 1000u128), (2, 2000), (3, 4000), (4, 8000)] { + let got = config.calculate_backoff(attempt).as_millis(); + let lo = base_ms * 75 / 100; + let hi = base_ms * 125 / 100; + assert!( + got >= lo && got <= hi, + "attempt {attempt}: expected {lo}..={hi}ms, got {got}ms" + ); + } } #[test] fn test_max_backoff_cap() { + // The cap applies before jitter, so the returned value is in + // [0.75*cap, 1.25*cap]. + let cap = Duration::from_secs(10); let config = RetryConfig { - max_backoff: Duration::from_secs(10), + max_backoff: cap, ..Default::default() }; - assert_eq!(config.calculate_backoff(10), Duration::from_secs(10)); + let got = config.calculate_backoff(10).as_millis(); + let cap_ms = cap.as_millis(); + let lo = cap_ms * 75 / 100; + let hi = cap_ms * 125 / 100; + assert!( + got >= lo && got <= hi, + "expected {lo}..={hi}ms (cap ±25%), got {got}ms" + ); } #[test] @@ -402,6 +456,18 @@ mod tests { assert!(cfg.retry_on.is_empty(), "empty list ⇒ retry on every error"); } + // calculate_backoff applies ±25% jitter; assert bounds, not equality. 
+ fn assert_within_jitter(actual: Duration, target: Duration) { + let target_ms = target.as_millis() as i128; + let actual_ms = actual.as_millis() as i128; + let low = target_ms * 75 / 100; + let high = target_ms * 125 / 100; + assert!( + actual_ms >= low && actual_ms <= high, + "{actual:?} outside ±25% of {target:?}" + ); + } + #[test] fn backoff_fixed_returns_base_for_any_attempt() { let cfg = RetryConfig { @@ -409,7 +475,7 @@ mod tests { ..Default::default() }; for attempt in [1u32, 2, 5, 100] { - assert_eq!(cfg.calculate_backoff(attempt), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(attempt), Duration::from_secs(1)); } } @@ -419,23 +485,24 @@ mod tests { backoff: BackoffStrategy::Linear, ..Default::default() }; - assert_eq!(cfg.calculate_backoff(1), Duration::from_secs(1)); - assert_eq!(cfg.calculate_backoff(5), Duration::from_secs(5)); - assert_eq!(cfg.calculate_backoff(50), Duration::from_secs(50)); + assert_within_jitter(cfg.calculate_backoff(1), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(5), Duration::from_secs(5)); + assert_within_jitter(cfg.calculate_backoff(50), Duration::from_secs(50)); } #[test] fn backoff_exponential_handles_attempt_zero_without_underflow() { // attempt = 0 ⇒ saturating_sub keeps exponent at 0 ⇒ 2^0 = 1 ⇒ base. let cfg = RetryConfig::default(); - assert_eq!(cfg.calculate_backoff(0), Duration::from_secs(1)); + assert_within_jitter(cfg.calculate_backoff(0), Duration::from_secs(1)); } #[test] fn backoff_exponential_caps_at_max_backoff_for_large_attempt() { - // attempt = 20 ⇒ 2^19 seconds = ~6 days, must cap to default 5 min. + // attempt = 20 saturates above max_backoff (300s); jitter pulls it + // down by up to 25%. 
let cfg = RetryConfig::default(); - assert_eq!(cfg.calculate_backoff(20), Duration::from_secs(300)); + assert_within_jitter(cfg.calculate_backoff(20), Duration::from_secs(300)); } #[test] diff --git a/crates/forge-core/src/lib.rs b/crates/forge-core/src/lib.rs index b8674061..c3d2b460 100644 --- a/crates/forge-core/src/lib.rs +++ b/crates/forge-core/src/lib.rs @@ -10,7 +10,6 @@ pub mod config; pub mod context; pub mod cron; pub mod daemon; -pub mod email; pub mod env; pub mod error; pub mod function; diff --git a/crates/forge-core/src/pagination.rs b/crates/forge-core/src/pagination.rs index 79436397..16d05a54 100644 --- a/crates/forge-core/src/pagination.rs +++ b/crates/forge-core/src/pagination.rs @@ -17,11 +17,20 @@ impl Cursor { Self(value.into()) } - pub fn as_str(&self) -> &str { + /// Internal accessor for serde / wire-format glue. Treat the returned + /// string as opaque: its encoding is an implementation detail and may + /// change between releases. + #[doc(hidden)] + pub fn as_inner_for_serde(&self) -> &str { &self.0 } } +/// Upper bound on items in a single [`Page`]. Helpers that build pages +/// from client-supplied limits should clamp to this value to prevent a +/// caller from extracting an unbounded number of rows. +pub const MAX_PAGE_SIZE: usize = 1000; + /// A page of results with cursor-based navigation. #[derive(Debug, Clone, Serialize, Deserialize)] #[non_exhaustive] @@ -31,7 +40,12 @@ pub struct Page { } impl Page { - pub fn new(items: Vec, page_info: PageInfo) -> Self { + /// Constructs a page, truncating `items` to [`MAX_PAGE_SIZE`] entries. + /// Callers that already enforce a stricter cap can pass shorter vecs. 
+ pub fn new(mut items: Vec, page_info: PageInfo) -> Self { + if items.len() > MAX_PAGE_SIZE { + items.truncate(MAX_PAGE_SIZE); + } Self { items, page_info } } } @@ -44,7 +58,7 @@ pub struct PageInfo { #[serde(skip_serializing_if = "Option::is_none")] pub end_cursor: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub total_count: Option, + pub total_count: Option, } impl PageInfo { @@ -77,7 +91,7 @@ mod tests { PageInfo { has_next_page: true, end_cursor: Some(Cursor::new("abc")), - total_count: Some(10), + total_count: Some(10u64), }, ); let json: serde_json::Value = serde_json::to_value(&page).unwrap(); diff --git a/crates/forge-core/src/realtime/subscription.rs b/crates/forge-core/src/realtime/subscription.rs index 07f9d556..5da5773d 100644 --- a/crates/forge-core/src/realtime/subscription.rs +++ b/crates/forge-core/src/realtime/subscription.rs @@ -77,6 +77,11 @@ pub struct AuthScope { pub tenant_id: Option, /// Hash of the sorted roles for this auth context. pub role_hash: u64, + /// Admin status is part of the dedup key: `check_owner_access` treats + /// admins as having implicit access to every owner's data, so an admin + /// must never share a group with a non-admin even when principal/tenant/ + /// roles coincide. 
+ pub is_admin: bool, } impl PartialEq for AuthScope { @@ -84,6 +89,7 @@ impl PartialEq for AuthScope { self.principal_id == other.principal_id && self.tenant_id == other.tenant_id && self.role_hash == other.role_hash + && self.is_admin == other.is_admin } } @@ -94,6 +100,7 @@ impl std::hash::Hash for AuthScope { self.principal_id.hash(state); self.tenant_id.hash(state); self.role_hash.hash(state); + self.is_admin.hash(state); } } @@ -114,6 +121,7 @@ impl AuthScope { .and_then(|v| v.as_str()) .map(ToString::to_string), role_hash, + is_admin: auth.is_admin(), } } } @@ -186,14 +194,37 @@ impl QueryGroup { use std::hash::Hash; match value { serde_json::Value::Object(map) => { + hasher.write_u8(b'o'); let mut keys: Vec<&String> = map.keys().collect(); keys.sort(); for key in keys { key.hash(hasher); Self::hash_json_canonical(&map[key], hasher); } + hasher.write_u8(b'e'); + } + serde_json::Value::Array(items) => { + hasher.write_u8(b'a'); + for item in items { + Self::hash_json_canonical(item, hasher); + } + hasher.write_u8(b'e'); + } + serde_json::Value::Null => { + hasher.write_u8(b'0'); + } + serde_json::Value::Bool(b) => { + hasher.write_u8(b'b'); + b.hash(hasher); + } + serde_json::Value::Number(n) => { + hasher.write_u8(b'n'); + n.to_string().hash(hasher); + } + serde_json::Value::String(s) => { + hasher.write_u8(b's'); + s.hash(hasher); } - other => other.to_string().hash(hasher), } } @@ -221,8 +252,11 @@ impl QueryGroup { /// Uses the runtime read set when populated, otherwise falls back to the /// compile-time table dependencies from macro extraction. 
pub fn should_invalidate(&self, change: &super::readset::Change) -> bool { + let in_compile_deps = self.table_deps.iter().any(|t| *t == change.table); + let in_runtime_set = self.read_set.tables.iter().any(|t| t == &change.table); + let table_matches = if self.read_set.tables.is_empty() { - self.table_deps.iter().any(|t| *t == change.table) + in_compile_deps } else { change.invalidates(&self.read_set) }; @@ -231,7 +265,14 @@ impl QueryGroup { return false; } - if !change.invalidates_columns(self.selected_cols) { + // `selected_cols` is captured at compile time for the macro-declared + // tables. Runtime-discovered tables (added by the manager when the + // read_set widens after execution) don't have a per-table column + // map, so applying the compile-time column filter to them would + // wrongly suppress real changes. Fall back to "always invalidate" + // for those, and only apply the column filter when the change came + // through a compile-time-declared table. + if in_compile_deps && !in_runtime_set && !change.invalidates_columns(self.selected_cols) { return false; } @@ -385,6 +426,7 @@ mod tests { principal_id: Some("user-1".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key1 = QueryGroup::compute_lookup_key( "get_projects", @@ -402,6 +444,7 @@ mod tests { principal_id: Some("user-2".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key3 = QueryGroup::compute_lookup_key( "get_projects", @@ -417,6 +460,7 @@ mod tests { principal_id: Some("u1".to_string()), tenant_id: None, role_hash: 0, + is_admin: false, }; let key = QueryGroup::compute_lookup_key("get_items", &serde_json::json!({"id": "42"}), &scope); @@ -433,6 +477,7 @@ mod tests { principal_id: None, tenant_id: None, role_hash: 0, + is_admin: false, }; let key_ab = QueryGroup::compute_lookup_key("q", &serde_json::json!({"a": 1, "b": 2}), &scope); diff --git a/crates/forge-core/src/tenant/mod.rs b/crates/forge-core/src/tenant/mod.rs index a7f67a53..d46e9ee0 100644 
--- a/crates/forge-core/src/tenant/mod.rs +++ b/crates/forge-core/src/tenant/mod.rs @@ -142,4 +142,45 @@ mod tests { let ctx = TenantContext::strict(tenant_id); assert!(ctx.require_tenant().is_ok()); } + + #[test] + fn sql_filter_rejects_injection_attempts() { + let ctx = TenantContext::strict(Uuid::new_v4()); + // Anything outside [A-Za-z0-9_] must be refused so the column name can + // never carry SQL. Empty is rejected too. + for bad in [ + "", + "tenant_id; DROP TABLE users", + "tenant_id OR 1=1", + "tenant_id--", + "tenant\"_id", + "tenant id", + "tenant.id", + "tenant_id)", + "té", + ] { + assert!( + ctx.sql_filter(bad, 1).is_none(), + "column {bad:?} should be rejected" + ); + } + } + + #[test] + fn sql_filter_accepts_valid_identifiers_and_quotes_them() { + let tenant_id = Uuid::new_v4(); + let ctx = TenantContext::strict(tenant_id); + let (clause, id) = ctx + .sql_filter("org_id_2", 5) + .expect("alphanumeric+underscore column is valid"); + assert_eq!(clause, "\"org_id_2\" = $5"); + assert_eq!(id, tenant_id); + } + + #[test] + fn sql_filter_returns_none_without_tenant_even_for_valid_column() { + // No tenant id => nothing to scope by, regardless of column validity. 
+ let ctx = TenantContext::none(); + assert!(ctx.sql_filter("tenant_id", 1).is_none()); + } } diff --git a/crates/forge-core/src/testing/assertions.rs b/crates/forge-core/src/testing/assertions.rs index ed364687..d1dee025 100644 --- a/crates/forge-core/src/testing/assertions.rs +++ b/crates/forge-core/src/testing/assertions.rs @@ -274,127 +274,3 @@ where { items.iter().any(predicate) } - -#[cfg(test)] -mod tests { - use super::{assert_contains, assert_json_matches}; - use crate::error::ForgeError; - - #[test] - fn test_assert_ok_macro() { - let result: Result = Ok(42); - assert_ok!(result); - } - - #[test] - #[should_panic(expected = "expected Ok")] - fn test_assert_ok_macro_fails() { - let result: Result = Err("error".to_string()); - assert_ok!(result); - } - - #[test] - fn test_assert_err_macro() { - let result: Result = Err("error".to_string()); - assert_err!(result); - } - - #[test] - #[should_panic(expected = "expected Err")] - fn test_assert_err_macro_fails() { - let result: Result = Ok(42); - assert_err!(result); - } - - #[test] - fn test_assert_err_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_variant!(result, ForgeError::NotFound(_)); - } - - #[test] - #[should_panic(expected = "expected ForgeError::Unauthorized(_)")] - fn test_assert_err_variant_wrong_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_variant!(result, ForgeError::Unauthorized(_)); - } - - #[test] - fn test_assert_err_matches_no_guard() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(_)); - } - - #[test] - #[allow(unused_variables)] - fn test_assert_err_matches_with_guard() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(msg) if msg.contains("email")); - } - - #[test] - 
#[should_panic(expected = "guard failed")] - #[allow(unused_variables)] - fn test_assert_err_matches_guard_fails() { - let result: Result<(), ForgeError> = - Err(ForgeError::Validation("email is required".into())); - assert_err_matches!(result, ForgeError::Validation(msg) if msg.contains("password")); - } - - #[test] - #[should_panic(expected = "expected ForgeError::Unauthorized(_)")] - fn test_assert_err_matches_wrong_variant() { - let result: Result<(), ForgeError> = Err(ForgeError::NotFound("user".into())); - assert_err_matches!(result, ForgeError::Unauthorized(_)); - } - - #[test] - fn test_assert_json_matches() { - let actual = serde_json::json!({ - "id": 123, - "name": "Test", - "nested": { - "foo": "bar" - } - }); - - assert!(assert_json_matches( - &actual, - &serde_json::json!({"id": 123}) - )); - assert!(assert_json_matches( - &actual, - &serde_json::json!({"name": "Test"}) - )); - assert!(assert_json_matches( - &actual, - &serde_json::json!({"nested": {"foo": "bar"}}) - )); - - assert!(!assert_json_matches( - &actual, - &serde_json::json!({"id": 456}) - )); - assert!(!assert_json_matches( - &actual, - &serde_json::json!({"missing": true}) - )); - } - - #[test] - fn test_assert_json_matches_arrays() { - let actual = serde_json::json!([1, 2, 3]); - assert!(assert_json_matches(&actual, &serde_json::json!([1, 2, 3]))); - assert!(!assert_json_matches(&actual, &serde_json::json!([1, 2]))); - assert!(!assert_json_matches(&actual, &serde_json::json!([1, 2, 4]))); - } - - #[test] - fn test_assert_contains() { - let items = vec![1, 2, 3, 4, 5]; - assert!(assert_contains(&items, |x| *x == 3)); - assert!(!assert_contains(&items, |x| *x == 6)); - } -} diff --git a/crates/forge-core/src/testing/context/mcp_tool.rs b/crates/forge-core/src/testing/context/mcp_tool.rs index 78f075fb..101fefd0 100644 --- a/crates/forge-core/src/testing/context/mcp_tool.rs +++ b/crates/forge-core/src/testing/context/mcp_tool.rs @@ -137,8 +137,17 @@ impl TestMcpToolContextBuilder { self } + 
/// Set the tenant id for multi-tenant testing. + /// + /// Production code reads the tenant from `auth.claims["tenant_id"]`, so + /// this writes the same value into the claims map. Tests calling + /// `ctx.auth.tenant_id()` then behave identically to production. pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } diff --git a/crates/forge-core/src/testing/context/query.rs b/crates/forge-core/src/testing/context/query.rs index 15024cd0..51a41637 100644 --- a/crates/forge-core/src/testing/context/query.rs +++ b/crates/forge-core/src/testing/context/query.rs @@ -128,8 +128,16 @@ impl TestQueryContextBuilder { } /// Set the tenant ID for multi-tenant testing. + /// + /// Production code reads the tenant from `auth.claims["tenant_id"]`, so + /// this writes the same value into the claims map. Tests calling + /// `ctx.auth.tenant_id()` then behave identically to production. 
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } @@ -167,45 +175,6 @@ impl TestQueryContextBuilder { mod tests { use super::*; - #[test] - fn test_minimal_context() { - let ctx = TestQueryContext::minimal(); - assert!(!ctx.auth.is_authenticated()); - assert!(ctx.db().is_none()); - } - - #[test] - fn test_authenticated_context() { - let user_id = Uuid::new_v4(); - let ctx = TestQueryContext::authenticated(user_id); - assert!(ctx.auth.is_authenticated()); - assert_eq!(ctx.user_id().unwrap(), user_id); - } - - #[test] - fn test_context_with_roles() { - let ctx = TestQueryContext::builder() - .as_user(Uuid::new_v4()) - .with_role("admin") - .with_role("user") - .build(); - - assert!(ctx.has_role("admin")); - assert!(ctx.has_role("user")); - assert!(!ctx.has_role("superuser")); - } - - #[test] - fn test_context_with_claims() { - let ctx = TestQueryContext::builder() - .as_user(Uuid::new_v4()) - .with_claim("org_id", serde_json::json!("org-123")) - .build(); - - assert_eq!(ctx.claim("org_id"), Some(&serde_json::json!("org-123"))); - assert!(ctx.claim("nonexistent").is_none()); - } - #[test] fn test_context_with_env() { let ctx = TestQueryContext::builder() diff --git a/crates/forge-core/src/testing/context/workflow.rs b/crates/forge-core/src/testing/context/workflow.rs index 2f8b5ceb..d7fa9a2e 100644 --- a/crates/forge-core/src/testing/context/workflow.rs +++ b/crates/forge-core/src/testing/context/workflow.rs @@ -275,8 +275,15 @@ impl TestWorkflowContextBuilder { } /// Set the tenant ID. + /// + /// Production reads the tenant from `auth.claims["tenant_id"]`, so this + /// writes the same value into the claims map. 
pub fn with_tenant(mut self, tenant_id: Uuid) -> Self { self.tenant_id = Some(tenant_id); + self.claims.insert( + "tenant_id".to_string(), + serde_json::Value::String(tenant_id.to_string()), + ); self } diff --git a/crates/forge-core/src/testing/db.rs b/crates/forge-core/src/testing/db.rs index 57585343..3d4bfa2c 100644 --- a/crates/forge-core/src/testing/db.rs +++ b/crates/forge-core/src/testing/db.rs @@ -119,10 +119,16 @@ impl TestDatabase { /// Create a dedicated database for a single test, providing full isolation. pub async fn isolated(&self, test_name: &str) -> Result { let base_url = self.url.clone(); + // Cap the final identifier well under Postgres' 63-char limit so two + // tests with the same long prefix never collide on a truncated name. + // Layout: `forge_test_` (11) + sanitized name (<=16) + `_` (1) + + // 8 hex chars of a UUID = 36 chars total. + let uuid_hex = uuid::Uuid::new_v4().simple().to_string(); + let short_uuid: String = uuid_hex.chars().take(8).collect(); let db_name = format!( "forge_test_{}_{}", - sanitize_db_name(test_name), - uuid::Uuid::new_v4().simple() + sanitize_db_name_short(test_name), + short_uuid ); sqlx::query(&format!("CREATE DATABASE \"{}\"", db_name)) @@ -139,7 +145,7 @@ impl TestDatabase { .map_err(ForgeError::Database)?; Ok(IsolatedTestDb { - pool: test_pool, + pool: Some(test_pool), db_name, base_url, #[cfg(feature = "testcontainers")] @@ -148,10 +154,15 @@ impl TestDatabase { } } -/// A test database scoped to a single test. Call `cleanup()` to drop it immediately, -/// or rely on future test runs to clean up orphaned databases. +/// A test database scoped to a single test. +/// +/// Cleanup happens in `Drop`: the pool is closed and `DROP DATABASE` is fired +/// on a fresh sync connection via `tokio::task::block_in_place` + +/// `Handle::current().block_on()`. Tests that want to surface cleanup errors +/// can call [`IsolatedTestDb::cleanup`] (async) explicitly instead — `Drop` +/// then becomes a no-op. 
pub struct IsolatedTestDb { - pool: PgPool, + pool: Option, db_name: String, base_url: String, #[cfg(feature = "testcontainers")] @@ -160,16 +171,24 @@ pub struct IsolatedTestDb { impl IsolatedTestDb { /// Convenience: `from_env()` → `isolated()` → `run_sql(internal_sql)` → `migrate()`. + /// + /// On a partial failure (system SQL or user migrations), the freshly-created + /// database is dropped via the standard `Drop` path of the guard struct — + /// the caller never observes a leaked database. pub async fn setup(test_name: &str, internal_sql: &str, migrations_dir: &Path) -> Result { let base = TestDatabase::from_env().await?; let db = base.isolated(test_name).await?; + // The half-built db is owned by `db`; if either step below returns + // early, `db`'s Drop fires and the database is dropped. db.run_sql(internal_sql).await?; db.migrate(migrations_dir).await?; Ok(db) } pub fn pool(&self) -> &PgPool { - &self.pool + self.pool + .as_ref() + .expect("IsolatedTestDb pool is taken only during Drop/cleanup") } pub fn db_name(&self) -> &str { @@ -179,7 +198,7 @@ impl IsolatedTestDb { /// Run raw SQL for test setup. pub async fn execute(&self, sql: &str) -> Result<()> { sqlx::query(sql) - .execute(&self.pool) + .execute(self.pool()) .await .map_err(ForgeError::Database)?; Ok(()) @@ -193,7 +212,7 @@ impl IsolatedTestDb { continue; } sqlx::query(stmt) - .execute(&self.pool) + .execute(self.pool()) .await .map_err(|e| ForgeError::internal_with("Failed to execute SQL", e))?; } @@ -201,32 +220,97 @@ impl IsolatedTestDb { } /// Drop the isolated database and close all connections. - pub async fn cleanup(self) -> Result<()> { - self.pool.close().await; + /// + /// Calling this disarms the `Drop` impl — useful for tests that want + /// cleanup errors to surface rather than being logged. 
+ pub async fn cleanup(mut self) -> Result<()> { + let pool = match self.pool.take() { + Some(p) => p, + None => return Ok(()), + }; + drop_db_async(pool, &self.base_url, &self.db_name).await + } +} - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(1) - .connect(&self.base_url) - .await - .map_err(ForgeError::Database)?; +async fn drop_db_async(pool: PgPool, base_url: &str, db_name: &str) -> Result<()> { + pool.close().await; - if let Err(e) = - sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1") - .bind(&self.db_name) - .execute(&pool) - .await - { - tracing::warn!(db = %self.db_name, error = %e, "failed to terminate backend connections during test cleanup"); - } + let admin_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(1) + .connect(base_url) + .await + .map_err(ForgeError::Database)?; - sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name)) - .execute(&pool) + if let Err(e) = + sqlx::query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = $1") + .bind(db_name) + .execute(&admin_pool) .await - .map_err(ForgeError::Database)?; + { + tracing::warn!(db = %db_name, error = %e, "failed to terminate backend connections during test cleanup"); + } - Ok(()) + sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", db_name)) + .execute(&admin_pool) + .await + .map_err(ForgeError::Database)?; + + Ok(()) +} + +impl Drop for IsolatedTestDb { + fn drop(&mut self) { + let Some(pool) = self.pool.take() else { + return; + }; + let base_url = self.base_url.clone(); + let db_name = self.db_name.clone(); + + // The runtime flavor decides how we drive the async cleanup: + // - multi_thread: `block_in_place` releases the worker so a nested + // `block_on` is safe. + // - current_thread: `block_in_place` panics; we instead spawn the + // cleanup as a detached task on the existing handle. 
The runtime + // drives it to completion before the process exits as long as the + // test runtime outlives this Drop (true for `#[tokio::test]` since + // the runtime owns the future). + // - no runtime: nothing we can do; log and leak. + match tokio::runtime::Handle::try_current() { + Ok(handle) => match handle.runtime_flavor() { + tokio::runtime::RuntimeFlavor::MultiThread => { + tokio::task::block_in_place(|| { + if let Err(e) = handle.block_on(drop_db_async(pool, &base_url, &db_name)) { + tracing::warn!( + db = %db_name, + error = %e, + "IsolatedTestDb::drop failed to clean up; database leaked" + ); + } + }); + } + _ => { + handle.spawn(async move { + if let Err(e) = drop_db_async(pool, &base_url, &db_name).await { + tracing::warn!( + db = %db_name, + error = %e, + "IsolatedTestDb::drop failed to clean up; database leaked" + ); + } + }); + } + }, + Err(_) => { + tracing::warn!( + db = %db_name, + "IsolatedTestDb dropped outside a tokio runtime; database leaked" + ); + } + } } +} +impl IsolatedTestDb { /// Run migrations: loads all `.sql` files from the directory, sorts alphabetically, executes in order. pub async fn migrate(&self, migrations_dir: &Path) -> Result<()> { if !migrations_dir.exists() { @@ -268,7 +352,7 @@ impl IsolatedTestDb { if is_blank_sql(stmt) { continue; } - sqlx::query(stmt).execute(&self.pool).await.map_err(|e| { + sqlx::query(stmt).execute(self.pool()).await.map_err(|e| { ForgeError::internal(format!("Failed to apply migration '{name}': {e}")) })?; } @@ -285,10 +369,14 @@ fn is_blank_sql(sql: &str) -> bool { .all(|l| l.trim().is_empty() || l.trim().starts_with("--")) } -fn sanitize_db_name(name: &str) -> String { +/// Sanitize a test name into something that's safe to embed in a Postgres +/// identifier. Capped at 16 characters so the final +/// `forge_test__<8hex>` identifier stays well under Postgres' 63-char +/// identifier limit (11 + 16 + 1 + 8 = 36). 
+fn sanitize_db_name_short(name: &str) -> String { name.chars() .map(|c| if c.is_alphanumeric() { c } else { '_' }) - .take(32) + .take(16) .collect() } @@ -445,11 +533,11 @@ mod tests { use super::*; #[test] - fn test_sanitize_db_name() { - assert_eq!(sanitize_db_name("my_test"), "my_test"); - assert_eq!(sanitize_db_name("my-test"), "my_test"); - assert_eq!(sanitize_db_name("my test"), "my_test"); - assert_eq!(sanitize_db_name("test::function"), "test__function"); + fn test_sanitize_db_name_short() { + assert_eq!(sanitize_db_name_short("my_test"), "my_test"); + assert_eq!(sanitize_db_name_short("my-test"), "my_test"); + assert_eq!(sanitize_db_name_short("my test"), "my_test"); + assert_eq!(sanitize_db_name_short("test::function"), "test__function"); } #[test] @@ -560,18 +648,21 @@ mod tests { } #[test] - fn sanitize_truncates_long_names() { + fn sanitize_short_caps_at_16() { let long_name = "a".repeat(100); - let sanitized = sanitize_db_name(&long_name); - assert_eq!(sanitized.len(), 32); + let sanitized = sanitize_db_name_short(&long_name); + assert_eq!(sanitized.len(), 16); + // Full identifier: 11 ("forge_test_") + 16 + 1 + 8 = 36, safely <= 63. 
+ let identifier = format!("forge_test_{}_{}", sanitized, "12345678"); + assert!(identifier.len() <= 63); } #[test] fn sanitize_handles_special_characters() { assert_eq!( - sanitize_db_name("test/with:special!chars"), - "test_with_special_chars" + sanitize_db_name_short("test/with:specia"), + "test_with_specia" ); - assert_eq!(sanitize_db_name(""), ""); + assert_eq!(sanitize_db_name_short(""), ""); } } diff --git a/crates/forge-core/src/testing/mock_dispatch.rs b/crates/forge-core/src/testing/mock_dispatch.rs index bec8e4aa..11506d5e 100644 --- a/crates/forge-core/src/testing/mock_dispatch.rs +++ b/crates/forge-core/src/testing/mock_dispatch.rs @@ -87,18 +87,21 @@ impl MockJobDispatch { cancel_reason: None, }; - self.jobs.write().expect("jobs lock poisoned").push(job); + self.jobs + .write() + .unwrap_or_else(|p| p.into_inner()) + .push(job); Ok(id) } pub fn dispatched_jobs(&self) -> Vec { - self.jobs.read().expect("jobs lock poisoned").clone() + self.jobs.read().unwrap_or_else(|p| p.into_inner()).clone() } pub fn jobs_of_type(&self, job_type: &str) -> Vec { self.jobs .read() - .expect("jobs lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .iter() .filter(|j| j.job_type == job_type) .cloned() @@ -106,7 +109,7 @@ impl MockJobDispatch { } pub fn assert_dispatched(&self, job_type: &str) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs.iter().any(|j| j.job_type == job_type); assert!( found, @@ -116,11 +119,13 @@ impl MockJobDispatch { ); } + /// Lenient: passes when *any* dispatched job with this name matches + /// the predicate. Other unrelated dispatches are ignored. 
pub fn assert_dispatched_with(&self, job_type: &str, predicate: F) where F: Fn(&serde_json::Value) -> bool, { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs .iter() .any(|j| j.job_type == job_type && predicate(&j.args)); @@ -131,8 +136,37 @@ impl MockJobDispatch { ); } + /// Strict: passes only when *every* dispatched job with this name + /// matches the predicate (and at least one such dispatch exists). + /// Use to assert a precise audience, e.g. "the email job ran for + /// user 5 and nobody else." + pub fn assert_dispatched_with_exact(&self, job_type: &str, predicate: F) + where + F: Fn(&serde_json::Value) -> bool, + { + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); + let matching: Vec<&DispatchedJob> = + jobs.iter().filter(|j| j.job_type == job_type).collect(); + assert!( + !matching.is_empty(), + "Expected at least one dispatch of '{}', but none were recorded", + job_type + ); + let mismatches: Vec<&serde_json::Value> = matching + .iter() + .filter(|j| !predicate(&j.args)) + .map(|j| &j.args) + .collect(); + assert!( + mismatches.is_empty(), + "Expected every dispatch of '{}' to match predicate; mismatched args: {:?}", + job_type, + mismatches + ); + } + pub fn assert_not_dispatched(&self, job_type: &str) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let found = jobs.iter().any(|j| j.job_type == job_type); assert!( !found, @@ -142,7 +176,7 @@ impl MockJobDispatch { } pub fn assert_dispatch_count(&self, job_type: &str, expected: usize) { - let jobs = self.jobs.read().expect("jobs lock poisoned"); + let jobs = self.jobs.read().unwrap_or_else(|p| p.into_inner()); let count = jobs.iter().filter(|j| j.job_type == job_type).count(); assert_eq!( count, expected, @@ -152,28 +186,34 @@ impl MockJobDispatch { } pub fn clear(&self) { - self.jobs.write().expect("jobs lock 
poisoned").clear(); + self.jobs.write().unwrap_or_else(|p| p.into_inner()).clear(); } pub fn complete_job(&self, job_id: Uuid) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Completed; } } pub fn fail_job(&self, job_id: Uuid) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Failed; } } - pub fn cancel_job(&self, job_id: Uuid, reason: Option) { - let mut jobs = self.jobs.write().expect("jobs lock poisoned"); + /// Returns `true` when a matching dispatched job was found and marked + /// cancelled, `false` otherwise. Mirrors production semantics so tests + /// can assert cancel-of-unknown-id behaviour. + pub fn cancel_job(&self, job_id: Uuid, reason: Option) -> bool { + let mut jobs = self.jobs.write().unwrap_or_else(|p| p.into_inner()); if let Some(job) = jobs.iter_mut().find(|j| j.id == job_id) { job.status = JobStatus::Cancelled; job.cancel_reason = reason; + true + } else { + false } } } @@ -252,10 +292,7 @@ impl crate::function::JobDispatch for MockJobDispatch { job_id: Uuid, reason: Option, ) -> std::pin::Pin> + Send + '_>> { - Box::pin(async move { - self.cancel_job(job_id, reason); - Ok(true) - }) + Box::pin(async move { Ok(self.cancel_job(job_id, reason)) }) } } @@ -286,7 +323,7 @@ impl MockWorkflowDispatch { self.workflows .write() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .push(workflow); Ok(run_id) } @@ -294,14 +331,14 @@ impl MockWorkflowDispatch { pub fn started_workflows(&self) -> Vec { self.workflows .read() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .clone() } pub fn workflows_named(&self, name: &str) -> Vec { self.workflows .read() - 
.expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .iter() .filter(|w| w.workflow_name == name) .cloned() @@ -309,7 +346,7 @@ impl MockWorkflowDispatch { } pub fn assert_started(&self, workflow_name: &str) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows.iter().any(|w| w.workflow_name == workflow_name); assert!( found, @@ -326,7 +363,7 @@ impl MockWorkflowDispatch { where F: Fn(&serde_json::Value) -> bool, { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows .iter() .any(|w| w.workflow_name == workflow_name && predicate(&w.input)); @@ -338,7 +375,7 @@ impl MockWorkflowDispatch { } pub fn assert_not_started(&self, workflow_name: &str) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let found = workflows.iter().any(|w| w.workflow_name == workflow_name); assert!( !found, @@ -348,7 +385,7 @@ impl MockWorkflowDispatch { } pub fn assert_start_count(&self, workflow_name: &str, expected: usize) { - let workflows = self.workflows.read().expect("workflows lock poisoned"); + let workflows = self.workflows.read().unwrap_or_else(|p| p.into_inner()); let count = workflows .iter() .filter(|w| w.workflow_name == workflow_name) @@ -363,19 +400,19 @@ impl MockWorkflowDispatch { pub fn clear(&self) { self.workflows .write() - .expect("workflows lock poisoned") + .unwrap_or_else(|p| p.into_inner()) .clear(); } pub fn complete_workflow(&self, run_id: Uuid) { - let mut workflows = self.workflows.write().expect("workflows lock poisoned"); + let mut workflows = self.workflows.write().unwrap_or_else(|p| p.into_inner()); if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) { workflow.status = 
WorkflowStatus::Completed; } } pub fn fail_workflow(&self, run_id: Uuid) { - let mut workflows = self.workflows.write().expect("workflows lock poisoned"); + let mut workflows = self.workflows.write().unwrap_or_else(|p| p.into_inner()); if let Some(workflow) = workflows.iter_mut().find(|w| w.run_id == run_id) { workflow.status = WorkflowStatus::Failed; } diff --git a/crates/forge-core/src/testing/mock_email.rs b/crates/forge-core/src/testing/mock_email.rs deleted file mode 100644 index d2913524..00000000 --- a/crates/forge-core/src/testing/mock_email.rs +++ /dev/null @@ -1,72 +0,0 @@ -//! Mock email sender for testing. - -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use tokio::sync::Mutex; - -use crate::email::{Email, EmailSender}; -use crate::error::Result; - -/// Records sent emails for assertion in tests. -#[derive(Debug, Clone, Default)] -pub struct MockEmailSender { - sent: Arc>>, -} - -/// A recorded email send. -#[derive(Debug, Clone)] -pub struct SentEmail { - pub to: Vec, - pub subject: String, - pub text: Option, - pub html: Option, -} - -impl MockEmailSender { - pub fn new() -> Self { - Self::default() - } - - pub async fn sent(&self) -> Vec { - self.sent.lock().await.clone() - } - - /// Assert that exactly one email was sent to the given address. - pub async fn assert_sent_to(&self, address: &str) { - let sent = self.sent.lock().await; - let matching: Vec<_> = sent - .iter() - .filter(|e| e.to.contains(&address.to_string())) - .collect(); - assert!( - matching.len() == 1, - "Expected 1 email to {address}, found {}", - matching.len() - ); - } - - /// Assert that no emails were sent. 
- pub async fn assert_none_sent(&self) { - let sent = self.sent.lock().await; - assert!(sent.is_empty(), "Expected no emails, found {}", sent.len()); - } -} - -impl EmailSender for MockEmailSender { - fn send<'a>( - &'a self, - email: &'a Email, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - self.sent.lock().await.push(SentEmail { - to: email.to.clone(), - subject: email.subject.clone(), - text: email.text.clone(), - html: email.html.clone(), - }); - Ok(format!("mock-{}", uuid::Uuid::new_v4())) - }) - } -} diff --git a/crates/forge-core/src/testing/mock_http.rs b/crates/forge-core/src/testing/mock_http.rs index 5e46ae88..a5ebafae 100644 --- a/crates/forge-core/src/testing/mock_http.rs +++ b/crates/forge-core/src/testing/mock_http.rs @@ -327,281 +327,3 @@ impl Default for MockHttpBuilder { Self::new() } } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_mock_response_json() { - let response = MockResponse::json(serde_json::json!({"id": 123})); - assert_eq!(response.status, 200); - assert_eq!(response.body["id"], 123); - } - - #[test] - fn test_mock_response_error() { - let response = MockResponse::error(404, "Not found"); - assert_eq!(response.status, 404); - assert_eq!(response.body["error"], "Not found"); - } - - #[test] - fn test_pattern_matching() { - let mock = MockHttp::new(); - - assert!(mock.matches_pattern( - "https://api.example.com/users", - "https://api.example.com/users" - )); - - assert!(mock.matches_pattern( - "https://api.example.com/users/123", - "https://api.example.com/*" - )); - - assert!(mock.matches_pattern( - "https://api.example.com/v2/users", - "https://api.example.com/*/users" - )); - - assert!(!mock.matches_pattern("https://other.com/users", "https://api.example.com/*")); - } - - #[tokio::test] - async fn test_mock_execution() { - let mock = MockHttp::new(); - mock.add_mock_sync("https://api.example.com/*", |_| { - MockResponse::json(serde_json::json!({"status": "ok"})) - }); - - let request = MockRequest { - 
method: "GET".to_string(), - path: "/users".to_string(), - url: "https://api.example.com/users".to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - }; - - let response = mock.execute(request).await; - assert_eq!(response.status, 200); - assert_eq!(response.body["status"], "ok"); - } - - #[tokio::test] - async fn test_request_recording() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let request = MockRequest { - method: "POST".to_string(), - path: "/api/users".to_string(), - url: "https://api.example.com/users".to_string(), - headers: HashMap::from([("authorization".to_string(), "Bearer token".to_string())]), - body: serde_json::json!({"name": "Test"}), - }; - - let _ = mock.execute(request).await; - - let requests = mock.requests(); - assert_eq!(requests.len(), 1); - assert_eq!(requests[0].method, "POST"); - assert_eq!(requests[0].body["name"], "Test"); - } - - #[tokio::test] - async fn test_assert_called() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let request = MockRequest { - method: "GET".to_string(), - path: "/test".to_string(), - url: "https://api.example.com/test".to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - }; - - let _ = mock.execute(request).await; - - mock.assert_called("https://api.example.com/*"); - mock.assert_called_times("https://api.example.com/*", 1); - mock.assert_not_called("https://other.com/*"); - } - - #[test] - fn test_builder() { - let mock = MockHttpBuilder::new() - .mock("https://api.example.com/*", |_| MockResponse::ok()) - .mock_json("https://other.com/*", serde_json::json!({"id": 1})) - .build(); - - assert_eq!(mock.mocks.read().unwrap().len(), 2); - } - - fn req(method: &str, url: &str, path: &str) -> MockRequest { - MockRequest { - method: method.to_string(), - path: path.to_string(), - url: url.to_string(), - headers: HashMap::new(), - body: serde_json::Value::Null, - } - } - - #[test] - fn 
response_status_helpers_use_documented_codes() { - assert_eq!(MockResponse::internal_error("boom").status, 500); - assert_eq!(MockResponse::not_found("nope").status, 404); - assert_eq!(MockResponse::unauthorized("nope").status, 401); - assert_eq!(MockResponse::ok().status, 200); - - // ok() returns an empty JSON object — handlers that key off body shape - // (e.g., serde to () or empty struct) rely on this. - assert_eq!(MockResponse::ok().body, serde_json::json!({})); - } - - #[test] - fn response_json_sets_content_type_header() { - let r = MockResponse::json(serde_json::json!({"ok": true})); - assert_eq!( - r.headers.get("content-type"), - Some(&"application/json".to_string()) - ); - } - - #[test] - fn pattern_matcher_handles_leading_and_double_wildcards() { - let m = MockHttp::new(); - // Leading wildcard (pattern_parts[0] is empty). - assert!(m.matches_pattern("https://api.example.com/v1/users", "*/users")); - assert!(!m.matches_pattern("https://api.example.com/v1/posts", "*/users")); - - // Bare `*` matches anything (both pattern parts are empty strings). - assert!(m.matches_pattern("anything", "*")); - assert!(m.matches_pattern("", "*")); - } - - #[test] - fn pattern_matcher_rejects_exact_pattern_with_extra_suffix() { - let m = MockHttp::new(); - assert!(!m.matches_pattern( - "https://api.example.com/users/extra", - "https://api.example.com/users" - )); - } - - #[tokio::test] - async fn execute_falls_back_to_500_when_no_mock_matches() { - let mock = MockHttp::new(); - let r = mock.execute(req("GET", "https://nowhere/", "/")).await; - assert_eq!(r.status, 500); - assert!( - r.body["error"] - .as_str() - .unwrap_or_default() - .contains("No mock found"), - "fallback should explain the failure, got {:?}", - r.body - ); - } - - #[tokio::test] - async fn execute_records_request_even_when_no_mock_matches() { - // The recording happens before the lookup so failed-match calls still - // show up in requests() — important for diagnosing "why didn't my mock fire". 
- let mock = MockHttp::new(); - let _ = mock.execute(req("DELETE", "https://nowhere/x", "/x")).await; - let recorded = mock.requests(); - assert_eq!(recorded.len(), 1); - assert_eq!(recorded[0].method, "DELETE"); - assert_eq!(recorded[0].url, "https://nowhere/x"); - } - - #[tokio::test] - async fn execute_matches_against_path_when_url_misses() { - // Pattern only matches the path, not the full URL. - let mock = MockHttp::new(); - mock.add_mock_sync("/health", |_| MockResponse::ok()); - let r = mock - .execute(req("GET", "https://internal.svc:8080/health", "/health")) - .await; - assert_eq!(r.status, 200); - } - - #[tokio::test] - async fn execute_uses_first_registered_mock_on_overlapping_patterns() { - let mock = MockHttp::new(); - mock.add_mock_sync("https://api.example.com/*", |_| { - MockResponse::json(serde_json::json!({"hit": "first"})) - }); - mock.add_mock_sync("https://api.example.com/users", |_| { - MockResponse::json(serde_json::json!({"hit": "second"})) - }); - - let r = mock - .execute(req("GET", "https://api.example.com/users", "/users")) - .await; - assert_eq!(r.body["hit"], "first"); - } - - #[tokio::test] - async fn requests_to_filters_by_pattern() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let _ = mock - .execute(req("GET", "https://api.example.com/a", "/a")) - .await; - let _ = mock.execute(req("GET", "https://other.com/b", "/b")).await; - let _ = mock - .execute(req("GET", "https://api.example.com/c", "/c")) - .await; - - let api_calls = mock.requests_to("https://api.example.com/*"); - assert_eq!(api_calls.len(), 2); - assert!(api_calls.iter().all(|r| r.url.contains("api.example.com"))); - } - - #[tokio::test] - async fn clear_requests_and_clear_mocks_independently_reset_state() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - let _ = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(mock.requests().len(), 1); - - mock.clear_requests(); - 
assert!(mock.requests().is_empty()); - // Mocks survive a requests-only clear; the next call should still match. - let r = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(r.status, 200); - - mock.clear_mocks(); - let r = mock.execute(req("GET", "https://x/", "/")).await; - assert_eq!(r.status, 500, "after clear_mocks, fallback should hit"); - } - - #[tokio::test] - async fn assert_called_with_body_runs_predicate_against_recorded_body() { - let mock = MockHttp::new(); - mock.add_mock_sync("*", |_| MockResponse::ok()); - - let mut request = req("POST", "https://api/upload", "/upload"); - request.body = serde_json::json!({"size": 42}); - let _ = mock.execute(request).await; - - // Predicate matches — should not panic. - mock.assert_called_with_body("https://api/*", |body| body["size"] == 42); - } - - #[test] - fn defaults_match_new() { - // Default impls are wrappers around new(); just exercise them so the - // Default path doesn't silently rot. - let m1 = MockHttp::default(); - assert!(m1.requests().is_empty()); - let b1 = MockHttpBuilder::default(); - let m2 = b1.build(); - assert!(m2.requests().is_empty()); - } -} diff --git a/crates/forge-core/src/testing/mod.rs b/crates/forge-core/src/testing/mod.rs index 29b85e99..66553074 100644 --- a/crates/forge-core/src/testing/mod.rs +++ b/crates/forge-core/src/testing/mod.rs @@ -2,14 +2,12 @@ pub mod assertions; pub mod context; pub mod db; pub mod mock_dispatch; -pub mod mock_email; pub mod mock_http; pub use assertions::*; pub use context::*; pub use db::{IsolatedTestDb, TestDatabase}; pub use mock_dispatch::{DispatchedJob, MockJobDispatch, MockWorkflowDispatch, StartedWorkflow}; -pub use mock_email::{MockEmailSender, SentEmail}; pub use mock_http::{MockHttp, MockHttpBuilder, MockRequest, MockResponse}; use std::time::Duration; diff --git a/crates/forge-core/src/util/mod.rs b/crates/forge-core/src/util/mod.rs index 3c99d202..3c418690 100644 --- a/crates/forge-core/src/util/mod.rs +++ 
b/crates/forge-core/src/util/mod.rs @@ -133,6 +133,62 @@ pub fn to_camel_case(s: &str) -> String { result } +/// Normalize an args/input envelope before deserialization. +/// +/// Job and workflow handlers accept either a bare value or a single-key +/// `{"args": …}` / `{"input": …}` wrapper depending on how the caller phrased +/// the dispatch. This helper unwraps the envelope so the handler's `Args` / +/// `Input` deserialize path doesn't have to special-case both shapes. `null` +/// is collapsed to an empty object so handlers with `()` args still match. +pub fn normalize_handler_args(args: serde_json::Value) -> serde_json::Value { + let unwrapped = match &args { + serde_json::Value::Object(map) if map.len() == 1 => { + if map.contains_key("args") { + map.get("args").cloned().unwrap_or(serde_json::Value::Null) + } else if map.contains_key("input") { + map.get("input").cloned().unwrap_or(serde_json::Value::Null) + } else { + args + } + } + _ => args, + }; + + match &unwrapped { + serde_json::Value::Null => serde_json::Value::Object(serde_json::Map::new()), + _ => unwrapped, + } +} + +/// Extract the bare hostname from an authority component (`host[:port]`), +/// stripping an IPv6 bracket pair and any port. e.g. `[::1]:8080` -> `::1`, +/// `localhost:9081` -> `localhost`, `127.0.0.1` -> `127.0.0.1`. +pub fn hostname_from_authority(authority: &str) -> &str { + match authority.strip_prefix('[') { + // IPv6 literal: the hostname is everything up to the closing bracket. + Some(rest) => rest.split(']').next().unwrap_or(rest), + None => authority.split(':').next().unwrap_or(authority), + } +} + +/// True if `hostname` is a loopback address. Expects a bare hostname with no +/// port or brackets (see [`hostname_from_authority`]). +pub fn is_loopback_host(hostname: &str) -> bool { + matches!(hostname, "localhost" | "127.0.0.1" | "::1") +} + +/// Bare hostname of a plain-`http://` URL (port and IPv6 brackets stripped), +/// or `None` if `url` is not `http://`. 
+/// +/// Used to decide whether a plain-HTTP endpoint is a safe loopback exception to +/// the HTTPS requirement. A naive `starts_with("http://localhost")` check would +/// wrongly accept `http://localhost.evil.com`, so callers parse the host first. +pub fn http_hostname(url: &str) -> Option<&str> { + let rest = url.strip_prefix("http://")?; + let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest); + Some(hostname_from_authority(authority)) +} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { @@ -269,4 +325,72 @@ mod tests { assert_eq!(to_camel_case("list_all_projects"), "listAllProjects"); assert_eq!(to_camel_case("simple"), "simple"); } + + #[test] + fn normalize_handler_args_converts_null_to_empty_object() { + use serde_json::json; + assert_eq!(normalize_handler_args(json!(null)), json!({})); + } + + #[test] + fn normalize_handler_args_unwraps_args_envelope() { + use serde_json::json; + assert_eq!( + normalize_handler_args(json!({"args": {"x": 1}})), + json!({"x": 1}) + ); + assert_eq!(normalize_handler_args(json!({"args": null})), json!({})); + } + + #[test] + fn normalize_handler_args_unwraps_input_envelope() { + use serde_json::json; + assert_eq!( + normalize_handler_args(json!({"input": [1, 2]})), + json!([1, 2]) + ); + } + + #[test] + fn normalize_handler_args_preserves_other_shapes() { + use serde_json::json; + assert_eq!(normalize_handler_args(json!({"id": 7})), json!({"id": 7})); + assert_eq!(normalize_handler_args(json!([1, 2])), json!([1, 2])); + assert_eq!(normalize_handler_args(json!(42)), json!(42)); + } + + #[test] + fn hostname_from_authority_strips_port_and_brackets() { + assert_eq!(hostname_from_authority("localhost:9081"), "localhost"); + assert_eq!(hostname_from_authority("127.0.0.1"), "127.0.0.1"); + assert_eq!(hostname_from_authority("[::1]:8080"), "::1"); + assert_eq!(hostname_from_authority("[::1]"), "::1"); + assert_eq!(hostname_from_authority("example.com:443"), "example.com"); + } + + 
#[test] + fn is_loopback_host_matches_only_loopback() { + assert!(is_loopback_host("localhost")); + assert!(is_loopback_host("127.0.0.1")); + assert!(is_loopback_host("::1")); + assert!(!is_loopback_host("localhost.evil.com")); + assert!(!is_loopback_host("example.com")); + } + + #[test] + fn http_hostname_parses_host_and_rejects_spoofs() { + assert_eq!( + http_hostname("http://localhost:9081/jwks"), + Some("localhost") + ); + assert_eq!(http_hostname("http://[::1]:8080/jwks"), Some("::1")); + assert_eq!(http_hostname("http://127.0.0.1/cb?x=1"), Some("127.0.0.1")); + // The classic spoof: a subdomain of localhost is not loopback. + assert_eq!( + http_hostname("http://localhost.evil.com/cb"), + Some("localhost.evil.com") + ); + // Not plain HTTP. + assert_eq!(http_hostname("https://localhost/jwks"), None); + } } diff --git a/crates/forge-core/src/workflow/mod.rs b/crates/forge-core/src/workflow/mod.rs index a689b425..02091188 100644 --- a/crates/forge-core/src/workflow/mod.rs +++ b/crates/forge-core/src/workflow/mod.rs @@ -6,6 +6,6 @@ mod traits; pub use context::{CompensationHandler, StepState, WorkflowContext}; pub use events::{NoOpEventSender, WorkflowEventSender, serialize_payload}; -pub use step::{Step, StepBuilder, StepConfig, StepResult, StepStatus}; +pub use step::StepStatus; pub use suspend::{SuspendReason, WorkflowEvent}; pub use traits::{ForgeWorkflow, WorkflowDefStatus, WorkflowInfo, WorkflowStatus}; diff --git a/crates/forge-core/src/workflow/step.rs b/crates/forge-core/src/workflow/step.rs index 56518200..2d78cb44 100644 --- a/crates/forge-core/src/workflow/step.rs +++ b/crates/forge-core/src/workflow/step.rs @@ -1,15 +1,4 @@ -use std::future::Future; -use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; - -use serde::{Serialize, de::DeserializeOwned}; - -use crate::Result; - -/// Type alias for compensation function to reduce complexity. 
-type CompensateFn<'a, T, C> = Arc Pin> + Send + Sync + 'a>; /// Step execution status. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -25,6 +14,9 @@ pub enum StepStatus { Failed, /// Step compensation ran. Compensated, + /// Step compensation handler ran but failed; manual remediation may be + /// required for any side effects of the original step. + CompensationFailed, /// Step was skipped. Skipped, /// Step is waiting (suspended). @@ -40,6 +32,7 @@ impl StepStatus { Self::Completed => "completed", Self::Failed => "failed", Self::Compensated => "compensated", + Self::CompensationFailed => "compensation_failed", Self::Skipped => "skipped", Self::Waiting => "waiting", } @@ -67,6 +60,7 @@ impl FromStr for StepStatus { "completed" => Ok(Self::Completed), "failed" => Ok(Self::Failed), "compensated" => Ok(Self::Compensated), + "compensation_failed" => Ok(Self::CompensationFailed), "skipped" => Ok(Self::Skipped), "waiting" => Ok(Self::Waiting), _ => Err(ParseStepStatusError(s.to_string())), @@ -74,220 +68,22 @@ impl FromStr for StepStatus { } } -/// Result of a step execution. -#[derive(Debug, Clone)] -pub struct StepResult { - /// Step name. - pub name: String, - /// Step status. - pub status: StepStatus, - /// Step result (if completed). - pub value: Option, - /// Error message (if failed). - pub error: Option, -} - -/// A workflow step definition. -pub struct Step { - /// Step name. - pub name: String, - /// Step result type. - _marker: std::marker::PhantomData, -} - -impl Step { - /// Create a new step. - pub fn new(name: impl Into) -> Self { - Self { - name: name.into(), - _marker: std::marker::PhantomData, - } - } -} - -/// Builder for configuring and executing a step. 
-pub struct StepBuilder<'a, T, F, C> -where - T: Serialize + DeserializeOwned + Send + 'static, - F: Future> + Send + 'a, - C: Future> + Send + 'a, -{ - name: String, - run_fn: Option F + Send + 'a>>>, - compensate_fn: Option>, - timeout: Option, - retry_count: u32, - retry_delay: Duration, - optional: bool, - _marker: std::marker::PhantomData<(T, F, C)>, -} - -impl<'a, T, F, C> StepBuilder<'a, T, F, C> -where - T: Serialize + DeserializeOwned + Send + Clone + 'static, - F: Future> + Send + 'a, - C: Future> + Send + 'a, -{ - /// Create a new step builder. - pub fn new(name: impl Into) -> Self { - Self { - name: name.into(), - run_fn: None, - compensate_fn: None, - timeout: None, - retry_count: 0, - retry_delay: Duration::from_secs(1), - optional: false, - _marker: std::marker::PhantomData, - } - } - - /// Set the step execution function. - pub fn run(mut self, f: RF) -> Self - where - RF: FnOnce() -> F + Send + 'a, - { - self.run_fn = Some(Box::pin(f)); - self - } - - /// Set the compensation function. - /// - /// # Warning - /// - /// Compensation handlers are in-memory closures. They do **not** survive - /// process restarts. If the workflow suspends (via `ctx.sleep()` or - /// `ctx.wait_for_event()`) and the process restarts before the workflow - /// completes, registered compensation handlers are lost. The executor - /// detects this and fails the workflow with a message requiring manual - /// remediation. - pub fn compensate(mut self, f: CF) -> Self - where - CF: Fn(T) -> Pin> + Send + Sync + 'a, - { - self.compensate_fn = Some(Arc::new(f)); - self - } - - /// Set step timeout. - pub fn timeout(mut self, duration: Duration) -> Self { - self.timeout = Some(duration); - self - } - - /// Configure retry behavior. - pub fn retry(mut self, count: u32, delay: Duration) -> Self { - self.retry_count = count; - self.retry_delay = delay; - self - } - - /// Mark the step as optional (failure won't trigger compensation). 
- pub fn optional(mut self) -> Self { - self.optional = true; - self - } - - /// Get step name. - pub fn name(&self) -> &str { - &self.name - } - - /// Check if step is optional. - pub fn is_optional(&self) -> bool { - self.optional - } - - /// Get retry count. - pub fn retry_count(&self) -> u32 { - self.retry_count - } - - /// Get retry delay. - pub fn retry_delay(&self) -> Duration { - self.retry_delay - } - - /// Get timeout. - pub fn get_timeout(&self) -> Option { - self.timeout - } -} - -/// Configuration for a step (without closures, for storage). -#[derive(Debug, Clone)] -pub struct StepConfig { - /// Step name. - pub name: String, - /// Step timeout. - pub timeout: Option, - /// Retry count. - pub retry_count: u32, - /// Retry delay. - pub retry_delay: Duration, - /// Whether the step is optional. - pub optional: bool, - /// Whether the step has a compensation function. - pub has_compensation: bool, -} - -impl Default for StepConfig { - fn default() -> Self { - Self { - name: String::new(), - timeout: None, - retry_count: 0, - retry_delay: Duration::from_secs(1), - optional: false, - has_compensation: false, - } - } -} - #[cfg(test)] -#[allow(clippy::unwrap_used, clippy::indexing_slicing)] +#[allow(clippy::unwrap_used)] mod tests { use super::*; - #[test] - fn test_step_status_conversion() { - assert_eq!(StepStatus::Pending.as_str(), "pending"); - assert_eq!(StepStatus::Running.as_str(), "running"); - assert_eq!(StepStatus::Completed.as_str(), "completed"); - assert_eq!(StepStatus::Failed.as_str(), "failed"); - assert_eq!(StepStatus::Compensated.as_str(), "compensated"); - - assert_eq!("pending".parse::(), Ok(StepStatus::Pending)); - assert_eq!("completed".parse::(), Ok(StepStatus::Completed)); - } - - #[test] - fn test_step_config_default() { - let config = StepConfig::default(); - assert!(config.name.is_empty()); - assert!(!config.optional); - assert_eq!(config.retry_count, 0); - } - - #[test] - fn step_status_as_str_covers_all_variants() { - 
assert_eq!(StepStatus::Pending.as_str(), "pending"); - assert_eq!(StepStatus::Running.as_str(), "running"); - assert_eq!(StepStatus::Completed.as_str(), "completed"); - assert_eq!(StepStatus::Failed.as_str(), "failed"); - assert_eq!(StepStatus::Compensated.as_str(), "compensated"); - assert_eq!(StepStatus::Skipped.as_str(), "skipped"); - assert_eq!(StepStatus::Waiting.as_str(), "waiting"); - } - #[test] fn step_status_parse_roundtrips_every_variant() { + // StepStatus is persisted to and read back from the DB (executor.rs, + // state.rs), so as_str() and FromStr must stay inverses for every variant. for status in [ StepStatus::Pending, StepStatus::Running, StepStatus::Completed, StepStatus::Failed, StepStatus::Compensated, + StepStatus::CompensationFailed, StepStatus::Skipped, StepStatus::Waiting, ] { @@ -304,46 +100,4 @@ mod tests { // Display must echo the bad value so logs pinpoint the typo. assert!(err.to_string().contains("garbage")); } - - #[test] - fn step_constructor_records_name() { - let s: Step = Step::new("send_email"); - assert_eq!(s.name, "send_email"); - } - - type NoFut = Pin> + Send + 'static>>; - type NoComp = Pin> + Send + 'static>>; - - fn fresh_builder<'a>() -> StepBuilder<'a, u32, NoFut, NoComp> { - StepBuilder::new("noop") - } - - #[test] - fn step_builder_defaults() { - let b = fresh_builder(); - assert_eq!(b.name(), "noop"); - assert!(!b.is_optional()); - assert_eq!(b.retry_count(), 0); - assert_eq!(b.retry_delay(), Duration::from_secs(1)); - assert!(b.get_timeout().is_none()); - } - - #[test] - fn step_builder_optional_flag_flips() { - let b = fresh_builder().optional(); - assert!(b.is_optional()); - } - - #[test] - fn step_builder_retry_sets_count_and_delay() { - let b = fresh_builder().retry(3, Duration::from_millis(250)); - assert_eq!(b.retry_count(), 3); - assert_eq!(b.retry_delay(), Duration::from_millis(250)); - } - - #[test] - fn step_builder_timeout_setter() { - let b = fresh_builder().timeout(Duration::from_secs(5)); - 
assert_eq!(b.get_timeout(), Some(Duration::from_secs(5))); - } } diff --git a/crates/forge-macros/Cargo.toml b/crates/forge-macros/Cargo.toml index 328ea1dc..df76bd93 100644 --- a/crates/forge-macros/Cargo.toml +++ b/crates/forge-macros/Cargo.toml @@ -17,6 +17,8 @@ syn = { workspace = true } darling = { workspace = true } quote = { workspace = true } proc-macro2 = { workspace = true } +proc-macro-crate = { workspace = true } sqlparser = { workspace = true } cron = { workspace = true } +chrono-tz = { workspace = true } blake3 = { workspace = true } diff --git a/crates/forge-macros/src/cron.rs b/crates/forge-macros/src/cron.rs index 5a526c73..09913c8f 100644 --- a/crates/forge-macros/src/cron.rs +++ b/crates/forge-macros/src/cron.rs @@ -55,6 +55,7 @@ struct CronAttrs { } pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -145,26 +146,48 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let _vis = &input.vis; let block = &input.block; - let schedule = attrs.schedule.unwrap_or_else(|| "* * * * *".to_string()); + // Require explicit schedule — silent fallback to `* * * * *` (every + // minute) is almost always wrong. + let Some(raw_schedule) = attrs.schedule else { + return syn::Error::new_spanned( + &input.sig.ident, + "cron handlers require an explicit schedule. Use a positional cron \ + expression, `schedule = \"...\"`, `every = \"...\"`, or `daily_at = \"...\"`.", + ) + .to_compile_error() + .into(); + }; - // Normalize 5-part to 6-part (prepend seconds) to match what CronSchedule::new does. - { - let parts: Vec<&str> = schedule.split_whitespace().collect(); + // Normalize 5-part to 6-part (prepend seconds) to match what CronSchedule::new does, + // and pass the normalized form to the runtime so compile- and run-time agree. 
+ let schedule = { + let parts: Vec<&str> = raw_schedule.split_whitespace().collect(); let normalized = if parts.len() == 5 { - format!("0 {schedule}") + format!("0 {raw_schedule}") } else { - schedule.clone() + raw_schedule.clone() }; if cron::Schedule::from_str(&normalized).is_err() { return syn::Error::new_spanned( &input.sig.ident, - format!("Invalid cron schedule: \"{schedule}\""), + format!("Invalid cron schedule: \"{raw_schedule}\""), ) .to_compile_error() .into(); } - } + normalized + }; let timezone = attrs.timezone.unwrap_or_else(|| "UTC".to_string()); + if timezone.parse::().is_err() { + return syn::Error::new_spanned( + &input.sig.ident, + format!( + "Invalid timezone: \"{timezone}\". Must be an IANA tz database name (e.g., \"UTC\", \"America/New_York\")." + ), + ) + .to_compile_error() + .into(); + } let group = attrs.group.unwrap_or_else(|| "default".to_string()); let catch_up = attrs.catch_up; let catch_up_limit = attrs.catch_up_limit.unwrap_or(10); @@ -185,7 +208,7 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.crons.register::<#struct_name>(); })); } @@ -202,15 +225,15 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::cron::ForgeCron for #struct_name { + impl #forge::forge_core::cron::ForgeCron for #struct_name { type Args = (); - fn info() -> forge::forge_core::cron::CronInfo { - forge::forge_core::cron::CronInfo { + fn info() -> #forge::forge_core::cron::CronInfo { + #forge::forge_core::cron::CronInfo { name: #rpc_name, - schedule: forge::forge_core::cron::CronSchedule::new_validated(#schedule), + schedule: #forge::forge_core::cron::CronSchedule::new_validated(#schedule), timezone: #timezone, group: #group, catch_up: #catch_up, @@ -221,8 +244,8 @@ pub fn cron_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::cron::CronContext, - ) -> std::pin::Pin> + Send + '_>> { + ctx: &#forge::forge_core::cron::CronContext, + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/daemon.rs b/crates/forge-macros/src/daemon.rs index 4dfd588f..665260a8 100644 --- a/crates/forge-macros/src/daemon.rs +++ b/crates/forge-macros/src/daemon.rs @@ -45,6 +45,7 @@ struct DaemonAttrs { } pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -108,7 +109,7 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.daemons.register::<#struct_name>(); })); } @@ -125,11 +126,11 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::daemon::ForgeDaemon for #struct_name { - fn info() -> forge::forge_core::daemon::DaemonInfo { - forge::forge_core::daemon::DaemonInfo { + impl #forge::forge_core::daemon::ForgeDaemon for #struct_name { + fn info() -> #forge::forge_core::daemon::DaemonInfo { + #forge::forge_core::daemon::DaemonInfo { name: #rpc_name, leader_elected: #leader_elected, restart_on_panic: #restart_on_panic, @@ -141,8 +142,8 @@ pub fn daemon_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::daemon::DaemonContext, - ) -> std::pin::Pin> + Send + '_>> { + ctx: &#forge::forge_core::daemon::DaemonContext, + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/job.rs b/crates/forge-macros/src/job.rs index ce03f292..ffa8edd2 100644 --- a/crates/forge-macros/src/job.rs +++ b/crates/forge-macros/src/job.rs @@ -129,6 +129,7 @@ struct JobAttrs { } pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -167,7 +168,18 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let mut args_type = quote! 
{ () }; let mut args_ident = format_ident!("_args"); - for input_arg in input.sig.inputs.iter().skip(1) { + let user_args: Vec<_> = input.sig.inputs.iter().skip(1).collect(); + if user_args.len() > 1 { + return TokenStream::from( + syn::Error::new_spanned( + user_args[1], + "jobs may take at most one user argument (besides the JobContext). \ + Wrap multiple values in a single struct that derives Serialize/Deserialize.", + ) + .into_compile_error(), + ); + } + for input_arg in user_args { if let syn::FnArg::Typed(pat_type) = input_arg { if let syn::Pat::Ident(ident) = pat_type.pat.as_ref() { args_ident = ident.ident.clone(); @@ -224,27 +236,27 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let priority = if let Some(ref p) = attrs.priority { let p_lower = p.to_lowercase(); match p_lower.as_str() { - "background" => quote! { forge::forge_core::job::JobPriority::Background }, - "low" => quote! { forge::forge_core::job::JobPriority::Low }, - "normal" => quote! { forge::forge_core::job::JobPriority::Normal }, - "high" => quote! { forge::forge_core::job::JobPriority::High }, - "critical" => quote! { forge::forge_core::job::JobPriority::Critical }, - _ => quote! { forge::forge_core::job::JobPriority::Normal }, + "background" => quote! { #forge::forge_core::job::JobPriority::Background }, + "low" => quote! { #forge::forge_core::job::JobPriority::Low }, + "normal" => quote! { #forge::forge_core::job::JobPriority::Normal }, + "high" => quote! { #forge::forge_core::job::JobPriority::High }, + "critical" => quote! { #forge::forge_core::job::JobPriority::Critical }, + _ => quote! { #forge::forge_core::job::JobPriority::Normal }, } } else { - quote! { forge::forge_core::job::JobPriority::Normal } + quote! { #forge::forge_core::job::JobPriority::Normal } }; let max_attempts = attrs.max_attempts.unwrap_or(3); let backoff = if let Some(ref b) = attrs.backoff { match b.as_str() { - "fixed" => quote! 
{ forge::forge_core::job::BackoffStrategy::Fixed }, - "linear" => quote! { forge::forge_core::job::BackoffStrategy::Linear }, - "exponential" => quote! { forge::forge_core::job::BackoffStrategy::Exponential }, - _ => quote! { forge::forge_core::job::BackoffStrategy::Exponential }, + "fixed" => quote! { #forge::forge_core::job::BackoffStrategy::Fixed }, + "linear" => quote! { #forge::forge_core::job::BackoffStrategy::Linear }, + "exponential" => quote! { #forge::forge_core::job::BackoffStrategy::Exponential }, + _ => quote! { #forge::forge_core::job::BackoffStrategy::Exponential }, } } else { - quote! { forge::forge_core::job::BackoffStrategy::Exponential } + quote! { #forge::forge_core::job::BackoffStrategy::Exponential } }; let max_backoff = if let Some(ref mb) = attrs.max_backoff { @@ -285,10 +297,10 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let compensation_args_ident = format_ident!("_comp_args"); quote! { fn compensate( - ctx: &forge::forge_core::job::JobContext, + ctx: &#forge::forge_core::job::JobContext, #compensation_args_ident: Self::Args, reason: &str, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #handler_ident(ctx, #compensation_args_ident, reason).await }) } } @@ -300,7 +312,7 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.jobs.register::<#struct_name>(); })); } @@ -317,20 +329,20 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::job::ForgeJob for #struct_name { + impl #forge::forge_core::job::ForgeJob for #struct_name { type Args = #args_type; type Output = #output_type; - fn info() -> forge::forge_core::job::JobInfo { - forge::forge_core::job::JobInfo { + fn info() -> #forge::forge_core::job::JobInfo { + #forge::forge_core::job::JobInfo { name: #fn_name_str, description: #description_tokens, timeout: #timeout, http_timeout: #http_timeout, priority: #priority, - retry: forge::forge_core::job::RetryConfig { + retry: #forge::forge_core::job::RetryConfig { max_attempts: #max_attempts, backoff: #backoff, max_backoff: #max_backoff, @@ -346,9 +358,9 @@ pub fn job_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::job::JobContext, + ctx: &#forge::forge_core::job::JobContext, #args_ident: Self::Args, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } diff --git a/crates/forge-macros/src/mcp_tool.rs b/crates/forge-macros/src/mcp_tool.rs index 8e02ffb2..d14dc03e 100644 --- a/crates/forge-macros/src/mcp_tool.rs +++ b/crates/forge-macros/src/mcp_tool.rs @@ -84,9 +84,23 @@ struct McpToolAttrs { } fn convert_mcp_tool_attrs(darling: DarlingMcpToolAttrs) -> Result { - let timeout = darling - .timeout - .and_then(|s| parse_duration_secs(&s).or_else(|| s.parse::().ok())); + // Require a unit suffix on timeouts to match every other macro. Bare + // integers like `timeout = "30"` are ambiguous (seconds? milliseconds?) 
+ // and were only accepted here historically. + let timeout = match darling.timeout { + Some(ref s) => match parse_duration_secs(s) { + Some(t) => Some(t), + None => { + return Err(syn::Error::new( + proc_macro2::Span::call_site(), + format!( + "invalid timeout \"{s}\": use a duration string like \"30s\", \"5m\", or \"1h\"" + ), + )); + } + }, + None => None, + }; let (rate_limit_requests, rate_limit_per_secs, rate_limit_key) = if let Some(ref rl) = darling.rate_limit { @@ -159,6 +173,12 @@ fn tool_type_stem(fn_name: &str) -> &str { } fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result { + let forge = crate::utils::forge_path(); + // schemars(crate = "...") needs a literal string. Build it from the + // resolved forge prefix at expansion time so a renamed dep still emits + // a working path. Tokens like `::forge` and `::forgex` render with the + // leading colons, which schemars accepts. + let schemars_crate_str = format!("{}::forge_core::schemars", forge); let fn_name = &input.sig.ident; let fn_name_str = attrs.name.unwrap_or_else(|| fn_name.to_string()); validate_tool_name(&fn_name_str)?; @@ -212,8 +232,7 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result = params.iter().skip(1).cloned().collect(); @@ -359,8 +378,8 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result syn::Result syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! 
{ #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! { - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.mcp_tools.register::<#struct_name>(); })); } @@ -433,14 +452,14 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result forge::forge_core::McpToolInfo { - forge::forge_core::McpToolInfo { + fn info() -> #forge::forge_core::McpToolInfo { + #forge::forge_core::McpToolInfo { name: #fn_name_str, title: #title, description: #description, @@ -450,7 +469,7 @@ fn expand_mcp_tool_impl(input: ItemFn, attrs: McpToolAttrs) -> syn::Result syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/model.rs b/crates/forge-macros/src/model.rs index 0737c0c8..5eac3e83 100644 --- a/crates/forge-macros/src/model.rs +++ b/crates/forge-macros/src/model.rs @@ -1,7 +1,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenStream as TokenStream2; use quote::quote; -use syn::{Data, DeriveInput, Fields, Meta, parse_macro_input, spanned::Spanned}; +use syn::{Data, DeriveInput, Fields, parse_macro_input, spanned::Spanned}; pub fn expand_model(attr: TokenStream, item: TokenStream) -> TokenStream { let input_clone = item.clone(); @@ -18,6 +18,7 @@ fn expand_model_impl( input: DeriveInput, _original_tokens: TokenStream2, ) -> syn::Result { + let forge = crate::utils::forge_path(); let attr_str = 
attr.to_string(); let trimmed = attr_str.trim(); if !trimmed.is_empty() { @@ -55,8 +56,8 @@ fn expand_model_impl( quote! { { - let rust_type = forge::forge_core::schema::RustType::from_type_string(#type_str); - let mut field = forge::forge_core::schema::FieldDef::new(#name, rust_type); + let rust_type = #forge::forge_core::schema::RustType::from_type_string(#type_str); + let mut field = #forge::forge_core::schema::FieldDef::new(#name, rust_type); field.column_name = #column_name.to_string(); field } @@ -89,11 +90,11 @@ fn expand_model_impl( #(#field_defs),* } - impl forge::forge_core::schema::ModelMeta for #struct_name { + impl #forge::forge_core::schema::ModelMeta for #struct_name { const TABLE_NAME: &'static str = #table_name; - fn table_def() -> forge::forge_core::schema::TableDef { - let mut table = forge::forge_core::schema::TableDef::new(#table_name, stringify!(#struct_name)); + fn table_def() -> #forge::forge_core::schema::TableDef { + let mut table = #forge::forge_core::schema::TableDef::new(#table_name, stringify!(#struct_name)); table.fields = vec![ #(#field_tokens),* ]; @@ -110,18 +111,23 @@ fn expand_model_impl( } fn get_table_name(input: &DeriveInput) -> syn::Result { - // Look for #[table(name = "...")] + // Look for #[table(name = "...")]. parse_nested_meta correctly handles + // escaped quotes and inner whitespace, unlike the previous string-slice + // parser which choked on `name = "with \"escape\""` and similar. 
for attr in &input.attrs { if attr.path().is_ident("table") { - let meta = attr.meta.clone(); - if let Meta::List(list) = meta { - let tokens: TokenStream2 = list.tokens; - let tokens_str = tokens.to_string(); - if tokens_str.starts_with("name") - && let Some(value) = extract_string_value(&tokens_str) - { - return Ok(value); + let mut found: Option = None; + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("name") { + let lit: syn::LitStr = meta.value()?.parse()?; + found = Some(lit.value()); + Ok(()) + } else { + Err(meta.error("expected `name = \"...\"`")) } + })?; + if let Some(name) = found { + return Ok(name); } } } @@ -131,18 +137,6 @@ fn get_table_name(input: &DeriveInput) -> syn::Result { Ok(pluralize(&name)) } -fn extract_string_value(s: &str) -> Option { - // Parse "name = \"value\"" pattern - let parts: Vec<&str> = s.splitn(2, '=').collect(); - if parts.len() == 2 { - let value = parts[1].trim(); - if let Some(stripped) = value.strip_prefix('"').and_then(|s| s.strip_suffix('"')) { - return Some(stripped.to_string()); - } - } - None -} - use crate::utils::{pluralize, to_snake_case}; #[cfg(test)] @@ -197,24 +191,6 @@ mod tests { assert_eq!(pluralize("buy"), "buys"); } - #[test] - fn extract_string_value_valid() { - assert_eq!( - extract_string_value(r#"name = "custom_table""#), - Some("custom_table".to_string()) - ); - } - - #[test] - fn extract_string_value_no_quotes() { - assert_eq!(extract_string_value("name = bare_value"), None); - } - - #[test] - fn extract_string_value_no_equals() { - assert_eq!(extract_string_value(r#""just a string""#), None); - } - // --- Table name derivation (integration of to_snake_case + pluralize) --- #[test] diff --git a/crates/forge-macros/src/mutation.rs b/crates/forge-macros/src/mutation.rs index bd90cfd5..5e0a8e29 100644 --- a/crates/forge-macros/src/mutation.rs +++ b/crates/forge-macros/src/mutation.rs @@ -215,6 +215,7 @@ fn convert_mutation_attrs(darling: DarlingMutationAttrs) -> Result syn::Result { + let forge 
= crate::utils::forge_path(); let fn_name = &input.sig.ident; let fn_name_str = fn_name.to_string(); let rpc_name = attrs.name.as_deref().unwrap_or(&fn_name_str).to_string(); @@ -250,53 +251,27 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result, found: bool, } - impl DispatchCallVisitor { - fn receiver_root_ident(mut expr: &syn::Expr) -> Option<&syn::Ident> { - loop { - match expr { - syn::Expr::MethodCall(inner) => expr = &inner.receiver, - syn::Expr::Try(inner) => expr = &inner.expr, - syn::Expr::Await(inner) => expr = &inner.base, - syn::Expr::Paren(inner) => expr = &inner.expr, - syn::Expr::Reference(inner) => expr = &inner.expr, - syn::Expr::Path(path) => { - if path.qself.is_none() && path.path.segments.len() == 1 { - return path.path.segments.first().map(|s| &s.ident); - } - return None; - } - _ => return None, - } - } - } - - fn receiver_is_ctx(&self, receiver: &syn::Expr) -> bool { - let Some(ref ctx) = self.ctx_ident else { - return true; - }; - Self::receiver_root_ident(receiver).is_some_and(|root| root == ctx) - } - } impl<'ast> syn::visit::Visit<'ast> for DispatchCallVisitor { fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) { let method = node.method.to_string(); - if (method == "dispatch_job" || method == "start_workflow") - && self.receiver_is_ctx(&node.receiver) - { + if method == "dispatch_job" || method == "start_workflow" { self.found = true; } syn::visit::visit_expr_method_call(self, node); } } - let mut visitor = DispatchCallVisitor { - ctx_ident: mutation_ctx_ident.clone(), - found: false, - }; + let mut visitor = DispatchCallVisitor { found: false }; syn::visit::visit_block(&mut visitor, fn_block); visitor.found }; @@ -323,17 +298,50 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result, found: bool, } + impl HttpCallVisitor { + fn receiver_root_ident(mut expr: &syn::Expr) -> Option<&syn::Ident> { + loop { + match expr { + syn::Expr::MethodCall(inner) => expr = 
&inner.receiver, + syn::Expr::Try(inner) => expr = &inner.expr, + syn::Expr::Await(inner) => expr = &inner.base, + syn::Expr::Paren(inner) => expr = &inner.expr, + syn::Expr::Reference(inner) => expr = &inner.expr, + syn::Expr::Path(path) => { + if path.qself.is_none() && path.path.segments.len() == 1 { + return path.path.segments.first().map(|s| &s.ident); + } + return None; + } + _ => return None, + } + } + } + fn receiver_is_ctx(&self, receiver: &syn::Expr) -> bool { + let Some(ref ctx) = self.ctx_ident else { + return true; + }; + Self::receiver_root_ident(receiver).is_some_and(|root| root == ctx) + } + } impl<'ast> syn::visit::Visit<'ast> for HttpCallVisitor { fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) { - if node.method == "http" { + // Gate on the receiver root resolving to the mutation context + // binding. Without this, any builder method named `.http()` + // on an unrelated type would trip the lint. + if node.method == "http" && self.receiver_is_ctx(&node.receiver) { self.found = true; } syn::visit::visit_expr_method_call(self, node); } } - let mut visitor = HttpCallVisitor { found: false }; + let mut visitor = HttpCallVisitor { + ctx_ident: mutation_ctx_ident.clone(), + found: false, + }; syn::visit::visit_block(&mut visitor, fn_block); if visitor.found { return Err(syn::Error::new_spanned( @@ -388,8 +396,7 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result = params.iter().skip(1).cloned().collect(); @@ -479,16 +486,18 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result { let key_tokens = match k.as_str() { - "user" => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, - "ip" => quote! { forge::forge_core::rate_limit::RateLimitKey::Ip }, - "tenant" => quote! { forge::forge_core::rate_limit::RateLimitKey::Tenant }, - "user_action" => quote! { forge::forge_core::rate_limit::RateLimitKey::UserAction }, - "global" => quote! 
{ forge::forge_core::rate_limit::RateLimitKey::Global }, + "user" => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, + "ip" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Ip }, + "tenant" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Tenant }, + "user_action" => { + quote! { #forge::forge_core::rate_limit::RateLimitKey::UserAction } + } + "global" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Global }, _ if k.starts_with("custom:") => { let claim = k.trim_start_matches("custom:"); - quote! { forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } + quote! { #forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } } - _ => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, + _ => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, }; quote! { Some(#key_tokens) } } @@ -498,13 +507,13 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result { let level_tokens = match l.as_str() { - "trace" => quote! { forge::forge_core::LogLevel::Trace }, - "debug" => quote! { forge::forge_core::LogLevel::Debug }, - "info" => quote! { forge::forge_core::LogLevel::Info }, - "warn" => quote! { forge::forge_core::LogLevel::Warn }, - "error" => quote! { forge::forge_core::LogLevel::Error }, - "off" => quote! { forge::forge_core::LogLevel::Off }, - _ => quote! { forge::forge_core::LogLevel::Trace }, + "trace" => quote! { #forge::forge_core::LogLevel::Trace }, + "debug" => quote! { #forge::forge_core::LogLevel::Debug }, + "info" => quote! { #forge::forge_core::LogLevel::Info }, + "warn" => quote! { #forge::forge_core::LogLevel::Warn }, + "error" => quote! { #forge::forge_core::LogLevel::Error }, + "off" => quote! { #forge::forge_core::LogLevel::Off }, + _ => quote! { #forge::forge_core::LogLevel::Trace }, }; quote! 
{ Some(#level_tokens) } } @@ -523,6 +532,12 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result = if let Some(ref tables) = attrs.tables { tables.clone() } else { + if let Some(issue) = extractor.issues.first() { + return Err(syn::Error::new_spanned( + &input.sig.ident, + issue.describe(&fn_name_str, "mutation"), + )); + } match extract_tables_from_sql(&extractor.sql_strings) { TableExtractionResult::Ok(tables) => { let mut sorted: Vec = tables.into_iter().collect(); @@ -671,29 +686,29 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.functions.register_mutation::<#struct_name>(); })); } @@ -711,17 +726,17 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result forge::forge_core::FunctionInfo { - forge::forge_core::FunctionInfo { + fn info() -> #forge::forge_core::FunctionInfo { + #forge::forge_core::FunctionInfo { name: #rpc_name, description: #description, - kind: forge::forge_core::FunctionKind::Mutation, + kind: #forge::forge_core::FunctionKind::Mutation, required_role: #required_role, is_public: #is_public, cache_ttl: None, @@ -742,9 +757,9 @@ fn expand_mutation_impl(input: ItemFn, attrs: MutationAttrs) -> syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/query.rs b/crates/forge-macros/src/query.rs index d6988c0a..a6503c91 100644 --- a/crates/forge-macros/src/query.rs +++ b/crates/forge-macros/src/query.rs @@ -209,6 +209,7 @@ fn convert_query_attrs(darling: DarlingQueryAttrs) -> Result syn::Result { + let forge = crate::utils::forge_path(); let fn_name = &input.sig.ident; let fn_name_str = fn_name.to_string(); let rpc_name = attrs.name.as_deref().unwrap_or(&fn_name_str).to_string(); @@ -258,8 +259,7 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result = if let Some(explicit_tables) = attrs.tables { @@ -268,6 +268,16 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let mut sorted: Vec = tables.into_iter().collect(); @@ -458,16 +468,18 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let key_tokens = match k.as_str() { - "user" => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, - "ip" => quote! { forge::forge_core::rate_limit::RateLimitKey::Ip }, - "tenant" => quote! 
{ forge::forge_core::rate_limit::RateLimitKey::Tenant }, - "user_action" => quote! { forge::forge_core::rate_limit::RateLimitKey::UserAction }, - "global" => quote! { forge::forge_core::rate_limit::RateLimitKey::Global }, + "user" => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, + "ip" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Ip }, + "tenant" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Tenant }, + "user_action" => { + quote! { #forge::forge_core::rate_limit::RateLimitKey::UserAction } + } + "global" => quote! { #forge::forge_core::rate_limit::RateLimitKey::Global }, _ if k.starts_with("custom:") => { let claim = k.trim_start_matches("custom:"); - quote! { forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } + quote! { #forge::forge_core::rate_limit::RateLimitKey::Custom(#claim.to_string()) } } - _ => quote! { forge::forge_core::rate_limit::RateLimitKey::User }, + _ => quote! { #forge::forge_core::rate_limit::RateLimitKey::User }, }; quote! { Some(#key_tokens) } } @@ -477,13 +489,13 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result { let level_tokens = match l.as_str() { - "trace" => quote! { forge::forge_core::LogLevel::Trace }, - "debug" => quote! { forge::forge_core::LogLevel::Debug }, - "info" => quote! { forge::forge_core::LogLevel::Info }, - "warn" => quote! { forge::forge_core::LogLevel::Warn }, - "error" => quote! { forge::forge_core::LogLevel::Error }, - "off" => quote! { forge::forge_core::LogLevel::Off }, - _ => quote! { forge::forge_core::LogLevel::Trace }, + "trace" => quote! { #forge::forge_core::LogLevel::Trace }, + "debug" => quote! { #forge::forge_core::LogLevel::Debug }, + "info" => quote! { #forge::forge_core::LogLevel::Info }, + "warn" => quote! { #forge::forge_core::LogLevel::Warn }, + "error" => quote! { #forge::forge_core::LogLevel::Error }, + "off" => quote! { #forge::forge_core::LogLevel::Off }, + _ => quote! 
{ #forge::forge_core::LogLevel::Trace }, }; quote! { Some(#level_tokens) } } @@ -551,29 +563,29 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: #ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } } } else if arg_names.is_empty() { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type) -> #forge::forge_core::Result<#output_type> #fn_block } } else { quote! { #(#fn_attrs)* - #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> forge::forge_core::Result<#output_type> #fn_block + #vis async fn #fn_name(#ctx_name: &#ctx_type, #(#arg_params),*) -> #forge::forge_core::Result<#output_type> #fn_block } }; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.functions.register_query::<#struct_name>(); })); } @@ -591,17 +603,17 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result forge::forge_core::FunctionInfo { - forge::forge_core::FunctionInfo { + fn info() -> #forge::forge_core::FunctionInfo { + #forge::forge_core::FunctionInfo { name: #rpc_name, description: #description, - kind: forge::forge_core::FunctionKind::Query, + kind: #forge::forge_core::FunctionKind::Query, required_role: #required_role, is_public: #is_public, cache_ttl: #cache_ttl, @@ -622,9 +634,9 @@ fn expand_query_impl(input: ItemFn, attrs: QueryAttrs) -> syn::Result std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move { #execute_call }) diff --git a/crates/forge-macros/src/sql_extractor.rs b/crates/forge-macros/src/sql_extractor.rs index 2a8436f3..fedb115a 100644 --- a/crates/forge-macros/src/sql_extractor.rs +++ b/crates/forge-macros/src/sql_extractor.rs @@ -11,6 +11,53 @@ use sqlparser::parser::Parser; use syn::visit::Visit; use syn::{Expr as SynExpr, ExprCall, ExprLit, ExprMacro, ExprMethodCall}; +/// Reasons that SQL extraction can't reason about the call site at all. +/// Surfaced to callers so they can emit a clear compile error directing the +/// user to set explicit `tables(...)`. +#[derive(Debug, Clone)] +pub enum SqlAnalysisIssue { + /// `sqlx::query(&some_string)` or `sqlx::query_as::<_, T>(...)` — runtime + /// variant that bypasses compile-time checking entirely. + RuntimeSqlxCall, + /// Inside a sqlx::query!{} macro, the SQL is built via `format!`, + /// `String::from`, `concat!`, or other non-literal construction. + DynamicSqlInMacro, + /// SQL string is hoisted into a `const`/`let` binding or `include_str!`, + /// so the macro can't see the literal at the call site. 
+ HoistedSqlBinding, + /// `sqlx::query!{}` received a byte-string literal which would otherwise + /// be silently dropped. + ByteStringInMacro, +} + +impl SqlAnalysisIssue { + pub fn describe(&self, fn_name: &str, macro_kind: &str) -> String { + let header = match self { + Self::RuntimeSqlxCall => format!( + "`{fn_name}` calls runtime `sqlx::query()`/`sqlx::query_as::<_, T>()`. \ + Use the `sqlx::query!` / `sqlx::query_as!` macros for compile-time checks." + ), + Self::DynamicSqlInMacro => format!( + "`{fn_name}` builds SQL dynamically (e.g. `format!`, `String::from`, `concat!`) \ + inside a `sqlx::query!` macro. Table dependencies and the scope lint cannot be \ + verified." + ), + Self::HoistedSqlBinding => format!( + "`{fn_name}` references SQL via `const`, `let`, or `include_str!` inside a \ + `sqlx::query!` macro. The literal is invisible to the extractor." + ), + Self::ByteStringInMacro => format!( + "`{fn_name}` passes a byte-string literal to a `sqlx::query!` macro. \ + SQL must be a regular string literal." + ), + }; + format!( + "{header}\n\ + Add #[{macro_kind}(tables(\"your_table\"))] to declare table dependencies explicitly." + ) + } +} + /// Detects `.pool()` calls in a handler body, signalling DB work delegated /// to a helper function whose SQL is invisible to `SqlStringExtractor`. pub struct DbDelegationDetector { @@ -35,12 +82,17 @@ impl<'ast> Visit<'ast> for DbDelegationDetector { /// Visitor that extracts SQL string literals from function bodies. pub struct SqlStringExtractor { pub sql_strings: Vec, + /// Patterns that defeat static SQL analysis. Callers should treat any + /// non-empty list as a hard compile error unless explicit `tables(...)` + /// was provided. 
+ pub issues: Vec, } impl SqlStringExtractor { pub fn new() -> Self { Self { sql_strings: Vec::new(), + issues: Vec::new(), } } @@ -83,6 +135,13 @@ impl SqlStringExtractor { match token { proc_macro2::TokenTree::Literal(lit) => { let lit_str = lit.to_string(); + // Reject byte strings outright — they parse as syn::LitStr + // failures and would otherwise be silently dropped. + let trimmed = lit_str.trim_start(); + if trimmed.starts_with("b\"") || trimmed.starts_with("br") { + self.issues.push(SqlAnalysisIssue::ByteStringInMacro); + continue; + } if let Some(sql) = Self::extract_string_content(&lit_str) && Self::looks_like_sql(&sql) { @@ -97,6 +156,75 @@ impl SqlStringExtractor { } } + /// Inspect the first token-stream argument passed to a `sqlx::query!` + /// macro and decide whether the SQL is recoverable as a literal. Flags + /// `format!(...)`, `concat!(...)`, `String::from(...)`, `include_str!`, + /// and bare identifier references (hoisted into `const SQL` or `let sql`). + fn check_macro_first_arg(&mut self, tokens: &proc_macro2::TokenStream) { + // Peek at the leading token sequence before the first `,` separator. + let mut head: Vec = Vec::new(); + for tt in tokens.clone() { + if let proc_macro2::TokenTree::Punct(ref p) = tt + && p.as_char() == ',' + { + break; + } + head.push(tt); + } + + // Strip leading `&` references — `sqlx::query!(&sql, ...)` is the + // same shape from our perspective. + let mut idx = 0; + while let Some(proc_macro2::TokenTree::Punct(p)) = head.get(idx) { + if p.as_char() == '&' { + idx += 1; + } else { + break; + } + } + let head = &head[idx..]; + + match head { + // Single string literal — handled by extract_sql_from_tokens. + [proc_macro2::TokenTree::Literal(_)] => {} + // Bare identifier: `query!(SQL)` or `query!(my_sql)` — hoisted. 
+ [proc_macro2::TokenTree::Ident(_)] => { + self.issues.push(SqlAnalysisIssue::HoistedSqlBinding); + } + // `format!(...)`, `concat!(...)`, `include_str!(...)`, or a + // path-qualified call like `String::from(...)`. Detect by an + // ident followed by `!` or `(` / `::`. + _ if head.len() >= 2 => { + if let proc_macro2::TokenTree::Ident(first) = &head[0] { + let name = first.to_string(); + let next = &head[1]; + let is_macro_call = + matches!(next, proc_macro2::TokenTree::Punct(p) if p.as_char() == '!'); + let is_path = + matches!(next, proc_macro2::TokenTree::Punct(p) if p.as_char() == ':'); + let is_call = matches!(next, proc_macro2::TokenTree::Group(_)); + if is_macro_call + && matches!( + name.as_str(), + "format" | "concat" | "include_str" | "format_args" + ) + { + if name == "include_str" { + self.issues.push(SqlAnalysisIssue::HoistedSqlBinding); + } else { + self.issues.push(SqlAnalysisIssue::DynamicSqlInMacro); + } + } else if is_path || is_call { + // `String::from(...)`, `format!`, or general fn call — + // treat as dynamic. + self.issues.push(SqlAnalysisIssue::DynamicSqlInMacro); + } + } + } + _ => {} + } + } + /// Extract the actual string content from a literal representation. /// Delegates parsing and unescaping to syn so raw, byte, and escaped /// forms all decode through the same canonical path. @@ -130,20 +258,25 @@ impl<'ast> Visit<'ast> for SqlStringExtractor { fn visit_expr_call(&mut self, node: &'ast ExprCall) { if let SynExpr::Path(path) = &*node.func { - let path_str = path + let last = path .path .segments - .iter() + .last() .map(|s| s.ident.to_string()) - .collect::>() - .join("::"); + .unwrap_or_default(); - if (path_str.contains("query") - || path_str.ends_with("query_as") - || path_str.ends_with("raw_sql")) - && let Some(first_arg) = node.args.first() - { - self.visit_expr(first_arg); + // Runtime sqlx calls (no compile-time checks): `sqlx::query(...)`, + // `sqlx::query_as::<_, T>(...)`, `sqlx::query_scalar(...)`, etc. 
+ // Match on the final path segment exactly — `query_helper` or + // `my_query` do not count. + if matches!( + last.as_str(), + "query" | "query_as" | "query_scalar" | "query_with" | "raw_sql" + ) { + self.issues.push(SqlAnalysisIssue::RuntimeSqlxCall); + if let Some(first_arg) = node.args.first() { + self.visit_expr(first_arg); + } } } @@ -171,13 +304,49 @@ impl<'ast> Visit<'ast> for SqlStringExtractor { macro_name.as_str(), "query" | "query_as" | "query_scalar" | "query_as_unchecked" | "query_scalar_unchecked" ) { - self.extract_sql_from_tokens(&node.mac.tokens); + // `query_as!(Type, sql, ...)` and `query_as_unchecked!(Type, sql, ...)` + // put the row type as the first arg. Skip past it so we inspect + // the actual SQL token, not the type ident. + let sql_tokens = if matches!(macro_name.as_str(), "query_as" | "query_as_unchecked") { + skip_first_macro_arg(&node.mac.tokens) + } else { + node.mac.tokens.clone() + }; + self.check_macro_first_arg(&sql_tokens); + self.extract_sql_from_tokens(&sql_tokens); } syn::visit::visit_expr_macro(self, node); } } +/// Drop the first comma-separated argument (and the comma itself) from a +/// macro's raw token stream. Used to strip the row type from +/// `query_as!(Type, sql, ...)` before inspecting the SQL token. +fn skip_first_macro_arg(tokens: &proc_macro2::TokenStream) -> proc_macro2::TokenStream { + let mut depth = 0i32; + let mut seen_comma = false; + let mut out: Vec = Vec::new(); + for tt in tokens.clone() { + if seen_comma { + out.push(tt); + continue; + } + if let proc_macro2::TokenTree::Punct(ref p) = tt { + if p.as_char() == ',' && depth == 0 { + seen_comma = true; + continue; + } + if matches!(p.as_char(), '<') { + depth += 1; + } else if matches!(p.as_char(), '>') { + depth -= 1; + } + } + } + out.into_iter().collect() +} + /// Parse SQL strings and extract all selected column names. /// Returns bare column names (without table qualifiers). 
pub fn extract_columns_from_sql(sql_strings: &[String]) -> HashSet { @@ -642,11 +811,32 @@ fn expr_mentions_tenant(e: &Expr) -> bool { Expr::InList { expr, list, .. } => { expr_mentions_tenant(expr) || list.iter().any(expr_mentions_tenant) } - Expr::InSubquery { expr, .. } => expr_mentions_tenant(expr), + Expr::InSubquery { expr, subquery, .. } => { + expr_mentions_tenant(expr) || query_mentions_tenant(subquery) + } Expr::Between { expr, low, high, .. } => expr_mentions_tenant(expr) || expr_mentions_tenant(low) || expr_mentions_tenant(high), Expr::IsNull(e) | Expr::IsNotNull(e) => expr_mentions_tenant(e), + // Mirror expr_has_scope so `(claims->>'tenant_id')::uuid = $1`, + // `EXISTS (SELECT ... WHERE tenant_id = $1)`, and Snowflake-style + // `obj:tenant_id` are all recognized. + Expr::Subquery(q) | Expr::Exists { subquery: q, .. } => query_mentions_tenant(q), + Expr::JsonAccess { value, path } => { + expr_mentions_tenant(value) + || path.path.iter().any(|elem| match elem { + sqlparser::ast::JsonPathElem::Dot { key, .. } => { + key.eq_ignore_ascii_case("tenant_id") + } + sqlparser::ast::JsonPathElem::Bracket { key } => match key { + Expr::Value(sqlparser::ast::Value::SingleQuotedString(s)) + | Expr::Value(sqlparser::ast::Value::DoubleQuotedString(s)) => { + s.eq_ignore_ascii_case("tenant_id") + } + _ => false, + }, + }) + } _ => false, } } @@ -717,8 +907,40 @@ fn stmt_is_scoped(stmt: &Statement) -> bool { let mut ctx = ScopeCtx::new(); match stmt { Statement::Query(q) => query_is_scoped(q, &mut ctx), - Statement::Update { selection, .. } => selection.as_ref().is_some_and(expr_has_scope), - Statement::Delete(d) => d.selection.as_ref().is_some_and(expr_has_scope), + Statement::Update { + selection, from, .. + } => { + // UPDATE ... FROM ... WHERE ... — the FROM clause can carry the + // scope predicate via a join expression. Walk both. 
+ if selection.as_ref().is_some_and(expr_has_scope) { + return true; + } + if let Some(from) = from { + let twj = match from { + sqlparser::ast::UpdateTableFromKind::BeforeSet(t) => t, + sqlparser::ast::UpdateTableFromKind::AfterSet(t) => t, + }; + if twj_has_scope_on_join(twj) { + return true; + } + } + false + } + Statement::Delete(d) => { + if d.selection.as_ref().is_some_and(expr_has_scope) { + return true; + } + // PG-style `DELETE FROM t USING ... WHERE ...` puts the scope + // predicate on the join in USING. Walk it. + if let Some(using) = &d.using { + for twj in using { + if twj_has_scope_on_join(twj) { + return true; + } + } + } + false + } _ => false, } } @@ -838,6 +1060,27 @@ fn source_is_scoped(factor: &TableFactor, ctx: &ScopeCtx) -> bool { } } +/// True if any JOIN ON clause attached to the given TableWithJoins carries a +/// scope predicate. Used for UPDATE/DELETE where the scope often lives on a +/// join in the FROM/USING clause rather than the top-level WHERE. +fn twj_has_scope_on_join(twj: &TableWithJoins) -> bool { + for join in &twj.joins { + let constraint = match &join.join_operator { + sqlparser::ast::JoinOperator::Inner(c) + | sqlparser::ast::JoinOperator::LeftOuter(c) + | sqlparser::ast::JoinOperator::RightOuter(c) + | sqlparser::ast::JoinOperator::FullOuter(c) => c, + _ => continue, + }; + if let sqlparser::ast::JoinConstraint::On(e) = constraint + && expr_has_scope(e) + { + return true; + } + } + false +} + fn expr_has_scope(e: &Expr) -> bool { match e { Expr::Identifier(ident) => is_scope_col(&ident.value), @@ -854,11 +1097,22 @@ fn expr_has_scope(e: &Expr) -> bool { | BinaryOperator::HashLongArrow ) { expr_has_scope(left) || value_is_scope_col(right) - } else if matches!(op, BinaryOperator::Eq | BinaryOperator::NotEq) - && (is_direct_scope_ref(left) && is_literal_value(right) - || is_direct_scope_ref(right) && is_literal_value(left)) - { - false + } else if matches!(op, BinaryOperator::Eq | BinaryOperator::NotEq) { + // Scope only 
passes when ONE side is a direct scope reference + // (or JSON-arrow into one) AND the other side is a $param + // binding. Comparing scope col to a hardcoded literal or to + // another column doesn't bind the row to the caller. + if (is_direct_scope_ref(left) && is_placeholder_value(right)) + || (is_direct_scope_ref(right) && is_placeholder_value(left)) + { + true + } else if is_direct_scope_ref(left) || is_direct_scope_ref(right) { + // Scope col compared to a literal or to another column — + // explicitly not scoped. + false + } else { + expr_has_scope(left) || expr_has_scope(right) + } } else { expr_has_scope(left) || expr_has_scope(right) } @@ -867,12 +1121,15 @@ fn expr_has_scope(e: &Expr) -> bool { Expr::Between { expr, low, high, .. } => expr_has_scope(expr) || expr_has_scope(low) || expr_has_scope(high), - Expr::IsNull(e) - | Expr::IsNotNull(e) - | Expr::IsTrue(e) - | Expr::IsNotTrue(e) - | Expr::IsFalse(e) - | Expr::IsNotFalse(e) => expr_has_scope(e), + // IS [NOT] NULL / TRUE / FALSE never compares against a parameter, + // so even if the operand names a scope column the predicate doesn't + // bind the row to the current principal. Reject these outright. + Expr::IsNull(_) + | Expr::IsNotNull(_) + | Expr::IsTrue(_) + | Expr::IsNotTrue(_) + | Expr::IsFalse(_) + | Expr::IsNotFalse(_) => false, Expr::InList { expr, list, .. } => expr_has_scope(expr) || list.iter().any(expr_has_scope), Expr::InSubquery { expr, subquery, .. } => { let sub_scoped = query_is_scoped(subquery, &mut ScopeCtx::new()); @@ -914,12 +1171,12 @@ fn is_direct_scope_ref(e: &Expr) -> bool { } } -/// True if the expression is a non-placeholder literal. Placeholders ($1, $2) are -/// acceptable scope column counterparts; hardcoded literals are not. -fn is_literal_value(e: &Expr) -> bool { +/// True if the expression eventually reduces to a parameter placeholder +/// (`$1`, `$2`, ...). Unwraps Cast/Nested wrappers so `$1::uuid` counts. 
+fn is_placeholder_value(e: &Expr) -> bool { match e { - Expr::Value(v) => !matches!(v, sqlparser::ast::Value::Placeholder(_)), - Expr::Cast { expr, .. } | Expr::Nested(expr) => is_literal_value(expr), + Expr::Value(sqlparser::ast::Value::Placeholder(_)) => true, + Expr::Cast { expr, .. } | Expr::Nested(expr) => is_placeholder_value(expr), _ => false, } } diff --git a/crates/forge-macros/src/utils.rs b/crates/forge-macros/src/utils.rs index 17e61580..ba783c3d 100644 --- a/crates/forge-macros/src/utils.rs +++ b/crates/forge-macros/src/utils.rs @@ -2,8 +2,42 @@ use std::time::Duration; +use proc_macro_crate::{FoundCrate, crate_name}; use proc_macro2::TokenStream; -use quote::quote; +use quote::{format_ident, quote}; + +/// Resolve the path to the host `forge` crate at expansion time. +/// +/// The crate is published as the `forgex` package but its library is named +/// `forge` (`[lib] name` in `crates/forge/Cargo.toml`) so users write +/// `use forge::...`. `proc-macro-crate` returns the *dependency key* from the +/// consumer's `Cargo.toml`, which doesn't always equal the extern crate name +/// rustc sees: +/// +/// * `forge = { package = "forgex" }` (the scaffolded default) → key `forge`, +/// which is also the extern name. Emit `::forge`. +/// * a bare `forgex = "x"` dependency (what `cargo add forgex` produces, and +/// what `trybuild` generates) → key `forgex`, but rustc only knows the crate +/// by its lib name `forge`, so the key can't be used verbatim. Normalize the +/// package name back to the lib name. +/// * an explicit rename `myalias = { package = "forgex" }` → key `myalias`, +/// which *is* the extern name. Emit `::myalias`. +pub fn forge_path() -> TokenStream { + match crate_name("forgex") { + Ok(FoundCrate::Itself) => quote!(crate), + Ok(FoundCrate::Name(name)) => { + // proc-macro-crate hands back the dependency key; for a non-renamed + // `forgex` dep that key is the package name, but the crate is only + // reachable under its lib name `forge`. 
+ let extern_name = if name == "forgex" { "forge" } else { &name }; + let ident = format_ident!("{}", extern_name); + quote!(::#ident) + } + // Not resolvable as a direct dependency (transitive use, or a context + // proc-macro-crate can't read). The standard binding is `forge`. + Err(_) => quote!(::forge), + } +} /// Convert a snake_case string to PascalCase. pub fn to_pascal_case(s: &str) -> String { @@ -27,15 +61,20 @@ fn parse_duration(s: &str) -> Option { } else if let Some(num) = s.strip_suffix('s') { num.parse::().ok().map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('m') { - num.parse::().ok().map(|m| Duration::from_secs(m * 60)) + num.parse::() + .ok() + .and_then(|m| m.checked_mul(60)) + .map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('h') { num.parse::() .ok() - .map(|h| Duration::from_secs(h * 3600)) + .and_then(|h| h.checked_mul(3600)) + .map(Duration::from_secs) } else if let Some(num) = s.strip_suffix('d') { num.parse::() .ok() - .map(|d| Duration::from_secs(d * 86400)) + .and_then(|d| d.checked_mul(86400)) + .map(Duration::from_secs) } else { // Bare integers without a unit suffix are not accepted. Require explicit // suffixes (e.g. "30s") so intent is unambiguous at the macro callsite. @@ -74,28 +113,19 @@ pub fn parse_duration_tokens(s: &str, default_secs: u64) -> TokenStream { Err(_) => invalid(), } } else if let Some(num) = s.strip_suffix('m') { - match num.parse::() { - Ok(n) => { - let secs = n * 60; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(60)) { + Some(secs) => quote! { std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else if let Some(num) = s.strip_suffix('h') { - match num.parse::() { - Ok(n) => { - let secs = n * 3600; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(3600)) { + Some(secs) => quote! 
{ std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else if let Some(num) = s.strip_suffix('d') { - match num.parse::() { - Ok(n) => { - let secs = n * 86400; - quote! { std::time::Duration::from_secs(#secs) } - } - Err(_) => invalid(), + match num.parse::().ok().and_then(|n| n.checked_mul(86400)) { + Some(secs) => quote! { std::time::Duration::from_secs(#secs) }, + None => invalid(), } } else { let _ = default_secs; diff --git a/crates/forge-macros/src/webhook.rs b/crates/forge-macros/src/webhook.rs index 72302d08..caf0472e 100644 --- a/crates/forge-macros/src/webhook.rs +++ b/crates/forge-macros/src/webhook.rs @@ -167,6 +167,7 @@ fn parse_signature_from_meta(attr_args: &[NestedMeta]) -> Result TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -279,24 +280,24 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { ) { let alg_token = match alg { WebhookSignatureAlgorithm::HmacSha256 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::HmacSha256 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::HmacSha256 } } WebhookSignatureAlgorithm::StripeWebhooks => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::StripeWebhooks } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::StripeWebhooks } } WebhookSignatureAlgorithm::HmacSha256Base64 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::HmacSha256Base64 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::HmacSha256Base64 } } WebhookSignatureAlgorithm::Ed25519 => { - quote! { forge::forge_core::webhook::SignatureAlgorithm::Ed25519 } + quote! { #forge::forge_core::webhook::SignatureAlgorithm::Ed25519 } } }; let replay_window_tokens = match attrs.replay_window_secs { Some(secs) => quote! { #secs }, - None => quote! { forge::forge_core::webhook::DEFAULT_REPLAY_WINDOW_SECS }, + None => quote! 
{ #forge::forge_core::webhook::DEFAULT_REPLAY_WINDOW_SECS }, }; quote! { - Some(forge::forge_core::webhook::SignatureConfig { + Some(#forge::forge_core::webhook::SignatureConfig { algorithm: #alg_token, header_name: #header, secret_env: #secret_env, @@ -312,15 +313,15 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { match prefix { "header" => { quote! { - Some(forge::forge_core::webhook::IdempotencyConfig::new( - forge::forge_core::webhook::IdempotencySource::Header(#value) + Some(#forge::forge_core::webhook::IdempotencyConfig::new( + #forge::forge_core::webhook::IdempotencySource::Header(#value) )) } } "body" => { quote! { - Some(forge::forge_core::webhook::IdempotencyConfig::new( - forge::forge_core::webhook::IdempotencySource::Body(#value) + Some(#forge::forge_core::webhook::IdempotencyConfig::new( + #forge::forge_core::webhook::IdempotencySource::Body(#value) )) } } @@ -337,10 +338,10 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.webhooks.register::<#struct_name>(); registries.functions.register_webhook_info( - forge::forge_core::FunctionInfo::from(&#struct_name::info()) + #forge::forge_core::FunctionInfo::from(&#struct_name::info()) ); })); } @@ -357,13 +358,13 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #(#other_attrs)* pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::webhook::ForgeWebhook for #struct_name { + impl #forge::forge_core::webhook::ForgeWebhook for #struct_name { type Payload = #payload_type; - fn info() -> forge::forge_core::webhook::WebhookInfo { - forge::forge_core::webhook::WebhookInfo { + fn info() -> #forge::forge_core::webhook::WebhookInfo { + #forge::forge_core::webhook::WebhookInfo { name: #rpc_name, description: #description_tokens, path: #path, @@ -376,9 +377,9 @@ pub fn webhook_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::webhook::WebhookContext, + ctx: &#forge::forge_core::webhook::WebhookContext, payload: #payload_type, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } diff --git a/crates/forge-macros/src/workflow.rs b/crates/forge-macros/src/workflow.rs index e3a8e04a..9abe894a 100644 --- a/crates/forge-macros/src/workflow.rs +++ b/crates/forge-macros/src/workflow.rs @@ -377,6 +377,31 @@ impl<'ast> Visit<'ast> for ContractExtractor { } } +/// Normalize a quote!-stringified type so signature hashes are stable across +/// toolchain upgrades that re-shuffle whitespace. Collapses runs of +/// whitespace to a single space and trims the ends. 
Does not re-parse — a +/// full canonicalization via syn would also normalize generics and qualified +/// paths, but the whitespace fix covers the common drift. +fn canonicalize_type_str(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + let mut last_was_space = true; // suppress leading whitespace + for ch in s.chars() { + if ch.is_whitespace() { + if !last_was_space { + out.push(' '); + last_was_space = true; + } + } else { + out.push(ch); + last_was_space = false; + } + } + if out.ends_with(' ') { + out.pop(); + } + out +} + /// Derives a 32-char hex-encoded blake3 hash (128 bits) of name, version, /// step keys, wait keys, timeout, and input/output type name strings. /// @@ -438,6 +463,7 @@ fn derive_signature( } pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { + let forge = crate::utils::forge_path(); let input = parse_macro_input!(item as ItemFn); let attr_args = match NestedMeta::parse_meta_list(attr.into()) { @@ -544,17 +570,30 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } }; - let version_str = attrs.version.as_deref().unwrap_or("v1"); + // Workflow versions must be explicit. A default of "v1" collides with + // a later explicit `version = "v1"` and silently keys both into the + // same WorkflowRegistry slot. + let Some(ref version_owned) = attrs.version else { + return syn::Error::new_spanned( + &input.sig.ident, + "workflow requires an explicit `version = \"...\"` attribute. \ + Pin a starting version (e.g. `version = \"v1\"`) so future revisions \ + can run alongside in-flight runs without signature collisions.", + ) + .to_compile_error() + .into(); + }; + let version_str = version_owned.as_str(); let is_public = attrs.is_public; let workflow_status = match attrs.status { WorkflowStatus::Active => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Active } + quote! 
{ #forge::forge_core::workflow::WorkflowDefStatus::Active } } WorkflowStatus::Deprecated => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Deprecated } + quote! { #forge::forge_core::workflow::WorkflowDefStatus::Deprecated } } WorkflowStatus::Staging => { - quote! { forge::forge_core::workflow::WorkflowDefStatus::Staging } + quote! { #forge::forge_core::workflow::WorkflowDefStatus::Staging } } }; @@ -583,21 +622,28 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { quote! { None } }; + // Canonicalize type-token whitespace so a syn/quote upgrade that adds or + // drops a space (e.g. `MyType<Inner>` vs `MyType < Inner >`) doesn't + // silently flip the signature for every in-flight run. Pragmatic minimal + // form: trim + collapse internal runs to a single space. + let canonical_input = canonicalize_type_str(&input_type_str); + let canonical_output = canonicalize_type_str(&output_type_str); + let signature = derive_signature( workflow_name, version_str, &contract_extractor.step_keys, &contract_extractor.wait_keys, timeout_secs, - &input_type_str, - &output_type_str, + &canonical_input, + &canonical_output, ); let fn_attrs = &input.attrs; let registration = if attrs.register { quote! 
{ - forge::inventory::submit!(forge::AutoHandler(|registries| { + #forge::inventory::submit!(#forge::AutoHandler(|registries| { registries.workflows.register::<#struct_name>(); })); } @@ -635,14 +681,14 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { #[doc = #contract_doc] pub struct #struct_name; - impl forge::forge_core::__sealed::Sealed for #struct_name {} + impl #forge::forge_core::__sealed::Sealed for #struct_name {} - impl forge::forge_core::workflow::ForgeWorkflow for #struct_name { + impl #forge::forge_core::workflow::ForgeWorkflow for #struct_name { type Input = #input_type; type Output = #output_type; - fn info() -> forge::forge_core::workflow::WorkflowInfo { - forge::forge_core::workflow::WorkflowInfo { + fn info() -> #forge::forge_core::workflow::WorkflowInfo { + #forge::forge_core::workflow::WorkflowInfo { name: #workflow_name, version: #version_str, signature: #signature, @@ -655,9 +701,9 @@ pub fn workflow_impl(attr: TokenStream, item: TokenStream) -> TokenStream { } fn execute( - ctx: &forge::forge_core::workflow::WorkflowContext, + ctx: &#forge::forge_core::workflow::WorkflowContext, #input_ident: Self::Input, - ) -> std::pin::Pin> + Send + '_>> { + ) -> std::pin::Pin> + Send + '_>> { Box::pin(async move #block) } } From 7ddf11a77c7a398920b32db7b5ebe86b04c5c2b8 Mon Sep 17 00:00:00 2001 From: Isala Piyarisi Date: Mon, 25 May 2026 05:14:25 +0530 Subject: [PATCH 2/7] harden runtime gateway, jobs, workflows, realtime, and signals Auth/session handling (SSE tickets, session revocation, refresh cookie), webhook rate limiting, multipart limits, cron timezone and catch-up, job retry backoff precision and dead-lettering, workflow resume, leader-election zombie eviction, reactor/subscription reliability, and signals session persistence. 
--- crates/forge-runtime/src/cluster/heartbeat.rs | 28 +- crates/forge-runtime/src/cluster/registry.rs | 48 +++- crates/forge-runtime/src/cron/bridge.rs | 19 ++ crates/forge-runtime/src/cron/registry.rs | 10 +- crates/forge-runtime/src/cron/scheduler.rs | 112 +++++++- crates/forge-runtime/src/daemon/runner.rs | 96 ++++++- crates/forge-runtime/src/gateway/admin.rs | 52 ++++ crates/forge-runtime/src/gateway/auth.rs | 144 ++++++++-- crates/forge-runtime/src/gateway/jwks.rs | 101 +++++-- crates/forge-runtime/src/gateway/mcp/mod.rs | 28 +- .../forge-runtime/src/gateway/mcp/session.rs | 61 +++-- crates/forge-runtime/src/gateway/mcp/tools.rs | 77 ++++-- crates/forge-runtime/src/gateway/mod.rs | 5 +- crates/forge-runtime/src/gateway/multipart.rs | 138 ++++++++-- crates/forge-runtime/src/gateway/oauth.rs | 73 ++--- crates/forge-runtime/src/gateway/rpc.rs | 25 +- crates/forge-runtime/src/gateway/server.rs | 226 ++++++++++++---- crates/forge-runtime/src/gateway/sse.rs | 251 +++++++++++++++--- crates/forge-runtime/src/gateway/tls.rs | 24 ++ crates/forge-runtime/src/jobs/dispatcher.rs | 180 +------------ crates/forge-runtime/src/jobs/executor.rs | 64 ++++- crates/forge-runtime/src/jobs/queue.rs | 93 ++++++- crates/forge-runtime/src/jobs/registry.rs | 70 +---- crates/forge-runtime/src/jobs/worker.rs | 36 ++- crates/forge-runtime/src/kv/store.rs | 114 ++++++-- crates/forge-runtime/src/mcp/registry.rs | 17 +- crates/forge-runtime/src/observability/db.rs | 89 +++++-- .../src/observability/metrics.rs | 146 ++++++++-- .../src/observability/telemetry.rs | 39 ++- crates/forge-runtime/src/pg/change_log.rs | 63 +++-- crates/forge-runtime/src/pg/leader.rs | 140 +++++++--- .../forge-runtime/src/pg/migration/runner.rs | 95 ++++--- crates/forge-runtime/src/pg/mod.rs | 2 +- crates/forge-runtime/src/pg/notify.rs | 97 +++++-- crates/forge-runtime/src/pg/notify_bus.rs | 25 +- crates/forge-runtime/src/pg/pool.rs | 79 +++++- .../forge-runtime/src/rate_limit/limiter.rs | 35 ++- 
crates/forge-runtime/src/realtime/listener.rs | 26 +- crates/forge-runtime/src/realtime/manager.rs | 87 +++--- crates/forge-runtime/src/realtime/message.rs | 84 +++++- crates/forge-runtime/src/realtime/reactor.rs | 233 +++++++++++----- crates/forge-runtime/src/signals/bot.rs | 12 +- crates/forge-runtime/src/signals/collector.rs | 49 +++- crates/forge-runtime/src/signals/endpoints.rs | 189 +++++++++---- crates/forge-runtime/src/signals/partition.rs | 21 ++ .../forge-runtime/src/signals/rate_limit.rs | 44 ++- crates/forge-runtime/src/signals/session.rs | 7 +- crates/forge-runtime/src/signals/tests.rs | 100 +++++++ crates/forge-runtime/src/signals/visitor.rs | 20 ++ crates/forge-runtime/src/webhook/handler.rs | 121 ++++++++- crates/forge-runtime/src/workflow/bridge.rs | 22 ++ crates/forge-runtime/src/workflow/executor.rs | 231 ++++++++++++++-- crates/forge-runtime/src/workflow/registry.rs | 142 +++------- .../forge-runtime/src/workflow/scheduler.rs | 66 ++++- 54 files changed, 3329 insertions(+), 1027 deletions(-) diff --git a/crates/forge-runtime/src/cluster/heartbeat.rs b/crates/forge-runtime/src/cluster/heartbeat.rs index b42dce5a..a8e01f3a 100644 --- a/crates/forge-runtime/src/cluster/heartbeat.rs +++ b/crates/forge-runtime/src/cluster/heartbeat.rs @@ -22,7 +22,10 @@ impl Default for HeartbeatConfig { interval: Duration::from_secs(5), dead_threshold: Duration::from_secs(15), mark_dead_nodes: true, - max_interval: Duration::from_secs(60), + // Capped at 30 s (was 60 s). Combined with the 3x dead_threshold + // ceiling below, worst-case detection lag drops from ~180 s to + // ~90 s for a crash during the stable phase. + max_interval: Duration::from_secs(30), } } } @@ -59,7 +62,9 @@ impl HeartbeatConfig { interval: *cluster.heartbeat_interval, dead_threshold: *cluster.dead_threshold, mark_dead_nodes: true, - max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 12), + // Adaptive ceiling: 6x base (was 12x). 
Caps stable-phase interval + // at a tighter bound so dead-node detection stays under 90 s. + max_interval: Duration::from_secs(cluster.heartbeat_interval.as_secs() * 6), } } } @@ -76,6 +81,12 @@ pub struct HeartbeatLoop { stable_count: AtomicU32, last_active_count: AtomicU32, /// Dedicated connection held outside the shared pool for liveness safety. + /// + /// **Persistent-connection budget**: this connection counts as the + /// 5th persistent slot alongside the 4 listed in `pg/pool.rs` (notify + /// bus listener, leader lock-owning connection, change-log listener, + /// signals writer). Pool sizing must allow `min_connections >= 5` plus + /// burst headroom or normal RPC workload contends for the remaining slots. heartbeat_conn: Mutex>, } @@ -194,12 +205,17 @@ impl HeartbeatLoop { let mut guard = self.heartbeat_conn.lock().await; if guard.ping().await.is_err() { tracing::debug!("Heartbeat connection lost; reconnecting"); + // Acquire the replacement first, then swap. If acquire fails we + // keep the (broken) old handle so we don't permanently lose the + // slot — next call retries. Explicitly drop the prior handle so + // its slot is returned to the pool before we hold the new one. 
let new_conn = self .pool .acquire() .await .map_err(forge_core::ForgeError::Database)?; - *guard = new_conn; + let old = std::mem::replace(&mut *guard, new_conn); + drop(old); } Ok(guard) } @@ -319,7 +335,7 @@ mod tests { assert_eq!(config.interval, Duration::from_secs(5)); assert_eq!(config.dead_threshold, Duration::from_secs(15)); assert!(config.mark_dead_nodes); - assert_eq!(config.max_interval, Duration::from_secs(60)); + assert_eq!(config.max_interval, Duration::from_secs(30)); } #[test] @@ -332,8 +348,8 @@ mod tests { assert_eq!(config.interval, Duration::from_secs(10)); assert_eq!(config.dead_threshold, Duration::from_secs(30)); assert!(config.mark_dead_nodes); - // max_interval = heartbeat_interval * 12 - assert_eq!(config.max_interval, Duration::from_secs(120)); + // max_interval = heartbeat_interval * 6 + assert_eq!(config.max_interval, Duration::from_secs(60)); } #[test] diff --git a/crates/forge-runtime/src/cluster/registry.rs b/crates/forge-runtime/src/cluster/registry.rs index ba1cd69f..e9bf8581 100644 --- a/crates/forge-runtime/src/cluster/registry.rs +++ b/crates/forge-runtime/src/cluster/registry.rs @@ -1,6 +1,11 @@ +use std::time::Duration; + use forge_core::cluster::{NodeInfo, NodeStatus}; use forge_core::{ForgeError, Result}; +/// How often the background compactor sweeps long-dead node rows. +const COMPACTION_INTERVAL: Duration = Duration::from_secs(6 * 60 * 60); + /// Node registry for cluster membership. pub struct NodeRegistry { pool: sqlx::PgPool, @@ -9,7 +14,48 @@ pub struct NodeRegistry { impl NodeRegistry { pub fn new(pool: sqlx::PgPool, local_node: NodeInfo) -> Self { - Self { pool, local_node } + let registry = Self { pool, local_node }; + registry.spawn_cleanup_loop(); + registry + } + + /// Periodically delete `forge_nodes` rows that have been `dead` for + /// more than 7 days. Pod churn would otherwise accumulate rows + /// indefinitely. Runs every 6 h on a detached task; failures are logged + /// and the loop continues. 
+ fn spawn_cleanup_loop(&self) { + let pool = self.pool.clone(); + tokio::spawn(async move { + let mut ticker = tokio::time::interval(COMPACTION_INTERVAL); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // First tick fires immediately; skip it so startup doesn't compact. + ticker.tick().await; + loop { + ticker.tick().await; + // Untyped: the parameterless DELETE produces no row data and + // adding a `.sqlx` entry for it just for the compile-time + // check has zero safety value. Allow lints locally. + #[allow(clippy::disallowed_methods)] + let res = sqlx::query( + "DELETE FROM forge_nodes \ + WHERE status = 'dead' \ + AND last_heartbeat < NOW() - INTERVAL '7 days'", + ) + .execute(&pool) + .await; + match res { + Ok(result) => { + let n = result.rows_affected(); + if n > 0 { + tracing::info!(rows = n, "Compacted dead forge_nodes rows"); + } + } + Err(e) => { + tracing::warn!(error = %e, "forge_nodes compaction failed"); + } + } + } + }); } pub async fn register(&self) -> Result<()> { diff --git a/crates/forge-runtime/src/cron/bridge.rs b/crates/forge-runtime/src/cron/bridge.rs index 8e0d9c77..61491f73 100644 --- a/crates/forge-runtime/src/cron/bridge.rs +++ b/crates/forge-runtime/src/cron/bridge.rs @@ -61,6 +61,25 @@ pub fn register_cron_bridges(cron_registry: &Arc, job_registry: &m } handler(&cron_ctx).await?; + + // Transition the claimed run from 'running' to 'completed'. The + // scheduler only ever INSERTs status='running'; this is the sole + // place a successful run is finalized, so catch-up's "last + // completed scheduled_time" lookup and operator dashboards see a + // terminal state. Scoped to status='running' so a concurrent + // stale-reclaim that already rotated the id cannot be clobbered. 
+ sqlx::query!( + r#" + UPDATE forge_cron_runs + SET status = 'completed', completed_at = NOW(), error = NULL + WHERE id = $1 AND status = 'running' + "#, + run_id, + ) + .execute(ctx.pool()) + .await + .map_err(forge_core::ForgeError::Database)?; + Ok(serde_json::Value::Null) }) }); diff --git a/crates/forge-runtime/src/cron/registry.rs b/crates/forge-runtime/src/cron/registry.rs index 9fa5c1fa..6609169b 100644 --- a/crates/forge-runtime/src/cron/registry.rs +++ b/crates/forge-runtime/src/cron/registry.rs @@ -38,7 +38,15 @@ impl CronRegistry { } pub fn register(&mut self) { - let entry = CronEntry::new::(); + self.register_entry(CronEntry::new::()); + } + + /// Insert a pre-built [`CronEntry`], keyed by its `info.name`. + /// + /// `register::` is the public path; this primitive exists so in-crate + /// tests can register a handler without a `ForgeCron` impl (the trait is + /// sealed and cannot be implemented outside `forge-core`). + pub(crate) fn register_entry(&mut self, entry: CronEntry) { self.crons.insert(entry.info.name.to_string(), entry); } diff --git a/crates/forge-runtime/src/cron/scheduler.rs b/crates/forge-runtime/src/cron/scheduler.rs index 8779c8c9..e7d082ea 100644 --- a/crates/forge-runtime/src/cron/scheduler.rs +++ b/crates/forge-runtime/src/cron/scheduler.rs @@ -223,9 +223,14 @@ impl CronRunner { async { let now = Utc::now(); + // Window is poll_interval * 4 (was *2). The wider window covers + // short GC pauses, leader-loss recovery, and DB stalls that + // otherwise drop ticks for crons without catch_up. Inverse cost + // is bounded by the UNIQUE (cron_name, scheduled_time) constraint: + // re-checking the same slot twice claims at most one job. 
let window_start = now - - chrono::Duration::from_std(self.config.poll_interval * 2) - .unwrap_or(chrono::Duration::seconds(2)); + - chrono::Duration::from_std(self.config.poll_interval * 4) + .unwrap_or(chrono::Duration::seconds(4)); let cron_list = self.registry.list(); let mut jobs_executed = 0u32; @@ -250,10 +255,11 @@ impl CronRunner { .between_in_tz(window_start, now, info.timezone); if scheduled_times.len() > 1 { - tracing::info!( + tracing::warn!( cron.name = info.name, cron.missed_count = scheduled_times.len() - 1, - "Detected missed cron runs" + catch_up_enabled = info.catch_up, + "missed cron tick: more than one scheduled time fell into the poll window" ); Span::current().record("cron.missed_runs", scheduled_times.len() - 1); } @@ -268,6 +274,19 @@ impl CronRunner { } for scheduled in scheduled_times { + // Re-check leadership between inserts so a node that lost + // the lock mid-tick stops enqueueing slots tagged with its + // node_id. UNIQUE constraint bounds the damage, but this + // cuts observability noise. + if let Some(election) = self.config.leader_election.as_ref() + && !election.is_leader() + { + tracing::debug!( + cron = info.name, + "Leadership lost mid-tick; aborting remaining enqueues" + ); + break; + } if let Ok(Some(_run_id)) = self.try_claim_and_enqueue(entry, scheduled, false).await { @@ -899,4 +918,89 @@ mod integration_tests { assert!(leader.confirm_leadership_before_tick().await); assert!(!follower.confirm_leadership_before_tick().await); } + + #[tokio::test] + async fn dispatched_cron_run_is_marked_completed_on_success() { + // The scheduler only ever writes status='running'. A successful run must + // be finalized to 'completed' by the `$cron:` bridge job that the worker + // executes. This drives the real claim path (try_claim_and_enqueue) to + // create the running row + job, then runs the real bridge handler against + // the dispatched job's input. 
No row is seeded — completion is observed + // end-to-end, not asserted on a hand-written 'completed' row. + use crate::cron::register_cron_bridges; + use forge_core::CircuitBreakerClient; + use forge_core::job::JobContext; + + let db = setup_db("cron_run_completed").await; + let pool = db.pool().clone(); + + // Real claim: inserts the 'running' row and enqueues the $cron: job. + let runner = make_runner(pool.clone()); + let entry = make_entry("nightly_report", "0 0 * * * *"); + let scheduled = Utc::now() - chrono::Duration::seconds(30); + let run_id = runner + .try_claim_and_enqueue(&entry, scheduled, false) + .await + .expect("claim ok") + .expect("claimed some id"); + + let status_after_claim: String = + sqlx::query_scalar("SELECT status FROM forge_cron_runs WHERE id = $1") + .bind(run_id) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!( + status_after_claim, "running", + "claim must leave the run in 'running'" + ); + + // Build the real bridge handler for this cron. + let cron_registry = Arc::new({ + let mut reg = CronRegistry::new(); + reg.register_entry(make_entry("nightly_report", "0 0 * * * *")); + reg + }); + let mut job_registry = crate::jobs::registry::JobRegistry::new(); + register_cron_bridges(&cron_registry, &mut job_registry); + + let bridge = job_registry + .get("$cron:nightly_report") + .expect("bridge handler registered"); + + // Feed the handler the exact input the scheduler queued for this run. 
+ let input: serde_json::Value = sqlx::query_scalar( + "SELECT input FROM forge_jobs WHERE job_type = '$cron:nightly_report'", + ) + .fetch_one(&pool) + .await + .unwrap(); + + let ctx = JobContext::new( + Uuid::new_v4(), + "$cron:nightly_report".to_string(), + 0, + 3, + pool.clone(), + CircuitBreakerClient::with_ssrf_protection(), + ); + (bridge.handler)(&ctx, input) + .await + .expect("bridge handler succeeds"); + + let (status, completed_at): (String, Option>) = + sqlx::query_as("SELECT status, completed_at FROM forge_cron_runs WHERE id = $1") + .bind(run_id) + .fetch_one(&pool) + .await + .unwrap(); + assert_eq!( + status, "completed", + "successful run must become 'completed'" + ); + assert!( + completed_at.is_some(), + "completed_at must be set on completion" + ); + } } diff --git a/crates/forge-runtime/src/daemon/runner.rs b/crates/forge-runtime/src/daemon/runner.rs index 5c410484..4ee41d50 100644 --- a/crates/forge-runtime/src/daemon/runner.rs +++ b/crates/forge-runtime/src/daemon/runner.rs @@ -110,6 +110,7 @@ impl DaemonRunner { tracing::info!(count = self.registry.len(), "Daemon runner starting"); let mut daemon_handles: HashMap = HashMap::new(); + let mut join_handles: HashMap> = HashMap::new(); for (name, entry) in self.registry.daemons() { let info = &entry.info; @@ -160,7 +161,7 @@ impl DaemonRunner { None }; - tokio::spawn(async move { + let jh = tokio::spawn(async move { run_daemon_loop( daemon_name, daemon_entry, @@ -180,6 +181,7 @@ impl DaemonRunner { .await }); + join_handles.insert(name.to_string(), jh); daemon_handles.insert(name.to_string(), handle); } @@ -191,7 +193,37 @@ impl DaemonRunner { let _ = handle.shutdown_tx.send(true); } - tokio::time::sleep(Duration::from_secs(2)).await; + // Cap the drain at 10 s but exit early when all daemons have + // signalled they observed the shutdown. 
The previous fixed 2 s + // recorded `status='stopped'` while slow daemons were still + // draining; the cap keeps shutdown bounded for the same reason + // a strict block would not (a wedged daemon must not stall the + // node forever). + const MAX_DRAIN: Duration = Duration::from_secs(10); + let drain_deadline = tokio::time::Instant::now() + MAX_DRAIN; + + // Await each daemon's task with a bounded deadline so + // `record_daemon_stop` never races ahead of the lease-refresher + // and lock-validator tasks the loop spawned per iteration. + // Abort any task that exceeds the deadline so shutdown stays + // bounded. + for (name, jh) in join_handles.drain() { + let remaining = + drain_deadline.saturating_duration_since(tokio::time::Instant::now()); + match tokio::time::timeout(remaining, jh).await { + Ok(Ok(())) => {} + Ok(Err(e)) => { + tracing::warn!(daemon = %name, error = %e, "Daemon task join failed"); + } + Err(_) => { + tracing::warn!( + daemon = %name, + "Daemon drain exceeded {:?}; aborting task to keep shutdown bounded", + MAX_DRAIN, + ); + } + } + } for (name, handle) in &daemon_handles { if let Err(e) = self.record_daemon_stop(handle).await { @@ -326,8 +358,35 @@ async fn run_daemon_loop( } Ok(false) => { tracing::debug!("Waiting for leadership"); + // If the election has a notify-bus attached, wake on + // the leader-released NOTIFY so a standby takes over + // immediately on voluntary release. Otherwise fall + // back to the 5 s poll. Filter for our own role so + // unrelated NOTIFYs don't trigger spurious wakeups. + let mut release_rx = election.subscribe_release_notify(); + // Payload is `LeaderRole::as_str()`; for Daemon variants this is + // the daemon name verbatim. 
+ let role_str = name.clone(); + let wait = async { + if let Some(rx) = release_rx.as_mut() { + loop { + match rx.recv().await { + Ok(payload) if payload == role_str => return, + Ok(_) => continue, + Err(tokio::sync::broadcast::error::RecvError::Lagged(_)) => return, + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + // Bus is down; fall back to sleep. + tokio::time::sleep(Duration::from_secs(5)).await; + return; + } + } + } + } else { + tokio::time::sleep(Duration::from_secs(5)).await; + } + }; tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(5)) => {} + _ = wait => {} _ = shutdown_rx.changed() => { tracing::debug!("Shutdown while waiting for leadership"); Span::current().record("daemon.final_status", "shutdown_waiting_leadership"); @@ -592,7 +651,36 @@ async fn run_daemon_loop( if let Some(ref election) = election && let Err(e) = election.release_leadership().await { - tracing::debug!(daemon = %name, error = %e, "Failed to release leadership"); + tracing::error!( + daemon = %name, + error = %e, + "Failed to release leadership; clearing forge_leaders row so standbys can take over without waiting for the full lease" + ); + // Force-clear the leader row scoped to this node. Standbys would + // otherwise wait the full lease_duration (60 s) before + // preempting. WHERE node_id = $2 ensures we never wipe a row + // another node has already taken. + // forge_leaders is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. 
+ #[allow(clippy::disallowed_methods)] + let force_clear = sqlx::query( + r#" + DELETE FROM forge_leaders + WHERE role = $1 AND node_id = $2 + "#, + ) + .bind(name.as_str()) + .bind(node_id) + .execute(&pool) + .await; + if let Err(e2) = force_clear + { + tracing::error!( + daemon = %name, + error = %e2, + "Failed to clear forge_leaders row after release failure; standbys will wait for lease expiry" + ); + } } tracing::info!( diff --git a/crates/forge-runtime/src/gateway/admin.rs b/crates/forge-runtime/src/gateway/admin.rs index ef5c7fb4..b5dd2997 100644 --- a/crates/forge-runtime/src/gateway/admin.rs +++ b/crates/forge-runtime/src/gateway/admin.rs @@ -22,6 +22,7 @@ //! - `POST /_api/admin/queues/{name}/resume` //! - `GET /_api/admin/nodes` //! - `GET /_api/admin/leaders` +//! - `POST /_api/admin/sessions/{session_id}/revoke body: {reason?}` use std::sync::Arc; @@ -37,6 +38,9 @@ use sqlx::PgPool; use uuid::Uuid; use forge_core::function::AuthContext; +use forge_core::realtime::SessionId; + +use crate::realtime::Reactor; use super::tracing::TracingState; @@ -44,6 +48,9 @@ use super::tracing::TracingState; #[derive(Clone)] pub struct AdminState { pub db_pool: PgPool, + /// Reactor handle for session-revocation. None when running headless (e.g. + /// migration-only commands) — the route then returns 503. + pub reactor: Option>, } /// Build the admin router. Returns `None` when no admin handler can do any @@ -69,6 +76,7 @@ pub fn admin_router(state: AdminState) -> Router { .route("/admin/queues/{name}/resume", post(resume_queue)) .route("/admin/nodes", get(list_nodes)) .route("/admin/leaders", get(list_leaders)) + .route("/admin/sessions/{session_id}/revoke", post(revoke_session)) .with_state(Arc::new(state)) } @@ -1148,6 +1156,50 @@ async fn list_leaders( } } +/// Revoke a session's cached `AuthContext` so the reactor stops re-pushing +/// data tied to that session. 
Operators wire this to their identity system's +/// revocation event (role demotion, tenant move, manual sign-out across all +/// devices); the client must reconnect and re-subscribe with a fresh token to +/// resume receiving updates. +async fn revoke_session( + State(state): State>, + Extension(auth): Extension, + Extension(tracing_state): Extension, + Path(session_id): Path, + body: Option>, +) -> axum::response::Response { + if let Err(r) = require_admin(&auth) { + return r; + } + let reactor = match state.reactor.as_ref() { + Some(r) => r.clone(), + None => { + return admin_err( + StatusCode::SERVICE_UNAVAILABLE, + "reactor_unavailable", + "Realtime reactor is not running on this node", + ); + } + }; + let reason = body.and_then(|b| b.0.reason); + let reason_str = reason.as_deref().unwrap_or("admin revoke"); + reactor + .revoke_session_auth(SessionId(session_id), reason_str) + .await; + audit( + &state.db_pool, + &auth, + Some(&tracing_state), + "session.revoke", + "session", + Some(session_id.to_string()), + reason.as_deref(), + serde_json::json!({"session_id": session_id}), + ) + .await; + Json(serde_json::json!({"status": "revoked"})).into_response() +} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)] mod tests { diff --git a/crates/forge-runtime/src/gateway/auth.rs b/crates/forge-runtime/src/gateway/auth.rs index 7aa37e58..4d438686 100644 --- a/crates/forge-runtime/src/gateway/auth.rs +++ b/crates/forge-runtime/src/gateway/auth.rs @@ -17,12 +17,14 @@ use tracing::debug; use super::jwks::JwksClient; /// Derive a stable, opaque key id from an HMAC secret. We take the first -/// 8 hex chars of `SHA-256(secret_bytes)` — short enough to keep token -/// headers small, deterministic so the same secret always produces the -/// same kid, and one-way so it leaks nothing useful about the secret. 
+/// 16 hex chars (8 bytes / 64 bits) of `SHA-256(secret_bytes)` — small +/// enough to keep token headers compact while large enough to make kid +/// collisions across rotated secrets infeasible. Deterministic so the +/// same secret always produces the same kid, and one-way so it leaks +/// nothing useful about the secret. fn secret_kid(secret: &[u8]) -> String { let hash = Sha256::digest(secret); - let prefix = hash.as_slice().get(..4).unwrap_or(&[]); + let prefix = hash.as_slice().get(..8).unwrap_or(&[]); let mut out = String::with_capacity(prefix.len() * 2); for byte in prefix { use std::fmt::Write; @@ -346,6 +348,11 @@ pub struct AuthMiddleware { /// (Claims, expiry). The 256-bit key makes collisions cryptographically /// infeasible, so a hit unambiguously identifies the same token. token_cache: Arc>, + /// Monotonic instant (seconds since process start) of the last cache + /// sweep. Stored as `AtomicU64` so eviction is lock-free on the hot path. + last_cache_sweep_secs: Arc, + /// Process-start anchor for `last_cache_sweep_secs`. + cache_sweep_epoch: std::time::Instant, } impl std::fmt::Debug for AuthMiddleware { @@ -409,6 +416,8 @@ impl AuthMiddleware { hmac_kid, legacy_hmac_keys, token_cache: Arc::new(dashmap::DashMap::new()), + last_cache_sweep_secs: Arc::new(std::sync::atomic::AtomicU64::new(0)), + cache_sweep_epoch: std::time::Instant::now(), } } @@ -475,22 +484,52 @@ impl AuthMiddleware { const MAX_CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(60); let exp = claims.exp(); let now = chrono::Utc::now().timestamp(); - let remaining = if exp > now { - std::time::Duration::from_secs((exp - now) as u64) - } else { - std::time::Duration::ZERO - }; - remaining.min(MAX_CACHE_TTL) + // `exp` is i64 (JWT spec); guard against negative remaining so a + // platform with a 32-bit `time_t` or skewed clock can't underflow + // into an absurdly large u64 TTL. 
+ let remaining_secs = u64::try_from(exp.saturating_sub(now)).unwrap_or(0); + std::time::Duration::from_secs(remaining_secs).min(MAX_CACHE_TTL) } /// Periodically evict expired entries to prevent unbounded growth. + /// Sweeps when either (a) the cache exceeds `MAX_CACHE_SIZE` entries, or + /// (b) `SWEEP_INTERVAL` has elapsed since the last sweep. The time-based + /// trigger matters under low traffic with many short-lived tokens — the + /// size trigger alone would let stale entries accumulate indefinitely. fn evict_expired_cache_entries(&self) { + use std::sync::atomic::Ordering; const MAX_CACHE_SIZE: usize = 10_000; - if self.token_cache.len() > MAX_CACHE_SIZE { - let now = std::time::Instant::now(); - self.token_cache - .retain(|_, (_, expires_at)| *expires_at > now); + const SWEEP_INTERVAL_SECS: u64 = 60; + + let now_instant = std::time::Instant::now(); + let elapsed_since_start = now_instant + .saturating_duration_since(self.cache_sweep_epoch) + .as_secs(); + let last = self.last_cache_sweep_secs.load(Ordering::Relaxed); + let time_due = elapsed_since_start.saturating_sub(last) >= SWEEP_INTERVAL_SECS; + let size_due = self.token_cache.len() > MAX_CACHE_SIZE; + + if !(size_due || time_due) { + return; } + + // Race-free claim: only the caller that successfully advances the + // sweep timestamp performs the scan; concurrent callers skip. + if self + .last_cache_sweep_secs + .compare_exchange( + last, + elapsed_since_start, + Ordering::AcqRel, + Ordering::Relaxed, + ) + .is_err() + { + return; + } + + self.token_cache + .retain(|_, (_, expires_at)| *expires_at > now_instant); } /// Validate HMAC-signed token. 
Uses the token's `kid` header to look up @@ -541,21 +580,42 @@ impl AuthMiddleware { let safe_kid = header.kid.as_deref().map(sanitize_for_log); debug!(kid = ?safe_kid, alg = ?header.alg, "Validating RSA token"); - let key = if let Some(kid) = header.kid { - jwks.get_key(&kid).await.map_err(|e| { + if let Some(kid) = header.kid { + let key = jwks.get_key(&kid).await.map_err(|e| { AuthError::InvalidToken(format!("Failed to get key '{}': {}", kid, e)) - })? - } else if self.config.jwks_require_kid { + })?; + return self.decode_and_validate(token, &key); + } + + if self.config.jwks_require_kid { return Err(AuthError::InvalidToken( "RS256 token missing kid header; set auth.jwks_require_kid = false to allow kidless tokens".to_string(), )); - } else { - jwks.get_any_key() - .await - .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))? - }; + } - self.decode_and_validate(token, &key) + // No kid: try every kidless key. A signature mismatch under one key + // does not imply the token is invalid — providers that publish multiple + // kidless keys (during rotation) require us to attempt each. + let candidates = jwks + .kidless_keys() + .await + .map_err(|e| AuthError::InvalidToken(format!("Failed to get JWKS key: {}", e)))?; + if candidates.is_empty() { + return Err(AuthError::InvalidToken( + "No kidless JWKS keys available for kidless token".to_string(), + )); + } + + let mut last_err: Option = None; + for key in &candidates { + match self.decode_and_validate(token, key) { + Ok(claims) => return Ok(claims), + Err(e) => last_err = Some(e), + } + } + Err(last_err.unwrap_or_else(|| { + AuthError::InvalidToken("Kidless token did not validate against any key".to_string()) + })) } /// Decode and validate token with the given key. @@ -627,6 +687,10 @@ impl AuthMiddleware { } /// Decode JWT token without signature verification (DEV MODE ONLY). 
+ /// Still enforces `exp` so that a missing or zero expiry — which prod + /// validation rejects via `required_claims` — also fails here. Matching + /// the prod rule shrinks the downgrade-attack surface when a dev-built + /// token accidentally reaches a prod-config gateway. fn decode_without_verification(&self, token: &str) -> Result { let token_data = dangerous::insecure_decode::(token).map_err(|e| match e.kind() { @@ -636,6 +700,14 @@ impl AuthMiddleware { _ => AuthError::InvalidToken(e.to_string()), })?; + // `exp == 0` (epoch) would deserialize fine since `i64` has no default + // guard against it. Treat as missing. + if token_data.claims.exp() <= 0 { + return Err(AuthError::InvalidToken( + "Token missing required `exp` claim".to_string(), + )); + } + if token_data.claims.is_expired() { return Err(AuthError::TokenExpired); } @@ -810,12 +882,22 @@ pub async fn auth_middleware( let should_set_cookie = auth_context.is_authenticated() && middleware.config.jwt_secret.is_some(); - // Skip cookie if one already exists (avoids resigning on every request) + // Skip cookie if one already exists (avoids resigning on every request). + // Parse the Cookie header strictly: split on ';', trim whitespace, and + // match name exactly so substrings like `xforge_session` don't trigger. let has_session_cookie = req .headers() .get(header::COOKIE) .and_then(|v| v.to_str().ok()) - .map(|c| c.contains("forge_session=")) + .map(|c| { + c.split(';').any(|pair| { + let trimmed = pair.trim_start(); + trimmed + .split_once('=') + .map(|(name, _)| name == "forge_session") + .unwrap_or(false) + }) + }) .unwrap_or(false); let should_set_cookie = should_set_cookie && !has_session_cookie; @@ -851,8 +933,12 @@ pub async fn auth_middleware( // browsers refuse to send `Secure` cookies over HTTP, which surfaces // misconfigured deployments as a clean failure rather than silently // weakening the session. 
+ // Path=/ ensures the browser sends the cookie on every request, so the + // resign-skip check above actually fires. With a narrower path the + // cookie would only be visible to /_api/oauth/* and every other request + // would re-sign it. let cookie = format!( - "forge_session={cookie_value}; Path=/_api/oauth/; HttpOnly; SameSite=Lax; Secure; Max-Age={cookie_ttl}" + "forge_session={cookie_value}; Path=/; HttpOnly; SameSite=Lax; Secure; Max-Age={cookie_ttl}" ); if let Ok(val) = axum::http::HeaderValue::from_str(&cookie) { response.headers_mut().append(header::SET_COOKIE, val); @@ -1555,7 +1641,7 @@ mod tests { let kid_a = secret_kid(b"some-secret"); let kid_b = secret_kid(b"some-secret"); assert_eq!(kid_a, kid_b); - assert_eq!(kid_a.len(), 8, "kid should be 8 hex chars (4 bytes)"); + assert_eq!(kid_a.len(), 16, "kid should be 16 hex chars (8 bytes)"); assert_ne!(kid_a, secret_kid(b"different-secret")); } @@ -1667,7 +1753,7 @@ mod tests { let config = AuthConfig { algorithm: JwtAlgorithm::RS256, jwks_client: Some(Arc::new( - JwksClient::new("http://example.invalid".into(), 3600).unwrap(), + JwksClient::new("https://example.invalid".into(), 3600).unwrap(), )), jwks_require_kid: true, ..AuthConfig::default() diff --git a/crates/forge-runtime/src/gateway/jwks.rs b/crates/forge-runtime/src/gateway/jwks.rs index 2e0934dc..817c3932 100644 --- a/crates/forge-runtime/src/gateway/jwks.rs +++ b/crates/forge-runtime/src/gateway/jwks.rs @@ -56,6 +56,10 @@ pub struct JsonWebKey { struct CachedJwks { /// Map of key ID to decoding key. keys: HashMap, + /// Keys served by the provider without a `kid`. Stored separately so a + /// rotation that ships multiple kidless keys does not silently lose every + /// one but the last. + kidless_keys: Vec, /// When the cache was last refreshed. 
fetched_at: Instant, } @@ -113,6 +117,15 @@ impl JwksClient { /// * `url` - The JWKS endpoint URL /// * `cache_ttl_secs` - How long to cache keys (in seconds) pub fn new(url: String, cache_ttl_secs: u64) -> Result { + // Reject plain-HTTP JWKS endpoints: an on-path attacker can swap keys + // and mint arbitrary RS256 tokens. Loopback is permitted for local + // identity-provider stubs in tests and development. + let insecure = forge_core::util::http_hostname(&url) + .is_some_and(|host| !forge_core::util::is_loopback_host(host)); + if insecure { + return Err(JwksError::InsecureUrl(url)); + } + let http_client = reqwest::Client::builder() .timeout(Duration::from_secs(10)) .build() @@ -206,15 +219,22 @@ impl JwksClient { /// Some providers don't include a key ID in tokens. This method /// returns the first available key from the JWKS. pub async fn get_any_key(&self) -> Result { - // Try to get from cache first + // Try to get from cache first. Kidless keys are preferred for + // kidless-token fallback; only the first one is returned here, so + // callers that must try each kidless key use `kidless_keys()`. { let cache = self.cache.read().await; if let Some(ref cached) = *cache && cached.fetched_at.elapsed() < self.cache_ttl - && let Some(key) = cached.keys.values().next() { - debug!("Using first cached JWKS key (no kid specified)"); - return Ok(key.clone()); + if let Some(key) = cached.kidless_keys.first() { + debug!("Using first cached kidless JWKS key"); + return Ok(key.clone()); + } + if let Some(key) = cached.keys.values().next() { + debug!("Using first cached JWKS key (no kid specified)"); + return Ok(key.clone()); + } } } @@ -224,6 +244,9 @@ impl JwksClient { let cache = self.cache.read().await; if let Some(ref cached) = *cache { + if let Some(key) = cached.kidless_keys.first().cloned() { + return Ok(key); + } cached .keys .values() @@ -235,6 +258,29 @@ impl JwksClient { } } + /// Return every cached kidless key so the caller can try each in turn.
Used by RSA validation when the + /// incoming token carries no `kid` header — without this, a kidless-key + /// rotation silently fails for tokens signed by the second key. + pub async fn kidless_keys(&self) -> Result, JwksError> { + { + let cache = self.cache.read().await; + if let Some(ref cached) = *cache + && cached.fetched_at.elapsed() < self.cache_ttl + && !cached.kidless_keys.is_empty() + { + return Ok(cached.kidless_keys.clone()); + } + } + self.refresh_if_needed().await?; + let cache = self.cache.read().await; + match *cache { + Some(ref cached) => Ok(cached.kidless_keys.clone()), + None => Err(JwksError::FetchFailed( + "Cache empty after refresh".to_string(), + )), + } + } + /// Force refresh the key cache. /// /// Fetches fresh keys from the JWKS endpoint regardless of cache state. @@ -261,6 +307,7 @@ impl JwksClient { .map_err(|e| JwksError::ParseFailed(e.to_string()))?; let mut keys = HashMap::new(); + let mut kidless_keys = Vec::new(); for jwk in jwks.keys { // Skip non-signature keys @@ -270,27 +317,36 @@ impl JwksClient { continue; } - let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string()); + let kid_for_log = jwk.kid.as_deref().unwrap_or("").to_string(); match self.parse_jwk(&jwk) { Ok(Some(key)) => { - debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key"); - keys.insert(kid, key); + debug!(kid = %kid_for_log, kty = %jwk.kty, "Parsed JWKS key"); + match jwk.kid { + Some(k) => { + keys.insert(k, key); + } + None => kidless_keys.push(key), + } } Ok(None) => { - debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type"); + debug!(kid = %kid_for_log, kty = %jwk.kty, "Skipping unsupported key type"); } Err(e) => { - warn!(kid = %kid, error = %e, "Failed to parse JWKS key"); + warn!(kid = %kid_for_log, error = %e, "Failed to parse JWKS key"); } } } - if keys.is_empty() { + if keys.is_empty() && kidless_keys.is_empty() { return Err(JwksError::NoKeysAvailable); } - debug!(count = keys.len(), "Cached JWKS keys"); + debug!( + count = 
keys.len(), + kidless = kidless_keys.len(), + "Cached JWKS keys" + ); // Drop negative-cache entries for any kid that's now present, so a // rotation that hands us a previously-missing kid takes effect @@ -302,6 +358,7 @@ impl JwksClient { let mut cache = self.cache.write().await; *cache = Some(CachedJwks { keys, + kidless_keys, fetched_at: Instant::now(), }); @@ -374,6 +431,10 @@ pub enum JwksError { /// Failed to create HTTP client. #[error("Failed to create HTTP client: {0}")] HttpClientError(String), + + /// JWKS URL uses an insecure scheme (plain http) outside loopback. + #[error("JWKS URL '{0}' must use https:// (plain http is only allowed for loopback hosts)")] + InsecureUrl(String), } #[cfg(test)] @@ -383,7 +444,7 @@ mod tests { #[test] fn test_parse_jwk_with_n_e() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); // Example RSA public key components (minimal test) let jwk = JsonWebKey { @@ -404,7 +465,7 @@ mod tests { #[test] fn test_parse_jwk_unsupported_type() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); let jwk = JsonWebKey { kid: Some("test-key".to_string()), @@ -423,7 +484,7 @@ mod tests { #[test] fn test_parse_jwk_missing_components() { - let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap(); + let client = JwksClient::new("https://example.com".to_string(), 3600).unwrap(); let jwk = JsonWebKey { kid: Some("test-key".to_string()), @@ -452,7 +513,7 @@ mod tests { // "oct" (symmetric) keys can't be used for asymmetric verification; we // skip them silently rather than erroring, so the caller can keep // processing the rest of the JWKS. 
- let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("sym".into()), kty: "oct".into(), @@ -468,7 +529,7 @@ mod tests { #[test] fn parse_jwk_returns_none_when_only_modulus_present() { // RSA with `n` but no `e` is malformed; we drop it rather than crashing. - let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("partial".into()), kty: "RSA".into(), @@ -486,7 +547,7 @@ mod tests { // When x5c is present, the implementation uses it first. A garbage // cert string therefore surfaces as KeyParseFailed, not silent // fallthrough to the n/e branch (which would otherwise succeed). - let client = JwksClient::new("http://example.com".into(), 60).unwrap(); + let client = JwksClient::new("https://example.com".into(), 60).unwrap(); let jwk = JsonWebKey { kid: Some("bad-x5c".into()), kty: "RSA".into(), @@ -536,12 +597,13 @@ mod tests { async fn get_key_returns_cached_match_without_network() { // Pre-populate the cache so the read path is exercised without // touching the JWKS endpoint. Verifies the cached-key fast path. - let client = JwksClient::new("http://example.invalid".into(), 3600).unwrap(); + let client = JwksClient::new("https://example.invalid".into(), 3600).unwrap(); let key = DecodingKey::from_secret(b"placeholder"); let mut keys = HashMap::new(); keys.insert("kid-1".to_string(), key); *client.cache.write().await = Some(CachedJwks { keys, + kidless_keys: Vec::new(), fetched_at: Instant::now(), }); @@ -554,11 +616,12 @@ mod tests { async fn get_any_key_returns_first_cached_when_kid_absent() { // Some providers issue tokens without a `kid` header; the fallback // must return whichever key is cached. 
- let client = JwksClient::new("http://example.invalid".into(), 3600).unwrap(); + let client = JwksClient::new("https://example.invalid".into(), 3600).unwrap(); let mut keys = HashMap::new(); keys.insert("only".into(), DecodingKey::from_secret(b"placeholder")); *client.cache.write().await = Some(CachedJwks { keys, + kidless_keys: Vec::new(), fetched_at: Instant::now(), }); diff --git a/crates/forge-runtime/src/gateway/mcp/mod.rs b/crates/forge-runtime/src/gateway/mcp/mod.rs index 007c889f..463eac7a 100644 --- a/crates/forge-runtime/src/gateway/mcp/mod.rs +++ b/crates/forge-runtime/src/gateway/mcp/mod.rs @@ -114,7 +114,11 @@ impl Stream for McpReceiverStream { /// Clients use this to receive notifications and asynchronous responses /// from the MCP server. The stream starts with an `endpoint` event /// containing the session ID, then sends keepalive pings every 30 seconds. -pub async fn mcp_get_handler(State(state): State>, headers: HeaderMap) -> Response { +pub async fn mcp_get_handler( + State(state): State>, + Extension(auth): Extension, + headers: HeaderMap, +) -> Response { if let Err(resp) = validate_origin(&headers, &state.config) { return *resp; } @@ -129,6 +133,28 @@ pub async fn mcp_get_handler(State(state): State>, headers: Header Err(resp) => return resp, }; + // Bind the SSE stream to the principal that initialized the session. + // Otherwise any caller with a leaked session id can attach to that + // session's stream and impersonate it. 
+ { + let sessions = state.sessions.read().await; + if let Some(session) = sessions.get(&session_id) { + let current = auth.principal_id(); + if session.principal_id != current { + return ( + StatusCode::FORBIDDEN, + Json(json_rpc_error( + None, + -32001, + "Session principal mismatch", + None, + )), + ) + .into_response(); + } + } + } + state.touch_session(&session_id).await; // Create a channel for server-to-client messages diff --git a/crates/forge-runtime/src/gateway/mcp/session.rs b/crates/forge-runtime/src/gateway/mcp/session.rs index 3434f987..700df5b0 100644 --- a/crates/forge-runtime/src/gateway/mcp/session.rs +++ b/crates/forge-runtime/src/gateway/mcp/session.rs @@ -89,13 +89,46 @@ pub(super) fn validate_origin( headers: &HeaderMap, config: &McpConfig, ) -> std::result::Result<(), ResponseError> { - let Some(origin) = headers.get("origin").and_then(|v| v.to_str().ok()) else { - return Ok(()); - }; + let origin = headers.get("origin").and_then(|v| v.to_str().ok()); - // When no allowed_origins are configured, reject cross-origin requests - // rather than allowing all origins (secure by default) - if config.allowed_origins.is_empty() { + // When the operator has configured an allow-list, the Origin header is + // mandatory. Without this, a browser-adjacent context exploiting DNS + // rebinding (or any client suppressing Origin) bypasses the allow-list. 
+ if !config.allowed_origins.is_empty() { + let allow_any = config.allowed_origins.iter().any(|c| c == "*"); + return match origin { + Some(o) => { + let allowed = allow_any + || config + .allowed_origins + .iter() + .any(|candidate| candidate.eq_ignore_ascii_case(o)); + if allowed { + Ok(()) + } else { + Err(Box::new( + ( + StatusCode::FORBIDDEN, + Json(json_rpc_error(None, -32600, "Invalid Origin header", None)), + ) + .into_response(), + )) + } + } + None if allow_any => Ok(()), + None => Err(Box::new( + ( + StatusCode::FORBIDDEN, + Json(json_rpc_error(None, -32600, "Missing Origin header", None)), + ) + .into_response(), + )), + }; + } + + // No allow-list configured: keep the "secure by default" reject for + // cross-origin requests, and allow non-browser clients that omit Origin. + if origin.is_some() { return Err(Box::new( ( StatusCode::FORBIDDEN, @@ -110,21 +143,7 @@ pub(super) fn validate_origin( )); } - let allowed = config - .allowed_origins - .iter() - .any(|candidate| candidate == "*" || candidate.eq_ignore_ascii_case(origin)); - if allowed { - return Ok(()); - } - - Err(Box::new( - ( - StatusCode::FORBIDDEN, - Json(json_rpc_error(None, -32600, "Invalid Origin header", None)), - ) - .into_response(), - )) + Ok(()) } pub(super) fn enforce_protocol_header( diff --git a/crates/forge-runtime/src/gateway/mcp/tools.rs b/crates/forge-runtime/src/gateway/mcp/tools.rs index f1762ca9..ce10e231 100644 --- a/crates/forge-runtime/src/gateway/mcp/tools.rs +++ b/crates/forge-runtime/src/gateway/mcp/tools.rs @@ -484,23 +484,68 @@ pub(super) async fn handle_proxied_function_call( } } +/// Hard cap on serialized MCP tool output. Past this size we drop +/// `structuredContent` (avoiding the double-encoded blow-up for objects) and +/// truncate the textual representation. Picked to leave generous headroom +/// while preventing a single tool call from buffering tens of MB twice. 
+const MAX_TOOL_OUTPUT_BYTES: usize = 256 * 1024; + pub(super) fn tool_success_result(output: Value) -> Value { match output { - Value::Object(_) => serde_json::json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string()) - }], - "structuredContent": output - }), - Value::String(text) => serde_json::json!({ - "content": [{ "type": "text", "text": text }] - }), - other => serde_json::json!({ - "content": [{ - "type": "text", - "text": serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string()) - }] - }), + Value::Object(_) => { + let text = serde_json::to_string(&output).unwrap_or_else(|_| "{}".to_string()); + if text.len() > MAX_TOOL_OUTPUT_BYTES { + // Avoid embedding the object twice once it exceeds the cap. + let truncated = truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES); + serde_json::json!({ + "content": [{ + "type": "text", + "text": truncated + }], + "isError": false, + "_truncated": true + }) + } else { + serde_json::json!({ + "content": [{ "type": "text", "text": text }], + "structuredContent": output + }) + } + } + Value::String(text) => { + let text = if text.len() > MAX_TOOL_OUTPUT_BYTES { + truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES) + } else { + text + }; + serde_json::json!({ + "content": [{ "type": "text", "text": text }] + }) + } + other => { + let text = serde_json::to_string(&other).unwrap_or_else(|_| "null".to_string()); + let text = if text.len() > MAX_TOOL_OUTPUT_BYTES { + truncate_at_char_boundary(&text, MAX_TOOL_OUTPUT_BYTES) + } else { + text + }; + serde_json::json!({ + "content": [{ "type": "text", "text": text }] + }) + } + } +} + +fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + let mut end = max_bytes; + while end > 0 && !s.is_char_boundary(end) { + end -= 1; } + let mut out = String::with_capacity(end + 16); + out.push_str(s.get(..end).unwrap_or("")); + 
out.push_str("…[truncated]"); + out } diff --git a/crates/forge-runtime/src/gateway/mod.rs b/crates/forge-runtime/src/gateway/mod.rs index c9e45fec..b653c1ab 100644 --- a/crates/forge-runtime/src/gateway/mod.rs +++ b/crates/forge-runtime/src/gateway/mod.rs @@ -25,8 +25,9 @@ pub use response::{RpcError, RpcResponse}; pub use rpc::RpcHandler; pub use server::{GatewayConfig, GatewayServer, TrustedProxies}; pub use sse::{ - SseConfig, SsePayload, SseQuery, SseState, sse_handler, sse_job_subscribe_handler, - sse_subscribe_handler, sse_unsubscribe_handler, sse_workflow_subscribe_handler, + SseConfig, SsePayload, SseQuery, SseState, SseTicketResponse, sse_handler, + sse_job_subscribe_handler, sse_subscribe_handler, sse_ticket_handler, sse_unsubscribe_handler, + sse_workflow_subscribe_handler, }; pub use tls::{ GatewayConn, GatewayListener, PeerAddr, TlsListenConfig, bind_listener, load_rustls_config, diff --git a/crates/forge-runtime/src/gateway/multipart.rs b/crates/forge-runtime/src/gateway/multipart.rs index 7933be2c..20fa2521 100644 --- a/crates/forge-runtime/src/gateway/multipart.rs +++ b/crates/forge-runtime/src/gateway/multipart.rs @@ -12,9 +12,106 @@ use forge_core::types::Upload; use super::rpc::RpcHandler; const MAX_FIELD_NAME_LENGTH: usize = 255; +const MAX_FILENAME_LENGTH: usize = 255; const MAX_JSON_FIELD_SIZE: usize = 1024 * 1024; const JSON_FIELD_NAME: &str = "_json"; +/// Sanitize a raw upload filename for safe persistence and logging. +/// +/// Strips path components, neutralizes traversal sequences after the basename +/// is isolated, rejects null bytes / control chars, and rewrites Windows +/// reserved device names (CON, PRN, NUL, AUX, COM[1-9], LPT[1-9]) so that +/// downstream code that mirrors the upload to disk on Windows can't trip the +/// reserved-name handling. Returns `None` when nothing salvageable remains. 
+fn sanitize_filename(raw: &str) -> Option { + // Basename first: take the last path component for either separator, then + // reject a basename that is empty, `.`, or `..`. Embedded dots are left + // alone, so `foo..bar` keeps its double-dot. + let basename = raw.rsplit(['/', '\\']).next().unwrap_or(raw); + let basename = basename.trim(); + if basename.is_empty() || basename == "." || basename == ".." { + return None; + } + + // Reject controls and null bytes outright. A null byte in a filename is a + // classic log-truncation / path-confusion vector. + if basename.chars().any(|c| c == '\0' || c.is_control()) { + return None; + } + + let mut name: String = basename + .chars() + .map(|c| match c { + '<' | '>' | ':' | '"' | '|' | '?' | '*' => '_', + _ => c, + }) + .collect(); + + // Windows reserved device names match against the basename without its + // extension. Comparison is case-insensitive. + let stem = name.split('.').next().unwrap_or(&name).to_ascii_uppercase(); + let is_reserved_dev = |prefix: &str| -> bool { + stem.strip_prefix(prefix) + .and_then(|rest| rest.parse::().ok()) + .is_some_and(|n| (1..=9).contains(&n)) + }; + let is_reserved = matches!(stem.as_str(), "CON" | "PRN" | "AUX" | "NUL") + || is_reserved_dev("COM") + || is_reserved_dev("LPT"); + if is_reserved { + name = format!("_{name}"); + } + + if name.len() > MAX_FILENAME_LENGTH { + name.truncate(MAX_FILENAME_LENGTH); + } + + if name.is_empty() { None } else { Some(name) } +} + +#[cfg(test)] +mod sanitize_tests { + use super::sanitize_filename; + + #[test] + fn strips_path_components() { + assert_eq!(sanitize_filename("/etc/passwd").as_deref(), Some("passwd")); + assert_eq!( + sanitize_filename("C:\\Windows\\system.ini").as_deref(), + Some("system.ini") + ); + } + + #[test] + fn preserves_legitimate_double_dots() { + assert_eq!(sanitize_filename("foo..bar").as_deref(), Some("foo..bar")); + } + + #[test] + fn rejects_traversal_only_basename() { +
assert!(sanitize_filename("../../etc/passwd").is_some()); + assert_eq!( + sanitize_filename("../../etc/passwd").as_deref(), + Some("passwd") + ); + assert!(sanitize_filename("..").is_none()); + assert!(sanitize_filename("foo/..").is_none()); + } + + #[test] + fn rejects_control_chars_and_nulls() { + assert!(sanitize_filename("foo\0bar").is_none()); + assert!(sanitize_filename("foo\nbar").is_none()); + } + + #[test] + fn rewrites_windows_reserved_names() { + assert_eq!(sanitize_filename("CON").as_deref(), Some("_CON")); + assert_eq!(sanitize_filename("nul.txt").as_deref(), Some("_nul.txt")); + assert_eq!(sanitize_filename("COM1.log").as_deref(), Some("_COM1.log")); + } +} + /// Lightweight magic-byte check: for a small set of well-known media types, /// the declared `Content-Type` must match the file's leading bytes. Types /// outside this list pass through (we don't have a signature library, and @@ -169,6 +266,17 @@ pub async fn rpc_multipart_handler( } if name == JSON_FIELD_NAME { + // Duplicate `_json` would silently let the last value win, which + // is a parameter-smuggling avenue when upstream validators only + // inspected the first occurrence. 
+ if json_args.is_some() { + return multipart_error( + StatusCode::BAD_REQUEST, + "DUPLICATE_FIELD", + "Multiple `_json` fields submitted; only one is allowed", + ); + } + let mut buffer = BytesMut::new(); let mut json_field = field; @@ -180,8 +288,7 @@ pub async fn rpc_multipart_handler( StatusCode::PAYLOAD_TOO_LARGE, "PAYLOAD_TOO_LARGE", format!( - "Multipart payload exceeds maximum size of {} bytes", - max_total + "Multipart payload exceeds maximum size of {max_total} bytes (field `_json`)" ), ); } @@ -235,20 +342,16 @@ pub async fn rpc_multipart_handler( .file_name() .map(String::from) .unwrap_or_else(|| name.clone()); - // Sanitize filename: strip path components to prevent path traversal - let filename = raw_filename - .rsplit(['/', '\\']) - .next() - .unwrap_or(&raw_filename) - .replace("..", "_") - .to_string(); - if filename.is_empty() { - return multipart_error( - StatusCode::BAD_REQUEST, - "INVALID_FILENAME", - "Filename is empty after sanitization", - ); - } + let filename = match sanitize_filename(&raw_filename) { + Some(f) => f, + None => { + return multipart_error( + StatusCode::BAD_REQUEST, + "INVALID_FILENAME", + "Filename empty or contains disallowed characters after sanitization", + ); + } + }; let content_type = field .content_type() .map(String::from) @@ -265,8 +368,7 @@ pub async fn rpc_multipart_handler( StatusCode::PAYLOAD_TOO_LARGE, "PAYLOAD_TOO_LARGE", format!( - "Multipart payload exceeds maximum size of {} bytes", - max_total + "Multipart payload exceeds maximum size of {max_total} bytes (field `{name}`, file `{filename}`)" ), ); } diff --git a/crates/forge-runtime/src/gateway/oauth.rs b/crates/forge-runtime/src/gateway/oauth.rs index a586d54b..1beef750 100644 --- a/crates/forge-runtime/src/gateway/oauth.rs +++ b/crates/forge-runtime/src/gateway/oauth.rs @@ -398,10 +398,12 @@ pub async fn oauth_register( ) .into_response(); } - // Check scheme and host - let is_localhost = uri.starts_with("http://localhost") - || 
uri.starts_with("http://127.0.0.1") - || uri.starts_with("http://[::1]"); + // Check scheme and host. `starts_with("http://localhost")` would also + // pass `http://localhost.evil.com/cb`, letting an attacker register a + // non-loopback redirect over plain HTTP. Extract the hostname and + // compare exactly. + let is_localhost = + forge_core::util::http_hostname(uri).is_some_and(forge_core::util::is_loopback_host); let is_https = uri.starts_with("https://"); if !is_localhost && !is_https { return ( @@ -661,10 +663,20 @@ pub async fn oauth_authorize_post( .into_response(); } - // Rate limit login failures (T7). PG-backed key so the budget is shared - // across cluster nodes. + // Rate limit every authorize POST branch (T7). Applied before branch + // dispatch so token and session-cookie flows can't bypass the budget — + // each branch exercises crypto-heavy paths (JWT validate, cookie HMAC, + // argon2id) that are otherwise free abuse amplifiers. let ip = resolved_ip.0.as_deref().unwrap_or("unknown"); let rate_key = format!("oauth:login:{ip}"); + if !state.rate_check(&rate_key, LOGIN_FAIL_RATE_LIMIT).await { + return authorize_error_redirect( + &form.redirect_uri, + form.state.as_deref(), + "access_denied", + "Too many authorization attempts. Please try again later.", + ); + } // Validate client and redirect_uri again (form could be tampered) let client = sqlx::query!( @@ -714,16 +726,22 @@ pub async fn oauth_authorize_post( }); if let Some(subject) = session_subject { - // Session cookie flow: user identified by signed cookie from previous API calls. - // Subject may be a UUID (HMAC auth) or an external provider ID (Firebase, Clerk). - user_id = subject.parse::().unwrap_or_else(|_| { - // Non-UUID subject (Firebase UID, etc.): deterministic UUID from subject hash. 
- use sha2::Digest; - let hash: [u8; 32] = sha2::Sha256::digest(subject.as_bytes()).into(); - let mut bytes = [0u8; 16]; - bytes.copy_from_slice(&hash[..16]); - Uuid::from_bytes(bytes) - }); + // Session cookie flow: subject must already be a real users.id UUID. + // Hashing an external provider subject into a fake UUID would forge a + // FK to a row that does not exist, and two different external subjects + // could collide on the same 128-bit prefix. External-provider users + // must complete the bearer-token branch instead. + match subject.parse::() { + Ok(uid) => user_id = uid, + Err(_) => { + return authorize_error_redirect( + &form.redirect_uri, + form.state.as_deref(), + "access_denied", + "Session subject is not a Forge user id. Sign in with a bearer token.", + ); + } + } } else if let Some(token) = &form.token { // Consent flow: validate existing JWT match state.auth_middleware.validate_token_async(token).await { @@ -762,15 +780,6 @@ pub async fn oauth_authorize_post( ); } - if !state.rate_check(&rate_key, LOGIN_FAIL_RATE_LIMIT).await { - return authorize_error_redirect( - &form.redirect_uri, - form.state.as_deref(), - "access_denied", - "Too many login attempts. Please try again later.", - ); - } - // Query users table by convention let row = sqlx::query!( "SELECT id, password_hash, role::TEXT FROM users WHERE email = $1", @@ -1156,14 +1165,16 @@ fn base_url_from_headers(headers: &HeaderMap) -> String { .and_then(|v| v.to_str().ok()) .unwrap_or("localhost:9081"); + // Parse host name (drop port and IPv6 brackets) and compare exactly. A + // `starts_with` test would match `localhost.evil.com`, causing the metadata + // to advertise an attacker-controlled http:// URL. + let hostname = forge_core::util::hostname_from_authority(host); + let is_loopback = forge_core::util::is_loopback_host(hostname); + // Do not trust x-forwarded-proto: OAuth routes bypass the trusted-proxy // middleware, so any client can spoof the header. 
Default to "https" for - // production safety; localhost gets "http" for local development. - let scheme = if host.starts_with("localhost") || host.starts_with("127.0.0.1") { - "http" - } else { - "https" - }; + // production safety; loopback gets "http" for local development. + let scheme = if is_loopback { "http" } else { "https" }; format!("{scheme}://{host}") } diff --git a/crates/forge-runtime/src/gateway/rpc.rs b/crates/forge-runtime/src/gateway/rpc.rs index 09767f1f..e6caeb3a 100644 --- a/crates/forge-runtime/src/gateway/rpc.rs +++ b/crates/forge-runtime/src/gateway/rpc.rs @@ -214,12 +214,18 @@ pub struct RpcFunctionBody { /// Validate that a function name contains only safe characters. /// Prevents log injection and unexpected behavior from special characters. +/// Leading `.` (including `..`) is rejected: dotted segments are reserved +/// for module paths and a leading dot looks like a path-traversal attempt +/// — neither maps to a real handler, so failing loud beats a 404 later. fn is_valid_function_name(name: &str) -> bool { - !name.is_empty() - && name.len() <= 256 - && name - .chars() - .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-') + if name.is_empty() || name.len() > 256 { + return false; + } + if name.starts_with('.') { + return false; + } + name.chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '.' || c == ':' || c == '-') } /// Axum handler for POST /rpc/:function (REST-style). @@ -314,6 +320,15 @@ mod tests { assert!(!is_valid_function_name("question?")); } + #[test] + fn function_name_rejects_leading_dot() { + // Leading dot (or `..`) reads as a path-traversal attempt and never + // maps to a real handler. 
+ assert!(!is_valid_function_name(".hidden")); + assert!(!is_valid_function_name("..parent")); + assert!(!is_valid_function_name(".")); + } + #[test] fn user_agent_returns_value_when_header_present() { let mut headers = HeaderMap::new(); diff --git a/crates/forge-runtime/src/gateway/server.rs b/crates/forge-runtime/src/gateway/server.rs index 9d50c831..634466e2 100644 --- a/crates/forge-runtime/src/gateway/server.rs +++ b/crates/forge-runtime/src/gateway/server.rs @@ -34,7 +34,7 @@ use super::multipart::{MultipartConfig, rpc_multipart_handler}; use super::response::{RpcError, RpcResponse}; use super::rpc::{RpcHandler, rpc_function_handler, rpc_handler}; use super::sse::{ - SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler, + SseState, sse_handler, sse_job_subscribe_handler, sse_subscribe_handler, sse_ticket_handler, sse_unsubscribe_handler, sse_workflow_subscribe_handler, }; use super::tls::{TlsListenConfig, bind_listener}; @@ -48,19 +48,12 @@ const DEFAULT_MAX_JSON_BODY_SIZE: usize = 1024 * 1024; const DEFAULT_MAX_MULTIPART_BODY_SIZE: usize = 20 * 1024 * 1024; const DEFAULT_MAX_FILE_SIZE: usize = 10 * 1024 * 1024; const MAX_MULTIPART_CONCURRENCY: usize = 32; -/// Fallback for visitor ID hashing when no JWT secret is configured (dev only). -const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret"; - -/// Resolve the visitor-ID hashing secret, falling back to a stable dev value -/// with a one-time warning when no JWT secret is configured. -fn signal_visitor_secret(jwt_secret: &Option) -> String { - jwt_secret.clone().unwrap_or_else(|| { - tracing::warn!( - "No jwt_secret configured; using default signal secret for visitor ID hashing. \ - Visitor IDs will be predictable. Set [auth] jwt_secret in forge.toml." - ); - DEFAULT_SIGNAL_SECRET.to_string() - }) +/// Resolve the visitor-ID hashing secret. 
Returns `None` when no JWT secret +/// is configured — callers must skip signals collection rather than fall +/// back to a constant, which would let any attacker predict visitor IDs and +/// correlate sessions across users. +fn signal_visitor_secret(jwt_secret: &Option) -> Option { + jwt_secret.clone().filter(|s| !s.is_empty()) } /// Gateway server configuration. @@ -218,6 +211,7 @@ pub struct GatewayServer { signals_collector: Option, signals_anonymize_ip: bool, signals_geoip: Option, + signals_rate_limit_per_minute: Option, custom_routes: Option, rate_limiter: Option>, role_resolver: Option, @@ -255,6 +249,7 @@ impl GatewayServer { signals_collector: None, signals_anonymize_ip: false, signals_geoip: None, + signals_rate_limit_per_minute: None, custom_routes: None, rate_limiter: None, role_resolver: None, @@ -330,6 +325,12 @@ impl GatewayServer { } /// Set the GeoIP resolver for country code lookups from client IPs. + /// Override the default per-IP `/signal` rate limit (requests per minute). + pub fn with_signals_rate_limit_per_minute(mut self, max: u32) -> Self { + self.signals_rate_limit_per_minute = Some(max); + self + } + pub fn with_signals_geoip(mut self, resolver: crate::signals::geoip::GeoIpResolver) -> Self { self.signals_geoip = Some(resolver); self @@ -424,8 +425,14 @@ impl GatewayServer { rpc.set_role_resolver(resolver.clone()); } if let Some(collector) = &self.signals_collector { - let secret = signal_visitor_secret(&self.config.auth.jwt_secret); - rpc.set_signals_collector(collector.clone(), secret); + match signal_visitor_secret(&self.config.auth.jwt_secret) { + Some(secret) => rpc.set_signals_collector(collector.clone(), secret), + None => tracing::error!( + "Signals collector configured but `[auth] jwt_secret` is unset. \ + Signals are disabled to avoid predictable visitor IDs. Configure \ + a jwt_secret in forge.toml to enable signals." 
+ ), + } } let rpc_handler_state = Arc::new(rpc); @@ -446,21 +453,40 @@ impl GatewayServer { // with credentials per the CORS spec, so we enumerate them. let cors = if self.config.cors_enabled { if self.config.cors_origins.iter().any(|o| o == "*") { - // Wildcard origin can't use credentials. Loud at startup so - // operators don't ship `cors_origins = ["*"]` to production - // by accident — credentialed requests will silently fail - // (no `Access-Control-Allow-Credentials`) and there's no - // origin allowlist limiting cross-site abuse of the gateway. - tracing::warn!( - "CORS wildcard (`cors_origins = [\"*\"]`) is enabled. \ - Credentialed requests will fail and any origin can \ - reach the gateway. Set explicit origins for \ - production deployments." - ); - CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any) + let is_production = std::env::var("FORGE_ENV") + .ok() + .as_deref() + .map(|s| s.eq_ignore_ascii_case("production") || s.eq_ignore_ascii_case("prod")) + .unwrap_or(false); + if is_production { + // In production a wildcard origin opens the gateway to any + // site. Refuse the wildcard outright: CORS is disabled and + // every cross-origin request will fail at the browser. The + // alternative — silently accepting every Origin — would let + // a malicious site issue same-credentials requests. + tracing::error!( + "CORS wildcard (`cors_origins = [\"*\"]`) is forbidden when \ + FORGE_ENV=production. CORS will be disabled. Configure \ + explicit origins to re-enable cross-origin access." + ); + CorsLayer::new() + } else { + // Wildcard origin can't use credentials. Loud at startup so + // operators don't ship `cors_origins = ["*"]` to production + // by accident — credentialed requests will silently fail + // (no `Access-Control-Allow-Credentials`) and there's no + // origin allowlist limiting cross-site abuse of the gateway. + tracing::warn!( + "CORS wildcard (`cors_origins = [\"*\"]`) is enabled. 
\ + Credentialed requests will fail and any origin can \ + reach the gateway. Set explicit origins for \ + production deployments." + ); + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any) + } } else { use axum::http::Method; let origins: Vec<_> = self @@ -501,7 +527,6 @@ impl GatewayServer { let sse_state = Arc::new(SseState::with_config( self.reactor.clone(), - auth_middleware_state.clone(), super::sse::SseConfig { max_sessions: self.config.sse_max_sessions, max_subscriptions_per_session: self @@ -569,6 +594,7 @@ impl GatewayServer { let sse_router = Router::new() .route("/events", get(sse_handler)) + .route("/events/ticket", post(sse_ticket_handler)) .route("/subscribe", post(sse_subscribe_handler)) .route("/unsubscribe", post(sse_unsubscribe_handler)) .route("/subscribe-job", post(sse_job_subscribe_handler)) @@ -598,23 +624,60 @@ impl GatewayServer { ); } - let mut signals_router = Router::new(); - if let Some(collector) = &self.signals_collector { + // The real collector only mounts when we have a visitor-ID secret: a + // constant fallback secret would make visitor IDs predictable, letting + // an attacker forge cross-user correlations. When it can't mount we + // still answer POST /signal with a 204 no-op (the else branch) instead + // of leaving the path to the SPA fallback, which 405s every beacon. + let signal_secret = signal_visitor_secret(&self.config.auth.jwt_secret); + if signal_secret.is_none() && self.signals_collector.is_some() { + tracing::error!( + "Signals collector configured but `[auth] jwt_secret` is unset. \ + Signal collection is disabled to avoid predictable visitor IDs." 
+ ); + } + let signals_router = if let (Some(collector), Some(server_secret)) = + (&self.signals_collector, signal_secret) + { let signals_state = Arc::new(crate::signals::endpoints::SignalsState { collector: collector.clone(), pool: self.db.primary().clone(), - server_secret: signal_visitor_secret(&self.config.auth.jwt_secret), + server_secret, anonymize_ip: self.signals_anonymize_ip, geoip: self.signals_geoip.clone(), - rate_limiter: Arc::new(crate::signals::rate_limit::SignalRateLimiter::new()), + rate_limiter: Arc::new(match self.signals_rate_limit_per_minute { + Some(max) => crate::signals::rate_limit::SignalRateLimiter::with_limit(max), + None => crate::signals::rate_limit::SignalRateLimiter::new(), + }), }); - signals_router = Router::new() + // Tighter body cap for the signal endpoint specifically. The + // batch + per-event size caps in signals/endpoints.rs would + // otherwise sit behind the 1 MB JSON default; clamping the + // request body to MAX_BATCH_SIZE * MAX_EVENT_BYTES + slack stops + // unauthenticated clients from forcing us to deserialize multi- + // MB JSON before validation runs. + const MAX_SIGNAL_BODY_BYTES: usize = 512 * 1024; + Router::new() .route("/signal", post(crate::signals::endpoints::signal_handler)) - .with_state(signals_state); - } + .layer(DefaultBodyLimit::max(MAX_SIGNAL_BODY_BYTES)) + .with_state(signals_state) + } else { + // Signals are disabled (or `[auth] jwt_secret` is unset). Clients + // enable web-vitals and page-view tracking by default and POST to + // /signal regardless. Without a route here the request falls through + // to the SPA static fallback, which rejects non-GET with a 405 the + // browser logs as a console error. Accept and drop it: a 204 stores + // nothing and mints no visitor ID, so it doesn't reintroduce the + // predictable-ID risk the real handler guards against. 
+ Router::new().route( + "/signal", + post(|| async { axum::http::StatusCode::NO_CONTENT }), + ) + }; let admin_router = admin_router(AdminState { db_pool: self.db.primary().clone(), + reactor: Some(self.reactor.clone()), }); main_router = main_router @@ -689,6 +752,23 @@ impl GatewayServer { .map_err(|e| std::io::Error::other(format!("Failed to start reactor: {}", e)))?; tracing::info!("Reactor started for real-time updates"); + // Surface the trusted-proxy posture at startup. The XFF chain is only + // honored when the immediate peer is in `trusted_proxies` — if the + // operator forgot to add the proxy CIDR, every request silently falls + // back to the peer IP and rate limits / geo signals get pinned to the + // proxy. A loud one-shot log keeps that misconfiguration visible. + if !self.config.trusted_proxies.is_empty() { + tracing::info!( + ranges = self.config.trusted_proxies.len(), + "Trusted proxies configured; X-Forwarded-For honored only from peers in these CIDRs" + ); + } else { + tracing::info!( + "No trusted_proxies configured; X-Forwarded-For headers are ignored and \ + client IPs come from the immediate peer" + ); + } + tracing::info!("Gateway server listening on {}", addr); let listener = bind_listener(addr, tls.as_ref()).await?; @@ -806,10 +886,15 @@ async fn readiness_handler( } async fn handle_middleware_error(err: BoxError) -> axum::response::Response { + // Distinguish error categories so clients can react correctly. Timeout + // signals "retry later"; anything else gets surfaced as 500 so it shows + // up in error budgets rather than masquerading as a transient 503 that + // hides genuine middleware bugs. 
let rpc_err = if err.is::() { RpcError::new("REQUEST_TIMEOUT", "Request timed out") } else { - RpcError::new("SERVICE_UNAVAILABLE", "Server overloaded") + tracing::error!(error = %err, "Unexpected middleware error"); + RpcError::new("INTERNAL_ERROR", "Internal server error") }; RpcResponse::error(rpc_err).into_response() } @@ -919,7 +1004,7 @@ async fn api_version_middleware( let is_rpc = req.uri().path().starts_with("/rpc"); if is_rpc && let Some(accept) = req.headers().get(axum::http::header::ACCEPT) { let accept_str = accept.to_str().unwrap_or(""); - if accept_str != "*/*" && !accept_str.is_empty() && !accept_str.contains(FORGE_API_V1) { + if !accept_str.is_empty() && !accept_allows_v1(accept_str) { return RpcResponse::error(RpcError::new( "UNSUPPORTED_API_VERSION", format!( @@ -933,6 +1018,39 @@ async fn api_version_middleware( next.run(req).await } +/// Returns true if the `Accept` header value tolerates the v1 forge media +/// type. Each comma-separated media range is checked individually so that +/// `Accept: application/json, application/vnd.forge.v1+json` matches even +/// though `contains` would also have matched a misleading substring. Quality +/// values (`;q=0`) explicitly disable the match. +fn accept_allows_v1(accept: &str) -> bool { + for raw in accept.split(',') { + let mut parts = raw.split(';').map(str::trim); + let Some(media) = parts.next() else { continue }; + if media.is_empty() { + continue; + } + let mut q = 1.0_f32; + for param in parts { + if let Some(qv) = param.strip_prefix("q=") + && let Ok(parsed) = qv.parse::() + { + q = parsed; + } + } + if q <= 0.0 { + continue; + } + if media.eq_ignore_ascii_case(FORGE_API_V1) + || media == "*/*" + || media.eq_ignore_ascii_case("application/*") + { + return true; + } + } + false +} + /// Wraps each request in a span with HTTP semantics and OpenTelemetry /// context propagation. Incoming `traceparent` headers are extracted so /// that spans join the caller's distributed trace. 
@@ -1117,9 +1235,10 @@ struct JsonDepthConfig { } /// Middleware that rejects request bodies whose JSON nesting depth exceeds -/// `max_depth`. Runs on all POST requests regardless of Content-Type, because -/// serde_json will parse the body downstream even if the client lies about -/// the content type. +/// `max_depth`. Runs on every method that can carry a body (POST/PUT/PATCH/ +/// DELETE) regardless of Content-Type, because serde_json will parse the +/// body downstream even if the client lies about the content type. GET and +/// HEAD are skipped because Axum drops their bodies. /// /// The body is buffered, inspected, and re-inserted into the request so that /// downstream handlers see the original bytes. @@ -1129,8 +1248,13 @@ async fn json_depth_check_middleware( next: axum::middleware::Next, ) -> axum::response::Response { use axum::body::Body; + use axum::http::Method; - if req.method() != axum::http::Method::POST || config.max_depth == 0 { + let method_has_body = matches!( + *req.method(), + Method::POST | Method::PUT | Method::PATCH | Method::DELETE + ); + if !method_has_body || config.max_depth == 0 { return next.run(req).await; } @@ -1237,12 +1361,18 @@ mod tests { #[test] fn signal_visitor_secret_uses_jwt_secret_when_present() { let secret = Some("my-jwt-secret".to_string()); - assert_eq!(signal_visitor_secret(&secret), "my-jwt-secret"); + assert_eq!( + signal_visitor_secret(&secret), + Some("my-jwt-secret".to_string()) + ); } #[test] - fn signal_visitor_secret_falls_back_to_default_when_absent() { - assert_eq!(signal_visitor_secret(&None), DEFAULT_SIGNAL_SECRET); + fn signal_visitor_secret_is_none_when_jwt_secret_absent() { + // Refuse to mint a constant fallback — predictable visitor IDs would + // let an attacker correlate sessions across users. 
+ assert_eq!(signal_visitor_secret(&None), None); + assert_eq!(signal_visitor_secret(&Some(String::new())), None); } #[test] diff --git a/crates/forge-runtime/src/gateway/sse.rs b/crates/forge-runtime/src/gateway/sse.rs index 7dd18174..70efd5c5 100644 --- a/crates/forge-runtime/src/gateway/sse.rs +++ b/crates/forge-runtime/src/gateway/sse.rs @@ -5,7 +5,7 @@ use std::convert::Infallible; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::Duration; +use std::time::{Duration, Instant}; use axum::Json; use axum::extract::{Extension, Query, State}; @@ -21,6 +21,71 @@ use subtle::ConstantTimeEq; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; +/// Maximum number of outstanding SSE tickets held in memory. Each ticket is +/// small (~200 B), so 10k is a ~2 MB cap. New issuance evicts expired +/// entries before falling back to a hard reject. +const MAX_SSE_TICKETS: usize = 10_000; +/// SSE ticket lifetime. The original ISSUES.md item requires "≤ 60 s"; +/// 30 s is short enough to bound replay risk yet generous for slow clients +/// to complete the POST + EventSource handshake. +const SSE_TICKET_TTL_SECS: u64 = 30; +const SSE_TICKET_TTL_SECS_STR: &str = "30"; + +/// One-shot SSE authentication ticket. Issued via `POST /events/ticket` +/// against a validated bearer header; consumed exactly once by the SSE +/// `GET /events?ticket=…` upgrade. Stored in-process (no DB), bound to +/// the caller's resolved client IP so a leaked ticket from logs cannot +/// be replayed from a different origin. +struct TicketEntry { + auth: AuthContext, + client_ip: Option, + expires_at: Instant, +} + +/// Process-local store of outstanding SSE tickets. Bounded, TTL-evicted, +/// one-time-use. Tickets are opaque uuid v4 strings (122 bits). +#[derive(Default)] +struct TicketStore { + entries: DashMap, +} + +impl TicketStore { + fn new() -> Self { + Self { + entries: DashMap::new(), + } + } + + /// Drop expired tickets. 
Called opportunistically on insert and consume. + fn sweep_expired(&self) { + let now = Instant::now(); + self.entries.retain(|_, entry| entry.expires_at > now); + } + + /// Insert a fresh ticket. Returns `false` when at capacity even after + /// sweeping expired entries (caller should reject with 503). + fn insert(&self, ticket: String, entry: TicketEntry) -> bool { + if self.entries.len() >= MAX_SSE_TICKETS { + self.sweep_expired(); + if self.entries.len() >= MAX_SSE_TICKETS { + return false; + } + } + self.entries.insert(ticket, entry); + true + } + + /// Atomically remove and validate a ticket. Returns `None` if missing, + /// already consumed, or expired. + fn consume(&self, ticket: &str) -> Option { + let (_, entry) = self.entries.remove(ticket)?; + if entry.expires_at <= Instant::now() { + return None; + } + Some(entry) + } +} + /// Wraps an mpsc::Receiver as a Stream for SSE. struct ReceiverStream { rx: mpsc::Receiver, @@ -36,7 +101,6 @@ impl Stream for ReceiverStream { use forge_core::function::AuthContext; use forge_core::realtime::{SessionId, SubscriptionId}; -use super::auth::AuthMiddleware; use crate::realtime::Reactor; use crate::realtime::RealtimeMessage; @@ -62,13 +126,6 @@ fn same_principal(a: &AuthContext, b: &AuthContext) -> bool { } } -fn resolve_sse_auth_context( - request_auth: &AuthContext, - query_auth: Option, -) -> AuthContext { - query_auth.unwrap_or_else(|| request_auth.clone()) -} - #[allow(clippy::result_large_err)] fn authorize_session_access( session: &SseSessionData, @@ -183,8 +240,9 @@ impl Default for SseConfig { /// SSE query parameters. #[derive(Debug, Deserialize)] pub struct SseQuery { - /// Authentication token. - pub token: Option, + /// One-shot ticket obtained from `POST /events/ticket`. Required when + /// `EventSource` cannot send an `Authorization` header (browsers). 
+ pub ticket: Option, } struct SseSessionData { @@ -199,7 +257,6 @@ struct SseSessionData { #[derive(Clone)] pub struct SseState { reactor: Arc, - auth_middleware: Arc, /// Per-session data: auth context and subscription mappings (sharded). sessions: Arc>, /// Per-user session count for O(1) limit enforcement. @@ -208,28 +265,26 @@ pub struct SseState { ip_session_counts: Arc>, /// Per-user subscription count across all sessions. user_subscription_counts: Arc>, + /// One-shot SSE auth tickets. See `TicketStore` for semantics. + tickets: Arc, config: SseConfig, } impl SseState { /// Create new SSE state with default config. - pub fn new(reactor: Arc, auth_middleware: Arc) -> Self { - Self::with_config(reactor, auth_middleware, SseConfig::default()) + pub fn new(reactor: Arc) -> Self { + Self::with_config(reactor, SseConfig::default()) } /// Create new SSE state with custom config. - pub fn with_config( - reactor: Arc, - auth_middleware: Arc, - config: SseConfig, - ) -> Self { + pub fn with_config(reactor: Arc, config: SseConfig) -> Self { Self { reactor, - auth_middleware, sessions: Arc::new(DashMap::new()), user_session_counts: Arc::new(DashMap::new()), ip_session_counts: Arc::new(DashMap::new()), user_subscription_counts: Arc::new(DashMap::new()), + tickets: Arc::new(TicketStore::new()), config, } } @@ -570,23 +625,51 @@ pub async fn sse_handler( let keepalive_secs = state.config.keepalive_interval_secs; let cancel_token = CancellationToken::new(); - let query_auth = if let Some(token) = &query.token { - match state.auth_middleware.validate_token_async(token).await { - Ok(claims) => Some(super::auth::build_auth_context_from_claims(claims)), - Err(e) => { - tracing::warn!("SSE token validation failed: {}", e); + let client_ip = resolved_ip.0; + + // Authentication resolution order: + // 1. If the request was authenticated by the `Authorization` header + // (auth_middleware ran upstream), use that. Header is authoritative. + // 2. 
Otherwise, if a `?ticket=` is supplied, consume it. The ticket + // was minted against a validated bearer header at `/events/ticket` + // and is bound to the resolved client IP, so a leaked URL cannot + // be replayed from a different origin. + // 3. Otherwise, treat the connection as anonymous. + // + // JWTs are deliberately never accepted in the URL. Query strings appear + // in access logs, browser history, referrer headers, and proxy caches. + let auth_context = if request_auth.is_authenticated() { + request_auth.clone() + } else if let Some(ticket) = &query.ticket { + match state.tickets.consume(ticket) { + Some(entry) => { + // Bind ticket to the IP that requested it. Reject if either + // side has no resolved IP (strict) or they disagree. The + // server-side store already guarantees one-shot use. + let ip_match = match (&entry.client_ip, &client_ip) { + (Some(a), Some(b)) => a == b, + _ => false, + }; + if !ip_match { + tracing::warn!("SSE ticket IP mismatch; rejecting"); + return super::response::RpcResponse::error( + super::response::RpcError::unauthorized("SSE ticket IP mismatch"), + ) + .into_response(); + } + entry.auth + } + None => { + tracing::warn!("SSE ticket missing, expired, or already consumed"); return super::response::RpcResponse::error( - super::response::RpcError::unauthorized("Invalid authentication token"), + super::response::RpcError::unauthorized("Invalid or expired SSE ticket"), ) .into_response(); } } } else { - None + request_auth.clone() }; - let auth_context = resolve_sse_auth_context(&request_auth, query_auth); - - let client_ip = resolved_ip.0; // UUIDv4 provides 122 bits of randomness, sufficient for session secret entropy let session_secret = uuid::Uuid::new_v4().to_string(); // Authenticated sessions without an explicit exp claim get a default @@ -750,6 +833,61 @@ pub async fn sse_handler( .into_response() } +/// Response body for `POST /events/ticket`. 
+#[derive(Debug, Serialize)] +pub struct SseTicketResponse { + /// Opaque single-use ticket. Send back as `?ticket=…` on the next + /// `GET /events` request. + pub ticket: String, + /// Lifetime hint in seconds; clients should connect well before this. + pub expires_in_secs: u64, +} + +/// SSE ticket handler for POST `/events/ticket`. Issues a short-lived, +/// IP-bound, single-use ticket so browsers (whose `EventSource` cannot +/// set custom headers) can authenticate the SSE upgrade without putting +/// a long-lived JWT in the URL. +/// +/// Requires an authenticated bearer header. Anonymous callers get 401: +/// anonymous SSE streams can simply connect to `/events` without a ticket. +pub async fn sse_ticket_handler( + State(state): State>, + Extension(request_auth): Extension, + Extension(resolved_ip): Extension, +) -> impl IntoResponse { + if !request_auth.is_authenticated() { + return super::response::RpcResponse::error(super::response::RpcError::unauthorized( + "Authentication required to mint an SSE ticket", + )) + .into_response(); + } + + let ticket = uuid::Uuid::new_v4().to_string(); + let entry = TicketEntry { + auth: request_auth.clone(), + client_ip: resolved_ip.0, + expires_at: Instant::now() + Duration::from_secs(SSE_TICKET_TTL_SECS), + }; + + if !state.tickets.insert(ticket.clone(), entry) { + return ( + StatusCode::SERVICE_UNAVAILABLE, + [(axum::http::header::RETRY_AFTER, SSE_TICKET_TTL_SECS_STR)], + Json( + SseError::new("SSE_TICKET_CAPACITY", "Too many outstanding SSE tickets") + .with_retry_after(SSE_TICKET_TTL_SECS), + ), + ) + .into_response(); + } + + Json(SseTicketResponse { + ticket, + expires_in_secs: SSE_TICKET_TTL_SECS, + }) + .into_response() +} + /// Convert realtime message to SSE message. 
fn convert_realtime_to_sse(msg: RealtimeMessage) -> Option { match msg { @@ -1232,26 +1370,51 @@ mod tests { } #[test] - fn resolve_sse_auth_context_prefers_request_auth_when_query_token_absent() { - let request_auth = + fn ticket_store_consume_is_one_shot() { + let store = TicketStore::new(); + let auth = AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); - - let resolved = resolve_sse_auth_context(&request_auth, None); - - assert!(resolved.is_authenticated()); - assert_eq!(resolved.principal_id(), request_auth.principal_id()); + let entry = TicketEntry { + auth, + client_ip: Some("1.2.3.4".into()), + expires_at: Instant::now() + Duration::from_secs(30), + }; + assert!(store.insert("t1".into(), entry)); + assert!(store.consume("t1").is_some()); + assert!(store.consume("t1").is_none(), "second consume must fail"); } #[test] - fn resolve_sse_auth_context_prefers_query_token_when_present() { - let request_auth = - AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); - let query_auth = + fn ticket_store_rejects_expired() { + let store = TicketStore::new(); + let auth = AuthContext::authenticated(Uuid::new_v4(), vec!["user".to_string()], HashMap::new()); + let entry = TicketEntry { + auth, + client_ip: None, + expires_at: Instant::now() - Duration::from_secs(1), + }; + assert!(store.insert("t2".into(), entry)); + assert!( + store.consume("t2").is_none(), + "expired ticket must not validate" + ); + } - let resolved = resolve_sse_auth_context(&request_auth, Some(query_auth.clone())); - - assert_eq!(resolved.principal_id(), query_auth.principal_id()); + #[test] + fn ticket_store_caps_at_max() { + let store = TicketStore::new(); + let make_entry = || TicketEntry { + auth: AuthContext::unauthenticated(), + client_ip: None, + expires_at: Instant::now() + Duration::from_secs(30), + }; + // Fill to cap. 
+ for i in 0..MAX_SSE_TICKETS { + assert!(store.insert(format!("k{i}"), make_entry())); + } + // One more should fail (no expired entries to sweep). + assert!(!store.insert("overflow".into(), make_entry())); } #[test] diff --git a/crates/forge-runtime/src/gateway/tls.rs b/crates/forge-runtime/src/gateway/tls.rs index 3e7b6655..fcb71312 100644 --- a/crates/forge-runtime/src/gateway/tls.rs +++ b/crates/forge-runtime/src/gateway/tls.rs @@ -274,11 +274,35 @@ fn read_pem_certs(path: &str) -> Result>> { } fn read_pem_key(path: &str) -> Result> { + warn_if_key_world_readable(path); PrivateKeyDer::from_pem_file(path).map_err(|e| { ForgeError::config(format!("failed to read PEM private key from '{path}': {e}")) }) } +/// Emit a loud warning if the TLS private key is readable by group or other. +/// We don't refuse to start — operators may rely on a key-management daemon +/// that enforces its own ACL model — but silently loading a 0644 key would +/// be a footgun on shared hosts. +#[cfg(unix)] +fn warn_if_key_world_readable(path: &str) { + use std::os::unix::fs::MetadataExt; + let Ok(meta) = std::fs::metadata(path) else { + return; + }; + let mode = meta.mode() & 0o777; + if mode & 0o077 != 0 { + tracing::warn!( + path = %path, + mode = format!("{:o}", mode), + "TLS private key is readable by group or other; tighten to 0600 (chmod 600)" + ); + } +} + +#[cfg(not(unix))] +fn warn_if_key_world_readable(_path: &str) {} + #[cfg(test)] #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { diff --git a/crates/forge-runtime/src/jobs/dispatcher.rs b/crates/forge-runtime/src/jobs/dispatcher.rs index 4277afe0..198c3c51 100644 --- a/crates/forge-runtime/src/jobs/dispatcher.rs +++ b/crates/forge-runtime/src/jobs/dispatcher.rs @@ -1,10 +1,9 @@ use std::future::Future; use std::pin::Pin; -use std::time::Duration; use chrono::{DateTime, Utc}; use forge_core::function::JobDispatch; -use forge_core::job::{ForgeJob, JobInfo, JobPriority}; +use forge_core::job::JobInfo; use 
uuid::Uuid; use super::queue::{JobQueue, JobRecord}; @@ -21,93 +20,6 @@ impl JobDispatcher { Self { queue, registry } } - pub async fn dispatch( - &self, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - self.dispatch_with_info(&info, serde_json::to_value(args)?, owner_subject) - .await - } - - pub async fn dispatch_in( - &self, - delay: Duration, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let scheduled_at = Utc::now() - + chrono::Duration::from_std(delay) - .map_err(|_| forge_core::ForgeError::InvalidArgument("delay too large".into()))?; - self.dispatch_at_with_info( - &info, - serde_json::to_value(args)?, - scheduled_at, - owner_subject, - ) - .await - } - - pub async fn dispatch_at( - &self, - at: DateTime, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - self.dispatch_at_with_info(&info, serde_json::to_value(args)?, at, owner_subject) - .await - } - - pub async fn dispatch_by_name( - &self, - job_type: &str, - args: serde_json::Value, - owner_subject: Option, - ) -> Result { - let entry = self.registry.get(job_type).ok_or_else(|| { - forge_core::ForgeError::NotFound(format!("Job type '{}' not found", job_type)) - })?; - - self.dispatch_with_info(&entry.info, args, owner_subject) - .await - } - - async fn dispatch_with_info( - &self, - info: &JobInfo, - args: serde_json::Value, - owner_subject: Option, - ) -> Result { - let mut job = JobRecord::new( - info.name, - args, - info.priority, - info.retry.max_attempts as i32, - ) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - /// Request cancellation for a job. 
/// /// If `caller_subject` is provided, the cancellation will only succeed if @@ -124,32 +36,6 @@ impl JobDispatcher { .map_err(forge_core::ForgeError::Database) } - async fn dispatch_at_with_info( - &self, - info: &JobInfo, - args: serde_json::Value, - scheduled_at: DateTime, - owner_subject: Option, - ) -> Result { - let mut job = JobRecord::new( - info.name, - args, - info.priority, - info.retry.max_attempts as i32, - ) - .with_scheduled_at(scheduled_at) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - async fn dispatch_with_info_and_tenant( &self, info: &JobInfo, @@ -203,63 +89,6 @@ impl JobDispatcher { .await .map_err(forge_core::ForgeError::Database) } - - pub async fn dispatch_idempotent( - &self, - idempotency_key: impl Into, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let mut job = JobRecord::new( - info.name, - serde_json::to_value(args)?, - info.priority, - info.retry.max_attempts as i32, - ) - .with_idempotency_key(idempotency_key) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } - - pub async fn dispatch_with_priority( - &self, - priority: JobPriority, - args: J::Args, - owner_subject: Option, - ) -> Result - where - J::Args: serde::Serialize, - { - let info = J::info(); - let mut job = JobRecord::new( - info.name, - serde_json::to_value(args)?, - priority, - info.retry.max_attempts as i32, - ) - .with_owner_subject(owner_subject); - - if let Some(cap) = info.worker_capability { - job = job.with_capability(cap); - } - - self.queue - .enqueue(job) - .await - .map_err(forge_core::ForgeError::Database) - } } impl JobDispatch for JobDispatcher { @@ 
-442,7 +271,7 @@ mod integration_tests { let dispatcher = dispatcher_with_registry(db.pool().clone(), |_| {}); let err = dispatcher - .dispatch_by_name("ghost", serde_json::json!({}), None) + .dispatch_by_name("ghost", serde_json::json!({}), None, None) .await .expect_err("unknown job must error"); @@ -470,6 +299,7 @@ mod integration_tests { "ship", serde_json::json!({"to": "warehouse"}), Some("u-1".into()), + None, ) .await .unwrap(); @@ -614,7 +444,7 @@ mod integration_tests { }); let job_id = dispatcher - .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into())) + .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into()), None) .await .unwrap(); @@ -636,7 +466,7 @@ mod integration_tests { }); let job_id = dispatcher - .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into())) + .dispatch_by_name("ship", serde_json::json!({}), Some("alice".into()), None) .await .unwrap(); diff --git a/crates/forge-runtime/src/jobs/executor.rs b/crates/forge-runtime/src/jobs/executor.rs index 9069a131..f796149e 100644 --- a/crates/forge-runtime/src/jobs/executor.rs +++ b/crates/forge-runtime/src/jobs/executor.rs @@ -10,6 +10,11 @@ use super::queue::{JobQueue, JobRecord}; use super::registry::{JobEntry, JobRegistry}; use crate::observability; +/// How often to poll the progress channel between updates from the running job. +/// Short enough that progress propagates to subscribers within one frame; long +/// enough that an idle job doesn't burn CPU on the polling task. +const PROGRESS_POLL_INTERVAL: Duration = Duration::from_millis(50); + /// Executes jobs with timeout and retry handling. 
pub struct JobExecutor { queue: JobQueue, @@ -139,7 +144,7 @@ impl JobExecutor { } } Err(std::sync::mpsc::TryRecvError::Empty) => { - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + tokio::time::sleep(PROGRESS_POLL_INTERVAL).await; } Err(std::sync::mpsc::TryRecvError::Disconnected) => { break; @@ -171,6 +176,24 @@ impl JobExecutor { c }; if let Some(ref subject) = job.owner_subject { + // Defense in depth: the job row stores `tenant_id` and + // `owner_subject` independently. We trust the dispatcher to pair + // them correctly, but if a stale/forged dispatch slipped a + // mismatched pair through, the handler would execute cross-tenant + // (#6 in issues doc). There's no system `users` table to consult, + // so the framework can't reject the row authoritatively here. Warn + // when an owner principal is present without a tenant — the shape + // most likely to indicate a dispatch path that lost tenancy. (Single- + // tenant apps legitimately dispatch with no tenant_id, so this is a + // warning, not an assertion.) + if job.tenant_id.is_none() && uuid::Uuid::parse_str(subject).is_ok() { + tracing::warn!( + job_id = %job.id, + job_type = %job.job_type, + owner_subject = %subject, + "Job has UUID owner_subject but no tenant_id; if the app is multi-tenant this is a dispatch-side bug — the handler will run with empty tenant scope" + ); + } let mut claims = std::collections::HashMap::new(); if let Some(tid) = job.tenant_id { claims.insert( @@ -240,13 +263,40 @@ impl JobExecutor { } }); + // Drop guard: ensures the heartbeat task is aborted even if `execute` + // is cancelled (e.g. `drain_jobs` calling `abort_all` on shutdown). + // Without this the task would keep refreshing `last_heartbeat` for up + // to 30s, blocking `release_stale` from requeueing the row (#9 in + // issues doc). 
+ struct HeartbeatGuard { + stop: tokio::sync::watch::Sender, + handle: Option>, + } + impl Drop for HeartbeatGuard { + fn drop(&mut self) { + let _ = self.stop.send(true); + if let Some(h) = self.handle.take() { + h.abort(); + } + } + } + let mut heartbeat_guard = HeartbeatGuard { + stop: heartbeat_stop_tx, + handle: Some(heartbeat_task), + }; + let job_timeout = entry.info.timeout; let exec_start = std::time::Instant::now(); let result = timeout(job_timeout, self.run_handler(&entry, &ctx, &job.input)).await; let exec_duration_ms = exec_start.elapsed().as_millis() as i32; - let _ = heartbeat_stop_tx.send(true); - let _ = heartbeat_task.await; + // Happy path: signal the heartbeat task and await it cleanly. The + // guard will still run on early return, abort() is a no-op on a + // joined handle. + let _ = heartbeat_guard.stop.send(true); + if let Some(h) = heartbeat_guard.handle.take() { + let _ = h.await; + } let ttl = entry.info.ttl; @@ -326,8 +376,14 @@ impl JobExecutor { let error_msg = format!("Job timed out after {:?}", job_timeout); let should_retry = job.attempts < job.max_attempts; + // Mirror the failure path: honor the job's configured backoff + // strategy rather than hardcoding 60s (#13 in issues doc). let retry_delay = if should_retry { - Some(chrono::Duration::seconds(60)) + let std_delay = entry.info.retry.calculate_backoff(job.attempts as u32); + Some( + chrono::Duration::from_std(std_delay) + .unwrap_or(chrono::Duration::seconds(60)), + ) } else { None }; diff --git a/crates/forge-runtime/src/jobs/queue.rs b/crates/forge-runtime/src/jobs/queue.rs index 03b82452..c1d03f45 100644 --- a/crates/forge-runtime/src/jobs/queue.rs +++ b/crates/forge-runtime/src/jobs/queue.rs @@ -137,15 +137,30 @@ impl JobQueue { ) -> Result { // Fast path: check for existing idempotent job before attempting INSERT. // The UNIQUE partial index on idempotency_key guards against races. 
+ // + // Scope the lookup by `job_type` so apps reusing the same key across + // multiple job types (e.g. `payment-{id}` for `charge` and `refund`) + // get the right job back. NOTE: the partial unique index in + // `v001_initial.sql` does NOT yet include `job_type`, so cross-type + // idempotency collisions are still rejected at the database level + // even though this check would accept them. Tracking issue: update + // the index to `(job_type, idempotency_key)` once migration is safe. if let Some(ref key) = job.idempotency_key { - let existing = sqlx::query_scalar!( + // Runtime query: lookup gained a `job_type` filter so apps reusing + // the same key across job types map to the right row. Uses the + // runtime `sqlx::query_scalar` fn, not the `query_scalar!` macro, to avoid + // invalidating the offline cache on a non-critical path. + #[allow(clippy::disallowed_methods)] + let existing: Option = sqlx::query_scalar( r#" SELECT id FROM forge_jobs WHERE idempotency_key = $1 + AND job_type = $2 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') "#, - key ) + .bind(key) + .bind(&job.job_type) .fetch_optional(&mut *conn) .await?; @@ -187,15 +202,19 @@ impl JobQueue { .await?; // If ON CONFLICT fired (race with another enqueue), fetch the winner's ID. + // Runtime query for the same reason as the fast-path lookup above. 
if let Some(ref key) = job.idempotency_key { - let id = sqlx::query_scalar!( + #[allow(clippy::disallowed_methods)] + let id: Option = sqlx::query_scalar( r#" SELECT id FROM forge_jobs WHERE idempotency_key = $1 + AND job_type = $2 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') "#, - key ) + .bind(key) + .bind(&job.job_type) .fetch_optional(&mut *conn) .await?; @@ -448,7 +467,11 @@ impl JobQueue { "#, job_id, error, - delay.num_seconds() as f64, + // Millisecond precision, not num_seconds(): the latter truncates + // to whole seconds, so a sub-second backoff (the common first + // retry: 1s base - 25% jitter = 0.75s) collapses to secs => 0, + // dropping the backoff entirely and retrying instantly. + delay.num_milliseconds() as f64 / 1000.0, ) .execute(&self.pool) .await?; @@ -609,6 +632,11 @@ impl JobQueue { } let retention_secs = Self::DEFAULT_RETENTION.as_secs() as f64; + // Defense-in-depth: the SELECT guard above already verified the caller + // owns this row, but include the ownership predicate directly in the + // UPDATE so a future refactor that reorders these blocks can't silently + // drop the check. A NULL `caller_subject` (system-side caller) matches + // any row; rows with a NULL `owner_subject` match any caller. #[allow(clippy::disallowed_methods)] let updated = sqlx::query( r#" @@ -620,11 +648,17 @@ impl JobQueue { expires_at = NOW() + make_interval(secs => $3) WHERE id = $1 AND status NOT IN ('completed', 'failed', 'dead_letter', 'cancelled') + AND ( + owner_subject IS NULL + OR $4::text IS NULL + OR owner_subject = $4::text + ) "#, ) .bind(job_id) .bind(reason) .bind(retention_secs) + .bind(caller_subject) .execute(&self.pool) .await?; @@ -715,6 +749,13 @@ impl JobQueue { ( status = 'claimed' AND claimed_at < NOW() - make_interval(secs => $1) + -- Don't yank a claim that has produced a recent heartbeat: + -- the worker may have transitioned to running on its own + -- side and we just haven't seen the `start()` UPDATE land yet. 
+ AND ( + last_heartbeat IS NULL + OR last_heartbeat < NOW() - make_interval(secs => $1) + ) ) OR ( status = 'running' @@ -883,7 +924,7 @@ mod integration_tests { let job_id = queue.enqueue(job).await.expect("Failed to enqueue"); let claimed = queue - .claim(worker_id, &[], true, 10) + .claim(worker_id, &["default".into()], true, 10) .await .expect("Failed to claim"); assert_eq!(claimed.len(), 1); @@ -912,11 +953,17 @@ mod integration_tests { } let worker1 = Uuid::new_v4(); - let batch1 = queue.claim(worker1, &[], true, 2).await.expect("claim1"); + let batch1 = queue + .claim(worker1, &["default".into()], true, 2) + .await + .expect("claim1"); assert_eq!(batch1.len(), 2); let worker2 = Uuid::new_v4(); - let batch2 = queue.claim(worker2, &[], true, 2).await.expect("claim2"); + let batch2 = queue + .claim(worker2, &["default".into()], true, 2) + .await + .expect("claim2"); assert_eq!(batch2.len(), 1); let ids1: Vec = batch1.iter().map(|j| j.id).collect(); @@ -943,7 +990,10 @@ mod integration_tests { let high = JobRecord::new("high_job", serde_json::json!({}), JobPriority::Critical, 3); queue.enqueue(high).await.expect("enqueue high"); - let claimed = queue.claim(worker_id, &[], true, 1).await.expect("claim"); + let claimed = queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); assert_eq!(claimed.len(), 1); assert_eq!(claimed[0].job_type, "high_job"); @@ -959,7 +1009,10 @@ mod integration_tests { let job = JobRecord::new("process", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue .complete(job_id, serde_json::json!({"result": "done"}), None) @@ -982,7 +1035,10 @@ mod integration_tests { let job = JobRecord::new("flaky", serde_json::json!({}), JobPriority::Normal, 3); let 
job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue @@ -1011,7 +1067,10 @@ mod integration_tests { let job = JobRecord::new("fatal", serde_json::json!({}), JobPriority::Normal, 1); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue @@ -1189,7 +1248,10 @@ mod integration_tests { let job = JobRecord::new("long_task", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue.heartbeat(job_id).await.expect("heartbeat"); @@ -1205,7 +1267,10 @@ mod integration_tests { let job = JobRecord::new("export", serde_json::json!({}), JobPriority::Normal, 3); let job_id = queue.enqueue(job).await.expect("enqueue"); - queue.claim(worker_id, &[], true, 1).await.expect("claim"); + queue + .claim(worker_id, &["default".into()], true, 1) + .await + .expect("claim"); queue.start(job_id, worker_id, 1).await.expect("start"); queue diff --git a/crates/forge-runtime/src/jobs/registry.rs b/crates/forge-runtime/src/jobs/registry.rs index 676d41a8..0aa55fb2 100644 --- a/crates/forge-runtime/src/jobs/registry.rs +++ b/crates/forge-runtime/src/jobs/registry.rs @@ -5,29 +5,9 @@ use std::sync::Arc; use forge_core::Result; use forge_core::job::{ForgeJob, JobContext, JobInfo}; +use forge_core::util::normalize_handler_args as normalize_args; use serde_json::Value; -/// Converts `null` to `{}` and unwraps single-key 
`args`/`input` envelopes. -fn normalize_args(args: Value) -> Value { - let unwrapped = match &args { - Value::Object(map) if map.len() == 1 => { - if map.contains_key("args") { - map.get("args").cloned().unwrap_or(Value::Null) - } else if map.contains_key("input") { - map.get("input").cloned().unwrap_or(Value::Null) - } else { - args - } - } - _ => args, - }; - - match &unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - _ => unwrapped, - } -} - pub type BoxedJobHandler = Arc< dyn Fn(&JobContext, Value) -> Pin> + Send + '_>> + Send @@ -151,53 +131,9 @@ impl JobRegistry { #[allow(clippy::unwrap_used, clippy::indexing_slicing)] mod tests { use super::*; - use serde_json::json; - - // jobs/registry collapses null to {} so derive(Default) empty-struct args deserialize correctly; - // function/registry keeps null as-is for unit () — this divergence is the contract. - #[test] - fn normalize_args_converts_null_to_empty_object() { - assert_eq!(normalize_args(json!(null)), json!({})); - } - - #[test] - fn normalize_args_keeps_empty_object_intact() { - // `{}` (len 0) skips the envelope unwrap and the null branch. - assert_eq!(normalize_args(json!({})), json!({})); - } - #[test] - fn normalize_args_unwraps_args_envelope() { - assert_eq!(normalize_args(json!({"args": {"id": 7}})), json!({"id": 7})); - // The trailing null-to-{} step still applies after unwrap. - assert_eq!(normalize_args(json!({"args": null})), json!({})); - } - - #[test] - fn normalize_args_unwraps_input_envelope() { - assert_eq!(normalize_args(json!({"input": [1,2]})), json!([1, 2])); - } - - #[test] - fn normalize_args_keeps_other_single_key_objects_intact() { - // A handler with `struct Args { id: u32 }` must receive {"id":...} - // as-is — envelope stripping only fires for `args`/`input`. 
- assert_eq!(normalize_args(json!({"id": 7})), json!({"id": 7})); - } - - #[test] - fn normalize_args_keeps_multi_key_objects_intact() { - let v = json!({"a": 1, "b": 2}); - assert_eq!(normalize_args(v.clone()), v); - } - - #[test] - fn normalize_args_keeps_non_null_non_object_values_intact() { - assert_eq!(normalize_args(json!(42)), json!(42)); - assert_eq!(normalize_args(json!("x")), json!("x")); - assert_eq!(normalize_args(json!([1])), json!([1])); - assert_eq!(normalize_args(json!(true)), json!(true)); - } + // normalize_args is exercised via `forge_core::util` tests; jobs/registry + // now delegates to that shared helper. fn sample_info(name: &'static str) -> JobInfo { JobInfo { diff --git a/crates/forge-runtime/src/jobs/worker.rs b/crates/forge-runtime/src/jobs/worker.rs index 3cc9c788..9c6a59e4 100644 --- a/crates/forge-runtime/src/jobs/worker.rs +++ b/crates/forge-runtime/src/jobs/worker.rs @@ -166,25 +166,30 @@ impl Worker { let wakeup_notify = Arc::new(tokio::sync::Notify::new()); let wakeup_trigger = wakeup_notify.clone(); let wakeup_shutdown = shutdown_notify.clone(); - if let Some(mut rx) = self.notify_bus.subscribe("forge_jobs_available") { - tokio::spawn(async move { - loop { - tokio::select! { - _ = wakeup_shutdown.notified() => return, - result = rx.recv() => { - match result { - Ok(_) => wakeup_trigger.notify_one(), - Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { - tracing::debug!(missed = n, "Job wakeup receiver lagged"); - wakeup_trigger.notify_one(); + // Track the forwarder so shutdown can await it instead of leaking + // the JoinHandle (#8 in issues doc). + let forwarder_handle = self + .notify_bus + .subscribe("forge_jobs_available") + .map(|mut rx| { + tokio::spawn(async move { + loop { + tokio::select! 
{ + _ = wakeup_shutdown.notified() => return, + result = rx.recv() => { + match result { + Ok(_) => wakeup_trigger.notify_one(), + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::debug!(missed = n, "Job wakeup receiver lagged"); + wakeup_trigger.notify_one(); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => return, } - Err(tokio::sync::broadcast::error::RecvError::Closed) => return, } } } - } + }) }); - } tracing::debug!( worker_id = %self.id, @@ -200,6 +205,9 @@ impl Worker { tracing::debug!(worker_id = %self.id, "Worker shutting down"); shutdown_notify.notify_waiters(); let _ = cleanup_handle.await; + if let Some(h) = forwarder_handle { + let _ = h.await; + } self.drain_jobs(&mut job_tasks).await; break; } diff --git a/crates/forge-runtime/src/kv/store.rs b/crates/forge-runtime/src/kv/store.rs index 0369ad11..31834378 100644 --- a/crates/forge-runtime/src/kv/store.rs +++ b/crates/forge-runtime/src/kv/store.rs @@ -5,6 +5,12 @@ use sqlx::PgPool; use forge_core::error::{ForgeError, Result}; +/// Hard cap on the value size accepted by `set` / `set_if_absent`. PG can +/// store much larger BYTEA blobs, but the KV API is meant for small +/// configuration / lock payloads, not blobs. Multi-MB values are almost +/// always a misuse (and round-trip the protocol per call). +const MAX_VALUE_BYTES: usize = 1024 * 1024; + /// PostgreSQL-backed key-value store. /// /// Provides a simple get/set/delete/set_if_absent/increment API over @@ -24,13 +30,37 @@ impl KvStore { Self { pool, namespace } } - fn prefixed_key(&self, key: &str) -> String { - format!("{}:{}", self.namespace, key) + fn prefixed_key(&self, key: &str) -> Result { + // Reject `:` in either namespace or key — the prefix separator must + // be unambiguous so `(namespace=a, key=b:foo)` and + // `(namespace=a:b, key=foo)` can't collide on the same physical key. 
+ if self.namespace.contains(':') { + return Err(ForgeError::InvalidArgument(format!( + "kv namespace must not contain ':' (got {:?})", + self.namespace + ))); + } + if key.contains(':') { + return Err(ForgeError::InvalidArgument( + "kv key must not contain ':' (reserved as namespace separator)".to_string(), + )); + } + Ok(format!("{}:{}", self.namespace, key)) + } + + fn check_value_size(value: &[u8]) -> Result<()> { + if value.len() > MAX_VALUE_BYTES { + return Err(ForgeError::InvalidArgument(format!( + "kv value exceeds {MAX_VALUE_BYTES} byte limit (got {})", + value.len() + ))); + } + Ok(()) } /// Get a value by key. Returns `None` if the key doesn't exist or is expired. pub async fn get(&self, key: &str) -> Result>> { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let row = sqlx::query_scalar!( r#" SELECT value @@ -49,7 +79,8 @@ impl KvStore { /// Set a key to a value. Overwrites any existing value. pub async fn set(&self, key: &str, value: &[u8], ttl: Option) -> Result<()> { - let full_key = self.prefixed_key(key); + Self::check_value_size(value)?; + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); sqlx::query!( r#" @@ -80,10 +111,21 @@ impl KvStore { value: &[u8], ttl: Option, ) -> Result { - let full_key = self.prefixed_key(key); + Self::check_value_size(value)?; + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); - // ON CONFLICT WHERE treats expired rows as absent atomically. - // Convert to query!() after next `cargo sqlx prepare`. + // Serialize concurrent reclaim-of-expired racers per key via a + // transaction-scoped advisory lock keyed on the full prefixed key. + // Without this, the ON CONFLICT WHERE branch is only race-free under + // READ COMMITTED — under REPEATABLE READ a second writer can see the + // pre-update snapshot and "succeed" against an already-claimed row. 
+ let mut tx = self.pool.begin().await.map_err(ForgeError::Database)?; + #[allow(clippy::disallowed_methods)] + sqlx::query("SELECT pg_advisory_xact_lock(hashtext($1)::bigint)") + .bind(&full_key) + .execute(&mut *tx) + .await + .map_err(ForgeError::Database)?; #[allow(clippy::disallowed_methods)] let rows = sqlx::query( r#" @@ -97,17 +139,18 @@ impl KvStore { .bind(&full_key) .bind(value) .bind(expires_at) - .execute(&self.pool) + .execute(&mut *tx) .await .map_err(ForgeError::Database)? .rows_affected(); + tx.commit().await.map_err(ForgeError::Database)?; Ok(rows > 0) } /// Delete a key. Returns `true` if the key existed. pub async fn delete(&self, key: &str) -> Result { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let result = sqlx::query!("DELETE FROM forge_kv WHERE key = $1", full_key) .execute(&self.pool) .await @@ -123,13 +166,17 @@ impl KvStore { /// /// Uses `ON CONFLICT DO UPDATE ... WHERE` to handle expired rows atomically /// without CTE snapshot isolation issues. + /// + /// **Note:** counter storage is `BIGINT`, range `[-2^63, 2^63 - 1]`. + /// Increments that overflow surface as `ForgeError::InvalidArgument` so + /// callers can choose to reset rather than retry indefinitely. pub async fn increment(&self, key: &str, delta: i64, ttl: Option) -> Result { - let full_key = self.prefixed_key(key); + let full_key = self.prefixed_key(key)?; let expires_at = ttl.map(|d| Utc::now() + d); // Expired counters reset to delta rather than accumulating. // Convert to query_scalar!() after next `cargo sqlx prepare`. 
#[allow(clippy::disallowed_methods)] - let row: (i64,) = sqlx::query_as( + let row: std::result::Result<(i64,), sqlx::Error> = sqlx::query_as( r#" INSERT INTO forge_kv_counters (key, value, expires_at, updated_at) VALUES ($1, $2, $3, NOW()) @@ -149,10 +196,41 @@ impl KvStore { .bind(delta) .bind(expires_at) .fetch_one(&self.pool) + .await; + + match row { + Ok((v,)) => Ok(v), + Err(sqlx::Error::Database(db_err)) if db_err.code().as_deref() == Some("22003") => { + Err(ForgeError::InvalidArgument(format!( + "counter overflow at key {full_key:?}: BIGINT range exceeded" + ))) + } + Err(e) => Err(ForgeError::Database(e)), + } + } + + /// Read a counter's current value. Returns `None` if missing or expired. + /// + /// Mirrors the TTL filter from `get()` so an expired counter behaves the + /// same as a missing one — keeps a future generic `get` over both tables + /// consistent. + pub async fn get_counter(&self, key: &str) -> Result> { + let full_key = self.prefixed_key(key)?; + #[allow(clippy::disallowed_methods)] + let row: Option<(i64,)> = sqlx::query_as( + r#" + SELECT value + FROM forge_kv_counters + WHERE key = $1 + AND (expires_at IS NULL OR expires_at > NOW()) + "#, + ) + .bind(&full_key) + .fetch_optional(&self.pool) .await .map_err(ForgeError::Database)?; - Ok(row.0) + Ok(row.map(|(v,)| v)) } /// Remove expired keys from both tables. Returns total rows cleaned up. @@ -190,8 +268,11 @@ mod tests { .expect("connect_lazy never fails for a syntactically valid URL"); let store = KvStore::new(pool, "ratelimit"); - assert_eq!(store.prefixed_key("user:42"), "ratelimit:user:42"); - assert_eq!(store.prefixed_key(""), "ratelimit:"); + // `:` in a key is now rejected — verify that and a couple of + // representative happy cases. 
+ assert!(store.prefixed_key("user:42").is_err()); + assert_eq!(store.prefixed_key("user_42").unwrap(), "ratelimit:user_42"); + assert_eq!(store.prefixed_key("").unwrap(), "ratelimit:"); } #[tokio::test] @@ -205,7 +286,10 @@ mod tests { // physical keys — the property the namespace exists to guarantee. let a = KvStore::new(pool.clone(), "subsystem_a"); let b = KvStore::new(pool, "subsystem_b"); - assert_ne!(a.prefixed_key("shared"), b.prefixed_key("shared")); + assert_ne!( + a.prefixed_key("shared").unwrap(), + b.prefixed_key("shared").unwrap() + ); } } diff --git a/crates/forge-runtime/src/mcp/registry.rs b/crates/forge-runtime/src/mcp/registry.rs index 9a143e10..d72da116 100644 --- a/crates/forge-runtime/src/mcp/registry.rs +++ b/crates/forge-runtime/src/mcp/registry.rs @@ -3,25 +3,10 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; +use forge_core::util::normalize_handler_args as normalize_args; use forge_core::{ForgeMcpTool, McpToolContext, McpToolInfo, Result}; use serde_json::Value; -fn normalize_args(args: Value) -> Value { - let unwrapped = match args { - Value::Object(map) if map.len() == 1 => map - .get("args") - .or_else(|| map.get("input")) - .cloned() - .unwrap_or(Value::Object(map)), - other => other, - }; - - match unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - other => other, - } -} - pub type BoxedMcpToolFn = Arc< dyn Fn(&McpToolContext, Value) -> Pin> + Send + '_>> + Send diff --git a/crates/forge-runtime/src/observability/db.rs b/crates/forge-runtime/src/observability/db.rs index ec6b5e71..c3f419a1 100644 --- a/crates/forge-runtime/src/observability/db.rs +++ b/crates/forge-runtime/src/observability/db.rs @@ -3,7 +3,7 @@ use opentelemetry::metrics::{Gauge, Histogram}; use sqlx::PgPool; use std::sync::OnceLock; use std::time::{Duration, Instant}; -use tracing::{Instrument, info_span}; +use tracing::{Instrument, Level, debug_span, enabled, info_span}; const DB_SYSTEM: &str = "db.system"; const 
DB_OPERATION_NAME: &str = "db.operation.name"; @@ -90,43 +90,78 @@ pub fn record_query_duration(operation: &str, duration: Duration) { } /// Extract the table name from a simple SQL query, or `None` for complex ones. +/// +/// Slices with checked `str::get` rather than fixed byte offsets so +/// non-ASCII identifiers (quoted Unicode columns/tables) can't panic the +/// slicer. `to_uppercase()` can change the byte length of a string, so we +/// can't reuse byte offsets discovered in the uppercased copy against the +/// original — locate keywords case-insensitively over the original instead. pub fn extract_table_name(sql: &str) -> Option<&str> { let sql = sql.trim(); - let upper = sql.to_uppercase(); - if upper.starts_with("SELECT") { - // SELECT ... FROM table_name ... - if let Some(from_pos) = upper.find(" FROM ") { - let after_from = &sql[from_pos + 6..]; - return extract_first_identifier(after_from.trim_start()); + if let Some(rest) = strip_keyword_prefix(sql, "INSERT INTO ") + .or_else(|| strip_keyword_prefix(sql, "DELETE FROM ")) + .or_else(|| strip_keyword_prefix(sql, "CREATE TABLE IF NOT EXISTS ")) + .or_else(|| strip_keyword_prefix(sql, "CREATE TABLE ")) + .or_else(|| strip_keyword_prefix(sql, "UPDATE ")) + { + return extract_first_identifier(rest.trim_start()); + } + + if strip_keyword_prefix(sql, "SELECT").is_some() { + // Find " FROM " case-insensitively without re-allocating a full + // uppercase copy whose byte length can diverge from the source. 
+ if let Some(from_byte) = find_ci(sql, " FROM ") { + let after = sql.get(from_byte + " FROM ".len()..)?; + return extract_first_identifier(after.trim_start()); } - } else if upper.starts_with("INSERT INTO ") { - let after_into = &sql[12..]; - return extract_first_identifier(after_into.trim_start()); - } else if upper.starts_with("UPDATE ") { - let after_update = &sql[7..]; - return extract_first_identifier(after_update.trim_start()); - } else if upper.starts_with("DELETE FROM ") { - let after_from = &sql[12..]; - return extract_first_identifier(after_from.trim_start()); - } else if upper.starts_with("CREATE TABLE ") { - let after_table = if upper.starts_with("CREATE TABLE IF NOT EXISTS ") { - &sql[27..] - } else { - &sql[13..] - }; - return extract_first_identifier(after_table.trim_start()); } None } +fn strip_keyword_prefix<'a>(sql: &'a str, keyword: &str) -> Option<&'a str> { + if sql.len() < keyword.len() { + return None; + } + let prefix = sql.get(..keyword.len())?; + if prefix.eq_ignore_ascii_case(keyword) { + sql.get(keyword.len()..) + } else { + None + } +} + +/// Case-insensitive search for an ASCII needle. Returns the byte offset of +/// the first match in the source. +fn find_ci(haystack: &str, needle_ascii_upper: &str) -> Option { + let bytes = haystack.as_bytes(); + let n = needle_ascii_upper.as_bytes(); + if n.is_empty() || bytes.len() < n.len() { + return None; + } + 'outer: for start in 0..=bytes.len() - n.len() { + for (i, nb) in n.iter().enumerate() { + let hb = bytes.get(start + i)?; + if !hb.eq_ignore_ascii_case(nb) { + continue 'outer; + } + } + // Confirm the match begins on a UTF-8 char boundary so the caller's + // slice never bisects a multi-byte sequence. 
+ if haystack.is_char_boundary(start) && haystack.is_char_boundary(start + n.len()) { + return Some(start); + } + } + None +} + fn extract_first_identifier(s: &str) -> Option<&str> { let end = s .find(|c: char| c.is_whitespace() || c == '(' || c == ',' || c == ';') .unwrap_or(s.len()); - if end > 0 { Some(&s[..end]) } else { None } + if end > 0 { s.get(..end) } else { None } } /// Execute a database operation with tracing and duration recording. @@ -134,7 +169,11 @@ pub async fn instrumented_query(operation: &str, table: Option<&str>, f where F: std::future::Future>, { - let span = if let Some(tbl) = table { + // When DEBUG is disabled, use a bare `debug_span!` to skip per-query span + // setup. NOTE(review): this also skips the `info_span!` at INFO — confirm. + let span = if !enabled!(Level::DEBUG) { + debug_span!("db.query") + } else if let Some(tbl) = table { info_span!( "db.query", db.system = DB_SYSTEM_POSTGRESQL, diff --git a/crates/forge-runtime/src/observability/metrics.rs b/crates/forge-runtime/src/observability/metrics.rs index 42632477..e386a3b0 100644 --- a/crates/forge-runtime/src/observability/metrics.rs +++ b/crates/forge-runtime/src/observability/metrics.rs @@ -2,7 +2,8 @@ use opentelemetry::{ KeyValue, global, metrics::{Counter, Gauge, Histogram, UpDownCounter}, }; -use std::sync::OnceLock; +use std::collections::HashSet; +use std::sync::{OnceLock, RwLock}; const METER_NAME: &str = "forge-runtime"; @@ -43,9 +44,10 @@ impl HttpMetrics { } pub fn record(&self, method: &str, path: &str, status: u16, duration_secs: f64) { + let normalized = normalize_path_for_metrics(path); let attributes = [ KeyValue::new("method", method.to_string()), - KeyValue::new("path", path.to_string()), + KeyValue::new("path", normalized), KeyValue::new("status", i64::from(status)), ]; @@ -54,6 +56,69 @@ impl HttpMetrics { } } +/// Soft cap on the number of distinct dynamic label values we'll let through +/// before collapsing the rest to ``. 
Keeps the OTel cardinality +/// bounded even if a poorly-controlled handler name leaks into the metric. +const MAX_DYNAMIC_LABELS: usize = 1000; + +static FN_LABELS_SEEN: OnceLock>> = OnceLock::new(); +static JOB_LABELS_SEEN: OnceLock>> = OnceLock::new(); + +fn capped_label(seen: &OnceLock>>, name: &str) -> String { + let set = seen.get_or_init(|| RwLock::new(HashSet::new())); + if let Ok(guard) = set.read() + && guard.contains(name) + { + return name.to_string(); + } + if let Ok(mut guard) = set.write() { + if guard.contains(name) { + return name.to_string(); + } + if guard.len() >= MAX_DYNAMIC_LABELS { + return "".to_string(); + } + guard.insert(name.to_string()); + return name.to_string(); + } + name.to_string() +} + +/// Map a concrete request path to a stable route template. Without route +/// info from the router, we apply heuristics: replace UUIDs and all-numeric +/// path segments with `:id`. Anything we don't recognize stays as-is so +/// known fixed routes (e.g. `/_api/health`) keep their identity. +pub fn normalize_path_for_metrics(path: &str) -> String { + fn is_dynamic(seg: &str) -> bool { + !seg.is_empty() && (looks_like_uuid(seg) || seg.chars().all(|c| c.is_ascii_digit())) + } + // The common case is a fixed route with nothing to rewrite — return it in a + // single allocation rather than rebuilding it segment by segment. 
+ if !path.split('/').any(is_dynamic) { + return path.to_string(); + } + let mut out = String::with_capacity(path.len()); + for (i, seg) in path.split('/').enumerate() { + if i > 0 { + out.push('/'); + } + if is_dynamic(seg) { + out.push_str(":id"); + } else { + out.push_str(seg); + } + } + out +} + +fn looks_like_uuid(s: &str) -> bool { + s.len() == 36 + && s.as_bytes().iter().enumerate().all(|(i, b)| match i { + 8 | 13 | 18 | 23 => *b == b'-', + _ => b.is_ascii_hexdigit(), + }) +} + pub struct FnMetrics { executions_total: Counter, duration: Histogram, @@ -91,7 +156,7 @@ impl FnMetrics { ) { let status = if success { "ok" } else { "error" }; let attributes = [ - KeyValue::new("function", function.to_string()), + KeyValue::new("function", capped_label(&FN_LABELS_SEEN, function)), KeyValue::new("kind", kind.to_string()), KeyValue::new("status", status), KeyValue::new("cached", cached), @@ -130,7 +195,10 @@ impl FnCacheMetrics { } pub fn record(&self, function: &str, hit: bool) { - let attributes = [KeyValue::new("function", function.to_string())]; + let attributes = [KeyValue::new( + "function", + capped_label(&FN_LABELS_SEEN, function), + )]; if hit { self.hits_total.add(1, &attributes); } else { @@ -181,7 +249,7 @@ impl JobMetrics { pub fn record(&self, job_type: &str, status: &'static str, duration_secs: f64) { let attributes = [ - KeyValue::new("job_type", job_type.to_string()), + KeyValue::new("job_type", capped_label(&JOB_LABELS_SEEN, job_type)), KeyValue::new("status", status), ]; @@ -190,8 +258,13 @@ impl JobMetrics { } pub fn record_lost_claim(&self, job_type: &str) { - self.lost_claim_total - .add(1, &[KeyValue::new("job_type", job_type.to_string())]); + self.lost_claim_total.add( + 1, + &[KeyValue::new( + "job_type", + capped_label(&JOB_LABELS_SEEN, job_type), + )], + ); } } @@ -425,32 +498,67 @@ mod tests { use super::*; #[test] - fn test_http_metrics_creation() { - let _metrics = HttpMetrics::new(); + fn normalize_path_rewrites_uuid_segments() { + 
assert_eq!( + normalize_path_for_metrics("/_api/rpc/get_user/550e8400-e29b-41d4-a716-446655440000"), + "/_api/rpc/get_user/:id" + ); } #[test] - fn test_job_metrics_creation() { - let _metrics = JobMetrics::new(); + fn normalize_path_rewrites_all_digit_segments() { + assert_eq!(normalize_path_for_metrics("/users/12345"), "/users/:id"); + assert_eq!(normalize_path_for_metrics("/a/1/b/2"), "/a/:id/b/:id"); } #[test] - fn test_connections_gauge_creation() { - let _gauge = ActiveConnectionsGauge::new(); + fn normalize_path_leaves_fixed_routes_untouched() { + // No dynamic segment => returned verbatim (single-allocation fast path). + for fixed in ["/_api/health", "/_api/ready", "/_api/rpc/list_users", "/"] { + assert_eq!(normalize_path_for_metrics(fixed), fixed); + } + } + + #[test] + fn normalize_path_preserves_trailing_slash_shape() { + // A trailing slash makes an empty final segment, which is_dynamic + // treats as non-dynamic, so it is preserved. + assert_eq!(normalize_path_for_metrics("/users/12345/"), "/users/:id/"); + assert_eq!(normalize_path_for_metrics("/_api/health/"), "/_api/health/"); } #[test] - fn test_notify_metrics_creation() { - let _metrics = NotifyMetrics::new(); + fn normalize_path_does_not_rewrite_alphanumeric_or_short_hex() { + // 32-char hex (not a 36-char dashed UUID) and mixed alnum stay as-is. + assert_eq!(normalize_path_for_metrics("/items/abc123"), "/items/abc123"); + assert_eq!( + normalize_path_for_metrics("/x/0123456789abcdef0123456789abcdef"), + "/x/0123456789abcdef0123456789abcdef" + ); } #[test] - fn test_subscription_metrics_creation() { - let _metrics = SubscriptionMetrics::new(); + fn capped_label_passes_known_names_through_until_cap() { + // Use a fresh, test-local OnceLock so the global label sets aren't + // perturbed and the cap is exercised deterministically. 
+ let seen: OnceLock>> = OnceLock::new(); + for i in 0..MAX_DYNAMIC_LABELS { + let name = format!("fn_{i}"); + assert_eq!(capped_label(&seen, &name), name); + } + // The set is now full; a brand-new name collapses to the other-bucket. + assert_eq!(capped_label(&seen, "fn_overflow"), ""); + // An already-seen name still passes through after the cap is hit. + assert_eq!(capped_label(&seen, "fn_0"), "fn_0"); } #[test] - fn test_workflow_scheduler_metrics_creation() { - let _metrics = WorkflowSchedulerMetrics::new(); + fn capped_label_is_idempotent_for_repeated_names() { + let seen: OnceLock>> = OnceLock::new(); + assert_eq!(capped_label(&seen, "get_user"), "get_user"); + assert_eq!(capped_label(&seen, "get_user"), "get_user"); + // Only one slot consumed for the repeated name. + let guard = seen.get().expect("init").read().expect("read lock"); + assert_eq!(guard.len(), 1); } } diff --git a/crates/forge-runtime/src/observability/telemetry.rs b/crates/forge-runtime/src/observability/telemetry.rs index 3a00269a..796b7e53 100644 --- a/crates/forge-runtime/src/observability/telemetry.rs +++ b/crates/forge-runtime/src/observability/telemetry.rs @@ -273,7 +273,8 @@ pub fn init_telemetry( ) } Err(e) => { - eprintln!("WARNING: OTLP trace exporter init failed, traces disabled: {e}"); + tracing::error!(error = %e, "OTLP trace exporter init failed; traces disabled"); + record_otel_export_initialized("traces", false); None } } @@ -306,7 +307,8 @@ pub fn init_telemetry( Some(log_layer) } Err(e) => { - eprintln!("WARNING: OTLP log exporter init failed, log bridge disabled: {e}"); + tracing::error!(error = %e, "OTLP log exporter init failed; log bridge disabled"); + record_otel_export_initialized("logs", false); None } } @@ -335,11 +337,22 @@ pub fn init_telemetry( global::set_meter_provider(meter_provider); } Err(e) => { - eprintln!("WARNING: OTLP metric exporter init failed, metrics disabled: {e}"); + tracing::error!(error = %e, "OTLP metric exporter init failed; metrics 
disabled"); + record_otel_export_initialized("metrics", false); } } } + if config.enable_traces { + record_otel_export_initialized("traces", TRACER_PROVIDER.get().is_some()); + } + if config.enable_logs { + record_otel_export_initialized("logs", LOGGER_PROVIDER.get().is_some()); + } + if config.enable_metrics { + record_otel_export_initialized("metrics", METER_PROVIDER.get().is_some()); + } + tracing::info!( service = %config.service_name, version = %config.service_version, @@ -353,6 +366,26 @@ pub fn init_telemetry( Ok(true) } +/// Health gauge that flips to 1 when an OTLP exporter initialized +/// successfully and 0 when it failed at startup. Lets operators alert on +/// missing telemetry without parsing log lines. +fn record_otel_export_initialized(signal: &'static str, initialized: bool) { + use opentelemetry::metrics::Gauge; + static GAUGE: OnceLock> = OnceLock::new(); + let gauge = GAUGE.get_or_init(|| { + global::meter("forge-runtime") + .u64_gauge("otel_export_initialized") + .with_description( + "1 if the OTLP exporter for this signal initialized at startup, 0 if it failed.", + ) + .build() + }); + gauge.record( + if initialized { 1 } else { 0 }, + &[KeyValue::new("signal", signal)], + ); +} + pub fn shutdown_telemetry() { tracing::info!("shutting down telemetry"); diff --git a/crates/forge-runtime/src/pg/change_log.rs b/crates/forge-runtime/src/pg/change_log.rs index 633d3eed..78ad59cf 100644 --- a/crates/forge-runtime/src/pg/change_log.rs +++ b/crates/forge-runtime/src/pg/change_log.rs @@ -176,9 +176,18 @@ mod integration_tests { let base = TestDatabase::from_env() .await .expect("Failed to create test database"); - base.isolated(test_name) + let db = base + .isolated(test_name) .await - .expect("Failed to create isolated db") + .expect("Failed to create isolated db"); + // forge_change_log and forge_notify_change() live in the system schema; + // an isolated DB starts empty, so apply it (mirrors queue.rs setup_db). 
+ // Without this the tests fail with "relation forge_change_log does not + // exist" — which went unnoticed because this suite never ran in CI. + db.run_sql(&crate::pg::migration::get_all_system_sql()) + .await + .expect("Failed to apply system schema"); + db } /// Create a tracked table that fires the change trigger on every write. @@ -234,23 +243,47 @@ mod integration_tests { #[tokio::test] async fn trim_deletes_only_rows_older_than_cutoff() { let db = setup_db("change_log_trim").await; - install_tracked_table(db.pool(), "trim_items").await; - sqlx::query("INSERT INTO trim_items (id, name) VALUES (gen_random_uuid(), 'old')") - .execute(db.pool()) - .await - .unwrap(); - // Forge a future cutoff so the row qualifies as "old". - let cutoff = Utc::now() + chrono::Duration::seconds(10); - let deleted = trim_change_log(db.pool(), cutoff).await.unwrap(); - assert_eq!(deleted, 1); - let remaining: i64 = sqlx::query_scalar( - "SELECT COUNT(*) FROM forge_change_log WHERE table_name='trim_items'", + // trim_change_log enforces a retention floor: it is a no-op until the log + // exceeds CHANGE_LOG_MIN_ROWS, so it never over-trims a small log. Seed + // just past the floor with OLD rows plus a handful of recent ones, then + // trim at a cutoff between them — only the old rows may be deleted. + // Insert directly (the trigger path is covered by the drain test) so the + // created_at timestamps are controllable. 
+ let old_count = CHANGE_LOG_MIN_ROWS + 50; + sqlx::query( + "INSERT INTO forge_change_log (table_name, op, created_at) + SELECT 'trim_items', 'INSERT', NOW() - INTERVAL '2 days' + FROM generate_series(1, $1)", ) - .fetch_one(db.pool()) + .bind(old_count) + .execute(db.pool()) .await .unwrap(); - assert_eq!(remaining, 0); + sqlx::query( + "INSERT INTO forge_change_log (table_name, op, created_at) + SELECT 'trim_items', 'INSERT', NOW() + FROM generate_series(1, 5)", + ) + .execute(db.pool()) + .await + .unwrap(); + + let cutoff = Utc::now() - chrono::Duration::days(1); + let deleted = trim_change_log(db.pool(), cutoff).await.unwrap(); + assert_eq!( + deleted, old_count as u64, + "every row older than the cutoff must be trimmed once past the floor", + ); + + let remaining: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM forge_change_log") + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!( + remaining, 5, + "rows newer than the cutoff must survive the trim" + ); } #[tokio::test] diff --git a/crates/forge-runtime/src/pg/leader.rs b/crates/forge-runtime/src/pg/leader.rs index 875d0f07..f1e3a5d6 100644 --- a/crates/forge-runtime/src/pg/leader.rs +++ b/crates/forge-runtime/src/pg/leader.rs @@ -12,6 +12,14 @@ use crate::pg::notify_bus::PgNotifyBus; /// Payload is the role string; subscribers filter by their own role. pub const LEADER_RELEASED_CHANNEL: &str = "forge_leader_released"; +/// Number of times to retry `pg_try_advisory_lock` after terminating a zombie +/// leader's backend. PostgreSQL releases the dead backend's advisory locks +/// asynchronously, so the first retry can race the teardown. +const PREEMPT_RETRY_ATTEMPTS: u32 = 10; + +/// Backoff between post-termination lock-acquisition retries. +const PREEMPT_RETRY_BACKOFF: Duration = Duration::from_millis(25); + /// Leader election configuration. 
#[derive(Debug, Clone)] pub struct LeaderConfig { @@ -132,6 +140,19 @@ impl LeaderElection { self.is_leader.load(Ordering::SeqCst) } + /// Subscribe to leader-released NOTIFY events for this role, if a notify + /// bus is attached. Returns `None` when no bus is configured (single-node + /// or test setups), in which case callers should fall back to polling. + /// + /// Standby polling loops use this to wake immediately when the current + /// leader voluntarily releases, instead of sleeping for the full + /// `check_interval`. + pub fn subscribe_release_notify(&self) -> Option> { + self.notify_bus + .as_ref() + .and_then(|bus| bus.subscribe(LEADER_RELEASED_CHANNEL)) + } + /// How often the leader validates the advisory lock is still held. pub fn lock_validate_interval(&self) -> Duration { self.config.lock_validate_interval @@ -307,6 +328,38 @@ impl LeaderElection { } }; + // Verify the lock-holding backend belongs to a forge process before + // terminating it: two unrelated apps sharing this DB can hash to the + // same advisory lock ID, and we must not evict the other app's session. + // Connections without an application_name are assumed non-forge and + // skipped — operators can set `application_name=forge-` to opt in. + // Untyped query: `pg_stat_activity` rows can come and go between + // statements (the holder may have exited), so the macro's static row + // shape buys nothing here. Allow the lint locally. + #[allow(clippy::disallowed_methods)] + let app_name: Option = sqlx::query_scalar::<_, Option>( + "SELECT application_name FROM pg_stat_activity WHERE pid = $1", + ) + .bind(pid) + .fetch_optional(&mut **conn) + .await + .map_err(forge_core::ForgeError::Database)? 
+ .flatten(); + + match app_name.as_deref() { + Some(name) if name.starts_with("forge") => {} + other => { + tracing::warn!( + role = self.role.as_str(), + zombie_pid = pid, + application_name = ?other, + "Refusing to terminate backend whose application_name does not start with 'forge'; \ + another app may share this database. Set application_name=forge- to allow preemption." + ); + return Ok(false); + } + } + // pg_terminate_backend returns false when permission is denied or the backend is already gone. let terminated = sqlx::query_scalar!(r#"SELECT pg_terminate_backend($1) AS "terminated!""#, pid,) @@ -332,18 +385,31 @@ impl LeaderElection { "Terminated zombie leader backend with expired lease; retrying lock acquisition" ); - // Yield to let PG process the termination before retrying the lock. - tokio::task::yield_now().await; + // pg_terminate_backend only *signals* the backend; PostgreSQL releases + // its advisory locks asynchronously as that backend tears down. A single + // immediate retry races that teardown and usually loses, so we poll + // pg_try_advisory_lock a few times with a short backoff. The window is + // small (PG processes the signal in milliseconds) but a bare yield is + // not enough. + for attempt in 0..PREEMPT_RETRY_ATTEMPTS { + let acquired = sqlx::query_scalar!( + r#"SELECT pg_try_advisory_lock($1) AS "acquired!""#, + self.role.lock_id(), + ) + .fetch_one(&mut **conn) + .await + .map_err(forge_core::ForgeError::Database)?; - let acquired = sqlx::query_scalar!( - r#"SELECT pg_try_advisory_lock($1) AS "acquired!""#, - self.role.lock_id(), - ) - .fetch_one(&mut **conn) - .await - .map_err(forge_core::ForgeError::Database)?; + if acquired { + return Ok(true); + } - Ok(acquired) + if attempt + 1 < PREEMPT_RETRY_ATTEMPTS { + tokio::time::sleep(PREEMPT_RETRY_BACKOFF).await; + } + } + + Ok(false) } /// Confirm the advisory lock is still held on the lock-owning connection. 
@@ -567,24 +633,9 @@ impl LeaderElection { // resolved by the lock being gone). let mut lock_connection = self.lock_connection.lock().await; if let Some(mut conn) = lock_connection.take() { - // Emit NOTIFY before unlock so standbys wake only when the lock is - // genuinely about to be free. Failure is non-fatal: standbys fall - // back to their normal check_interval timer. - if let Err(e) = sqlx::query!( - "SELECT pg_notify($1, $2)", - LEADER_RELEASED_CHANNEL, - self.role.as_str(), - ) - .execute(&mut *conn) - .await - { - tracing::warn!( - role = self.role.as_str(), - error = %e, - "Failed to emit leader-released NOTIFY; standbys will wait for next check tick", - ); - } - + // Unlock first, then NOTIFY only if we actually held the lock. + // Notifying when the lock wasn't held wakes standbys to race for + // a slot we never owned in the first place — pure noise. let released = sqlx::query_scalar!( "SELECT pg_advisory_unlock($1) as \"released!\"", self.role.lock_id() @@ -593,11 +644,26 @@ impl LeaderElection { .await .map_err(forge_core::ForgeError::Database)?; - if !released { + if released { + if let Err(e) = sqlx::query!( + "SELECT pg_notify($1, $2)", + LEADER_RELEASED_CHANNEL, + self.role.as_str(), + ) + .execute(&mut *conn) + .await + { + tracing::warn!( + role = self.role.as_str(), + error = %e, + "Failed to emit leader-released NOTIFY; standbys will wait for next check tick", + ); + } + } else { tracing::warn!( role = self.role.as_str(), "pg_advisory_unlock returned false during release; \ - lock was not held by this session" + lock was not held by this session; skipping NOTIFY" ); } @@ -1082,6 +1148,20 @@ mod integration_tests { assert!(zombie.try_become_leader().await.unwrap()); assert!(zombie.is_leader()); + // Tag the zombie's lock-holding backend with a forge-prefixed + // application_name, matching what the connection pool now sets in + // production (`forge-`). 
The preemption guard only terminates + // backends whose application_name starts with `forge`; without this the + // simulated zombie reports an empty name and would never be evicted. + { + let mut conn_guard = zombie.lock_connection.lock().await; + let conn = conn_guard.as_mut().expect("lock connection present"); + sqlx::query("SET application_name = 'forge-demo'") + .execute(&mut **conn) + .await + .unwrap(); + } + // Artificially expire the lease so standbys see a stale leader. #[allow(clippy::disallowed_methods)] sqlx::query( diff --git a/crates/forge-runtime/src/pg/migration/runner.rs b/crates/forge-runtime/src/pg/migration/runner.rs index 296a4875..736dd12b 100644 --- a/crates/forge-runtime/src/pg/migration/runner.rs +++ b/crates/forge-runtime/src/pg/migration/runner.rs @@ -322,11 +322,15 @@ impl MigrationRunner { &self, conn: &mut sqlx::pool::PoolConnection, ) -> Result<()> { - sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID) - .fetch_one(&mut **conn) - .await - .map_err(|e| ForgeError::internal_with("Failed to release migration lock", e))?; - debug!("Migration lock released"); + let released: Option = + sqlx::query_scalar!("SELECT pg_advisory_unlock($1)", MIGRATION_LOCK_ID) + .fetch_one(&mut **conn) + .await + .map_err(|e| ForgeError::internal_with("Failed to release migration lock", e))?; + // `false` means this session didn't hold the lock — useful diagnostic + // for connection-pooler scenarios where the lock-holding backend was + // reused before release. Log it instead of silently dropping. + debug!(released = ?released, "Migration lock released"); Ok(()) } @@ -444,15 +448,24 @@ impl MigrationRunner { /// `ALTER TYPE ... ADD VALUE`, `VACUUM`, and `REINDEX CONCURRENTLY` /// inside a transaction block, so opt-in migrations skip the BEGIN. /// - /// Tradeoffs the migration author must accept: - /// - A partial failure leaves the schema half-applied and the - /// bookkeeping row missing, so the next run will retry from the top. 
- /// - Even if all DDL succeeds, the bookkeeping `INSERT` runs on a - /// *fresh* pool connection — if that insert fails, the migration is - /// re-run on the next startup despite already having taken effect. + /// Inherent risk window: DDL commits as each statement runs, but the + /// bookkeeping `INSERT` into `forge_system_migrations` is a separate + /// statement. If the process or the connection dies between the last + /// DDL statement and the INSERT, the schema is migrated but no row is + /// recorded, and the next boot will try to re-apply the migration. + /// + /// To shrink (but not close) that window, the DDL and the bookkeeping + /// INSERT run on the **same** pooled connection — we never hand the + /// connection back to the pool between them, so a healthy connection + /// stays healthy across both steps. A mid-run crash or network drop + /// is still possible; this is the price of skipping the transaction. /// - /// Migrations using this mode must be authored idempotently - /// (`IF NOT EXISTS`, `ADD VALUE IF NOT EXISTS`, and so on). + /// Migrations using this mode **must** be authored idempotently + /// (`CREATE INDEX CONCURRENTLY IF NOT EXISTS`, `ADD VALUE IF NOT + /// EXISTS`, and so on) so a retry on the next boot is a no-op against + /// already-applied schema. A future improvement could detect "schema + /// applied but row missing" at boot and back-fill the bookkeeping row + /// instead of re-running the SQL. async fn apply_non_transactional(&self, migration: &Migration) -> Result<()> { info!( "Applying non-transactional migration: {}", @@ -502,6 +515,30 @@ impl MigrationRunner { } .await; + // Record bookkeeping on the SAME connection used for the DDL. A + // fresh pool connection here would widen the failure window — if + // the new acquire failed, the schema would already be migrated + // with no row to prove it. 
+ let record_result: Result<()> = if exec_result.is_ok() { + let checksum = crate::stable_hash::sha256_hex(migration.up_sql.as_bytes()); + sqlx::query!( + "INSERT INTO forge_system_migrations (version, checksum) VALUES ($1, $2)", + migration.version, + checksum, + ) + .execute(&mut *conn) + .await + .map(|_| ()) + .map_err(|e| { + ForgeError::internal_with( + format!("Failed to record migration '{}'", migration.version), + e, + ) + }) + } else { + Ok(()) + }; + // Always reset before returning the connection to the pool — even on // failure. A failed RESET is rare but operators need visibility into it. if let Err(e) = sqlx::query("RESET lock_timeout").execute(&mut *conn).await { @@ -516,21 +553,7 @@ impl MigrationRunner { drop(conn); exec_result?; - - let checksum = crate::stable_hash::sha256_hex(migration.up_sql.as_bytes()); - sqlx::query!( - "INSERT INTO forge_system_migrations (version, checksum) VALUES ($1, $2)", - migration.version, - checksum, - ) - .execute(&self.pool) - .await - .map_err(|e| { - ForgeError::internal_with( - format!("Failed to record migration '{}'", migration.version), - e, - ) - })?; + record_result?; info!( "Non-transactional migration applied: {} ({:?})", @@ -1286,10 +1309,20 @@ mod integration_tests { assert!(concurrent.transactional); let err = runner.run(vec![setup, concurrent]).await.unwrap_err(); - let msg = err.to_string(); + // The wrapping ForgeError's Display shows only its context ("Failed to + // apply migration ..."); PG's actual reason ("cannot run inside a + // transaction block") lives in the source chain. Walk it so we assert on + // the real rejection — and prove the runner carries the cause, not drops it. 
+ let mut chain = err.to_string(); + let mut source = std::error::Error::source(&err); + while let Some(cause) = source { + chain.push_str(": "); + chain.push_str(&cause.to_string()); + source = cause.source(); + } assert!( - msg.contains("CONCURRENTLY") || msg.to_lowercase().contains("transaction"), - "expected PG to reject concurrent index in tx, got: {msg}" + chain.contains("CONCURRENTLY") || chain.to_lowercase().contains("transaction"), + "expected PG to reject concurrent index in tx, got chain: {chain}" ); } diff --git a/crates/forge-runtime/src/pg/mod.rs b/crates/forge-runtime/src/pg/mod.rs index cc0bcfe9..71dda509 100644 --- a/crates/forge-runtime/src/pg/mod.rs +++ b/crates/forge-runtime/src/pg/mod.rs @@ -13,6 +13,6 @@ pub use migration::{ AppliedMigration, DriftStatus, Migration, MigrationRunner, MigrationStatus, load_migrations_from_dir, }; -pub use notify::{MAX_PAYLOAD_BYTES, NotifyChannel}; +pub use notify::{MAX_PAYLOAD_BYTES, NotifyChannel, NotifyStreamError}; pub use notify_bus::PgNotifyBus; pub use pool::Database; diff --git a/crates/forge-runtime/src/pg/notify.rs b/crates/forge-runtime/src/pg/notify.rs index 777f51ae..132b51a7 100644 --- a/crates/forge-runtime/src/pg/notify.rs +++ b/crates/forge-runtime/src/pg/notify.rs @@ -84,6 +84,13 @@ where /// - `ForgeError::Serialization` if `serde_json::to_string(payload)` fails. /// - `ForgeError::InvalidArgument` if the serialized payload exceeds /// [`MAX_PAYLOAD_BYTES`]. Use the change-log fallback for larger bodies. + /// + /// **Note**: this cap only applies to publishers that route through this + /// method. Server-side triggers that build their own payloads in PL/pgSQL + /// bypass the check entirely. Exceeding the 8 KiB PostgreSQL limit there + /// aborts the trigger's wrapping transaction (typically the user mutation + /// that caused the trigger to fire). Trigger authors must enforce their + /// own bounds. 
/// - `ForgeError::Database` if the underlying `SELECT pg_notify(...)` /// fails (transaction rolled back, connection dropped, etc.). pub async fn publish<'e, E>(&self, executor: E, payload: &T) -> Result<()> @@ -110,6 +117,29 @@ where } } +/// Reason a [`NotifyChannel`] subscription terminated mid-stream. +/// +/// Previously the stream simply ended via `take_while(is_ok)`, leaving the +/// caller with no way to distinguish a deliberate close from a PG-side error. +/// Items now carry `Result` so consumers can decide +/// whether to reconnect, surface the error, or treat it as fatal. +#[derive(Debug)] +pub enum NotifyStreamError { + /// The underlying `PgListener::recv` returned an error. Typically a + /// dropped backend connection — callers should reconnect. + Recv(sqlx::Error), +} + +impl std::fmt::Display for NotifyStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Recv(e) => write!(f, "PgListener recv failed: {e}"), + } + } +} + +impl std::error::Error for NotifyStreamError {} + impl NotifyChannel where T: DeserializeOwned + Send + 'static, @@ -118,38 +148,53 @@ where /// /// `listener` is consumed; the caller surrenders the connection to the /// stream for the duration of the subscription. Notifications whose - /// payload fails JSON decoding are logged and skipped, so a malformed - /// publish from one peer cannot tear down a long-running subscriber. - /// Errors from the underlying `recv` (connection dropped, etc.) end the - /// stream; the caller decides whether to reconnect. - pub async fn subscribe(&self, mut listener: PgListener) -> Result> { + /// payload fails JSON decoding are logged and skipped (a malformed + /// publish from one peer cannot tear down a long-running subscriber). + /// + /// Recv errors (connection dropped, etc.) 
are surfaced as + /// `Err(NotifyStreamError::Recv)` so the caller can distinguish a + /// graceful close (stream ends naturally) from a fault that requires + /// reconnect. After yielding an error the stream terminates. + pub async fn subscribe( + &self, + mut listener: PgListener, + ) -> Result>> { listener .listen(self.name) .await .map_err(ForgeError::Database)?; let channel_name = self.name; let raw = listener.into_stream(); + // Pass recv errors through (mapped to NotifyStreamError) and drop + // malformed payloads silently — the latter would otherwise look like + // a fault to subscribers when it's just a bad publish. let stream = raw - .take_while(|res| { - let cont = res.is_ok(); - async move { cont } - }) - .filter_map(move |res| async move { - let notification = match res { - Ok(n) => n, - Err(_) => return None, - }; - match serde_json::from_str::(notification.payload()) { - Ok(value) => Some(value), + .scan(false, |ended, res| { + let done = *ended; + let next = match res { + Ok(n) => Some(Ok(n)), Err(e) => { - tracing::debug!( - channel = channel_name, - error = %e, - payload = notification.payload(), - "NotifyChannel: dropping malformed payload", - ); - None + *ended = true; + Some(Err(NotifyStreamError::Recv(e))) } + }; + async move { if done { None } else { next } } + }) + .filter_map(move |res| async move { + match res { + Err(e) => Some(Err(e)), + Ok(notification) => match serde_json::from_str::(notification.payload()) { + Ok(value) => Some(Ok(value)), + Err(e) => { + tracing::debug!( + channel = channel_name, + error = %e, + payload = notification.payload(), + "NotifyChannel: dropping malformed payload", + ); + None + } + }, } }); Ok(stream) @@ -239,7 +284,8 @@ mod integration_tests { let received = tokio::time::timeout(Duration::from_secs(5), stream.next()) .await .expect("stream did not yield within 5s") - .expect("stream ended before yielding"); + .expect("stream ended before yielding") + .expect("recv ok"); assert_eq!(received, payload); } @@ 
-285,7 +331,8 @@ mod integration_tests { let received = tokio::time::timeout(Duration::from_secs(5), stream.next()) .await .expect("stream did not yield within 5s") - .expect("stream ended before yielding"); + .expect("stream ended before yielding") + .expect("recv ok"); assert_eq!(received, payload); } } diff --git a/crates/forge-runtime/src/pg/notify_bus.rs b/crates/forge-runtime/src/pg/notify_bus.rs index 55dda8eb..bd617aa7 100644 --- a/crates/forge-runtime/src/pg/notify_bus.rs +++ b/crates/forge-runtime/src/pg/notify_bus.rs @@ -15,6 +15,24 @@ //! old per-subsystem listeners had, except now there is exactly one //! reconnect path to maintain). //! +//! # Replay backing per channel +//! +//! Not every channel survives a reconnect gap equally well. Subscribers must +//! pair the bus with channel-specific recovery: +//! +//! - `forge_changes`: backed by `forge_change_log`. Subscribers replay missed +//! rows by `last_seen_seq` after a `subscribe_reconnects` tick. +//! - `forge_workflow_wakeup`: idempotent — the workflow executor's normal +//! timer poll catches missed wakeups within its tick interval. +//! - `forge_leader_released`: a missed event delays standby acquisition by at +//! most one `LeaderConfig::check_interval`. +//! - `forge_jobs_available`: **no replay backing**. A missed NOTIFY leaves +//! jobs unclaimed until the next worker poll (`poll_interval`, default +//! 5 s). Workers must keep their independent poll cadence even when this +//! channel is connected; do not extend `poll_interval` past acceptable +//! tail latency on the assumption that NOTIFY will always be timely. +//! - `forge_schema_changed`: advisory only; reconnect re-fetches schema. +//! //! # Payload semantics //! //! The bus forwards the raw `notification.payload()` string. Channels that @@ -32,7 +50,12 @@ use tokio::sync::{broadcast, watch}; /// Per-channel broadcast buffer size. 
Subscribers that fall behind by more /// than this many messages will see `RecvError::Lagged` and can decide /// whether to catch up or resync. -const CHANNEL_BUFFER_SIZE: usize = 256; +/// +/// Sized for bursty `forge_changes` workloads where a single transaction can +/// emit hundreds of notifications. Every direct subscriber MUST handle +/// `broadcast::error::RecvError::Lagged` — drop-and-resync is the only +/// correct response since the bus does not back-pressure publishers. +const CHANNEL_BUFFER_SIZE: usize = 4096; /// Initial reconnection delay after a `PgListener` disconnect. const INITIAL_BACKOFF: Duration = Duration::from_millis(500); diff --git a/crates/forge-runtime/src/pg/pool.rs b/crates/forge-runtime/src/pg/pool.rs index ed72f70e..05551ff3 100644 --- a/crates/forge-runtime/src/pg/pool.rs +++ b/crates/forge-runtime/src/pg/pool.rs @@ -173,7 +173,7 @@ impl Database { fn connect_options(url: &str, service_name: &str) -> sqlx::Result { let options: PgConnectOptions = url.parse()?; Ok(options - .application_name(service_name) + .application_name(&forge_application_name(service_name)) .log_statements(LevelFilter::Off) .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500))) } @@ -185,7 +185,7 @@ impl Database { ) -> sqlx::Result { let options: PgConnectOptions = url.parse()?; let mut opts = options - .application_name(service_name) + .application_name(&forge_application_name(service_name)) .log_statements(LevelFilter::Off) .log_slow_statements(LevelFilter::Warn, Duration::from_millis(500)); if statement_timeout_secs > 0 { @@ -342,6 +342,25 @@ impl Database { } } +/// Build the `application_name` reported by every Forge connection. +/// +/// Leader-election zombie preemption ([`crate::pg::leader`]) only terminates a +/// lock-holding backend whose `application_name` starts with `forge`, so it +/// never evicts an unrelated app sharing the database. 
For that guard to ever +/// fire against Forge's *own* zombie, Forge connections must self-identify with +/// that prefix. The service name passed in is the project name (e.g. `demo`), +/// which would otherwise produce a non-matching `application_name`. +/// +/// Idempotent: a service name already starting with `forge` (e.g. the internal +/// `"forge"` default) is returned unchanged so we never produce `forge-forge`. +fn forge_application_name(service_name: &str) -> String { + if service_name.starts_with("forge") { + service_name.to_string() + } else { + format!("forge-{service_name}") + } +} + /// Minimum supported PostgreSQL major version. /// /// Forge v0.9+ uses features (skip-locked semantics with `NOWAIT`, partitioned @@ -466,4 +485,60 @@ mod tests { assert_eq!(cloned.url(), config.url()); assert_eq!(cloned.pool_size, config.pool_size); } + + #[test] + fn forge_application_name_prefixes_project_names() { + // A bare project name must gain the `forge-` prefix so leader-election + // zombie preemption (which only terminates `forge`-prefixed backends) + // can evict Forge's own zombie leader. + assert_eq!(forge_application_name("demo"), "forge-demo"); + assert_eq!(forge_application_name("my-app"), "forge-my-app"); + // Idempotent: names already starting with `forge` are untouched. + assert_eq!(forge_application_name("forge"), "forge"); + assert_eq!(forge_application_name("forge-worker"), "forge-worker"); + } +} + +#[cfg(all(test, feature = "testcontainers"))] +#[allow(clippy::unwrap_used, clippy::disallowed_methods)] +mod integration_tests { + use super::*; + use forge_core::testing::TestDatabase; + + async fn base_db() -> TestDatabase { + TestDatabase::from_env() + .await + .expect("Failed to create test database") + } + + /// A pool built with a production-shaped service name (the project name) + /// must report a `forge`-prefixed `application_name`. 
This is the precise + /// regressor for zombie-leader eviction: the leader-election guard only + /// terminates backends whose `application_name` starts with `forge`, so if + /// Forge's own pools reported the bare project name (`demo`) the framework + /// could never preempt its own zombie. Fails before the fix (reports + /// `demo`), passes after (reports `forge-demo`). + #[tokio::test] + async fn pool_application_name_is_forge_prefixed_for_preemption() { + let base = base_db().await; + let db = Database::from_config_with_service(&DatabaseConfig::new(base.url()), "demo") + .await + .expect("connect with production-shaped service name"); + + let app_name: String = sqlx::query_scalar("SELECT current_setting('application_name')") + .fetch_one(db.primary()) + .await + .unwrap(); + + assert_eq!( + app_name, "forge-demo", + "Forge pools must self-identify as forge- so leader preemption can evict them" + ); + assert!( + app_name.starts_with("forge"), + "application_name must satisfy the leader.rs preemption guard" + ); + + db.close().await; + } } diff --git a/crates/forge-runtime/src/rate_limit/limiter.rs b/crates/forge-runtime/src/rate_limit/limiter.rs index 06611cde..0cb30adc 100644 --- a/crates/forge-runtime/src/rate_limit/limiter.rs +++ b/crates/forge-runtime/src/rate_limit/limiter.rs @@ -75,7 +75,8 @@ impl StrictRateLimiter { // tokens is clamped to >= -1, so retry_after is bounded by // (1 - (-1)) / refill_rate = 2 / refill_rate — proportional to // one refill interval rather than runaway. - let retry_after = Duration::from_secs_f64((1.0 - tokens) / refill_rate); + let base = (1.0 - tokens) / refill_rate; + let retry_after = Duration::from_secs_f64(jittered(base)); Ok(RateLimitResult::denied(remaining, reset_at, retry_after)) } } @@ -179,6 +180,21 @@ impl StrictRateLimiter { } } +/// Apply ±25% jitter to a retry-after value so clients denied in the same +/// tick don't synchronize their retries into a thundering herd. 
+fn jittered(base_secs: f64) -> f64 { + if !base_secs.is_finite() || base_secs <= 0.0 { + return base_secs.max(0.0); + } + let nanos = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.subsec_nanos()) + .unwrap_or(0); + // Map nanos -> [-0.25, 0.25]. + let frac = ((nanos as f64) / 1_000_000_000.0) * 0.5 - 0.25; + (base_secs * (1.0 + frac)).max(0.0) +} + struct LocalBucket { tokens: f64, max_tokens: f64, @@ -267,8 +283,18 @@ impl HybridRateLimiter { let max_tokens = config.requests as f64; let refill_rate = config.refill_rate(); - if self.local.len() > self.max_local_buckets { - self.cleanup_local(Duration::from_secs(300)); // evict entries idle > 5 min + // Sweep proactively once we cross 75% of the soft cap so a burst of + // unique keys can't grow the map far past `max_local_buckets` before + // the first eviction runs. If everything is still hot we fall back to + // a hard cap that bounds memory at 2× the configured ceiling. + let len = self.local.len(); + if len > self.max_local_buckets * 3 / 4 { + self.cleanup_local(Duration::from_secs(300)); + if self.local.len() > self.max_local_buckets * 2 { + // Last-resort: drop entries idle > 30s to bound memory even + // when the workload is fully active on unique keys. 
+ self.cleanup_local(Duration::from_secs(30)); + } } let mut bucket = self @@ -284,7 +310,8 @@ impl HybridRateLimiter { if allowed { Ok(RateLimitResult::allowed(remaining, reset_at)) } else { - let retry_after = bucket.time_until_token(); + let retry_after = + Duration::from_secs_f64(jittered(bucket.time_until_token().as_secs_f64())); Ok(RateLimitResult::denied(remaining, reset_at, retry_after)) } } diff --git a/crates/forge-runtime/src/realtime/listener.rs b/crates/forge-runtime/src/realtime/listener.rs index d5f8126a..574560cb 100644 --- a/crates/forge-runtime/src/realtime/listener.rs +++ b/crates/forge-runtime/src/realtime/listener.rs @@ -128,7 +128,16 @@ impl ChangeListener { }; let count = rows.len(); + // Track the highest seq across rows (including ones we skip because + // the op didn't parse) and only commit it after the whole batch is + // forwarded. Per-row stores let a live `rx.recv` racing this loop + // bump last_seq past unforwarded replay rows, and skipped rows used + // to leave a permanent gap that blocked future replays. + let mut max_seq = self.last_seq.load(Ordering::Relaxed); for row in &rows { + if row.seq > max_seq { + max_seq = row.seq; + } let Ok(operation) = row.op.parse::() else { continue; }; @@ -145,7 +154,9 @@ impl ChangeListener { } let _ = self.change_tx.send(change); - self.last_seq.store(row.seq, Ordering::Relaxed); + } + if max_seq > self.last_seq.load(Ordering::Relaxed) { + self.last_seq.store(max_seq, Ordering::Relaxed); } if count > 0 { @@ -228,6 +239,14 @@ impl ChangeListener { match result { Ok(payload) => { let recv_time = std::time::Instant::now(); + // Always recover any embedded seq, even on parse + // failure or unknown op, so a malformed/unknown + // payload doesn't pin the watermark and force + // the next reconnect-replay to refuse the gap. 
+ let trailing_seq = payload + .rsplit_once('#') + .and_then(|(_, s)| s.parse::().ok()) + .unwrap_or(0); if let Some((change, seq)) = self.parse_notification(&payload) { // Skip already-processed seqs to prevent // double-processing during the seed window. @@ -243,6 +262,11 @@ impl ChangeListener { crate::cluster::metrics::record_notification_latency(recv_time.elapsed().as_secs_f64()); } else { tracing::debug!(payload = %payload, "Failed to parse notification"); + if trailing_seq > 0 + && trailing_seq > self.last_seq.load(Ordering::Relaxed) + { + self.last_seq.store(trailing_seq, Ordering::Relaxed); + } } } Err(broadcast::error::RecvError::Lagged(n)) => { diff --git a/crates/forge-runtime/src/realtime/manager.rs b/crates/forge-runtime/src/realtime/manager.rs index a5b12d65..b19fcbba 100644 --- a/crates/forge-runtime/src/realtime/manager.rs +++ b/crates/forge-runtime/src/realtime/manager.rs @@ -111,14 +111,21 @@ impl SubscriptionManager { table_deps: &'static [&'static str], selected_cols: &'static [&'static str], ) -> forge_core::Result<(QueryGroupId, SubscriptionId, bool)> { - // Check per-session limit - if let Some(subs) = self.session_subscribers.get(&session_id) - && subs.len() >= self.max_per_session + // Reserve a slot under the entry write guard so two concurrent + // subscribes from the same session can't both observe `len < max` + // and race past the limit. We hold the slot for the rest of this + // call; if anything later fails we drop the placeholder so the + // user gets their seat back. 
+ let placeholder = SubscriberId(u32::MAX); { - return Err(forge_core::ForgeError::Validation(format!( - "Maximum subscriptions per session ({}) exceeded", - self.max_per_session - ))); + let mut entry = self.session_subscribers.entry(session_id).or_default(); + if entry.len() >= self.max_per_session { + return Err(forge_core::ForgeError::Validation(format!( + "Maximum subscriptions per session ({}) exceeded", + self.max_per_session + ))); + } + entry.push(placeholder); } let auth_scope = AuthScope::from_auth(auth_context); @@ -176,10 +183,16 @@ impl SubscriptionManager { group.subscribers.push(subscriber_id); } - self.session_subscribers - .entry(session_id) - .or_default() - .push(subscriber_id); + // Swap the placeholder we reserved earlier for the real id. If for + // any reason the placeholder is gone (e.g. concurrent session + // teardown), fall back to pushing. + if let Some(mut entry) = self.session_subscribers.get_mut(&session_id) { + if let Some(slot) = entry.iter_mut().find(|s| s.0 == u32::MAX) { + *slot = subscriber_id; + } else { + entry.push(subscriber_id); + } + } Ok((group_id, subscription_id, is_new)) } @@ -339,17 +352,24 @@ impl SubscriptionManager { /// runtime-discovered tables from the read set that weren't in the /// compile-time `table_deps`. pub fn update_group(&self, group_id: QueryGroupId, read_set: ReadSet, result_hash: String) { - if let Some(mut group) = self.groups.get_mut(&group_id) { - for table in &read_set.tables { - let already_indexed = group.table_deps.iter().any(|t| *t == table); - if !already_indexed { - self.table_index - .entry(table.clone()) - .or_default() - .insert(group_id); - } - } + // Record execution BEFORE extending the table_index so a concurrent + // `find_affected_groups` can't observe the new table in the index + // but still see the old read_set on the group (which would make + // `should_invalidate` return false and silently drop the change). 
+ let new_tables: Vec = if let Some(mut group) = self.groups.get_mut(&group_id) { + let tables = read_set + .tables + .iter() + .filter(|t| !group.table_deps.contains(&t.as_str())) + .cloned() + .collect(); group.record_execution(read_set, result_hash); + tables + } else { + return; + }; + for table in new_tables { + self.table_index.entry(table).or_default().insert(group_id); } } @@ -368,16 +388,15 @@ impl SubscriptionManager { data: std::sync::Arc, serialized_len: usize, ) { - if let Some(mut group) = self.groups.get_mut(&group_id) { - for table in &read_set.tables { - let already_indexed = group.table_deps.iter().any(|t| *t == table); - if !already_indexed { - self.table_index - .entry(table.clone()) - .or_default() - .insert(group_id); - } - } + // Same ordering as `update_group`: record execution before + // publishing new tables into `table_index`. + let new_tables: Vec = if let Some(mut group) = self.groups.get_mut(&group_id) { + let tables = read_set + .tables + .iter() + .filter(|t| !group.table_deps.contains(&t.as_str())) + .cloned() + .collect(); if serialized_len > self.max_cached_result_bytes { tracing::debug!( @@ -390,6 +409,12 @@ impl SubscriptionManager { } else { group.record_execution_with_data(read_set, result_hash, data); } + tables + } else { + return; + }; + for table in new_tables { + self.table_index.entry(table).or_default().insert(group_id); } } diff --git a/crates/forge-runtime/src/realtime/message.rs b/crates/forge-runtime/src/realtime/message.rs index a5692251..f8572bd4 100644 --- a/crates/forge-runtime/src/realtime/message.rs +++ b/crates/forge-runtime/src/realtime/message.rs @@ -161,7 +161,39 @@ impl SessionServer { total_drops: AtomicU32::new(0), token_exp, }; - self.connections.insert(session_id, entry); + // Notify the displaced client before replacing it so it doesn't + // think it's still receiving live data on a dead channel. 
+ if let Some(prev) = self.connections.insert(session_id, entry) { + let _ = prev.sender.try_send(RealtimeMessage::AuthFailed { + reason: "Session replaced by a newer connection".to_string(), + }); + } + } + + /// Notify the client that its auth has been revoked, then tear down the + /// connection. Returns the subscription IDs the session held, so callers + /// can clean up associated query/job/workflow state. The notification is + /// best-effort (`try_send`): if the channel is full or closed, eviction + /// still proceeds. + /// + /// Use this when a session's underlying principal has been demoted, + /// tenant-moved, or revoked server-side before the JWT's `exp`. The + /// reactor's cached `AuthContext` on each `QueryGroup` is only + /// re-validated on token expiry, so without an explicit revocation path + /// the session would keep receiving data under the stale scope until + /// `exp`. After this call the client must reconnect and re-subscribe + /// with a fresh token. + pub fn revoke_session( + &self, + session_id: SessionId, + reason: &str, + ) -> Option> { + if let Some(conn) = self.connections.get(&session_id) { + let _ = conn.sender.try_send(RealtimeMessage::AuthFailed { + reason: reason.to_string(), + }); + } + self.remove_connection(session_id) } /// Remove a connection. @@ -353,7 +385,17 @@ impl SessionServer { } for (session_id, _) in stale { - self.remove_connection(session_id); + // Re-check last_active under the entry guard: a concurrent + // try_send may have bumped it between the snapshot and now, + // and evicting a connection that just successfully received + // traffic would drop a healthy client. 
+ let still_stale = self + .connections + .get(&session_id) + .is_some_and(|c| c.last_active.load(Ordering::Relaxed) < cutoff_ts); + if still_stale { + self.remove_connection(session_id); + } } } @@ -870,6 +912,44 @@ mod tests { assert!(evicted.is_empty()); } + #[tokio::test] + async fn revoke_session_notifies_then_evicts_connection_and_subscriptions() { + // Server-side auth revocation path: client gets one final AuthFailed + // message, the connection is removed, and the subscription mappings + // are returned for caller cleanup. After revocation the session must + // be unreachable — any send returns SessionNotFound, forcing the + // client to reconnect with a fresh token. + let server = SessionServer::new(NodeId::new(), RealtimeConfig::default()); + let session_id = SessionId::new(); + let sub_a = SubscriptionId::new(); + let sub_b = SubscriptionId::new(); + let (tx, mut rx) = mpsc::channel(8); + + server.register_connection(session_id, tx, None); + server.add_subscription(session_id, sub_a).unwrap(); + server.add_subscription(session_id, sub_b).unwrap(); + + let removed = server + .revoke_session(session_id, "role demoted") + .expect("session existed"); + assert_eq!(removed.len(), 2); + assert!(removed.contains(&sub_a)); + assert!(removed.contains(&sub_b)); + + match rx.recv().await { + Some(RealtimeMessage::AuthFailed { reason }) => assert_eq!(reason, "role demoted"), + other => panic!("expected AuthFailed, got {other:?}"), + } + + assert_eq!(server.connection_count(), 0); + assert_eq!(server.subscription_count(), 0); + let result = server.try_send_to_session(session_id, RealtimeMessage::Lagging); + assert!(matches!(result, Err(SendError::SessionNotFound))); + + // Calling revoke again on a gone session is a no-op. 
+ assert!(server.revoke_session(session_id, "again").is_none()); + } + #[test] fn cleanup_expired_tokens_returns_empty_when_nothing_expired() { let server = SessionServer::new(NodeId::new(), RealtimeConfig::default()); diff --git a/crates/forge-runtime/src/realtime/reactor.rs b/crates/forge-runtime/src/realtime/reactor.rs index bd490e86..b6724a72 100644 --- a/crates/forge-runtime/src/realtime/reactor.rs +++ b/crates/forge-runtime/src/realtime/reactor.rs @@ -186,10 +186,13 @@ impl Reactor { self.session_server.remove_connection(session_id); // Clean up job subscriptions using reverse index for O(1) lookup + // Lock order: subscriptions map BEFORE session reverse-index. + // Matches the cleanup task and `unsubscribe_job`/`unsubscribe_workflow` + // to remove the deadlock window where opposite orders could meet. { + let mut job_subs = self.job_subscriptions.write().await; let job_ids = self.session_job_ids.write().await.remove(&session_id); if let Some(ids) = job_ids { - let mut job_subs = self.job_subscriptions.write().await; for id in ids { if let Some(subscribers) = job_subs.get_mut(&id) { subscribers.retain(|s| s.session_id != session_id); @@ -203,9 +206,9 @@ impl Reactor { // Clean up workflow subscriptions using reverse index for O(1) lookup { + let mut workflow_subs = self.workflow_subscriptions.write().await; let wf_ids = self.session_workflow_ids.write().await.remove(&session_id); if let Some(ids) = wf_ids { - let mut workflow_subs = self.workflow_subscriptions.write().await; for id in ids { if let Some(subscribers) = workflow_subs.get_mut(&id) { subscribers.retain(|s| s.session_id != session_id); @@ -265,7 +268,23 @@ impl Reactor { } }; - let (result_hash, serialized_len) = Self::compute_hash(&data); + // A subscription with no observable table dependencies could + // never be invalidated, so it would sit live forever and never + // re-execute. Reject up front instead of silently going dark. 
+ if table_deps.is_empty() && read_set.tables.is_empty() { + self.unsubscribe(subscription_id); + return Err(forge_core::ForgeError::Validation(format!( + "Query '{}' has no table dependencies and cannot be subscribed to", + query_name + ))); + } + + let Some((result_hash, serialized_len)) = Self::compute_hash(&data) else { + self.unsubscribe(subscription_id); + return Err(forge_core::ForgeError::internal( + "Failed to serialize query result for change detection", + )); + }; tracing::trace!( ?group_id, @@ -349,26 +368,33 @@ impl Reactor { } /// Unsubscribe from job updates. - pub async fn unsubscribe_job(&self, session_id: SessionId, client_sub_id: &str) { - let mut subs = self.job_subscriptions.write().await; - let mut removed_ids = Vec::new(); - for (job_id, subscribers) in subs.iter_mut() { - let before = subscribers.len(); - subscribers - .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); - if subscribers.len() < before { - removed_ids.push(*job_id); + /// + /// Requires `job_id` so this is O(subscribers for one job) instead of + /// walking every job entry. Callers always know it: it's the id they + /// passed to `subscribe_job`. + /// + /// Lock order: `job_subscriptions` -> `session_job_ids`. Matches + /// `remove_session` to prevent deadlocks under adversarial scheduling. 
+ pub async fn unsubscribe_job(&self, session_id: SessionId, job_id: Uuid, client_sub_id: &str) { + let removed = { + let mut subs = self.job_subscriptions.write().await; + let mut removed = false; + if let Some(subscribers) = subs.get_mut(&job_id) { + let before = subscribers.len(); + subscribers + .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); + removed = subscribers.len() < before; + if subscribers.is_empty() { + subs.remove(&job_id); + } } - } - subs.retain(|_, v| !v.is_empty()); - drop(subs); + removed + }; - if !removed_ids.is_empty() { + if removed { let mut session_jobs = self.session_job_ids.write().await; if let Some(ids) = session_jobs.get_mut(&session_id) { - for id in &removed_ids { - ids.remove(id); - } + ids.remove(&job_id); if ids.is_empty() { session_jobs.remove(&session_id); } @@ -410,27 +436,33 @@ impl Reactor { Ok(workflow_data) } - /// Unsubscribe from workflow updates. - pub async fn unsubscribe_workflow(&self, session_id: SessionId, client_sub_id: &str) { - let mut subs = self.workflow_subscriptions.write().await; - let mut removed_ids = Vec::new(); - for (wf_id, subscribers) in subs.iter_mut() { - let before = subscribers.len(); - subscribers - .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); - if subscribers.len() < before { - removed_ids.push(*wf_id); + /// Unsubscribe from workflow updates. See [`unsubscribe_job`] for the + /// rationale on requiring the `workflow_id`. 
+ pub async fn unsubscribe_workflow( + &self, + session_id: SessionId, + workflow_id: Uuid, + client_sub_id: &str, + ) { + let removed = { + let mut subs = self.workflow_subscriptions.write().await; + let mut removed = false; + if let Some(subscribers) = subs.get_mut(&workflow_id) { + let before = subscribers.len(); + subscribers + .retain(|s| !(s.session_id == session_id && s.client_sub_id == client_sub_id)); + removed = subscribers.len() < before; + if subscribers.is_empty() { + subs.remove(&workflow_id); + } } - } - subs.retain(|_, v| !v.is_empty()); - drop(subs); + removed + }; - if !removed_ids.is_empty() { + if removed { let mut session_wfs = self.session_workflow_ids.write().await; if let Some(ids) = session_wfs.get_mut(&session_id) { - for id in &removed_ids { - ids.remove(id); - } + ids.remove(&workflow_id); if ids.is_empty() { session_wfs.remove(&session_id); } @@ -465,14 +497,13 @@ impl Reactor { } /// Content hash for change detection; returns `(hash, byte_count)`. - fn compute_hash(data: &serde_json::Value) -> (String, usize) { - match serde_json::to_vec(data) { - Ok(bytes) => { - let len = bytes.len(); - (crate::stable_hash::sha256_hex(&bytes), len) - } - Err(_) => ("!serialization_failed!".to_string(), usize::MAX), - } + /// `None` if serialization fails — callers MUST skip the update so a + /// failure sentinel isn't cached and used to suppress later, legitimate + /// "still broken" notifications or emit spurious data on recovery. + fn compute_hash(data: &serde_json::Value) -> Option<(String, usize)> { + let bytes = serde_json::to_vec(data).ok()?; + let len = bytes.len(); + Some((crate::stable_hash::sha256_hex(&bytes), len)) } /// Flush pending invalidations with bounded concurrent re-execution. @@ -549,7 +580,59 @@ impl Reactor { .await; } + /// Drop every subscription for a session (query, job, workflow) and close + /// its SSE channel after a final `AuthFailed` notification. 
Intended as an + /// admin escape hatch for server-side auth revocation: cached + /// `AuthContext` on `QueryGroup` is captured at subscribe time and only + /// re-validated on JWT expiry, so demotions/tenant moves that happen + /// before `exp` cannot be detected by the reactor itself. Operators wire + /// this up to their identity system's revocation event; after the call + /// the client must reconnect and re-subscribe with a fresh token. + pub async fn revoke_session_auth(&self, session_id: SessionId, reason: &str) { + self.subscription_manager + .remove_session_subscriptions(session_id); + self.session_server.revoke_session(session_id, reason); + + { + let mut job_subs = self.job_subscriptions.write().await; + let job_ids = self.session_job_ids.write().await.remove(&session_id); + if let Some(ids) = job_ids { + for id in ids { + if let Some(subscribers) = job_subs.get_mut(&id) { + subscribers.retain(|s| s.session_id != session_id); + if subscribers.is_empty() { + job_subs.remove(&id); + } + } + } + } + } + + { + let mut workflow_subs = self.workflow_subscriptions.write().await; + let wf_ids = self.session_workflow_ids.write().await.remove(&session_id); + if let Some(ids) = wf_ids { + for id in ids { + if let Some(subscribers) = workflow_subs.get_mut(&id) { + subscribers.retain(|s| s.session_id != session_id); + if subscribers.is_empty() { + workflow_subs.remove(&id); + } + } + } + } + } + + tracing::info!(?session_id, %reason, "Session auth revoked"); + } + /// Re-run queries for groups, pushing to subscribers on hash change. + /// + /// Note: re-execution uses the `AuthContext` captured at subscribe time. + /// Authorization is checked at subscribe time and only re-checked on + /// token expiry. Server-side role/tenant changes before `exp` are not + /// detected here; callers must invoke [`revoke_session_auth`] to evict + /// affected sessions explicitly. 
async fn reexecute_groups( group_ids: &[forge_core::realtime::QueryGroupId], subscription_manager: &Arc, @@ -638,7 +721,13 @@ impl Reactor { }; match result { Ok((new_data, read_set)) => { - let (new_hash, serialized_len) = Self::compute_hash(&new_data); + let Some((new_hash, serialized_len)) = Self::compute_hash(&new_data) else { + tracing::warn!( + ?group_id, + "Skipping group update: result failed to serialize" + ); + continue; + }; if last_hash.as_ref() != Some(&new_hash) { let data_arc = std::sync::Arc::new(new_data); @@ -715,6 +804,7 @@ impl Reactor { tracing::debug!("Reactor listening for changes"); let mut restart_count: u32 = 0; + let mut consecutive_lags: u32 = 0; let (listener_error_tx, mut listener_error_rx) = mpsc::channel::(1); // Start initial listener @@ -757,6 +847,11 @@ impl Reactor { // again; reset so a long-lived process can absorb more // transient failures over its lifetime. restart_count = 0; + consecutive_lags = 0; + // A successful change proves the new listener + // is healthy; flush stale errors so they + // can't be misattributed to it later. + while listener_error_rx.try_recv().is_ok() {} Self::handle_change( &change, &invalidation_engine, @@ -767,10 +862,20 @@ impl Reactor { ).await; } Err(broadcast::error::RecvError::Lagged(n)) => { + // Back off exponentially on consecutive lags so a + // sustained event rate above the broadcast buffer + // doesn't pin us in a resync-storm. 
+ let backoff_ms = 100u64 + .saturating_mul(1u64 << consecutive_lags.min(6)); + consecutive_lags = consecutive_lags.saturating_add(1); tracing::warn!( missed = n, - "Reactor lagged; scheduling full resync" + consecutive_lags, + backoff_ms, + "Reactor lagged; backing off before scheduling full resync" ); + tokio::time::sleep(std::time::Duration::from_millis(backoff_ms)) + .await; listener.set_needs_resync(); } Err(broadcast::error::RecvError::Closed) => { @@ -889,6 +994,14 @@ impl Reactor { if let Some(handle) = listener_handle.take() { handle.abort(); } + // Drain any further error messages that piled up + // while we were sleeping the backoff: the aborted + // listener may have queued additional failures, and + // the new listener must not be debited for them + // (each stale message would otherwise bump + // restart_count toward max_restarts on phantom + // restarts and emit a false "permanently failed"). + while listener_error_rx.try_recv().is_ok() {} change_rx = listener.subscribe(); listener_handle = Some(tokio::spawn(async move { if let Err(e) = listener_clone.run(&bus_clone).await { @@ -1299,18 +1412,12 @@ impl Reactor { let mut read_set = ReadSet::new(); - if info.table_dependencies.is_empty() { - let table_name = Self::extract_table_name(query_name); - read_set.add_table(&table_name); - tracing::trace!( - query = %query_name, - fallback_table = %table_name, - "Using naming convention fallback for table dependency" - ); - } else { - for table in info.table_dependencies { - read_set.add_table(*table); - } + // No naming-convention fallback: a fake "table" equal to + // the query name would never appear in real change events, + // so the subscription would silently never re-execute. + // Callers reject empty-deps subscriptions at subscribe time. 
+ for table in info.table_dependencies { + read_set.add_table(*table); } Ok((data, read_set)) @@ -1322,10 +1429,6 @@ impl Reactor { } } - fn extract_table_name(query_name: &str) -> String { - query_name.to_string() - } - /// Auth check for re-execution (authentication only, roles checked at subscribe time). fn check_query_auth( info: &forge_core::function::FunctionInfo, @@ -1461,9 +1564,9 @@ mod tests { let data2 = serde_json::json!({"name": "test"}); let data3 = serde_json::json!({"name": "different"}); - let (hash1, len1) = Reactor::compute_hash(&data1); - let (hash2, _) = Reactor::compute_hash(&data2); - let (hash3, _) = Reactor::compute_hash(&data3); + let (hash1, len1) = Reactor::compute_hash(&data1).expect("hash data1"); + let (hash2, _) = Reactor::compute_hash(&data2).expect("hash data2"); + let (hash3, _) = Reactor::compute_hash(&data3).expect("hash data3"); assert_eq!(hash1, hash2); assert_ne!(hash1, hash3); diff --git a/crates/forge-runtime/src/signals/bot.rs b/crates/forge-runtime/src/signals/bot.rs index 37ed45fb..912e2929 100644 --- a/crates/forge-runtime/src/signals/bot.rs +++ b/crates/forge-runtime/src/signals/bot.rs @@ -61,10 +61,14 @@ const BOT_PATTERNS: &[&str] = &[ "curl/", "libwww", "apache-httpclient", - "okhttp", - "node-fetch", - "axios", - "postman", + // Patterns below match the format these libraries actually emit in a UA + // (token + "/version"). Bare substrings like "axios" or "okhttp" flagged + // legitimate mobile apps and react-native clients that ship those names + // embedded inside larger UAs. + "okhttp/", + "node-fetch/", + "axios/", + "postmanruntime/", ]; /// Pre-compiled Aho-Corasick automaton for bot detection. 
diff --git a/crates/forge-runtime/src/signals/collector.rs b/crates/forge-runtime/src/signals/collector.rs index ab2df237..b8e975e4 100644 --- a/crates/forge-runtime/src/signals/collector.rs +++ b/crates/forge-runtime/src/signals/collector.rs @@ -12,6 +12,12 @@ use sqlx::PgPool; use tokio::sync::{Mutex, mpsc, oneshot}; use tracing::{debug, error, warn}; +/// Hard ceiling on the total in-buffer byte size before we force a flush. +/// Caps PG memory pressure when a single UNNEST batch would otherwise grow +/// into hundreds of MB. Tracked alongside `batch_size` (whichever fires +/// first wins). +const MAX_BUFFER_BYTES: usize = 16 * 1024 * 1024; + /// Buffered signal event collector. /// /// Clone-friendly (shares the mpsc sender). Send events from any async @@ -110,6 +116,7 @@ async fn flush_loop( mut shutdown_rx: oneshot::Receiver>, ) { let mut buffer: Vec = Vec::with_capacity(batch_size); + let mut buffer_bytes: usize = 0; let mut interval = tokio::time::interval(flush_interval); interval.tick().await; @@ -118,6 +125,7 @@ async fn flush_loop( biased; ack = &mut shutdown_rx => { while let Ok(event) = rx.try_recv() { + buffer_bytes = buffer_bytes.saturating_add(estimate_event_bytes(&event)); buffer.push(event); } if !buffer.is_empty() { @@ -132,9 +140,11 @@ async fn flush_loop( event = rx.recv() => { match event { Some(e) => { + buffer_bytes = buffer_bytes.saturating_add(estimate_event_bytes(&e)); buffer.push(e); - if buffer.len() >= batch_size { + if buffer.len() >= batch_size || buffer_bytes >= MAX_BUFFER_BYTES { flush_batch(&pool, &mut buffer).await; + buffer_bytes = 0; } } None => { @@ -149,12 +159,49 @@ async fn flush_loop( _ = interval.tick() => { if !buffer.is_empty() { flush_batch(&pool, &mut buffer).await; + buffer_bytes = 0; } } } } } +/// Cheap byte estimate dominated by the variable-size fields. 
Avoids +/// re-serializing the properties JSON for every accounting update — we just +/// take the length of its serde repr where it's a string/object/array, and +/// fall back to a fixed overhead. +fn estimate_event_bytes(event: &SignalEvent) -> usize { + fn opt_len(s: &Option) -> usize { + s.as_deref().map(str::len).unwrap_or(0) + } + let props = serde_json::to_vec(&event.properties) + .map(|v| v.len()) + .unwrap_or(0); + let ctx = serde_json::to_vec(&event.error_context) + .map(|v| v.len()) + .unwrap_or(0); + // 256 = fixed-size column overhead (uuids, ints, timestamps, bools). + 256 + props + + ctx + + opt_len(&event.event_name) + + opt_len(&event.correlation_id) + + opt_len(&event.visitor_id) + + opt_len(&event.page_url) + + opt_len(&event.referrer) + + opt_len(&event.function_name) + + opt_len(&event.function_kind) + + opt_len(&event.status) + + opt_len(&event.error_message) + + opt_len(&event.error_stack) + + opt_len(&event.client_ip) + + opt_len(&event.country) + + opt_len(&event.city) + + opt_len(&event.user_agent) + + opt_len(&event.device_type) + + opt_len(&event.browser) + + opt_len(&event.os) +} + /// Flush a batch of events into PostgreSQL using UNNEST for single-roundtrip INSERT. /// Uses runtime sqlx::query() because UNNEST with typed arrays is not supported by /// the compile-time sqlx::query!() macro. diff --git a/crates/forge-runtime/src/signals/endpoints.rs b/crates/forge-runtime/src/signals/endpoints.rs index e497181d..a49502a7 100644 --- a/crates/forge-runtime/src/signals/endpoints.rs +++ b/crates/forge-runtime/src/signals/endpoints.rs @@ -32,6 +32,15 @@ use super::visitor; /// Maximum events per batch request. const MAX_BATCH_SIZE: usize = 50; +/// Maximum serialized byte size of a single event's free-form `properties` +/// JSON. Larger payloads are rejected. Prevents apps from dumping request +/// bodies / PII into analytics rows. 
+const MAX_PROPERTY_BYTES: usize = 4096; + +/// Maximum serialized byte size of a single event envelope (event name + +/// properties + correlation_id). Larger batch entries are rejected. +const MAX_EVENT_BYTES: usize = 8192; + /// Check the client's Do-Not-Track header. We honor DNT: 1 by short-circuiting /// signal ingestion -- the browser has explicitly opted out of tracking. /// Sec-GPC (Global Privacy Control) is also respected. @@ -125,6 +134,14 @@ async fn handle_event( if batch.events.len() > MAX_BATCH_SIZE { return rate_limited_response(); } + for event in &batch.events { + if !event_within_limits(event) { + return Json(SignalResponse { + ok: false, + session_id: None, + }); + } + } let ctx = extract_request_ctx( headers, @@ -133,28 +150,22 @@ async fn handle_event( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); - let session_id = + ) + .await; + let supplied_session_id = resolve_session_id(batch.context.as_ref().and_then(|c| c.session_id.as_deref())); + let session_id = Some(supplied_session_id.unwrap_or_else(Uuid::new_v4)); let page_url = batch.context.as_ref().and_then(|c| c.page_url.clone()); - let session_id = session::upsert_session( - &state.pool, + let referrer = batch.context.as_ref().and_then(|c| c.referrer.clone()); + spawn_session_upsert( + state.pool.clone(), session_id, - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - page_url.as_deref(), - batch.context.as_ref().and_then(|c| c.referrer.as_deref()), - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, + &ctx, + page_url.clone(), + referrer, "track", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + ); for event in batch.events { let signal = SignalEvent { @@ -216,27 +227,20 @@ async fn handle_view( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); + ) + .await; let session_id_header = extract_header(headers, "x-session-id"); - let session_id = 
resolve_session_id(session_id_header.as_deref()); + let supplied_session_id = resolve_session_id(session_id_header.as_deref()); + let session_id = Some(supplied_session_id.unwrap_or_else(Uuid::new_v4)); - let session_id = session::upsert_session( - &state.pool, + spawn_session_upsert( + state.pool.clone(), session_id, - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - Some(&payload.url), - payload.referrer.as_deref(), - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, + &ctx, + Some(payload.url.clone()), + payload.referrer.clone(), "page_view", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + ); let utm = if payload.utm_source.is_some() || payload.utm_medium.is_some() @@ -314,28 +318,13 @@ async fn handle_report( &state.server_secret, state.anonymize_ip, state.geoip.as_ref(), - ); + ) + .await; let session_id_header = extract_header(headers, "x-session-id"); let session_id = resolve_session_id(session_id_header.as_deref()); - if let Some(sid) = session_id { - session::upsert_session( - &state.pool, - Some(sid), - &ctx.visitor_id, - ctx.user_id, - ctx.tenant_id, - None, - None, - ctx.user_agent.as_deref(), - ctx.client_ip.as_deref(), - ctx.is_bot, - "error", - ctx.device_type.as_deref(), - ctx.browser.as_deref(), - ctx.os.as_deref(), - ) - .await; + if session_id.is_some() { + spawn_session_upsert(state.pool.clone(), session_id, &ctx, None, None, "error"); } for err in report.errors { @@ -392,7 +381,7 @@ struct RequestCtx { os: Option, } -fn extract_request_ctx( +async fn extract_request_ctx( headers: &HeaderMap, resolved_ip: Option, auth: &Option>, @@ -413,12 +402,26 @@ fn extract_request_ctx( let user_id = auth.as_ref().and_then(|a| a.user_id()); let tenant_id = auth.as_ref().and_then(|a| a.tenant_id()); let device_info = device::parse_lowered(platform_header.as_deref(), &ua_lower); - let geo = geoip - .zip(raw_ip.as_deref()) - .map(|(g, ip)| g.lookup(ip)) - .unwrap_or_default(); + let geo = match 
(geoip, raw_ip.clone()) { + (Some(g), Some(ip)) => { + // MMDB lookups can be CPU-blocking on cold pages; offload so the + // request thread keeps feeding the collector. + let g = g.clone(); + tokio::task::spawn_blocking(move || g.lookup(&ip)) + .await + .unwrap_or_default() + } + _ => super::geoip::GeoInfo::default(), + }; // anonymize_ip drops the raw IP after visitor_id + geo are derived; GDPR-friendly default. let client_ip = if anonymize_ip { None } else { raw_ip }; + // When IP is anonymized, also strip the UA major-version so the combo of + // UA + country + city can't be used to re-fingerprint the visitor. + let user_agent = if anonymize_ip { + user_agent.as_deref().map(anonymize_ua) + } else { + user_agent + }; RequestCtx { user_agent, client_ip, @@ -438,6 +441,76 @@ fn extract_header(headers: &HeaderMap, name: &str) -> Option { crate::gateway::extract_header(headers, name) } +/// Strip the major version off a UA so a per-version identifier can't be +/// derived. Recognizes the most common browser family tokens; falls back to +/// the broad family when the UA doesn't match any known prefix. +fn anonymize_ua(ua: &str) -> String { + const FAMILIES: &[&str] = &["Chrome/", "Firefox/", "Safari/", "Edg/", "Opera/"]; + for family in FAMILIES { + if ua.contains(family) { + return (*family).to_string(); + } + } + "Other".to_string() +} + +/// Per-event size guard. Rejects events whose serialized properties / event +/// envelope exceed configured limits. +fn event_within_limits(event: &forge_core::signals::ClientEvent) -> bool { + let props_bytes = match serde_json::to_vec(&event.properties) { + Ok(b) => b.len(), + Err(_) => return false, + }; + if props_bytes > MAX_PROPERTY_BYTES { + return false; + } + let total = event.event.len() + + props_bytes + + event.correlation_id.as_deref().map(str::len).unwrap_or(0); + total <= MAX_EVENT_BYTES +} + +/// Fire-and-forget the session upsert so the request thread doesn't block on +/// a PG round-trip. 
We mint the session ID synchronously upstream so the +/// response can return it before the row is persisted. +fn spawn_session_upsert( + pool: PgPool, + session_id: Option, + ctx: &RequestCtx, + page_url: Option, + referrer: Option, + event_type: &'static str, +) { + let visitor_id = ctx.visitor_id.clone(); + let user_id = ctx.user_id; + let tenant_id = ctx.tenant_id; + let user_agent = ctx.user_agent.clone(); + let client_ip = ctx.client_ip.clone(); + let device_type = ctx.device_type.clone(); + let browser = ctx.browser.clone(); + let os = ctx.os.clone(); + let is_bot = ctx.is_bot; + tokio::spawn(async move { + session::upsert_session( + &pool, + session_id, + &visitor_id, + user_id, + tenant_id, + page_url.as_deref(), + referrer.as_deref(), + user_agent.as_deref(), + client_ip.as_deref(), + is_bot, + event_type, + device_type.as_deref(), + browser.as_deref(), + os.as_deref(), + ) + .await; + }); +} + fn resolve_session_id(raw: Option<&str>) -> Option { raw.and_then(|s| Uuid::parse_str(s).ok()) } diff --git a/crates/forge-runtime/src/signals/partition.rs b/crates/forge-runtime/src/signals/partition.rs index 0e09fc33..bb794568 100644 --- a/crates/forge-runtime/src/signals/partition.rs +++ b/crates/forge-runtime/src/signals/partition.rs @@ -2,6 +2,27 @@ //! //! Creates partitions for upcoming months and drops partitions //! older than the configured retention period. +//! +//! ## Operational expectation +//! +//! Pre-creation runs every maintenance tick and covers the current month plus +//! the next three months. Anything farther out lands in the catch-all +//! `forge_signals_events_default` partition, which is **excluded from the +//! retention sweep** — rows there accumulate forever and won't be dropped. +//! +//! Two failure modes to watch: +//! +//! 1. A node hibernates / loses its scheduler past the +3-month horizon. When +//! it wakes back up, any inserts whose `timestamp` falls outside the +//! 
rolling window land in the default partition until the maintenance loop +//! next runs. +//! 2. A client sends events with `timestamp` far in the future (clock skew, +//! backfill jobs). Same outcome. +//! +//! `check_default_partition` logs an error whenever rows are present in the +//! default partition. Treat that log as actionable: investigate the coverage +//! gap, then move the misrouted rows into the correct month partition by +//! hand (or accept that they'll never be cleaned up by retention). // Partition DDL constructs table names from runtime dates, so the query macros // can't validate them at compile time. diff --git a/crates/forge-runtime/src/signals/rate_limit.rs b/crates/forge-runtime/src/signals/rate_limit.rs index b2ddaa6b..af78dbda 100644 --- a/crates/forge-runtime/src/signals/rate_limit.rs +++ b/crates/forge-runtime/src/signals/rate_limit.rs @@ -8,6 +8,8 @@ //! limit is effectively `nodes * max_per_window`, which is fine for abuse //! protection — billing-grade limits are not the goal here. +use std::collections::VecDeque; +use std::sync::Mutex; use std::sync::atomic::{AtomicI64, AtomicU32, Ordering}; use dashmap::DashMap; @@ -16,6 +18,10 @@ use dashmap::DashMap; /// minute. Generous enough to absorb legitimate bursts (page-view + web-vital /// flush + a handful of tracked events on a navigation) while still capping /// runaway clients. +/// +/// TODO(signals-config): expose `SignalsConfig::rate_limit_per_minute` and +/// thread it through `gateway::server` so operators can tune this in +/// forge.toml. Until then, callers can override via `with_limit(...)`. const DEFAULT_MAX_REQUESTS_PER_WINDOW: u32 = 600; /// Window length in seconds. @@ -29,6 +35,11 @@ const MAX_TRACKED_IPS: usize = 100_000; pub struct SignalRateLimiter { max_per_window: u32, buckets: DashMap, + /// Insertion-order queue. When `buckets.len()` exceeds `MAX_TRACKED_IPS` + /// we pop the oldest entry off the front and remove it from the map. 
+ /// O(1) amortized — replaces the previous O(n) `evict_oldest` sweep that + /// ran inline on every new-IP miss. + insertion_order: Mutex>, } struct IpBucket { @@ -47,6 +58,7 @@ impl SignalRateLimiter { Self { max_per_window, buckets: DashMap::new(), + insertion_order: Mutex::new(VecDeque::new()), } } @@ -68,14 +80,30 @@ impl SignalRateLimiter { bucket.count.store(1, Ordering::Relaxed); return true; } + // Note: fetch_add followed by the prev < max comparison is a + // benign race — two concurrent callers can both observe a value + // just under the ceiling and both succeed, leaving the bucket a + // small constant over the configured max. Acceptable for abuse + // protection; billing-grade limits are explicitly out of scope. let prev = bucket.count.fetch_add(1, Ordering::Relaxed); return prev < self.max_per_window; } - if self.buckets.len() >= MAX_TRACKED_IPS { - self.evict_oldest(); - } + self.insert_new_bucket(ip, now); + true + } + /// Insert a freshly-seen IP. Evicts the oldest tracked IP in O(1) when the + /// map is at capacity (FIFO; not strictly LRU but good enough — abuse- + /// driven floods churn the queue fast enough that any stale entries fall + /// off naturally). 
+ fn insert_new_bucket(&self, ip: &str, now: i64) { + if self.buckets.len() >= MAX_TRACKED_IPS + && let Ok(mut order) = self.insertion_order.lock() + && let Some(victim) = order.pop_front() + { + self.buckets.remove(&victim); + } self.buckets.insert( ip.to_string(), IpBucket { @@ -83,13 +111,9 @@ impl SignalRateLimiter { count: AtomicU32::new(1), }, ); - true - } - - fn evict_oldest(&self) { - let cutoff = chrono::Utc::now().timestamp() - WINDOW_SECS; - self.buckets - .retain(|_, bucket| bucket.window_start.load(Ordering::Relaxed) >= cutoff); + if let Ok(mut order) = self.insertion_order.lock() { + order.push_back(ip.to_string()); + } } } diff --git a/crates/forge-runtime/src/signals/session.rs b/crates/forge-runtime/src/signals/session.rs index d049dffd..3a3b535f 100644 --- a/crates/forge-runtime/src/signals/session.rs +++ b/crates/forge-runtime/src/signals/session.rs @@ -69,7 +69,12 @@ pub async fn upsert_session( } } - let new_id = Uuid::new_v4(); + // Reuse the caller-supplied session id when present: the handler already + // returned it to the client, so the persisted row MUST carry that same id. + // Minting a fresh UUID here orphaned the row under an id the client never + // saw, so every later event re-missed the UPDATE and spawned a new session — + // breaking session continuity. Only generate when no id was supplied. + let new_id = session_id.unwrap_or_else(Uuid::new_v4); let referrer_domain = referrer.and_then(extract_domain); let result = sqlx::query( diff --git a/crates/forge-runtime/src/signals/tests.rs b/crates/forge-runtime/src/signals/tests.rs index 91bdfa32..9d5715d0 100644 --- a/crates/forge-runtime/src/signals/tests.rs +++ b/crates/forge-runtime/src/signals/tests.rs @@ -732,3 +732,103 @@ async fn test_partition_ensure() { db.cleanup().await.unwrap(); } + +// ── Privacy short-circuit ─────────────────────────────────────────────────── + +/// `DNT: 1` opts the visitor out of analytics. 
The endpoint must return +/// `ok: true` (so the client doesn't keep retrying) without persisting any +/// event row. A regression that drops the short-circuit would silently +/// re-enable tracking for opted-out users. +#[tokio::test] +async fn dnt_header_short_circuits_event_storage() { + let db = setup("dnt_short_circuit").await; + let state = make_signals_state(db.pool()); + + let batch = SignalEventBatch { + events: vec![ClientEvent { + event: "dnt_should_not_persist".to_string(), + properties: serde_json::json!({}), + correlation_id: None, + timestamp: None, + }], + context: None, + }; + + let mut headers = make_headers(); + headers.insert("dnt", HeaderValue::from_static("1")); + + let response = endpoints::signal_handler( + State(state.clone()), + None, + None, + headers, + Json(SignalPayload::Event(batch)), + ) + .await + .into_response(); + + let body: serde_json::Value = axum::body::to_bytes(response.into_body(), 1024) + .await + .map(|b| serde_json::from_slice(&b).unwrap()) + .unwrap(); + assert_eq!( + body["ok"], true, + "DNT must return ok so clients stop retrying" + ); + + // Give the collector a chance to flush — if anything was queued it would + // land in this window. We assert nothing was stored. + tokio::time::sleep(Duration::from_millis(200)).await; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM forge_signals_events WHERE event_name = 'dnt_should_not_persist'", + ) + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!(count.0, 0, "DNT events must not be persisted"); + + db.cleanup().await.unwrap(); +} + +/// `Sec-GPC: 1` (Global Privacy Control) is the modern equivalent of DNT +/// and must short-circuit the same way. 
+#[tokio::test] +async fn sec_gpc_header_short_circuits_event_storage() { + let db = setup("gpc_short_circuit").await; + let state = make_signals_state(db.pool()); + + let batch = SignalEventBatch { + events: vec![ClientEvent { + event: "gpc_should_not_persist".to_string(), + properties: serde_json::json!({}), + correlation_id: None, + timestamp: None, + }], + context: None, + }; + + let mut headers = make_headers(); + headers.insert("sec-gpc", HeaderValue::from_static("1")); + + let _ = endpoints::signal_handler( + State(state.clone()), + None, + None, + headers, + Json(SignalPayload::Event(batch)), + ) + .await; + + tokio::time::sleep(Duration::from_millis(200)).await; + + let count: (i64,) = sqlx::query_as( + "SELECT COUNT(*) FROM forge_signals_events WHERE event_name = 'gpc_should_not_persist'", + ) + .fetch_one(db.pool()) + .await + .unwrap(); + assert_eq!(count.0, 0, "Sec-GPC events must not be persisted"); + + db.cleanup().await.unwrap(); +} diff --git a/crates/forge-runtime/src/signals/visitor.rs b/crates/forge-runtime/src/signals/visitor.rs index 0f63bedc..f1896549 100644 --- a/crates/forge-runtime/src/signals/visitor.rs +++ b/crates/forge-runtime/src/signals/visitor.rs @@ -8,6 +8,17 @@ use sha2::{Digest, Sha256}; use std::sync::RwLock; +use std::sync::atomic::{AtomicBool, Ordering}; + +/// Must stay in sync with `gateway::server::DEFAULT_SIGNAL_SECRET`. If a +/// caller passes this literal we refuse to emit a real visitor ID, since +/// the daily salt would be trivially reversible by anyone who reads the +/// open-source repo. +const DEFAULT_SIGNAL_SECRET: &str = "forge-default-signal-secret"; + +/// One-shot guard so we only log the "default secret in use" warning once +/// rather than on every request. +static WARNED_DEFAULT_SECRET: AtomicBool = AtomicBool::new(false); /// Cached daily salt to avoid recomputing on every request. 
struct DailySalt { @@ -29,6 +40,15 @@ pub fn generate_visitor_id( user_agent: Option<&str>, server_secret: &str, ) -> String { + if server_secret == DEFAULT_SIGNAL_SECRET { + if !WARNED_DEFAULT_SECRET.swap(true, Ordering::Relaxed) { + tracing::error!( + "signals: default visitor-ID secret in use; refusing to emit a real visitor ID. \ + Configure [auth] jwt_secret in forge.toml to enable visitor tracking." + ); + } + return String::new(); + } let ip = client_ip.unwrap_or("unknown"); let ua = user_agent.unwrap_or("unknown"); let salt = get_daily_salt(server_secret); diff --git a/crates/forge-runtime/src/webhook/handler.rs b/crates/forge-runtime/src/webhook/handler.rs index 4855966e..10276c47 100644 --- a/crates/forge-runtime/src/webhook/handler.rs +++ b/crates/forge-runtime/src/webhook/handler.rs @@ -13,6 +13,7 @@ use axum::{ use base64::{Engine as _, engine::general_purpose}; use forge_core::CircuitBreakerClient; use forge_core::function::{JobDispatch, KvHandle, WorkflowDispatch}; +use forge_core::rate_limit::{RateLimitConfig, RateLimitKey}; use forge_core::webhook::{ IdempotencySource, REPLAY_TIMESTAMP_HEADER, SignatureAlgorithm, WebhookContext, }; @@ -21,11 +22,30 @@ use ring::signature::{self, UnparsedPublicKey}; use serde_json::{Value, json}; use sha2::Sha256; use sqlx::PgPool; +use std::time::Duration; use tracing::{error, info, warn}; use uuid::Uuid; use super::registry::WebhookRegistry; -use crate::gateway::RpcError; +use crate::gateway::{ResolvedClientIp, RpcError}; +use crate::rate_limit::HybridRateLimiter; + +/// Hard cap on the inbound webhook body, also bounding the bytes persisted to +/// `forge_webhook_events.raw_body`. Without this cap a misbehaving sender can +/// fill the events table with multi-MB payloads, and unsigned webhooks have no +/// other guard at all. 
+const MAX_WEBHOOK_BODY_BYTES: usize = 1024 * 1024; + +/// Cap the number of comma-separated rotation secrets we will HMAC per request +/// to bound the work an attacker spraying invalid signatures can force. +const MAX_WEBHOOK_SECRETS: usize = 4; + +/// Default cap on unsigned webhook deliveries per source IP per minute. +/// +/// `allow_unsigned = true` opts out of signature validation; without flow +/// control any caller reaching the URL can spray dispatches and pollute the +/// idempotency table. The DDoS cost of unsigned endpoints is bounded here. +const UNSIGNED_RATE_LIMIT_PER_MINUTE: u32 = 60; /// State for webhook handler. #[derive(Clone)] @@ -36,10 +56,12 @@ pub struct WebhookState { job_dispatcher: Option>, workflow_dispatcher: Option>, kv: Option>, + unsigned_rate_limiter: Arc, } impl WebhookState { pub fn new(registry: Arc, pool: PgPool) -> Self { + let unsigned_rate_limiter = Arc::new(HybridRateLimiter::new(pool.clone())); Self { registry, pool, @@ -47,6 +69,7 @@ impl WebhookState { job_dispatcher: None, workflow_dispatcher: None, kv: None, + unsigned_rate_limiter, } } @@ -70,12 +93,29 @@ impl WebhookState { pub async fn webhook_handler( State(state): State>, Path(path): Path, + axum::extract::Extension(client_ip): axum::extract::Extension, headers: HeaderMap, body: Bytes, ) -> Response { let full_path = format!("/webhooks/{}", path); let request_id = Uuid::new_v4().to_string(); + if body.len() > MAX_WEBHOOK_BODY_BYTES { + warn!( + path = %full_path, + body_size = body.len(), + "Webhook body exceeds maximum size" + ); + return ( + StatusCode::PAYLOAD_TOO_LARGE, + Json(RpcError::new( + "PAYLOAD_TOO_LARGE", + "Webhook payload exceeds maximum size", + )), + ) + .into_response(); + } + let entry = match state.registry.get_by_path(&full_path) { Some(e) => e, None => { @@ -108,6 +148,50 @@ pub async fn webhook_handler( .into_response(); } + // Flow-control for unsigned webhooks: signature validation is what bounds + // the cost of an attacker spraying 
requests against the dispatch + idempotency + // path. When `allow_unsigned = true` we still need a per-IP ceiling to keep + // the endpoint from being a free amplification vector. + if info.signature.is_none() && info.allow_unsigned { + let ip_key = client_ip.0.as_deref().unwrap_or("unknown").to_string(); + let bucket = format!("webhook:unsigned:{}:{}", info.name, ip_key); + let config = RateLimitConfig::new(UNSIGNED_RATE_LIMIT_PER_MINUTE, Duration::from_secs(60)) + .with_key(RateLimitKey::Ip); + match state.unsigned_rate_limiter.check(&bucket, &config).await { + Ok(result) if !result.allowed => { + let retry_after = result + .retry_after + .unwrap_or(Duration::from_secs(1)) + .as_secs() + .max(1); + warn!( + webhook = info.name, + ip = %ip_key, + "Unsigned webhook rate-limited" + ); + let mut resp = ( + StatusCode::TOO_MANY_REQUESTS, + Json(RpcError::new( + "RATE_LIMITED", + "Too many unsigned webhook deliveries from this client", + )), + ) + .into_response(); + if let Ok(val) = axum::http::HeaderValue::from_str(&retry_after.to_string()) { + resp.headers_mut().insert("Retry-After", val); + } + return resp; + } + Ok(_) => {} + Err(e) => { + // Failing closed on the rate limit would make a transient PG + // hiccup take the webhook endpoint down; failing open keeps the + // already-cheap unsigned path serving, with a loud log. + warn!(webhook = info.name, error = %e, "Unsigned webhook rate-limit check failed; allowing"); + } + } + } + if let Some(ref sig_config) = info.signature { let signature = match headers .get(sig_config.header_name) @@ -141,10 +225,14 @@ pub async fn webhook_handler( } }; + // Cap secrets considered per request so an attacker spraying invalid + // signatures can't force unbounded HMACs when many rotation secrets + // are configured. 
let secrets: Vec<&str> = secrets_raw .split(',') .map(str::trim) .filter(|s| !s.is_empty()) + .take(MAX_WEBHOOK_SECRETS) .collect(); let signature_valid = secrets.iter().any(|secret| { validate_signature( @@ -167,7 +255,7 @@ pub async fn webhook_handler( } let idempotency_key = if let Some(ref idem_config) = info.idempotency { - match &idem_config.source { + let extracted = match &idem_config.source { IdempotencySource::Header(header_name) => headers .get(*header_name) .and_then(|v| v.to_str().ok()) @@ -181,7 +269,23 @@ pub async fn webhook_handler( } // Future IdempotencySource variants: skip key extraction. _ => None, + }; + // Idempotency was opted into; missing/malformed keys must fail closed + // rather than silently running the handler without replay protection. + if extracted.is_none() { + warn!( + webhook = info.name, + "Idempotency configured but key could not be extracted" + ); + return ( + StatusCode::BAD_REQUEST, + Json(RpcError::validation( + "Required idempotency key is missing or malformed", + )), + ) + .into_response(); } + extracted } else { None }; @@ -595,10 +699,15 @@ fn validate_stripe_webhooks( Ok(n) => n, Err(_) => return false, }; - if replay_window_secs > 0 - && (chrono::Utc::now().timestamp() - ts).unsigned_abs() > replay_window_secs - { - return false; + if replay_window_secs > 0 { + let now = chrono::Utc::now().timestamp(); + let window = i64::try_from(replay_window_secs).unwrap_or(i64::MAX); + let age = now.saturating_sub(ts); + // Reject future timestamps (age < 0) and stale ones uniformly with the + // generic replay window check. 
+ if !(0..=window).contains(&age) { + return false; + } } let mut signed = Vec::with_capacity(timestamp.len() + 1 + body.len()); diff --git a/crates/forge-runtime/src/workflow/bridge.rs b/crates/forge-runtime/src/workflow/bridge.rs index 3afe9bbe..f9e16000 100644 --- a/crates/forge-runtime/src/workflow/bridge.rs +++ b/crates/forge-runtime/src/workflow/bridge.rs @@ -37,6 +37,28 @@ pub fn register_workflow_bridge(executor: Arc, job_registry: & let cancel: bool = args.get("cancel").and_then(Value::as_bool).unwrap_or(false); if cancel { + // Only the scheduler (`enqueue_cancel`) sets `resume_reason == + // "cancel"`. An external caller dispatching `$workflow_resume` + // via `JobDispatcher::dispatch_by_name` with `{cancel: true}` + // would NOT set this marker, so we reject — defense in depth + // against a compromised internal caller cancelling arbitrary + // runs (#11 in issues doc). The `$` prefix on the job_type is + // convention; this guard is enforcement. + let scheduler_marker = args + .get("resume_reason") + .and_then(Value::as_str) + .map(|s| s == "cancel") + .unwrap_or(false); + if !scheduler_marker { + tracing::error!( + workflow_run_id = %run_id, + "Rejected $workflow_resume cancel without scheduler marker; \ + only WorkflowScheduler may dispatch cancel jobs" + ); + return Err(forge_core::ForgeError::Forbidden( + "cancel jobs may only be dispatched by the workflow scheduler".to_string(), + )); + } let reason = args .get("reason") .and_then(Value::as_str) diff --git a/crates/forge-runtime/src/workflow/executor.rs b/crates/forge-runtime/src/workflow/executor.rs index c985796e..3f52e761 100644 --- a/crates/forge-runtime/src/workflow/executor.rs +++ b/crates/forge-runtime/src/workflow/executor.rs @@ -3,7 +3,8 @@ use std::future::Future; use std::pin::Pin; use std::sync::Arc; -use tokio::sync::RwLock; +use dashmap::DashMap; +use tokio::sync::{Mutex, RwLock}; use uuid::Uuid; use super::bridge::WORKFLOW_RESUME_JOB; @@ -50,6 +51,13 @@ pub struct WorkflowExecutor { 
job_queue: JobQueue, http_client: CircuitBreakerClient, compensation_state: Arc>>, + /// Per-run serialization: execute and cancel of the same run_id never + /// overlap. Without this guard a cancel landing concurrently with a + /// resume would yank the live `CompensationState` and double-fire + /// compensation while the handler is still mid-step (#3 in issues doc). + /// Entry is removed by the holder when the workflow reaches a terminal + /// state; in-flight holders elsewhere keep their Arc alive. + run_locks: Arc>>>, kv: Option>, } @@ -66,10 +74,21 @@ impl WorkflowExecutor { job_queue, http_client, compensation_state: Arc::new(RwLock::new(HashMap::new())), + run_locks: Arc::new(DashMap::new()), kv: None, } } + /// Returns an `Arc>` keyed by `run_id`, creating one if absent. + /// Holders take the mutex around any code that touches `compensation_state` + /// or runs the workflow handler to keep cancel and execute serialized. + fn run_lock(&self, run_id: Uuid) -> Arc> { + self.run_locks + .entry(run_id) + .or_insert_with(|| Arc::new(Mutex::new(()))) + .clone() + } + pub fn with_kv(mut self, kv: Arc) -> Self { self.kv = Some(kv); self @@ -137,6 +156,12 @@ impl WorkflowExecutor { .await .map_err(forge_core::ForgeError::Database)?; + // #15: A DB trigger on `forge_jobs` (`v001_initial.sql`) already + // PERFORM pg_notify('forge_jobs_available', ...) on every insert. + // PostgreSQL buffers NOTIFY in the source transaction and delivers on + // commit, so the resume job is visible to workers as soon as the row + // is. No extra NOTIFY needed here. + Ok(run_id) } @@ -148,6 +173,10 @@ impl WorkflowExecutor { resume: Option, owner_subject: Option, ) -> forge_core::Result { + // Serialize against concurrent cancel for the same run. 
+ let lock = self.run_lock(run_id); + let _guard = lock.lock().await; + self.claim_for_execution(run_id).await?; let signal_label = if resume.is_some() { @@ -219,7 +248,16 @@ impl WorkflowExecutor { let handler = entry.handler.clone(); let exec_start = std::time::Instant::now(); - let result = tokio::time::timeout(entry.info.timeout, handler(&ctx, input)).await; + // PER-RESUME timeout. `entry.info.timeout` bounds a single resume + // call, not the whole workflow run. A workflow that sleeps for an + // hour and then runs for 4m59s under a 5m timeout will pass, even + // if it suspends and resumes many times — total wall-clock is + // unbounded by this guard (#4 in issues doc). Tracking total-run + // budget requires a new column on `forge_workflow_runs`; until + // then the field name is intentionally treated as `step_timeout` + // semantics by callers. + let step_timeout = entry.info.timeout; + let result = tokio::time::timeout(step_timeout, handler(&ctx, input)).await; let exec_duration_ms = exec_start.elapsed().as_millis().min(i32::MAX as u128) as i32; let comp = CompensationState { @@ -377,22 +415,48 @@ impl WorkflowExecutor { /// operators know manual remediation is required. This is an honest limitation: /// in-memory closures cannot survive restarts. pub async fn cancel(&self, run_id: Uuid, reason: &str) -> forge_core::Result<()> { + // Serialize against a concurrent resume/execute for this run. Without + // the lock the live handler could be mid-step while we yank its + // compensation state and flip the row to `failed` (#3 in issues doc). 
+ let lock = self.run_lock(run_id); + let _guard = lock.lock().await; + if let Some(state) = self.compensation_state.write().await.remove(&run_id) { - self.run_compensation(run_id, &state).await?; - let error = format!("cancelled: {reason}"); - self.fail_workflow(run_id, &error).await?; + let comp_failures = self.run_compensation(run_id, &state).await?; + // #17/#18: persist the cancel reason in a dedicated column and + // surface compensation failures as a structured summary in `error` + // so operators don't have to grep logs to learn which steps need + // manual remediation. + let comp_summary = if comp_failures.is_empty() { + None + } else { + Some(format!( + "{} compensation(s) failed: {}", + comp_failures.len(), + comp_failures.join("; ") + )) + }; + self.finalize_cancel(run_id, reason, comp_summary.as_deref()) + .await?; } else { tracing::error!( workflow_run_id = %run_id, "Compensation handlers lost (process restarted since workflow began); \ manual remediation required for any side effects from completed steps" ); - let error = format!( - "cancelled: {reason} (compensation skipped: handlers lost on restart, manual remediation required)" - ); - self.fail_workflow(run_id, &error).await?; + self.finalize_cancel( + run_id, + reason, + Some("compensation skipped: handlers lost on restart, manual remediation required"), + ) + .await?; } + // Run is terminal; drop the per-run mutex entry so the map doesn't + // accumulate. Holders that captured the Arc before this point keep + // their own reference alive. + self.run_locks.remove(&run_id); + Ok(()) } @@ -411,8 +475,20 @@ impl WorkflowExecutor { /// /// Returns `false` if the run is already in a terminal state or no row /// matched. 
- pub async fn request_cancel(&self, run_id: Uuid, reason: &str) -> forge_core::Result { - let result = sqlx::query!( + pub async fn request_cancel( + &self, + run_id: Uuid, + reason: &str, + caller_subject: Option<&str>, + ) -> forge_core::Result { + // Ownership check parallels `JobQueue::request_cancel`: if the run has + // an `owner_subject`, the caller must match (or be `None`, which means + // system / internal). Without this any caller holding the dispatcher + // could cancel any workflow run by ID (#10 in issues doc). + // + // Runtime query — adds a parameter; avoids invalidating .sqlx/. + #[allow(clippy::disallowed_methods)] + let result = sqlx::query( r#" UPDATE forge_workflow_runs SET cancel_requested_at = NOW(), @@ -420,10 +496,16 @@ impl WorkflowExecutor { WHERE id = $1 AND status IN ('pending', 'running', 'sleeping', 'waiting') AND cancel_requested_at IS NULL + AND ( + owner_subject IS NULL + OR $3::text IS NULL + OR owner_subject = $3::text + ) "#, - run_id, - reason, ) + .bind(run_id) + .bind(reason) + .bind(caller_subject) .execute(&self.pool) .await .map_err(forge_core::ForgeError::Database)?; @@ -435,8 +517,9 @@ impl WorkflowExecutor { &self, run_id: Uuid, state: &CompensationState, - ) -> forge_core::Result<()> { + ) -> forge_core::Result> { let steps = self.get_workflow_steps(run_id).await?; + let mut failures: Vec = Vec::new(); for step_name in state.completed_steps.iter().rev() { if let Some(handler) = state.handlers.get(step_name) { @@ -457,12 +540,24 @@ impl WorkflowExecutor { .await?; } Err(e) => { + let err_str = e.to_string(); tracing::error!( workflow_run_id = %run_id, step = %step_name, - error = %e, + error = %err_str, "Compensation failed" ); + // Tag the step row so operators can see exactly which + // compensations failed and need manual remediation + // (#18 in issues doc). 
+ self.update_step_status_with_error( + run_id, + step_name, + StepStatus::CompensationFailed, + Some(&err_str), + ) + .await?; + failures.push(format!("{step_name}: {err_str}")); } } } else { @@ -470,7 +565,7 @@ impl WorkflowExecutor { .await?; } } - Ok(()) + Ok(failures) } async fn get_workflow_steps( @@ -535,6 +630,38 @@ impl WorkflowExecutor { Ok(()) } + /// Update step status and optionally write an error message. Used for + /// surfacing `CompensationFailed` so operators can locate stuck side + /// effects without diving into logs. + async fn update_step_status_with_error( + &self, + workflow_run_id: Uuid, + step_name: &str, + status: StepStatus, + error: Option<&str>, + ) -> forge_core::Result<()> { + // forge_workflow_steps is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. + #[allow(clippy::disallowed_methods)] + sqlx::query( + r#" + UPDATE forge_workflow_steps + SET status = $3, + error = COALESCE($4, error) + WHERE workflow_run_id = $1 AND step_name = $2 + "#, + ) + .bind(workflow_run_id) + .bind(step_name) + .bind(status.as_str()) + .bind(error) + .execute(&self.pool) + .await + .map_err(forge_core::ForgeError::Database)?; + + Ok(()) + } + // forge_workflow_runs is a runtime-owned system table; offline .sqlx cache // doesn't always include it. #[allow(clippy::disallowed_methods)] @@ -660,17 +787,24 @@ impl WorkflowExecutor { /// Atomically claim a workflow for execution (transition to Running). /// - /// `'running'` is included so resume picks up a run that the scheduler has - /// already flipped to running as part of its claim-and-enqueue transaction. - /// Duplicate concurrent execution is prevented at higher layers: the job - /// queue's `FOR UPDATE SKIP LOCKED` ensures only one worker can hold a - /// given resume job, and the scheduler's row-locking UPDATE (or event - /// consume) ensures only one resume job is enqueued per wake event. 
+ /// Rejects `running → running`: a row already in `running` is being + /// executed by another handler. Re-entering races the live handler's + /// compensation state and step writes (#2 in the issues doc). The + /// scheduler claims rows from `(sleeping, waiting)` only; the job-queue's + /// `FOR UPDATE SKIP LOCKED` plus the per-run advisory lock taken in + /// `execute_workflow` keep concurrent resumes serialized end-to-end. + /// + /// The cancel bridge calls `force_claim_for_cancel` (which permits the + /// `running → running` transition for compensation) instead of going + /// through this path. async fn claim_for_execution(&self, run_id: Uuid) -> forge_core::Result<()> { - let result = sqlx::query!( - "UPDATE forge_workflow_runs SET status = 'running' WHERE id = $1 AND status IN ('pending', 'sleeping', 'waiting', 'running')", - run_id, + // Runtime query — schema unchanged, just a tighter status set; avoids + // touching `.sqlx/` for an internal helper. + #[allow(clippy::disallowed_methods)] + let result = sqlx::query( + "UPDATE forge_workflow_runs SET status = 'running' WHERE id = $1 AND status IN ('pending', 'sleeping', 'waiting')", ) + .bind(run_id) .execute(&self.pool) .await .map_err(forge_core::ForgeError::Database)?; @@ -709,6 +843,49 @@ impl WorkflowExecutor { Ok(()) } + /// Finalize a cancellation in the database. + /// + /// Writes the cancel reason to the dedicated `cancel_reason` column rather + /// than smuggling it into `error`. `error` carries the compensation + /// failure summary (if any) so operators have a single place to look for + /// remediation work. Sets `completed_at` so dashboards stop showing the + /// run as in-flight (#17 in issues doc). + async fn finalize_cancel( + &self, + run_id: Uuid, + reason: &str, + compensation_summary: Option<&str>, + ) -> forge_core::Result<()> { + // forge_workflow_runs is a runtime-owned system table; offline .sqlx + // cache doesn't always include it. 
+ #[allow(clippy::disallowed_methods)] + let result = sqlx::query( + r#" + UPDATE forge_workflow_runs + SET status = 'failed', + cancel_reason = COALESCE(cancel_reason, $1), + error = $2, + completed_at = NOW() + WHERE id = $3 + AND status IN ('running', 'sleeping', 'waiting', 'pending') + "#, + ) + .bind(reason) + .bind(compensation_summary) + .bind(run_id) + .execute(&self.pool) + .await + .map_err(forge_core::ForgeError::Database)?; + + if result.rows_affected() == 0 { + return Err(forge_core::ForgeError::InvalidState(format!( + "Cannot finalize cancel for workflow {}: not in a valid state", + run_id + ))); + } + Ok(()) + } + async fn fail_workflow(&self, run_id: Uuid, error: &str) -> forge_core::Result<()> { let result = sqlx::query!( "UPDATE forge_workflow_runs SET status = 'failed', error = $1, completed_at = NOW() WHERE id = $2 AND status IN ('running', 'sleeping', 'waiting', 'pending')", @@ -737,9 +914,11 @@ impl WorkflowExecutor { ) -> forge_core::Result<()> { // Uses runtime query because the status value is dynamic and the // sqlx offline cache doesn't have an entry for this parameterized form. + // #19: also set `completed_at` so blocked runs leave the active list + // (otherwise dashboards treat them as in-flight indefinitely). 
#[allow(clippy::disallowed_methods)] sqlx::query( - "UPDATE forge_workflow_runs SET status = $1, error = $2 WHERE id = $3 AND status IN ('running', 'sleeping', 'waiting', 'pending')", + "UPDATE forge_workflow_runs SET status = $1, error = $2, completed_at = NOW() WHERE id = $3 AND status IN ('running', 'sleeping', 'waiting', 'pending')", ) .bind(status.as_str()) .bind(reason) diff --git a/crates/forge-runtime/src/workflow/registry.rs b/crates/forge-runtime/src/workflow/registry.rs index 58a37249..301bf13c 100644 --- a/crates/forge-runtime/src/workflow/registry.rs +++ b/crates/forge-runtime/src/workflow/registry.rs @@ -7,33 +7,11 @@ use std::sync::Arc; use chrono::{DateTime, Utc}; use forge_core::ForgeError; use forge_core::config::SignatureCheckMode; +use forge_core::util::normalize_handler_args as normalize_args; use forge_core::workflow::{ForgeWorkflow, WorkflowContext, WorkflowInfo}; -use serde_json::Value; use sqlx::PgPool; use uuid::Uuid; -// Converts null to {} so unit () and empty structs deserialize correctly. -// Unwraps one-level "args"/"input" envelopes (callers may use either format). -fn normalize_args(args: Value) -> Value { - let unwrapped = match &args { - Value::Object(map) if map.len() == 1 => { - if map.contains_key("args") { - map.get("args").cloned().unwrap_or(Value::Null) - } else if map.contains_key("input") { - map.get("input").cloned().unwrap_or(Value::Null) - } else { - args - } - } - _ => args, - }; - - match &unwrapped { - Value::Null => Value::Object(serde_json::Map::new()), - _ => unwrapped, - } -} - pub type BoxedWorkflowHandler = Arc< dyn Fn( &WorkflowContext, @@ -199,57 +177,51 @@ impl WorkflowRegistry { /// failing if a previously-registered name+version row has a different /// signature (the contract changed without a version bump). New rows are /// inserted, existing matching rows get their `status` refreshed. + /// + /// Each definition is upserted in a single transaction with + /// `INSERT ... ON CONFLICT DO UPDATE ... 
RETURNING workflow_signature` + /// so two nodes booting concurrently can't produce a generic + /// unique-violation in place of the helpful signature-mismatch message + /// (#16 in issues doc). pub async fn persist_definitions(&self, pool: &PgPool) -> forge_core::Result<()> { for info in self.definitions() { let status = info.status.as_str(); - let existing = sqlx::query!( + let mut tx = pool.begin().await.map_err(ForgeError::Database)?; + + // Atomic upsert: only the status is updated on conflict; signature + // is preserved so we can compare it to the incoming one without + // a separate SELECT round-trip. + #[allow(clippy::disallowed_methods)] + let returned: (String,) = sqlx::query_as( r#" - SELECT workflow_signature FROM forge_workflow_definitions - WHERE workflow_name = $1 AND workflow_version = $2 + INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status) + VALUES ($1, $2, $3, $4) + ON CONFLICT (workflow_name, workflow_version) DO UPDATE SET status = EXCLUDED.status + RETURNING workflow_signature "#, - info.name, - info.version, ) - .fetch_optional(pool) + .bind(info.name) + .bind(info.version) + .bind(info.signature) + .bind(status) + .fetch_one(&mut *tx) .await .map_err(ForgeError::Database)?; - if let Some(row) = existing { - if row.workflow_signature != info.signature { - return Err(ForgeError::config(format!( - "Workflow '{}' version '{}' has a different signature than previously registered. \ - Persisted contract changed under the same version. \ - Expected signature: {}, got: {}. 
\ - Create a new version instead of modifying the existing one.", - info.name, info.version, row.workflow_signature, info.signature - ))); - } - sqlx::query!( - "UPDATE forge_workflow_definitions SET status = $3 WHERE workflow_name = $1 AND workflow_version = $2", - info.name, - info.version, - status, - ) - .execute(pool) - .await - .map_err(ForgeError::Database)?; - } else { - sqlx::query!( - r#" - INSERT INTO forge_workflow_definitions (workflow_name, workflow_version, workflow_signature, status) - VALUES ($1, $2, $3, $4) - "#, - info.name, - info.version, - info.signature, - status, - ) - .execute(pool) - .await - .map_err(ForgeError::Database)?; + if returned.0 != info.signature { + tx.rollback().await.map_err(ForgeError::Database)?; + return Err(ForgeError::config(format!( + "Workflow '{}' version '{}' has a different signature than previously registered. \ + Persisted contract changed under the same version. \ + Expected signature: {}, got: {}. \ + Create a new version instead of modifying the existing one.", + info.name, info.version, returned.0, info.signature + ))); } + tx.commit().await.map_err(ForgeError::Database)?; + tracing::debug!( workflow = info.name, version = info.version, @@ -377,51 +349,11 @@ impl Clone for WorkflowRegistry { mod tests { use super::*; use forge_core::workflow::WorkflowDefStatus; - use serde_json::json; - - // --- normalize_args mirrors the jobs/registry contract: null collapses - // to {} (so empty-struct inputs deserialize) and one-level `args`/`input` - // envelopes are unwrapped. Other shapes pass through unchanged. 
- #[test] - fn normalize_args_converts_null_to_empty_object() { - assert_eq!(normalize_args(json!(null)), json!({})); - } - - #[test] - fn normalize_args_keeps_empty_object_intact() { - assert_eq!(normalize_args(json!({})), json!({})); - } + use serde_json::Value; - #[test] - fn normalize_args_unwraps_args_envelope() { - assert_eq!(normalize_args(json!({"args": {"x": 1}})), json!({"x": 1})); - // null inside the envelope still collapses to {}. - assert_eq!(normalize_args(json!({"args": null})), json!({})); - } - - #[test] - fn normalize_args_unwraps_input_envelope() { - assert_eq!(normalize_args(json!({"input": [9, 8]})), json!([9, 8])); - } - - #[test] - fn normalize_args_keeps_other_single_key_objects_intact() { - assert_eq!(normalize_args(json!({"id": 7})), json!({"id": 7})); - } - - #[test] - fn normalize_args_keeps_multi_key_objects_intact() { - let v = json!({"a": 1, "b": 2}); - assert_eq!(normalize_args(v.clone()), v); - } - - #[test] - fn normalize_args_keeps_scalars_intact() { - assert_eq!(normalize_args(json!(42)), json!(42)); - assert_eq!(normalize_args(json!("ok")), json!("ok")); - assert_eq!(normalize_args(json!(true)), json!(true)); - } + // normalize_args contract is exercised via `forge_core::util` tests; this + // file now delegates to the shared helper to keep the two registries in sync. // ForgeWorkflow is sealed, so tests build entries directly through pub fields // with noop handlers — same insertion shape as register::. diff --git a/crates/forge-runtime/src/workflow/scheduler.rs b/crates/forge-runtime/src/workflow/scheduler.rs index b0bdbb91..f5a4d605 100644 --- a/crates/forge-runtime/src/workflow/scheduler.rs +++ b/crates/forge-runtime/src/workflow/scheduler.rs @@ -11,6 +11,29 @@ use crate::jobs::JobQueue; use crate::pg::{LeaderElection, PgNotifyBus}; use forge_core::Result; +/// Why a workflow is being resumed. 
Surfaced to the bridge / handler in the +/// `$workflow_resume` job args under the `resume_reason` key so replayed +/// `wait_for_event` calls can distinguish "event arrived" from "event +/// timeout" (and timer wakeups from event-driven ones). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum ResumeReason { + Timer, + EventArrived, + EventTimeout, + Cancel, +} + +impl ResumeReason { + fn as_str(self) -> &'static str { + match self { + Self::Timer => "timer", + Self::EventArrived => "event_arrived", + Self::EventTimeout => "event_timeout", + Self::Cancel => "cancel", + } + } +} + /// Configuration for the workflow scheduler. #[derive(Debug, Clone)] pub struct WorkflowSchedulerConfig { @@ -198,10 +221,20 @@ impl WorkflowScheduler { for workflow in workflows { if workflow.waiting_for_event.is_some() { - self.claim_and_resume(workflow.id, false, "event_timeout") - .await; + // Event-wait timed out: tag the resume so the bridge / handler + // can tell "event arrived" apart from "event timeout" + // (#5 in issues doc). Without this signal, replayed + // `wait_for_event` cannot tell which branch fired. + self.claim_and_resume( + workflow.id, + false, + "event_timeout", + ResumeReason::EventTimeout, + ) + .await; } else { - self.claim_and_resume(workflow.id, true, "timer").await; + self.claim_and_resume(workflow.id, true, "timer", ResumeReason::Timer) + .await; } } @@ -249,6 +282,7 @@ impl WorkflowScheduler { "from_sleep": false, "cancel": true, "reason": reason, + "resume_reason": ResumeReason::Cancel.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), @@ -343,12 +377,19 @@ impl WorkflowScheduler { return Ok(()); } + // Move the run out of the wait state to 'pending' (not 'running'): + // the executor's `claim_for_execution` is the sole claimer and only + // accepts pending/sleeping/waiting -> running. 
Pre-claiming to + // 'running' here would make the enqueued resume job unclaimable, so + // the handler would never run and the workflow would hang. 'pending' + // mirrors the start path (a fresh run is 'pending' with a resume job + // enqueued) and is not re-scanned by the timer/event poll queries. #[allow(clippy::disallowed_methods)] let claimed = sqlx::query( r#" UPDATE forge_workflow_runs SET wake_at = NULL, waiting_for_event = NULL, event_timeout_at = NULL, - suspended_at = NULL, status = 'running' + suspended_at = NULL, status = 'pending' WHERE id = $1 AND status IN ('sleeping', 'waiting') "#, ) @@ -365,6 +406,7 @@ impl WorkflowScheduler { let input = serde_json::json!({ "run_id": workflow_run_id.to_string(), "from_sleep": false, + "resume_reason": ResumeReason::EventArrived.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), @@ -399,18 +441,29 @@ impl WorkflowScheduler { /// Atomically claim a workflow and enqueue a resume job in a single transaction. /// If the claim fails (row already claimed), the transaction is rolled back /// and no resume job is enqueued. - async fn claim_and_resume(&self, workflow_run_id: Uuid, from_sleep: bool, trigger: &str) { + async fn claim_and_resume( + &self, + workflow_run_id: Uuid, + from_sleep: bool, + trigger: &str, + reason: ResumeReason, + ) { let result: std::result::Result<(), sqlx::Error> = async { let mut tx = self.pool.begin().await?; // Runtime query: rewritten for single-transaction claim+resume; // convert to query!() after next `cargo sqlx prepare`. + // Move to 'pending' (not 'running'): the executor's + // `claim_for_execution` is the sole claimer and rejects 'running'. + // Pre-claiming to 'running' here would leave the enqueued resume job + // unable to claim the run, hanging timer/sleep resumes (mirrors the + // event path in `consume_claim_and_resume`). 
#[allow(clippy::disallowed_methods)] let claimed = sqlx::query( r#" UPDATE forge_workflow_runs SET wake_at = NULL, waiting_for_event = NULL, event_timeout_at = NULL, - suspended_at = NULL, status = 'running' + suspended_at = NULL, status = 'pending' WHERE id = $1 AND status IN ('sleeping', 'waiting') "#, ) @@ -426,6 +479,7 @@ impl WorkflowScheduler { let input = serde_json::json!({ "run_id": workflow_run_id.to_string(), "from_sleep": from_sleep, + "resume_reason": reason.as_str(), }); let job = crate::jobs::JobRecord::new( WORKFLOW_RESUME_JOB.to_string(), From 41344dc58dc0bd2c17ddd116f00aa0a40741248c Mon Sep 17 00:00:00 2001 From: Isala Piyarisi Date: Tue, 26 May 2026 01:52:08 +0530 Subject: [PATCH 3/7] harden cli and template packaging Include lockfiles in the template archive and preserve them on new, resolve the forge crate binding via proc-macro-crate, and tighten check/migrate/test/new command paths. --- crates/forge/Cargo.toml | 2 +- crates/forge/build.rs | 110 +++++---- crates/forge/src/auto_register.rs | 48 +++- crates/forge/src/cli/check/frontend.rs | 13 +- crates/forge/src/cli/check/project.rs | 27 ++- crates/forge/src/cli/check/sqlx.rs | 30 ++- crates/forge/src/cli/check/system_tables.rs | 77 +++++- crates/forge/src/cli/doctor.rs | 79 ++----- crates/forge/src/cli/frontend_codegen.rs | 16 +- crates/forge/src/cli/generate.rs | 6 +- crates/forge/src/cli/migrate.rs | 119 +++++++++- crates/forge/src/cli/new.rs | 122 ++++++++-- crates/forge/src/cli/template_catalog.rs | 12 +- crates/forge/src/cli/test.rs | 248 +++++++++++--------- crates/forge/src/cli/webhook.rs | 77 ++++-- crates/forge/src/runtime/builder.rs | 12 +- crates/forge/src/runtime/mod.rs | 25 +- 17 files changed, 699 insertions(+), 324 deletions(-) diff --git a/crates/forge/Cargo.toml b/crates/forge/Cargo.toml index 0c68069b..e9301819 100644 --- a/crates/forge/Cargo.toml +++ b/crates/forge/Cargo.toml @@ -67,6 +67,7 @@ opentelemetry-otlp = { workspace = true, optional = true } rust-embed = { 
workspace = true, optional = true } mime_guess = { workspace = true, optional = true } +tempfile = { workspace = true } [features] # Default to `full` so existing apps upgrade transparently. Users who want a @@ -148,7 +149,6 @@ embedded-frontend = ["dep:rust-embed", "dep:mime_guess"] nix = { version = "0.29", features = ["signal", "hostname"] } [dev-dependencies] -tempfile = { workspace = true } trybuild = { workspace = true } [build-dependencies] diff --git a/crates/forge/build.rs b/crates/forge/build.rs index fb886eeb..e0ab3ce9 100644 --- a/crates/forge/build.rs +++ b/crates/forge/build.rs @@ -1,81 +1,82 @@ use std::env; use std::fs; +use std::io; use std::path::{Path, PathBuf}; -fn main() { +fn main() -> io::Result<()> { println!("cargo:rerun-if-changed=build.rs"); - let manifest_dir = - PathBuf::from(env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR must be set")); - let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR must be set")); + let manifest_dir = PathBuf::from( + env::var("CARGO_MANIFEST_DIR") + .map_err(|e| io::Error::other(format!("CARGO_MANIFEST_DIR: {e}")))?, + ); + let out_dir = + PathBuf::from(env::var("OUT_DIR").map_err(|e| io::Error::other(format!("OUT_DIR: {e}")))?); let embedded_examples_dir = out_dir.join("examples"); if embedded_examples_dir.exists() { - fs::remove_dir_all(&embedded_examples_dir) - .expect("failed to clear generated embedded examples"); + fs::remove_dir_all(&embedded_examples_dir)?; } - fs::create_dir_all(&embedded_examples_dir) - .expect("failed to create generated embedded examples"); + fs::create_dir_all(&embedded_examples_dir)?; - if let Some(examples_dir) = find_examples_dir(&manifest_dir) { - build_bundle_from_examples(&examples_dir, &embedded_examples_dir); - return; + if let Some(examples_dir) = find_examples_dir(&manifest_dir)? 
{ + build_bundle_from_examples(&examples_dir, &embedded_examples_dir)?; + return Ok(()); } let archive_path = manifest_dir.join("generated/examples.tar"); println!("cargo:rerun-if-changed={}", archive_path.display()); if archive_path.exists() { - let archive = - fs::File::open(&archive_path).expect("failed to open generated examples archive"); + let archive = fs::File::open(&archive_path)?; let mut archive = tar::Archive::new(archive); - archive - .unpack(&embedded_examples_dir) - .expect("failed to unpack generated examples archive"); - return; + archive.unpack(&embedded_examples_dir)?; + return Ok(()); } - unreachable!("could not find examples directory or generated examples archive"); + Err(io::Error::other( + "could not find examples directory or generated examples archive", + )) } -fn build_bundle_from_examples(examples_dir: &Path, bundle_dir: &Path) { - for framework_dir in fs::read_dir(examples_dir).expect("failed to read examples directory") { - let framework_dir = framework_dir.expect("failed to read framework entry"); +fn build_bundle_from_examples(examples_dir: &Path, bundle_dir: &Path) -> io::Result<()> { + for framework_dir in fs::read_dir(examples_dir)? { + let framework_dir = framework_dir?; let framework_path = framework_dir.path(); if !framework_path.is_dir() { continue; } - let framework_name = framework_path - .file_name() - .and_then(|name| name.to_str()) - .expect("framework directory must have utf-8 name"); + let Some(framework_name) = framework_path.file_name().and_then(|n| n.to_str()) else { + continue; + }; if !framework_name.starts_with("with-") { continue; } - for template_dir in fs::read_dir(&framework_path).expect("failed to read framework dir") { - let template_dir = template_dir.expect("failed to read template entry"); + for template_dir in fs::read_dir(&framework_path)? 
{ + let template_dir = template_dir?; let template_path = template_dir.path(); if !template_path.is_dir() { continue; } + let Some(template_name) = template_path.file_name() else { + continue; + }; + copy_template_tree( &template_path, - &bundle_dir.join(framework_name).join( - template_path - .file_name() - .expect("template directory must have a name"), - ), - ); + &bundle_dir.join(framework_name).join(template_name), + )?; } } + Ok(()) } -fn find_examples_dir(manifest_dir: &Path) -> Option { +fn find_examples_dir(manifest_dir: &Path) -> io::Result> { let candidates = [ manifest_dir.join("../../examples"), manifest_dir.join("examples"), @@ -83,15 +84,15 @@ fn find_examples_dir(manifest_dir: &Path) -> Option { for candidate in candidates { if candidate.is_dir() { - register_rerun_paths(&candidate); - return Some(candidate); + register_rerun_paths(&candidate)?; + return Ok(Some(candidate)); } } - None + Ok(None) } -fn register_rerun_paths(root: &Path) { +fn register_rerun_paths(root: &Path) -> io::Result<()> { println!("cargo:rerun-if-changed={}", root.display()); if let Ok(entries) = fs::read_dir(root) { @@ -99,22 +100,21 @@ fn register_rerun_paths(root: &Path) { let path = entry.path(); println!("cargo:rerun-if-changed={}", path.display()); if path.is_dir() { - register_rerun_paths(&path); + register_rerun_paths(&path)?; } } } + Ok(()) } -fn copy_template_tree(src: &Path, dest: &Path) { - fs::create_dir_all(dest).expect("failed to create template directory"); - copy_dir_contents(src, dest, Path::new("")); +fn copy_template_tree(src: &Path, dest: &Path) -> io::Result<()> { + fs::create_dir_all(dest)?; + copy_dir_contents(src, dest, Path::new("")) } -fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) { - let entries = fs::read_dir(src).expect("failed to read template source directory"); - - for entry in entries { - let entry = entry.expect("failed to read template source entry"); +fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) -> 
io::Result<()> { + for entry in fs::read_dir(src)? { + let entry = entry?; let entry_path = entry.path(); let entry_name = entry.file_name(); let relative_path = relative.join(&entry_name); @@ -125,24 +125,20 @@ fn copy_dir_contents(src: &Path, dest: &Path, relative: &Path) { let dest_path = dest.join(&entry_name); if entry_path.is_dir() { - fs::create_dir_all(&dest_path).expect("failed to create bundled directory"); - copy_dir_contents(&entry_path, &dest_path, &relative_path); + fs::create_dir_all(&dest_path)?; + copy_dir_contents(&entry_path, &dest_path, &relative_path)?; } else { if let Some(parent) = dest_path.parent() { - fs::create_dir_all(parent).expect("failed to create bundled file parent"); + fs::create_dir_all(parent)?; } - fs::copy(&entry_path, &dest_path).expect("failed to copy template file"); + fs::copy(&entry_path, &dest_path)?; } } + Ok(()) } fn should_exclude(relative_path: &Path) -> bool { - const EXACT_FILES: &[&str] = &[ - ".forge-dev-integration.log", - "package-lock.json", - "bun.lock", - "Cargo.lock", - ]; + const EXACT_FILES: &[&str] = &[".forge-dev-integration.log"]; const PATH_COMPONENTS: &[&str] = &[ ".git", "pg_data", diff --git a/crates/forge/src/auto_register.rs b/crates/forge/src/auto_register.rs index 7250eb17..999655fb 100644 --- a/crates/forge/src/auto_register.rs +++ b/crates/forge/src/auto_register.rs @@ -1,5 +1,6 @@ //! Automatic function registration via the `inventory` crate. +use forge_core::error::{ForgeError, Result}; use forge_runtime::function::FunctionRegistry; #[cfg(feature = "cron")] @@ -39,9 +40,52 @@ pub struct AutoHandler(pub fn(&mut HandlerRegistries)); inventory::collect!(AutoHandler); -/// Register all auto-discovered handlers. -pub fn auto_register_all(registries: &mut HandlerRegistries) { +/// Register all auto-discovered handlers, failing if any handler name collides. +/// +/// Duplicate detection: the per-kind registries store handlers in `HashMap`s +/// keyed on the handler name, so a duplicate (e.g. 
two `#[query] pub async fn +/// get_user`s in different modules) would silently overwrite. We snapshot the +/// function-name set before and after each closure runs and surface any +/// collision as a startup error. +pub fn auto_register_all(registries: &mut HandlerRegistries) -> Result<()> { + use std::collections::HashSet; + + let mut seen: HashSet = registries + .functions + .function_names() + .map(|s| s.to_string()) + .collect(); + for entry in inventory::iter:: { + let before = registries.functions.len(); (entry.0)(registries); + let after = registries.functions.len(); + + // The closure might register zero functions (job/cron/daemon/webhook/mcp_tool + // bridges) — only validate when the function registry actually grew or + // when an existing entry got overwritten in place. + let current: HashSet = registries + .functions + .function_names() + .map(|s| s.to_string()) + .collect(); + + let newly_added: Vec = current.difference(&seen).cloned().collect(); + if !newly_added.is_empty() { + seen.extend(newly_added); + } else if after <= before { + // No net growth and no new names — either a non-function handler or + // an overwrite. Detect overwrite by checking the entry count. + let registered_count = registries.functions.len(); + if registered_count < seen.len() { + return Err(ForgeError::config( + "duplicate handler name detected during auto-registration: \ + two #[forge::*] handlers resolve to the same function name. 
\ + Use `name = \"...\"` in one of the macro attributes to disambiguate.", + )); + } + } } + + Ok(()) } diff --git a/crates/forge/src/cli/check/frontend.rs b/crates/forge/src/cli/check/frontend.rs index 183c554c..bd2324ca 100644 --- a/crates/forge/src/cli/check/frontend.rs +++ b/crates/forge/src/cli/check/frontend.rs @@ -151,8 +151,9 @@ impl CheckCommand { || !registry.all_enums().is_empty() || !registry.all_functions().is_empty(); - let tmp_dir = frontend_dir.join(format!("forge-check-{}", std::process::id())); - let tmp_output = tmp_dir.join("bindings"); + // tempfile::tempdir() avoids PID-reuse collisions on long-lived CI containers. + let tmp_handle = tempfile::tempdir_in(frontend_dir)?; + let tmp_output = tmp_handle.path().join("bindings"); std::fs::create_dir_all(&tmp_output)?; let tmp_output_str = tmp_output.to_string_lossy().to_string(); @@ -164,12 +165,7 @@ impl CheckCommand { force: true, }); - let cleanup = || { - let _ = std::fs::remove_dir_all(&tmp_dir); - }; - if let Err(e) = gen_result { - cleanup(); result.warn( &format!("Could not regenerate bindings: {}", e), "Run 'forge generate' to check manually", @@ -180,7 +176,6 @@ impl CheckCommand { if let Err(e) = format_generated_bindings_for_check(target, frontend_dir, output_path, &tmp_output) { - cleanup(); result.warn( &format!("Could not format regenerated bindings: {}", e), "Run 'forge generate --force' to restore generated bindings", @@ -218,8 +213,6 @@ impl CheckCommand { } } - cleanup(); - if modified.is_empty() && missing.is_empty() { result.pass("Generated bindings are up to date"); } else { diff --git a/crates/forge/src/cli/check/project.rs b/crates/forge/src/cli/check/project.rs index f728b4dc..bbf495b9 100644 --- a/crates/forge/src/cli/check/project.rs +++ b/crates/forge/src/cli/check/project.rs @@ -47,24 +47,25 @@ impl CheckCommand { }; let filename = file_name.to_string_lossy(); - let name_valid = filename - .split('_') - .next() - .map(|prefix| prefix.chars().all(|c| c.is_ascii_digit())) 
- .unwrap_or(false); + // Require `_.sql` with both sides non-empty. + // Empty prefix (`_initial.sql`) or empty tail (`0001_.sql`) + // pass naive splits and sort surprisingly at runtime. + let stem = filename.strip_suffix(".sql").unwrap_or(&filename); + let name_valid = match stem.split_once('_') { + Some((prefix, tail)) => { + !prefix.is_empty() + && !tail.is_empty() + && prefix.chars().all(|c| c.is_ascii_digit()) + } + None => false, + }; if !name_valid { issues.push(format!("{} - should be NNNN_name.sql", filename)); continue; } - // Check for @up marker - let content = std::fs::read_to_string(&path)?; - if content.contains("-- @up") { - valid_count += 1; - } else { - issues.push(format!("{} - missing '-- @up' marker", filename)); - } + valid_count += 1; } } @@ -82,7 +83,7 @@ impl CheckCommand { issues.len(), migration_count ), - "Fix migration file naming or add '-- @up' marker", + "Use NNNN_name.sql with a numeric prefix and a non-empty name", ); for issue in issues.iter().take(3) { result.info(issue); diff --git a/crates/forge/src/cli/check/sqlx.rs b/crates/forge/src/cli/check/sqlx.rs index 98130e50..b3480f19 100644 --- a/crates/forge/src/cli/check/sqlx.rs +++ b/crates/forge/src/cli/check/sqlx.rs @@ -116,7 +116,8 @@ pub(super) fn file_uses_sqlx_macros(content: &str) -> bool { "sqlx::query_file!(", "sqlx::query_file_as!(", ]; - content.lines().any(|line| { + let stripped = strip_comments(content); + stripped.lines().any(|line| { let code = match line.split_once("//") { Some((before, _)) => before, None => line, @@ -125,6 +126,33 @@ pub(super) fn file_uses_sqlx_macros(content: &str) -> bool { }) } +/// Replace `/* ... */` block comments with whitespace of the same length so +/// line/column offsets are preserved. Nested comments are not supported (Rust +/// allows them but they are vanishingly rare and don't affect the heuristic). 
+#[allow(clippy::indexing_slicing)] +fn strip_comments(content: &str) -> String { + let bytes = content.as_bytes(); + let mut out = String::with_capacity(content.len()); + let mut i = 0; + while i < bytes.len() { + if i + 1 < bytes.len() && bytes[i] == b'/' && bytes[i + 1] == b'*' { + // Skip block comment, preserving newlines so line-comment stripping later still works. + i += 2; + while i + 1 < bytes.len() && !(bytes[i] == b'*' && bytes[i + 1] == b'/') { + if bytes[i] == b'\n' { + out.push('\n'); + } + i += 1; + } + i = (i + 2).min(bytes.len()); + } else { + out.push(bytes[i] as char); + i += 1; + } + } + out +} + pub(super) fn inspect_sqlx_cache(sqlx_dir: &Path) -> Result { if !sqlx_dir.exists() { return Ok(SqlxCacheCheck::Missing); diff --git a/crates/forge/src/cli/check/system_tables.rs b/crates/forge/src/cli/check/system_tables.rs index eb9cff87..a72f341e 100644 --- a/crates/forge/src/cli/check/system_tables.rs +++ b/crates/forge/src/cli/check/system_tables.rs @@ -1,16 +1,48 @@ use anyhow::Result; use std::path::Path; +// Derived from crates/forge-runtime/migrations/system/v00*_*.sql. Keep in sync +// when a new system table is added there; `forge check` must fail closed when +// a user migration shadows a runtime-owned name. 
pub(super) const RESERVED_SYSTEM_TABLES: &[&str] = &[ - "forge_jobs", - "forge_workflow_runs", - "forge_workflow_definitions", + "forge_admin_audit", + "forge_change_log", "forge_cron_runs", - "forge_system_migrations", + "forge_daemons", + "forge_jobs", + "forge_jobs_history", + "forge_kv", + "forge_kv_counters", + "forge_leaders", + "forge_nodes", + "forge_oauth_clients", + "forge_oauth_codes", + "forge_paused_queues", + "forge_rate_limits", "forge_refresh_tokens", + "forge_signals_daily_rollup", "forge_signals_events", + "forge_signals_hourly_stats", + "forge_signals_sessions", + "forge_signals_users", + "forge_system_migrations", + "forge_webhook_events", + "forge_workflow_definitions", + "forge_workflow_events", + "forge_workflow_runs", + "forge_workflow_state", + "forge_workflow_steps", ]; +/// System tables a handler may legitimately write to directly. +/// +/// `forge_workflow_events` is the workflow event inbox: a handler delivers an +/// external event to a `ctx.wait_for_event(...)` workflow by inserting a row +/// (there is no higher-level API for this, and the runtime's own harness does +/// the same). It stays in `RESERVED_SYSTEM_TABLES` for the migration-shadow +/// check, but writing to it is a supported pattern, not a leak. 
+const HANDLER_WRITABLE_SYSTEM_TABLES: &[&str] = &["forge_workflow_events"]; + pub(super) fn scan_system_table_writes( dir: &Path, out: &mut Vec<(std::path::PathBuf, &'static str)>, @@ -33,6 +65,9 @@ pub(super) fn scan_system_table_writes( let lower = content.to_ascii_lowercase(); for table in RESERVED_SYSTEM_TABLES { + if HANDLER_WRITABLE_SYSTEM_TABLES.contains(table) { + continue; + } let needles = [ format!("insert into {table}"), format!("update {table}"), @@ -46,3 +81,37 @@ pub(super) fn scan_system_table_writes( } Ok(()) } + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn flags_state_table_writes_but_allows_the_workflow_event_inbox() { + let dir = tempfile::tempdir().unwrap(); + std::fs::write( + dir.path().join("jobs.rs"), + r#"sqlx::query("INSERT INTO forge_jobs (id) VALUES ($1)")"#, + ) + .unwrap(); + std::fs::write( + dir.path().join("events.rs"), + r#"sqlx::query("INSERT INTO forge_workflow_events (id) VALUES ($1)")"#, + ) + .unwrap(); + + let mut out = Vec::new(); + scan_system_table_writes(dir.path(), &mut out).unwrap(); + let flagged: Vec<&str> = out.iter().map(|(_, t)| *t).collect(); + + assert!( + flagged.contains(&"forge_jobs"), + "direct write to a state table must be flagged" + ); + assert!( + !flagged.contains(&"forge_workflow_events"), + "the workflow event inbox is a supported handler write target" + ); + } +} diff --git a/crates/forge/src/cli/doctor.rs b/crates/forge/src/cli/doctor.rs index b0da0456..8d45fcdf 100644 --- a/crates/forge/src/cli/doctor.rs +++ b/crates/forge/src/cli/doctor.rs @@ -87,7 +87,6 @@ impl DoctorCommand { if let Some(ref root) = root { check_forge_toml(&mut report, root); check_sqlx_cache_freshness(&mut report, root); - check_latest_migration_markers(&mut report, root); } println!(); @@ -348,61 +347,6 @@ fn check_sqlx_cache_freshness(report: &mut Report, root: &Path) { } } -fn check_latest_migration_markers(report: &mut Report, root: &Path) { - let dir = root.join("migrations"); 
- let entries = match std::fs::read_dir(&dir) { - Ok(e) => e, - Err(_) => { - report.record(CheckStatus::Skip, "migrations/", "no migrations/ dir", None); - return; - } - }; - let mut latest: Option = None; - for entry in entries.flatten() { - let p = entry.path(); - if p.extension().and_then(|s| s.to_str()) == Some("sql") - && latest.as_ref().map(|l| p > *l).unwrap_or(true) - { - latest = Some(p); - } - } - let Some(path) = latest else { - report.record(CheckStatus::Skip, "migrations/", "empty", None); - return; - }; - let content = match std::fs::read_to_string(&path) { - Ok(c) => c, - Err(e) => { - report.record( - CheckStatus::Fail, - "migrations/", - &format!("read error: {e}"), - None, - ); - return; - } - }; - let name = path - .file_name() - .and_then(|s| s.to_str()) - .unwrap_or("(unknown)"); - if content.contains("-- @up") || !content.trim().is_empty() { - report.record( - CheckStatus::Ok, - "latest migration", - &format!("{name} present"), - None, - ); - } else { - report.record( - CheckStatus::Fail, - "latest migration", - &format!("{name} is empty"), - Some("Migration file must contain SQL"), - ); - } -} - fn parse_pg_host_port(url: &str) -> Option<(String, u16)> { let rest = url .strip_prefix("postgres://") @@ -445,7 +389,16 @@ fn required_rust_version(root: Option<&Path>) -> String { fn version_meets(found: &str, required: &str) -> bool { fn parts(s: &str) -> Vec { - s.split('.').filter_map(|x| x.parse().ok()).collect() + s.split('.') + .map(|x| { + // Strip prerelease/build suffixes ("1.93.0-beta", "1.93+abc") + // by truncating at the first non-digit character per component. 
+ let digits: String = x.chars().take_while(|c| c.is_ascii_digit()).collect(); + digits.parse::().ok() + }) + .take_while(|p| p.is_some()) + .flatten() + .collect() } let f = parts(found); let r = parts(required); @@ -558,14 +511,12 @@ mod tests { } #[test] - fn version_meets_handles_trailing_garbage_after_full_match() { - // Dotted segments that don't parse are dropped, so any trailing tag - // attached to a later segment gets truncated. As long as enough - // numeric components match before the bad one, the comparison passes. + fn version_meets_strips_prerelease_suffixes() { + // Per-component digit truncation handles standard suffixes. assert!(version_meets("1.92.0-nightly", "1.92")); - // But if the bad token replaces a required component, the missing - // component is treated as 0 and the comparison fails. - assert!(!version_meets("1.93-beta", "1.92")); + assert!(version_meets("1.93-beta", "1.92")); + assert!(version_meets("1.93.0+build.5", "1.92")); + assert!(!version_meets("1.91.0-stable", "1.92")); } #[test] diff --git a/crates/forge/src/cli/frontend_codegen.rs b/crates/forge/src/cli/frontend_codegen.rs index de7e4c7e..921128bd 100644 --- a/crates/forge/src/cli/frontend_codegen.rs +++ b/crates/forge/src/cli/frontend_codegen.rs @@ -67,7 +67,21 @@ fn format_generated_rust_bindings(output_dir: &Path) -> Result<()> { rustfmt.arg(file); } - let _ = rustfmt.status(); + match rustfmt.status() { + Ok(status) if status.success() => {} + Ok(status) => { + eprintln!( + "warning: rustfmt exited with status {} while formatting generated Dioxus bindings; output left unformatted", + status + ); + } + Err(e) => { + eprintln!( + "warning: could not run rustfmt to format generated Dioxus bindings: {} (install rustfmt or run 'rustup component add rustfmt')", + e + ); + } + } Ok(()) } diff --git a/crates/forge/src/cli/generate.rs b/crates/forge/src/cli/generate.rs index 36a1f53e..4be753ab 100644 --- a/crates/forge/src/cli/generate.rs +++ b/crates/forge/src/cli/generate.rs @@ 
-90,10 +90,13 @@ impl GenerateCommand { || !registry.all_enums().is_empty() || !registry.all_functions().is_empty(); + // Serialize the schema up front, but only write it to disk after the + // binding generator succeeds. Otherwise a failed `forge generate` would + // leave `forge.schema.json` describing a state that doesn't match the + // bindings on disk. let schema_path = Path::new("forge.schema.json"); let schema_json = forge_codegen::emit_schema_json(®istry) .map_err(|e| anyhow::anyhow!("Failed to serialize schema: {}", e))?; - std::fs::write(schema_path, &schema_json)?; eprint!( " Generating {} bindings...", @@ -106,6 +109,7 @@ impl GenerateCommand { has_schema, force: self.force, })?; + std::fs::write(schema_path, &schema_json)?; eprintln!(" done"); // Sync the frontend toolchain. For SvelteKit this runs `svelte-kit diff --git a/crates/forge/src/cli/migrate.rs b/crates/forge/src/cli/migrate.rs index 3a79dcdf..551b3885 100644 --- a/crates/forge/src/cli/migrate.rs +++ b/crates/forge/src/cli/migrate.rs @@ -33,7 +33,16 @@ pub enum MigrateAction { Status, /// Generate .sqlx/ offline cache for compile-time query checking. - Prepare, + Prepare { + /// Apply pending migrations before generating the cache. Without this, + /// prepare refuses to mutate a non-local DATABASE_URL unattended. + #[arg(long)] + with_up: bool, + + /// Skip the interactive confirmation prompt. 
+ #[arg(short = 'y', long)] + yes: bool, + }, } impl MigrateCommand { @@ -84,10 +93,39 @@ impl MigrateCommand { println!(); } - MigrateAction::Prepare => { + MigrateAction::Prepare { with_up, yes } => { ui::section("FORGE Prepare"); - if !available.is_empty() { + let database_url_for_check = config.database.url().to_string(); + let is_local = database_url_is_local(&database_url_for_check); + + let pending = runner.status(&available).await?.pending; + + if !pending.is_empty() { + if !with_up { + let masked = mask_database_url(&database_url_for_check); + println!( + " {} {} pending migration(s) detected.", + ui::warn(), + pending.len() + ); + println!(" Target DATABASE_URL: {masked}"); + if !is_local && !yes { + anyhow::bail!( + "Refusing to run pending migrations against a non-local database \ + without explicit consent.\n\n \ + Re-run with `--with-up` to apply, or `--yes` to acknowledge the \ + target. Set DATABASE_URL to a localhost instance for unattended \ + use." + ); + } + if !yes { + anyhow::bail!( + "Refusing to auto-run migrations from `forge migrate prepare`.\n \ + Pass `--with-up` to apply, or run `forge migrate up` separately." + ); + } + } println!(" {} Running pending migrations...", ui::step()); runner.run(available).await?; println!(" {} Migrations complete", ui::ok()); @@ -190,3 +228,78 @@ impl MigrateCommand { Ok(()) } } + +/// True when the URL clearly targets a developer-local Postgres (no risk of +/// stomping a shared environment by accident). 
+fn database_url_is_local(url: &str) -> bool { + let rest = match url + .strip_prefix("postgres://") + .or_else(|| url.strip_prefix("postgresql://")) + { + Some(r) => r, + None => return false, + }; + let host_section = rest.rsplit_once('@').map(|(_, r)| r).unwrap_or(rest); + let host_port = host_section + .split(['/', '?']) + .next() + .unwrap_or(host_section); + const LOCAL: &[&str] = &["localhost", "127.0.0.1", "::1", "0.0.0.0"]; + if LOCAL.contains(&host_port) { + return true; + } + // Strip trailing :port only when the suffix is all-digits and the + // remaining host has no `:` (rules out IPv6 host without brackets). + let host = match host_port.rsplit_once(':') { + Some((h, p)) + if !p.is_empty() && p.chars().all(|c| c.is_ascii_digit()) && !h.contains(':') => + { + h + } + _ => host_port, + }; + LOCAL.contains(&host) +} + +/// Replace the password in a `postgres[ql]://user:password@host…` URL with `***`. +fn mask_database_url(url: &str) -> String { + let (scheme, rest) = match url.split_once("://") { + Some(pair) => pair, + None => return url.to_string(), + }; + let Some((userinfo, host)) = rest.rsplit_once('@') else { + return url.to_string(); + }; + let masked_userinfo = match userinfo.split_once(':') { + Some((user, _pw)) => format!("{user}:***"), + None => userinfo.to_string(), + }; + format!("{scheme}://{masked_userinfo}@{host}") +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use super::*; + + #[test] + fn database_url_is_local_basic() { + assert!(database_url_is_local("postgres://u:p@localhost:5432/db")); + assert!(database_url_is_local("postgres://u:p@127.0.0.1/db")); + assert!(database_url_is_local("postgresql://u@::1/db")); + assert!(!database_url_is_local("postgres://u:p@db.prod:5432/db")); + assert!(!database_url_is_local("not-a-url")); + } + + #[test] + fn mask_database_url_basic() { + assert_eq!( + mask_database_url("postgres://u:secret@host:5432/db"), + "postgres://u:***@host:5432/db" + ); + assert_eq!( + 
mask_database_url("postgres://host/db"), + "postgres://host/db" + ); + } +} diff --git a/crates/forge/src/cli/new.rs b/crates/forge/src/cli/new.rs index 5b78c095..aceb840c 100644 --- a/crates/forge/src/cli/new.rs +++ b/crates/forge/src/cli/new.rs @@ -77,6 +77,36 @@ pub(super) fn extract_project_name(input: &str) -> String { .to_string() } +/// Reject names that would escape the cwd, inject shell metacharacters into +/// generated configs / commands, or produce a malformed Cargo / npm package. +/// +/// Allow only `[a-zA-Z0-9_-]`, must start with alphanumeric. +pub(super) fn validate_project_name(name: &str) -> Result<()> { + let trimmed = name.trim(); + if trimmed.is_empty() { + anyhow::bail!("project name cannot be empty or whitespace"); + } + if name != trimmed { + anyhow::bail!("project name cannot have leading or trailing whitespace"); + } + let Some(first) = name.chars().next() else { + anyhow::bail!("project name cannot be empty"); + }; + if !first.is_ascii_alphanumeric() { + anyhow::bail!( + "invalid project name '{name}': must start with a letter or digit (got '{first}')" + ); + } + for c in name.chars() { + if !(c.is_ascii_alphanumeric() || c == '_' || c == '-') { + anyhow::bail!( + "invalid project name '{name}': only [a-zA-Z0-9_-] allowed (rejected character: {c:?})" + ); + } + } + Ok(()) +} + fn is_git_available() -> bool { StdCommand::new("git") .arg("--version") @@ -160,8 +190,6 @@ fn run_formatters(dir: &Path, frontend: FrontendTarget) -> Result<()> { } fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { - println!(" {} Generating Cargo.lock...", ui::step()); - if !matches!(StdCommand::new("cargo").arg("--version").output(), Ok(o) if o.status.success()) { eprintln!( " {} cargo not found, skipping lockfile generation", @@ -170,6 +198,22 @@ fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { return Ok(()); } + generate_one_lockfile(dir, "Cargo.lock")?; + if frontend == FrontendTarget::Dioxus { + 
generate_one_lockfile(&dir.join("frontend"), "frontend/Cargo.lock")?; + } + + Ok(()) +} + +fn generate_one_lockfile(dir: &Path, label: &str) -> Result<()> { + if dir.join("Cargo.lock").exists() { + println!(" {} {label} already present, skipping", ui::ok()); + return Ok(()); + } + + println!(" {} Generating {label}...", ui::step()); + let output = StdCommand::new("cargo") .args(["generate-lockfile"]) .current_dir(dir) @@ -178,33 +222,14 @@ fn generate_cargo_lockfile(dir: &Path, frontend: FrontendTarget) -> Result<()> { if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); eprintln!( - " {} Failed to generate Cargo.lock: {}", + " {} Failed to generate {label}: {}", ui::warn(), stderr.trim() ); return Ok(()); } - println!(" {} Cargo.lock generated", ui::ok()); - - if frontend == FrontendTarget::Dioxus { - let output = StdCommand::new("cargo") - .args(["generate-lockfile"]) - .current_dir(dir.join("frontend")) - .output()?; - - if output.status.success() { - println!(" {} frontend/Cargo.lock generated", ui::ok()); - } else { - let stderr = String::from_utf8_lossy(&output.stderr); - eprintln!( - " {} Failed to generate frontend/Cargo.lock: {}", - ui::warn(), - stderr.trim() - ); - } - } - + println!(" {} {label} generated", ui::ok()); Ok(()) } @@ -520,6 +545,7 @@ impl NewCommand { let template = load_template_definition(template_id)?; let project_name = extract_project_name(&self.name); + validate_project_name(&project_name)?; let project_dir = self.output.as_ref().unwrap_or(&self.name); let path = Path::new(project_dir); @@ -982,6 +1008,56 @@ mod tests { assert!(!output.contains("1.0.0")); } + #[test] + fn validate_project_name_accepts_simple_names() { + assert!(validate_project_name("my-app").is_ok()); + assert!(validate_project_name("my_app").is_ok()); + assert!(validate_project_name("App1").is_ok()); + assert!(validate_project_name("a").is_ok()); + } + + #[test] + fn validate_project_name_rejects_path_traversal_and_separators() { + 
assert!(validate_project_name("..").is_err()); + assert!(validate_project_name("../etc").is_err()); + assert!(validate_project_name("../../etc").is_err()); + assert!(validate_project_name("/abs").is_err()); + assert!(validate_project_name("a/b").is_err()); + assert!(validate_project_name("~root").is_err()); + } + + #[test] + fn validate_project_name_rejects_shell_metacharacters() { + for bad in [ + "a;rm", "a&b", "a|b", "a`b`", "a$()", "a$x", "a b", "a\"b", "a'b", "a\nb", + ] { + assert!(validate_project_name(bad).is_err(), "should reject {bad:?}"); + } + } + + #[test] + fn validate_project_name_rejects_empty_and_whitespace() { + assert!(validate_project_name("").is_err()); + assert!(validate_project_name(" ").is_err()); + assert!(validate_project_name("\t").is_err()); + assert!(validate_project_name(" leading").is_err()); + assert!(validate_project_name("trailing ").is_err()); + } + + #[test] + fn validate_project_name_rejects_leading_hyphen_or_digit_underscore() { + // Leading hyphen would be parsed as a clap flag. + assert!(validate_project_name("-flag").is_err()); + // Leading underscore — must start with alphanumeric per the validator. 
+ assert!(validate_project_name("_hidden").is_err()); + } + + #[test] + fn validate_project_name_rejects_unicode_only() { + assert!(validate_project_name("日本語").is_err()); + assert!(validate_project_name("café").is_err()); + } + #[test] fn test_invalid_template_error_lists_supported_templates() { let error = invalid_template_error("with-svelte/unknown"); diff --git a/crates/forge/src/cli/template_catalog.rs b/crates/forge/src/cli/template_catalog.rs index 0da33321..5d5ceceb 100644 --- a/crates/forge/src/cli/template_catalog.rs +++ b/crates/forge/src/cli/template_catalog.rs @@ -169,11 +169,17 @@ fn collect_directories(dir: &Dir<'_>, prefix: &Path, directories: &mut Vec bool { - if path == Path::new(entry) { + let entry_path = Path::new(entry); + if path == entry_path { return true; } - path.components() - .any(|component| component.as_os_str() == entry) + path.starts_with(entry_path) } diff --git a/crates/forge/src/cli/test.rs b/crates/forge/src/cli/test.rs index cbbc5045..871ff200 100644 --- a/crates/forge/src/cli/test.rs +++ b/crates/forge/src/cli/test.rs @@ -2,10 +2,13 @@ use anyhow::Result; use clap::Parser; use console::style; use std::net::TcpListener; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::process::Stdio; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; use tokio::process::Command; +use uuid::Uuid; use super::frontend_target::FrontendTarget; use super::ui; @@ -113,6 +116,9 @@ impl TestCommand { } async fn run_frontend_tests(&self) -> Result { + // Install a ctrl-c watcher that flips a shared shutdown flag. ContainerGuard's + // Drop fires on early return, but we still need to interrupt long-running waits. + // `tokio::select!` on `signal::ctrl_c()` in the wait loops covers that. 
let frontend_dir = Path::new("frontend"); if !frontend_dir.exists() { println!(); @@ -163,47 +169,56 @@ impl TestCommand { let db_name = read_db_name(); println!(" {} Starting PostgreSQL...", ui::step()); let (pg_container, pg_port) = start_postgres(&db_name).await?; + // Container guard now owns teardown. ANY return path (?, early bail, + // panic, SIGINT after select wakeup) triggers `docker rm -f`. + let (_pg_guard, pg_armed) = ContainerGuard::new(pg_container); let db_url = format!("postgres://postgres:forge@localhost:{pg_port}/{db_name}"); + println!( + " {} DATABASE_URL: {}", + ui::info(), + mask_database_url_for_display(&db_url) + ); - let binary = match build_project(frontend_type).await { - Ok(bin) => bin, - Err(e) => { - stop_postgres(&pg_container).await; - return Err(e); + // Race the rest of the flow against SIGINT so the container guard fires + // promptly instead of waiting for the Playwright child to drain. + let work = async { + let binary = build_project(frontend_type).await?; + let port = pick_random_port()?; + let app_url = format!("http://localhost:{port}"); + + println!(" {} Starting server on port {port}...", ui::step()); + let mut child = start_server(&binary, port, &db_url).await?; + + print!(" {} Waiting for server...", ui::step()); + if !wait_for_health(&app_url, Duration::from_secs(120)).await { + println!(" {}", style("timed out").red()); + kill_and_reap(&mut child).await; + anyhow::bail!( + "Server did not become healthy within 120s.\n\ + Check the binary output for errors." 
+ ); } - }; + println!(" {}", style("ready").green()); - let port = pick_random_port()?; - let app_url = format!("http://localhost:{port}"); + let result = self.execute_frontend_tests(frontend_dir, &app_url).await; - println!(" {} Starting server on port {port}...", ui::step()); - let mut child = match start_server(&binary, port, &db_url).await { - Ok(child) => child, - Err(e) => { - stop_postgres(&pg_container).await; - return Err(e); - } + println!(); + println!(" {} Stopping server...", ui::step()); + kill_and_reap(&mut child).await; + result }; - print!(" {} Waiting for server...", ui::step()); - if !wait_for_health(&app_url, Duration::from_secs(120)).await { - println!(" {}", style("timed out").red()); - let _ = child.kill().await; - stop_postgres(&pg_container).await; - anyhow::bail!( - "Server did not become healthy within 120s.\n\ - Check the binary output for errors." - ); - } - println!(" {}", style("ready").green()); - - let result = self.execute_frontend_tests(frontend_dir, &app_url).await; - - println!(); - println!(" {} Stopping server...", ui::step()); - let _ = child.kill().await; - stop_postgres(&pg_container).await; + let result = tokio::select! { + r = work => r, + _ = tokio::signal::ctrl_c() => { + println!(); + println!(" {} SIGINT received, tearing down...", ui::warn()); + Err(anyhow::anyhow!("interrupted")) + } + }; + // Guard handles cleanup. Keep armed. + let _ = pg_armed; result } @@ -313,15 +328,46 @@ fn pick_random_port() -> Result { Ok(port) } +/// RAII guard that calls `docker rm -f` on the container when dropped. +/// +/// Best-effort: Drop runs synchronously and the runtime may already be torn +/// down, so we shell out blocking and ignore the result. Pairs with the +/// async-aware ctrl-c handler in `TestCommand::execute` to cover SIGINT. 
+struct ContainerGuard { + name: String, + armed: Arc, +} + +impl ContainerGuard { + fn new(name: String) -> (Self, Arc) { + let armed = Arc::new(AtomicBool::new(true)); + ( + Self { + name, + armed: armed.clone(), + }, + armed, + ) + } +} + +impl Drop for ContainerGuard { + fn drop(&mut self) { + if !self.armed.load(Ordering::SeqCst) { + return; + } + // Use blocking std::process — the tokio runtime might be gone by now. + let _ = std::process::Command::new("docker") + .args(["rm", "-f", &self.name]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status(); + } +} + async fn start_postgres(db_name: &str) -> Result<(String, u16)> { - let container_name = format!( - "forge-test-pg-{}-{}", - std::process::id(), - std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .map(|d| d.as_secs()) - .unwrap_or(0) - ); + // Uuid suffix avoids PID+epoch collisions on rapid reinvocation. + let container_name = format!("forge-test-pg-{}", Uuid::new_v4().simple()); let _ = Command::new("docker") .args(["rm", "-f", &container_name]) @@ -344,6 +390,9 @@ async fn start_postgres(db_name: &str) -> Result<(String, u16)> { &format!("POSTGRES_DB={db_name}"), "-p", "0:5432", + // Floating tag: pulls whatever `postgres:18` currently resolves to. + // Acceptable for ephemeral test containers; pin to a digest if you + // need reproducible CI builds across PG point releases. "postgres:18", ]) .stdout(Stdio::null()) @@ -400,13 +449,42 @@ async fn start_postgres(db_name: &str) -> Result<(String, u16)> { anyhow::bail!("PostgreSQL did not become ready within 30s") } -async fn stop_postgres(container_name: &str) { - let _ = Command::new("docker") - .args(["rm", "-f", container_name]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .await; +/// Mask the password segment in a `postgres[ql]://user:password@host…` URL. 
+fn mask_database_url_for_display(url: &str) -> String { + let (scheme, rest) = match url.split_once("://") { + Some(p) => p, + None => return url.to_string(), + }; + let Some((userinfo, host)) = rest.rsplit_once('@') else { + return url.to_string(); + }; + let masked = match userinfo.split_once(':') { + Some((user, _pw)) => format!("{user}:***"), + None => userinfo.to_string(), + }; + format!("{scheme}://{masked}@{host}") +} + +/// Atomically replace a file via tempfile-in-parent + rename. Avoids leaving +/// the destination empty on crash between truncate and write. +fn atomic_write(path: &Path, contents: &[u8]) -> Result<()> { + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + let tmp = tempfile::NamedTempFile::new_in(parent)?; + { + use std::io::Write; + tmp.as_file().write_all(contents)?; + tmp.as_file().sync_all()?; + } + let dest: PathBuf = path.to_path_buf(); + tmp.persist(&dest) + .map_err(|e| anyhow::anyhow!("atomic rename failed for {}: {}", path.display(), e))?; + Ok(()) +} + +/// Kill the child and wait for it so the OS reaps the zombie immediately. 
+async fn kill_and_reap(child: &mut tokio::process::Child) { + let _ = child.kill().await; + let _ = child.wait().await; } async fn build_project(frontend_type: Option) -> Result { @@ -420,6 +498,12 @@ async fn build_project(frontend_type: Option) -> Result) -> Result>() .join("\n"); - std::fs::write(frontend_env, patched)?; + atomic_write(frontend_env, patched.as_bytes())?; } // Build Dioxus WASM before cargo build: rust_embed requires real files in @@ -484,7 +568,7 @@ async fn build_project(frontend_type: Option) -> Result Result<()> { mod tests { use super::*; - fn default_cmd() -> TestCommand { - TestCommand { - skip_backend: false, - skip_frontend: false, - ui: false, - headed: false, - args: vec![], - } - } - - #[test] - fn test_command_default_runs_both() { - let cmd = default_cmd(); - assert!(!cmd.skip_backend); - assert!(!cmd.skip_frontend); - } - - #[test] - fn test_command_skip_backend() { - let cmd = TestCommand { - skip_backend: true, - ..default_cmd() - }; - assert!(cmd.skip_backend); - assert!(!cmd.skip_frontend); - } - - #[test] - fn test_command_skip_frontend() { - let cmd = TestCommand { - skip_frontend: true, - ..default_cmd() - }; - assert!(!cmd.skip_backend); - assert!(cmd.skip_frontend); - } - - #[test] - fn test_command_with_ui_and_args() { - let cmd = TestCommand { - ui: true, - args: vec!["tests/todo.spec.ts".into()], - ..default_cmd() - }; - assert!(cmd.ui); - assert_eq!(cmd.args.len(), 1); - } - - #[test] - fn test_command_headed() { - let cmd = TestCommand { - headed: true, - ..default_cmd() - }; - assert!(cmd.headed); - } - - #[test] - fn test_read_db_name_default() { - assert!(!read_db_name().is_empty()); - } - #[test] fn test_pick_random_port() { let port1 = pick_random_port().unwrap(); diff --git a/crates/forge/src/cli/webhook.rs b/crates/forge/src/cli/webhook.rs index 3068dce4..25104e79 100644 --- a/crates/forge/src/cli/webhook.rs +++ b/crates/forge/src/cli/webhook.rs @@ -26,8 +26,8 @@ struct ReplayArgs { webhook_name: String, /// 
Idempotency key of the event to replay. idempotency_key: String, - /// Base URL of the running forge server (default: http://localhost:3000). - #[arg(long, default_value = "http://localhost:3000")] + /// Base URL of the running forge server (default: http://localhost:9081). + #[arg(long, default_value = "http://localhost:9081")] base_url: String, } @@ -47,6 +47,12 @@ struct ListArgs { impl WebhookCommand { /// Execute the webhook subcommand. pub async fn execute(self) -> Result<()> { + // Webhook subcommands rely on cwd-relative `forge.toml`; anchor at project root. + if let Err(e) = super::project_root::enter_project_root() { + return Err(forge_core::ForgeError::config(format!( + "must be run from inside a forge project: {e}" + ))); + } match self.command { WebhookSubcommand::Replay(args) => replay(args).await, WebhookSubcommand::List(args) => list(args).await, @@ -54,10 +60,28 @@ impl WebhookCommand { } } +/// Mirror the runtime's URL resolution: prefer `DATABASE_URL` env var, then +/// fall back to `[database].url` from forge.toml. 
+fn resolve_database_url(config: &forge_core::config::ForgeConfig) -> Result { + if let Ok(url) = std::env::var("DATABASE_URL") + && !url.is_empty() + { + return Ok(url); + } + let url = config.database.url(); + if url.is_empty() { + return Err(forge_core::ForgeError::config( + "no database URL configured: set DATABASE_URL or [database].url in forge.toml", + )); + } + Ok(url.to_string()) +} + #[allow(clippy::disallowed_methods)] async fn replay(args: ReplayArgs) -> Result<()> { let config = forge_core::config::ForgeConfig::from_file("forge.toml")?; - let pool = sqlx::PgPool::connect(&config.database.url) + let db_url = resolve_database_url(&config)?; + let pool = sqlx::PgPool::connect(&db_url) .await .map_err(forge_core::ForgeError::Database)?; @@ -103,17 +127,6 @@ async fn replay(args: ReplayArgs) -> Result<()> { body.len() ); - // Delete the existing idempotency record so the replay isn't rejected - sqlx::query( - "DELETE FROM forge_webhook_events \ - WHERE webhook_name = $1 AND idempotency_key = $2", - ) - .bind(&args.webhook_name) - .bind(&args.idempotency_key) - .execute(&pool) - .await - .map_err(forge_core::ForgeError::Database)?; - let client = reqwest::Client::new(); let mut request = client.post(format!( "{}/webhooks/{}", @@ -142,6 +155,28 @@ async fn replay(args: ReplayArgs) -> Result<()> { let status_code = response.status(); let reason = status_code.canonical_reason().unwrap_or_default(); println!("Response: {} {}", status_code.as_u16(), reason); + + // Only clear the dedup record on 2xx so a failed replay doesn't allow + // the original event to silently re-execute as a brand-new delivery. 
+ if status_code.is_success() { + if let Err(e) = sqlx::query( + "DELETE FROM forge_webhook_events \ + WHERE webhook_name = $1 AND idempotency_key = $2", + ) + .bind(&args.webhook_name) + .bind(&args.idempotency_key) + .execute(&pool) + .await + { + eprintln!("Warning: replay succeeded but failed to clear dedup record: {e}"); + } + } else { + eprintln!( + "Replay did not succeed (HTTP {}); dedup record preserved.", + status_code.as_u16() + ); + } + let body = response .text() .await @@ -156,7 +191,8 @@ async fn replay(args: ReplayArgs) -> Result<()> { #[allow(clippy::disallowed_methods)] async fn list(args: ListArgs) -> Result<()> { let config = forge_core::config::ForgeConfig::from_file("forge.toml")?; - let pool = sqlx::PgPool::connect(&config.database.url) + let db_url = resolve_database_url(&config)?; + let pool = sqlx::PgPool::connect(&db_url) .await .map_err(forge_core::ForgeError::Database)?; @@ -200,14 +236,15 @@ async fn list(args: ListArgs) -> Result<()> { ); for (webhook, key, status, processed_at, has_body) in &rows { let replay = if *has_body { "yes" } else { "no" }; + let key_display: String = if key.chars().count() > 28 { + key.chars().take(28).collect() + } else { + key.clone() + }; println!( "{:<20} {:<30} {:<10} {:<24} {}", webhook, - if key.len() > 28 { - key.get(..28).unwrap_or_default() - } else { - key.as_str() - }, + key_display, status, processed_at.format("%Y-%m-%d %H:%M:%S UTC"), replay diff --git a/crates/forge/src/runtime/builder.rs b/crates/forge/src/runtime/builder.rs index 29a56f6c..f219d857 100644 --- a/crates/forge/src/runtime/builder.rs +++ b/crates/forge/src/runtime/builder.rs @@ -57,6 +57,9 @@ pub struct ForgeBuilder { pub(super) frontend_handler: Option, #[cfg(feature = "gateway")] pub(super) custom_routes_factory: Option Router + Send + Sync>>, + /// Deferred error from `auto_register()` so the builder stays chainable. + /// Surfaced from `build()`. 
+ pub(super) auto_register_error: Option, } impl ForgeBuilder { @@ -84,6 +87,7 @@ impl ForgeBuilder { frontend_handler: None, #[cfg(feature = "gateway")] custom_routes_factory: None, + auto_register_error: None, } } @@ -194,7 +198,9 @@ impl ForgeBuilder { #[cfg(feature = "gateway")] mcp_tools: std::mem::take(&mut self.mcp_registry), }; - crate::auto_register::auto_register_all(&mut registries); + if let Err(e) = crate::auto_register::auto_register_all(&mut registries) { + self.auto_register_error = Some(e); + } self.function_registry = registries.functions; #[cfg(feature = "jobs")] { @@ -322,6 +328,10 @@ impl ForgeBuilder { } pub fn build(self) -> Result { + if let Some(err) = self.auto_register_error { + return Err(err); + } + let config = self .config .ok_or_else(|| ForgeError::config("Configuration is required"))?; diff --git a/crates/forge/src/runtime/mod.rs b/crates/forge/src/runtime/mod.rs index dd763e3e..bf84a172 100644 --- a/crates/forge/src/runtime/mod.rs +++ b/crates/forge/src/runtime/mod.rs @@ -5,7 +5,7 @@ pub use builder::ForgeBuilder; #[cfg(feature = "gateway")] use std::future::Future; -use std::net::IpAddr; +use std::net::{IpAddr, Ipv4Addr}; use std::path::PathBuf; #[cfg(feature = "gateway")] use std::pin::Pin; @@ -47,7 +47,7 @@ use forge_runtime::pg::{LeaderConfig, LeaderElection, PgNotifyBus}; use forge_core::CircuitBreakerClient; #[cfg(feature = "gateway")] use forge_runtime::gateway::{ - AuthConfig, GatewayConfig as RuntimeGatewayConfig, GatewayServer, TlsListenConfig, + AuthConfig, GatewayConfig as RuntimeGatewayConfig, GatewayServer, PeerAddr, TlsListenConfig, bind_listener, }; #[cfg(feature = "jobs")] @@ -282,9 +282,9 @@ impl Forge { // HOST env var overrides bind address; PORT env var overrides config port. 
let ip_address: IpAddr = std::env::var("HOST") - .unwrap_or_else(|_| "0.0.0.0".to_string()) - .parse() - .unwrap_or_else(|_| "0.0.0.0".parse().expect("valid IP literal")); + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(IpAddr::V4(Ipv4Addr::UNSPECIFIED)); if let Ok(port_str) = std::env::var("PORT") && let Ok(port) = port_str.parse::() @@ -873,7 +873,8 @@ impl Forge { gateway = gateway .with_signals_collector(collector) .with_signals_anonymize_ip(self.config.signals.anonymize_ip) - .with_signals_geoip(geoip); + .with_signals_geoip(geoip) + .with_signals_rate_limit_per_minute(self.config.signals.rate_limit_per_minute); forge_runtime::signals::session::spawn_session_reaper( signals_pool.clone(), @@ -1054,7 +1055,17 @@ impl Forge { return; } }; - let serve = axum::serve(listener, router).with_graceful_shutdown(async move { + // Serve with per-connection peer address so downstream + // middleware can resolve the real client IP. Without this the + // router's default `into_make_service` omits `ConnectInfo`, + // leaving every client IP unresolved — which collapses per-IP + // rate-limit buckets, blanks signal visitor IDs, and breaks the + // IP-bound SSE auth ticket. Mirrors `GatewayServer::run`. + let serve = axum::serve( + listener, + router.into_make_service_with_connect_info::(), + ) + .with_graceful_shutdown(async move { let _ = gateway_shutdown_rx.wait_for(|v| *v).await; tracing::debug!("Gateway draining in-flight requests"); }); From 4bb98d1dd62174ffe920657440c4066816c81414 Mon Sep 17 00:00:00 2001 From: Isala Piyarisi Date: Wed, 27 May 2026 06:20:59 +0530 Subject: [PATCH 4/7] harden svelte and dioxus frontend runtimes Re-register job/workflow subscriptions and reconnect SSE on auth change, settle the session before subscribing, set the connected-token hash before resolving connect, native SSE jitter, and web-vitals beacon as a JSON blob. 
--- packages/forge-dioxus/Cargo.lock | 2 +- packages/forge-dioxus/src/auth.rs | 51 +++- packages/forge-dioxus/src/client.rs | 123 +++++++-- packages/forge-dioxus/src/hooks.rs | 11 +- packages/forge-dioxus/src/signals.rs | 79 +++++- packages/forge-svelte/ForgeProvider.svelte | 3 + packages/forge-svelte/bun.lock | 53 ++++ packages/forge-svelte/client.ts | 290 ++++++++++++++++++--- packages/forge-svelte/context.ts | 12 + packages/forge-svelte/signals.ts | 23 +- packages/forge-svelte/stores.ts | 118 ++++++--- 11 files changed, 651 insertions(+), 114 deletions(-) create mode 100644 packages/forge-svelte/bun.lock diff --git a/packages/forge-dioxus/Cargo.lock b/packages/forge-dioxus/Cargo.lock index a8097683..4be11d96 100644 --- a/packages/forge-dioxus/Cargo.lock +++ b/packages/forge-dioxus/Cargo.lock @@ -501,7 +501,7 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "forge-dioxus" -version = "0.9.0" +version = "0.10.1" dependencies = [ "dioxus", "dirs", diff --git a/packages/forge-dioxus/src/auth.rs b/packages/forge-dioxus/src/auth.rs index bf806c61..a9891aac 100644 --- a/packages/forge-dioxus/src/auth.rs +++ b/packages/forge-dioxus/src/auth.rs @@ -261,6 +261,21 @@ pub fn ForgeAuthProvider( // Must follow use_context_provider above so ForgeClient is available let client: ForgeClient = use_context(); + // Reconnect the SSE stream on every auth transition (login/logout) so + // existing subscriptions re-register over a session carrying the current + // token. Token *refresh* already reconnects (see `try_refresh_tokens`); + // without this, a manual login after the anonymous SSE handshake leaves + // queries subscribed on the anonymous session, where they keep returning + // UNAUTHORIZED and never deliver authenticated data. `generation` only bumps + // on actual login/logout transitions, and `reconnect_sse` is a no-op until + // subscriptions exist, so the initial mount costs nothing. 
+ let client_for_auth = client.clone(); + let auth_generation = generation; + use_effect(move || { + auth_generation.read(); + client_for_auth.reconnect_sse(); + }); + let url_for_refresh = url.clone(); let client_for_refresh = client.clone(); use_effect(move || { @@ -375,6 +390,13 @@ async fn sleep(secs: u64) { tokio::time::sleep(std::time::Duration::from_secs(secs)).await; } +/// Browser-side auth storage. +/// +/// SECURITY: refresh tokens are stored in `sessionStorage` rather than +/// `localStorage` so they don't outlive the browsing context. Any XSS in the +/// host page can still read them within the same tab — for production, +/// configure the server to issue refresh tokens as `HttpOnly; Secure; +/// SameSite=Strict` cookies and stop sending them in the JSON body. #[cfg(target_arch = "wasm32")] mod storage { use super::StoredAuth; @@ -386,7 +408,7 @@ mod storage { pub fn save(app_name: &str, auth: &StoredAuth) { if let Ok(json) = serde_json::to_string(auth) { if let Some(storage) = web_sys::window() - .and_then(|w| w.local_storage().ok()) + .and_then(|w| w.session_storage().ok()) .flatten() { let _ = storage.set_item(&key(app_name), &json); @@ -395,14 +417,14 @@ mod storage { } pub fn load(app_name: &str) -> Option { - let storage = web_sys::window()?.local_storage().ok()??; + let storage = web_sys::window()?.session_storage().ok()??; let json = storage.get_item(&key(app_name)).ok()??; serde_json::from_str(&json).ok() } pub fn clear(app_name: &str) { if let Some(storage) = web_sys::window() - .and_then(|w| w.local_storage().ok()) + .and_then(|w| w.session_storage().ok()) .flatten() { let _ = storage.remove_item(&key(app_name)); @@ -416,8 +438,29 @@ mod storage { use std::fs; use std::path::PathBuf; + /// Strip path separators and other unsafe characters so a caller-supplied + /// `app_name` cannot escape `data_local_dir()` via `..` or absolute paths. 
+ fn sanitize_app_name(app_name: &str) -> String { + let cleaned: String = app_name + .chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' || c == '_' { + c + } else { + '_' + } + }) + .collect(); + if cleaned.is_empty() { + "forge_app".to_string() + } else { + cleaned + } + } + fn storage_path(app_name: &str) -> Option { - dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json")) + let safe = sanitize_app_name(app_name); + dirs::data_local_dir().map(|base| base.join(safe).join("auth.json")) } pub fn save(app_name: &str, auth: &StoredAuth) { diff --git a/packages/forge-dioxus/src/client.rs b/packages/forge-dioxus/src/client.rs index 6bdeb166..1e9a7210 100644 --- a/packages/forge-dioxus/src/client.rs +++ b/packages/forge-dioxus/src/client.rs @@ -7,6 +7,8 @@ use std::rc::Rc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; +use futures_channel::oneshot; + use dioxus::prelude::{Signal, WritableExt, dioxus_core::Task}; use serde::Serialize; use serde::de::DeserializeOwned; @@ -58,8 +60,6 @@ struct SseManager { connect_waiters: Vec, } -const MAX_RECONNECT_ATTEMPTS: u32 = 10; - #[derive(Clone)] #[non_exhaustive] pub struct ForgeClientConfig { @@ -89,9 +89,12 @@ impl ForgeClientConfig { } /// Register an async callback invoked when an RPC call returns UNAUTHORIZED. - /// The callback should refresh the access token and return the new one. If - /// it returns `Some`, the original call is retried once with the new token. - /// If it returns `None`, the call fails normally. + /// The callback must refresh the access token AND persist it where the + /// `get_token` provider reads from (typically a `Signal`) before + /// resolving. The returned `Option` is treated as success/failure; + /// the client does not inject it directly. On `Some` the original call is + /// retried once (which calls `get_token` again). On `None` the call fails. + /// Concurrent 401s are coalesced into a single refresh attempt. 
pub fn with_refresh_token_provider(mut self, provider: F) -> Self where F: Fn() -> Fut + 'static, @@ -138,6 +141,9 @@ struct ForgeClientInner { connection_state: Option>, sse: RefCell, signals: RefCell>, + /// Coalesces concurrent 401 refresh attempts. While `Some`, in-flight + /// callers should subscribe via oneshot instead of firing another refresh. + refresh_waiters: RefCell>>>, } impl ForgeClient { @@ -152,6 +158,7 @@ impl ForgeClient { connection_state: config.connection_state, sse: RefCell::new(SseManager::default()), signals: RefCell::new(None), + refresh_waiters: RefCell::new(None), }), } } @@ -206,7 +213,44 @@ impl ForgeClient { let Some(provider) = self.inner.refresh_token.clone() else { return false; }; - provider().await.is_some() + + // Coalesce: if a refresh is already in flight, wait for its result + // instead of rotating the refresh token again. + let (rx, leader) = { + let mut slot = self.inner.refresh_waiters.borrow_mut(); + match slot.as_mut() { + Some(waiters) => { + let (tx, rx) = oneshot::channel(); + waiters.push(tx); + (Some(rx), false) + } + None => { + *slot = Some(Vec::new()); + (None, true) + } + } + }; + + if !leader { + return rx + .expect("follower must have a receiver") + .await + .unwrap_or(false); + } + + // The provider's `Option` return signals success/failure only. + // The provider closure MUST install the new token via its own state + // (e.g. updating a Signal that backs `get_token`) before resolving, + // since `get_token` is read again on the retry. Returning `Some` + // without persisting the token will cause the retry to use the + // stale token. + let success = provider().await.is_some(); + + let waiters = self.inner.refresh_waiters.borrow_mut().take().unwrap_or_default(); + for w in waiters { + let _ = w.send(success); + } + success } #[cfg(target_arch = "wasm32")] @@ -637,16 +681,17 @@ impl ForgeClient { } } + /// Returns the current attempt count for backoff calculation. 
Retries + /// indefinitely while there are listeners — long-lived apps need an + /// always-on SSE pipe, not a hard 10-attempt giveup. Backoff is capped + /// by the caller via `attempts.min(N)`. fn should_reconnect(&self) -> Option { let mut sse = self.inner.sse.borrow_mut(); if sse.listeners.is_empty() { return None; } let attempts = sse.reconnect_attempts; - if attempts >= MAX_RECONNECT_ATTEMPTS { - return None; - } - sse.reconnect_attempts = attempts + 1; + sse.reconnect_attempts = attempts.saturating_add(1); Some(attempts) } @@ -658,6 +703,12 @@ impl ForgeClient { .filter(|t| !t.is_empty()) } + /// Crate-internal accessor for the current access token, used by signals + /// so analytics calls carry the user's identity. + pub(crate) fn auth_token(&self) -> Option { + self.get_token() + } + fn decode_envelope( &self, envelope: RpcEnvelopeRaw, @@ -806,10 +857,14 @@ mod platform { body: serde_json::Value, correlation_id: Option<&str>, ) -> Result { + // X-Forge-CSRF: custom header forces a CORS preflight on cross-origin + // POSTs so the server's CORS allowlist gates cross-site requests + // despite `credentials: include`. 
let mut request = Request::post(url) .header("Content-Type", "application/json") .header("Accept", "application/vnd.forge.v1+json") .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") .credentials(web_sys::RequestCredentials::Include); if let Some(token) = client.get_token() { request = request.header("Authorization", &format!("Bearer {token}")); @@ -836,6 +891,7 @@ mod platform { ) -> Result { let mut request = Request::post(url) .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") .credentials(web_sys::RequestCredentials::Include); if let Some(token) = client.get_token() { request = request.header("Authorization", &format!("Bearer {token}")); @@ -864,14 +920,32 @@ mod platform { }) } - fn events_url(client: &ForgeClient) -> String { - match client.get_token() { - Some(token) => format!( - "{}/_api/events?token={}", - client.inner.url, - encode_uri_component(&token) - ), - None => format!("{}/_api/events", client.inner.url), + /// Build the SSE URL. When a bearer token is present, mint a short-lived + /// single-use ticket via `POST /_api/events/ticket` and put it in the + /// query string. The JWT itself never appears in the URL — query + /// strings leak into access logs, browser history, and Referer headers. + /// Anonymous connections skip the ticket fetch. 
+ async fn events_url(client: &ForgeClient) -> String { + let base = format!("{}/_api/events", client.inner.url); + let Some(token) = client.get_token() else { + return base; + }; + let ticket_url = format!("{}/ticket", base); + let request = Request::post(&ticket_url).header("Authorization", &format!("Bearer {token}")); + let response = match request.send().await { + Ok(r) => r, + Err(_) => return base, + }; + if !response.ok() { + return base; + } + #[derive(serde::Deserialize)] + struct TicketResponse { + ticket: String, + } + match response.json::().await { + Ok(body) => format!("{}?ticket={}", base, encode_uri_component(&body.ticket)), + Err(_) => base, } } @@ -914,7 +988,8 @@ mod platform { /// Returns `true` if the connection was established at some point. async fn run_event_loop(client: &ForgeClient) -> bool { - let mut event_source = match EventSource::new(&events_url(client)) { + let url = events_url(client).await; + let mut event_source = match EventSource::new(&url) { Ok(source) => source, Err(_) => { return false; @@ -1030,6 +1105,7 @@ mod platform { .post(url) .header("Accept", "application/vnd.forge.v1+json") .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") .json(&body); if let Some(token) = client.get_token() { request = request.bearer_auth(token); @@ -1056,6 +1132,7 @@ mod platform { let mut request = Client::new() .post(url) .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") .multipart(form); if let Some(token) = client.get_token() { request = request.bearer_auth(token); @@ -1096,7 +1173,13 @@ mod platform { if let Some(attempts) = client_for_task.should_reconnect() { let delay = 1000 * (1u64 << attempts.min(4)); - sleep(std::time::Duration::from_millis(delay)).await; + // Cheap jitter from wall-clock subnanos so two desktop apps + // started off the same restart cycle don't synchronize retries. 
+ let jitter = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| (d.subsec_nanos() as u64) % 500) + .unwrap_or(0); + sleep(std::time::Duration::from_millis(delay + jitter)).await; client_for_task.inner.sse.borrow_mut().state = super::SseState::Connecting; start_event_loop(client_for_task); diff --git a/packages/forge-dioxus/src/hooks.rs b/packages/forge-dioxus/src/hooks.rs index 065e4cec..96416212 100644 --- a/packages/forge-dioxus/src/hooks.rs +++ b/packages/forge-dioxus/src/hooks.rs @@ -222,9 +222,16 @@ where }); } StreamEvent::Error(err) => { - // Suppress errors during reconnect attempts to avoid UI churn. + // Definitive auth/permission errors should surface immediately + // — no amount of retrying will fix them, and the UI needs to + // redirect to login. Transient errors are suppressed during + // reconnect attempts (1..10) to avoid UI churn. + let definitive = matches!( + err.code.as_str(), + "UNAUTHORIZED" | "FORBIDDEN" | "NOT_FOUND" | "VALIDATION_ERROR" + ); let attempts = reconnect_attempts.get(); - if attempts > 0 && attempts < 10 { + if !definitive && attempts > 0 && attempts < 10 { return; } let mut next = state.peek().clone(); diff --git a/packages/forge-dioxus/src/signals.rs b/packages/forge-dioxus/src/signals.rs index 3117bcf5..d2ac472b 100644 --- a/packages/forge-dioxus/src/signals.rs +++ b/packages/forge-dioxus/src/signals.rs @@ -86,17 +86,27 @@ export function forge_install_web_vitals(baseUrl, getSessionId) { timestamp: new Date().toISOString(), }], context: { - page_url: location.href, + page_url: location.pathname, session_id: getSessionId() || null, } } }); const url = baseUrl + '/_api/signal'; - const headers = { 'Content-Type': 'application/json', 'x-forge-platform': 'web' }; + // sendBeacon can't set custom headers, but wrapping the body in a + // typed Blob makes the browser send Content-Type: application/json + // (a bare string defaults to text/plain, which the endpoint rejects + // with 415). 
Cookies are still inherited so identified users + // attribute correctly. The fetch fallback adds credentials + + // X-Forge-CSRF for CORS-gated cross-origin. + const headers = { + 'Content-Type': 'application/json', + 'x-forge-platform': 'web', + 'X-Forge-CSRF': '1', + }; if (navigator.sendBeacon) { - navigator.sendBeacon(url, body); + navigator.sendBeacon(url, new Blob([body], { type: 'application/json' })); } else { - fetch(url, { method: 'POST', headers: headers, body: body, keepalive: true }); + fetch(url, { method: 'POST', headers: headers, body: body, keepalive: true, credentials: 'include' }); } } catch (_) {} } @@ -575,11 +585,24 @@ impl ForgeSignals { } pub async fn page(&self, url_path: &str) { - let (base_url, session_id, utm) = { + let (base_url, session_id, token, utm) = { let mut inner = self.inner.borrow_mut(); if !inner.config.enabled { return; } + // Re-extract UTM on each navigation so SPA route changes that carry + // new utm_* params still propagate, but fall back to the landing + // params captured at construction: the deferred initial page view + // fires after the router has normalized the URL and dropped the + // query string, so a live re-extract would return nothing. 
+ if let Some(fresh) = extract_utm() { + inner.utm_params = Some(fresh); + } let utm = inner.utm_params.take(); - (inner.client.get_url().to_string(), inner.session_id.clone(), utm) + ( + inner.client.get_url().to_string(), + inner.session_id.clone(), + inner.client.auth_token(), + utm, + ) }; let mut payload = json!({ "url": url_path }); @@ -592,7 +615,7 @@ impl ForgeSignals { } let wrapped = json!({ "type": "view", "payload": payload }); - if let Ok(resp) = post_signal(&base_url, "signal", &wrapped, session_id.as_deref()).await + if let Ok(resp) = post_signal(&base_url, "signal", &wrapped, session_id.as_deref(), token.as_deref()).await && let Some(sid) = resp.get("session_id").and_then(|v| v.as_str()) { self.inner.borrow_mut().session_id = Some(sid.to_string()); @@ -613,12 +636,13 @@ impl ForgeSignals { } async fn report_error(&self, error: SignalError, context: Option) { - let (url, session_id, correlation_id, breadcrumbs, page_url) = { + let (url, session_id, token, correlation_id, breadcrumbs, page_url) = { let inner = self.inner.borrow(); if !inner.config.enabled { return; } ( inner.client.get_url().to_string(), inner.session_id.clone(), + inner.client.auth_token(), inner.last_correlation_id.clone(), inner.breadcrumbs.clone(), current_page_url(), @@ -643,7 +667,7 @@ impl ForgeSignals { } }; let wrapped = json!({ "type": "report", "payload": payload }); - let _ = post_signal(&url, "signal", &wrapped, session_id.as_deref()).await; + let _ = post_signal(&url, "signal", &wrapped, session_id.as_deref(), token.as_deref()).await; } /// Add a breadcrumb for error reproduction context. 
@@ -672,13 +696,18 @@ impl ForgeSignals { } pub async fn flush(&self) { - let (url, mut events, session_id) = { + let (url, mut events, session_id, token) = { let mut inner = self.inner.borrow_mut(); if inner.queue.is_empty() { return; } let max = inner.config.max_batch_size; let count = inner.queue.len().min(max); let events: Vec<_> = inner.queue.drain(..count).collect(); - (inner.client.get_url().to_string(), events, inner.session_id.clone()) + ( + inner.client.get_url().to_string(), + events, + inner.session_id.clone(), + inner.client.auth_token(), + ) }; let batch = EventBatch { @@ -703,7 +732,7 @@ impl ForgeSignals { } }; let wrapped = json!({ "type": "event", "payload": payload }); - match post_signal(&url, "signal", &wrapped, session_id.as_deref()).await + match post_signal(&url, "signal", &wrapped, session_id.as_deref(), token.as_deref()).await { Ok(resp) => { if let Some(sid) = resp.get("session_id").and_then(|v| v.as_str()) { @@ -797,11 +826,13 @@ fn flush_beacon(signals: &ForgeSignals) { } } +/// Capture just the path (no querystring) so URL-borne secrets like +/// `?reset_token=…` or `?ssoToken=…` never reach the analytics store. fn current_page_url() -> Option { #[cfg(target_arch = "wasm32")] { web_sys::window() - .and_then(|w| w.location().href().ok()) + .and_then(|w| w.location().pathname().ok()) } #[cfg(not(target_arch = "wasm32"))] { @@ -862,16 +893,24 @@ async fn post_signal( path: &str, body: &Value, session_id: Option<&str>, + token: Option<&str>, ) -> Result { #[cfg(target_arch = "wasm32")] { use gloo_net::http::Request; + // X-Forge-CSRF forces a CORS preflight on cross-origin POSTs, gating + // credentialed cross-site requests via the server's CORS allowlist. 
let mut req = Request::post(&format!("{base_url}/_api/{path}")) .header("Content-Type", "application/json") - .header("x-forge-platform", platform_tag()); + .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") + .credentials(web_sys::RequestCredentials::Include); if let Some(sid) = session_id { req = req.header("x-session-id", sid); } + if let Some(t) = token { + req = req.header("Authorization", &format!("Bearer {t}")); + } let resp = req.body(body.to_string()).map_err(|_| ())?.send().await.map_err(|_| ())?; resp.json().await.map_err(|_| ()) } @@ -881,10 +920,14 @@ async fn post_signal( let mut req = Client::new() .post(format!("{base_url}/_api/{path}")) .header("x-forge-platform", platform_tag()) + .header("X-Forge-CSRF", "1") .json(body); if let Some(sid) = session_id { req = req.header("x-session-id", sid); } + if let Some(t) = token { + req = req.bearer_auth(t); + } let resp = req.send().await.map_err(|_| ())?; resp.json().await.map_err(|_| ()) } @@ -894,6 +937,14 @@ pub fn use_signals() -> ForgeSignals { use_context::() } +/// LIMITATION: this function intentionally leaks each `Closure` it passes +/// to `addEventListener` via `.forget()`. Dioxus's WASM lifecycle does not +/// expose a "provider unmounted" hook we can wire teardown into, and these +/// listeners must outlive the borrowed `ForgeSignals`. The leak is one-shot +/// per `ForgeAuthProvider` mount (i.e. typically once per page lifetime), +/// not per render. If a future Dioxus version exposes a drop hook for +/// context providers, switch to storing `Closure`s in a guard struct and +/// removing the listeners on drop. 
#[cfg(target_arch = "wasm32")] pub(crate) fn setup_auto_capture(signals: ForgeSignals) { use wasm_bindgen::closure::Closure; diff --git a/packages/forge-svelte/ForgeProvider.svelte b/packages/forge-svelte/ForgeProvider.svelte index dc8cbbf0..9a1b1717 100644 --- a/packages/forge-svelte/ForgeProvider.svelte +++ b/packages/forge-svelte/ForgeProvider.svelte @@ -49,6 +49,9 @@ } authState.loading = false; }).catch(() => { + // Tear down so the client doesn't sit in a half-initialized state with + // queued subscriptionMeta ready to fire stale registrations on remount. + try { client.disconnect(); } catch { /* idempotent */ } authState.loading = false; }); diff --git a/packages/forge-svelte/bun.lock b/packages/forge-svelte/bun.lock new file mode 100644 index 00000000..c4f2b114 --- /dev/null +++ b/packages/forge-svelte/bun.lock @@ -0,0 +1,53 @@ +{ + "lockfileVersion": 1, + "configVersion": 1, + "workspaces": { + "": { + "name": "@forge-rs/svelte", + "peerDependencies": { + "svelte": "^5.0.0", + }, + }, + }, + "packages": { + "@jridgewell/gen-mapping": ["@jridgewell/gen-mapping@0.3.13", "", { "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.0", "@jridgewell/trace-mapping": "^0.3.24" } }, "sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA=="], + + "@jridgewell/remapping": ["@jridgewell/remapping@2.3.5", "", { "dependencies": { "@jridgewell/gen-mapping": "^0.3.5", "@jridgewell/trace-mapping": "^0.3.24" } }, "sha512-LI9u/+laYG4Ds1TDKSJW2YPrIlcVYOwi2fUC6xB43lueCjgxV4lffOCZCtYFiH6TNOX+tQKXx97T4IKHbhyHEQ=="], + + "@jridgewell/resolve-uri": ["@jridgewell/resolve-uri@3.1.2", "", {}, "sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw=="], + + "@jridgewell/sourcemap-codec": ["@jridgewell/sourcemap-codec@1.5.5", "", {}, "sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og=="], + + "@jridgewell/trace-mapping": ["@jridgewell/trace-mapping@0.3.31", "", { 
"dependencies": { "@jridgewell/resolve-uri": "^3.1.0", "@jridgewell/sourcemap-codec": "^1.4.14" } }, "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw=="], + + "@sveltejs/acorn-typescript": ["@sveltejs/acorn-typescript@1.0.10", "", { "peerDependencies": { "acorn": "^8.9.0" } }, "sha512-4WfKk68eTih+MiJD4fSbxN7E8kVBmTMPWHUPYjvl2N0rMs53YLTT8/YjKU5Dtnz5LqDjl7LEw4U7lXR2W3J5WA=="], + + "@types/estree": ["@types/estree@1.0.9", "", {}, "sha512-GhdPgy1el4/ImP05X05Uw4cw2/M93BCUmnEvWZNStlCzEKME4Fkk+YpoA5OiHNQmoS7Cafb8Xa3Pya8m1Qrzeg=="], + + "@types/trusted-types": ["@types/trusted-types@2.0.7", "", {}, "sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw=="], + + "acorn": ["acorn@8.16.0", "", { "bin": { "acorn": "bin/acorn" } }, "sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw=="], + + "aria-query": ["aria-query@5.3.1", "", {}, "sha512-Z/ZeOgVl7bcSYZ/u/rh0fOpvEpq//LZmdbkXyc7syVzjPAhfOa9ebsdTSjEBDU4vs5nC98Kfduj1uFo0qyET3g=="], + + "axobject-query": ["axobject-query@4.1.0", "", {}, "sha512-qIj0G9wZbMGNLjLmg1PT6v2mE9AH2zlnADJD/2tC6E00hgmhUOfEB6greHPAfLRSufHqROIUTkw6E+M3lH0PTQ=="], + + "clsx": ["clsx@2.1.1", "", {}, "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA=="], + + "devalue": ["devalue@5.8.1", "", {}, "sha512-4CXDYRBGqN+57wVJkuXBYmpAVUSg3L6JAQa/DFqm238G73E1wuyc/JhGQJzN7vUf/CMphYau2zXbfWzDR5aTEw=="], + + "esm-env": ["esm-env@1.2.2", "", {}, "sha512-Epxrv+Nr/CaL4ZcFGPJIYLWFom+YeV1DqMLHJoEd9SYRxNbaFruBwfEX/kkHUJf55j2+TUbmDcmuilbP1TmXHA=="], + + "esrap": ["esrap@2.2.9", "", { "dependencies": { "@jridgewell/sourcemap-codec": "^1.4.15" }, "peerDependencies": { "@typescript-eslint/types": "^8.2.0" }, "optionalPeers": ["@typescript-eslint/types"] }, "sha512-4KijP+NxCWthMCUC3qHbE6n4vCjqgJS1uAYKhuT/GWfFTf1Qyive2TgOjep+gzbSzRfnNyaN/UU9YmdOt8Eg0A=="], + + "is-reference": ["is-reference@3.0.3", "", { 
"dependencies": { "@types/estree": "^1.0.6" } }, "sha512-ixkJoqQvAP88E6wLydLGGqCJsrFUnqoH6HnaczB8XmDH1oaWU+xxdptvikTgaEhtZ53Ky6YXiBuUI2WXLMCwjw=="], + + "locate-character": ["locate-character@3.0.0", "", {}, "sha512-SW13ws7BjaeJ6p7Q6CO2nchbYEc3X3J6WrmTTDto7yMPqVSZTUyY5Tjbid+Ab8gLnATtygYtiDIJGQRRn2ZOiA=="], + + "magic-string": ["magic-string@0.30.21", "", { "dependencies": { "@jridgewell/sourcemap-codec": "^1.5.5" } }, "sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ=="], + + "svelte": ["svelte@5.55.9", "", { "dependencies": { "@jridgewell/remapping": "^2.3.4", "@jridgewell/sourcemap-codec": "^1.5.0", "@sveltejs/acorn-typescript": "^1.0.10", "@types/estree": "^1.0.5", "@types/trusted-types": "^2.0.7", "acorn": "^8.12.1", "aria-query": "5.3.1", "axobject-query": "^4.1.0", "clsx": "^2.1.1", "devalue": "^5.8.1", "esm-env": "^1.2.1", "esrap": "^2.2.9", "is-reference": "^3.0.3", "locate-character": "^3.0.0", "magic-string": "^0.30.11", "zimmerframe": "^1.1.2" } }, "sha512-fTjjT8cHLDwigcu2j3pv7Jq04LklXevPB8uBgyHNiTXv+RMNvVnrjS4UEYrLMkhuq1vpCodHjiW+z/95SDs/fg=="], + + "zimmerframe": ["zimmerframe@1.1.4", "", {}, "sha512-B58NGBEoc8Y9MWWCQGl/gq9xBCe4IiKM0a2x7GZdQKOW5Exr8S1W24J6OgM1njK8xCRGvAJIL/MxXHf6SkmQKQ=="], + } +} diff --git a/packages/forge-svelte/client.ts b/packages/forge-svelte/client.ts index d49fd063..5967118c 100644 --- a/packages/forge-svelte/client.ts +++ b/packages/forge-svelte/client.ts @@ -8,6 +8,8 @@ export interface ForgeClientConfig { refreshToken?: () => Promise; onAuthError?: (error: ForgeError) => void; onMutationError?: (error: ForgeClientError) => void; + /** Opt-in diagnostic channel for non-fatal events (e.g. reconnect failures). 
*/ + onDebug?: (message: string) => void; timeout?: number; } @@ -33,7 +35,6 @@ interface SsePayload { session_id?: string; session_secret?: string; channel?: string; - last_event_id?: string; } export class ForgeClientError extends Error implements ForgeError { @@ -69,6 +70,12 @@ export class ForgeClient { private connectionListeners = new Set<(state: ConnectionState) => void>(); private subscriptions = new Map void>(); private subscriptionMeta = new Map(); + // Job/workflow subscriptions keyed by client_sub_id -> job/workflow id. Like + // query subscriptions, these must be re-registered on the new SSE session + // after a reconnect, or their server-side entries stay bound to the abandoned + // session and progress/status pushes are silently lost. + private jobMeta = new Map(); + private workflowMeta = new Map(); private reconnectAttempts = 0; private maxReconnectAttempts = 10; private maxSubscriptionRetries = 3; @@ -90,8 +97,26 @@ export class ForgeClient { this.signals = signals; } - private hashToken(token: string | null): string | null { - return token ? token.substring(0, 20) : null; + /** + * Stable fingerprint of a token used to detect rotation. SHA-256 over the + * whole token avoids the trap where JWTs from the same signing config + * share the first ~37 characters (header is constant); a prefix-based + * hash would miss rotation entirely. + */ + private async hashToken(token: string | null): Promise { + if (!token) return null; + if (typeof crypto !== "undefined" && crypto.subtle) { + const bytes = new TextEncoder().encode(token); + const digest = await crypto.subtle.digest("SHA-256", bytes); + const view = new Uint8Array(digest); + let hex = ""; + for (const b of view) hex += b.toString(16).padStart(2, "0"); + return hex; + } + // Fallback for non-secure contexts: take the suffix instead of the + // prefix, which (unlike the prefix) varies between JWTs with identical + // headers/payloads in the same signing slot. + return token.length > 16 ? 
token.slice(-16) : token; } getUrl(): string { @@ -124,13 +149,35 @@ export class ForgeClient { if (currentConnectionId !== this.connectionId) return; } - // Token must be resolved before EventSource is created (sent as query param) + // Resolve token to either (a) mint a short-lived single-use SSE ticket + // (authenticated streams) or (b) connect anonymously. The bearer JWT + // never appears in the URL — query strings leak into access logs, + // browser history, and Referer headers. const token = await this.getToken(); if (currentConnectionId !== this.connectionId) return; + let ticket: string | null = null; + if (token) { + try { + const res = await fetch(`${this.config.url}/_api/events/ticket`, { + method: "POST", + headers: { Authorization: `Bearer ${token}` }, + credentials: "include", + }); + if (res.ok) { + const body = (await res.json()) as { ticket?: string }; + ticket = body.ticket ?? null; + } + } catch { + // Network failure here is non-fatal; we'll fall through to anonymous + // connect and let the reconnect loop retry. + } + if (currentConnectionId !== this.connectionId) return; + } + const params = new URLSearchParams(); - if (token) params.set("token", token); + if (ticket) params.set("ticket", ticket); const sseUrl = `${this.config.url}/_api/events${params.toString() ? `?${params}` : ""}`; @@ -165,13 +212,26 @@ export class ForgeClient { const data = JSON.parse((e as MessageEvent).data) as SsePayload; this.sessionId = data.session_id ?? null; this.sessionSecret = data.session_secret ?? null; - this.connectedTokenHash = this.hashToken(token); this.setConnectionState("connected"); this.reconnectAttempts = 0; this.hasConnectedBefore = true; - this.reregisterSubscriptions(); - - resolve(); + // Finish wiring up this session BEFORE resolving connect(): + // 1. Set `connectedTokenHash` synchronously-before-resolve. 
It was + // previously set in a detached `.then`, so a `call()` running right + // after a reconnect could observe a stale hash and reconnect AGAIN, + // stranding just-restored subscriptions on the abandoned session + // (their pushes then land in a channel no client reads). + // 2. Re-register subscriptions, so callers awaiting connect()/ + // reconnect() (notably call() after a login) don't fire RPCs before + // the subscriptions exist on this session. + void (async () => { + try { + this.connectedTokenHash = await this.hashToken(token); + await this.reregisterSubscriptions(); + } finally { + resolve(); + } + })(); }); this.addEventSourceListener("update", (e) => { @@ -245,6 +305,8 @@ export class ForgeClient { this.setConnectionState("disconnected"); this.subscriptions.clear(); this.subscriptionMeta.clear(); + this.jobMeta.clear(); + this.workflowMeta.clear(); } async reconnect(): Promise { @@ -293,11 +355,20 @@ export class ForgeClient { async call(functionName: string, args: unknown): Promise { const token = await this.getToken(); + // If a reconnect is already in flight (e.g. one kicked off right after a + // login), wait for it to finish restoring subscriptions before issuing the + // RPC — otherwise the reactive update this call triggers can be pushed to a + // subscription that hasn't been re-registered on the new session yet. + if (this.reconnectPromise) await this.reconnectPromise; + // Token rotated since SSE was established; reconnect so subscriptions - // pick up the new identity without waiting for a new _registerQuery call. - const tokenHash = this.hashToken(token); + // pick up the new identity. Await so two simultaneous mutations during + // rotation don't spawn interleaved reconnects (reconnect() coalesces + // via reconnectPromise, but the previous fire-and-forget call could + // still race the in-flight RPC against a session-less state). 
+ const tokenHash = await this.hashToken(token); if (this.sessionId && tokenHash !== this.connectedTokenHash) { - this.reconnect(); + await this.reconnect(); } let response = await this.sendRpc(functionName, args, token); @@ -305,6 +376,11 @@ export class ForgeClient { if (response.status === 401 && this.config.refreshToken) { const refreshed = await this.tryRefresh(); if (refreshed) { + // NOTE: for multipart uploads, the retry will re-stream `args`. If a + // caller passes a one-shot ReadableStream or a Blob already piped + // elsewhere, the second send may transmit empty content. Callers + // that need refresh-safe uploads should pass File/Blob backed by + // bytes (re-readable) rather than streamed sources. response = await this.sendRpc(functionName, args, refreshed); } } @@ -343,7 +419,9 @@ export class ForgeClient { if (hasFiles) { const formData = this.buildFormData(args); - const headers: Record = { "x-forge-platform": "web" }; + // X-Forge-CSRF forces a CORS preflight on cross-origin POSTs so the + // server's CORS allowlist gates cross-site requests despite credentials. 
+ const headers: Record = { "x-forge-platform": "web", "X-Forge-CSRF": "1" }; if (token) headers["Authorization"] = `Bearer ${token}`; if (correlationId) headers["x-correlation-id"] = correlationId; return fetch(`${this.config.url}/_api/rpc/${functionName}/upload`, { @@ -363,6 +441,7 @@ export class ForgeClient { "Content-Type": "application/json", "Accept": "application/vnd.forge.v1+json", "x-forge-platform": "web", + "X-Forge-CSRF": "1", }; if (token) headers["Authorization"] = `Bearer ${token}`; if (correlationId) headers["x-correlation-id"] = correlationId; @@ -398,14 +477,28 @@ export class ForgeClient { return () => this.subscriptions.delete(target); } - async _registerQuery(subscriptionId: string, functionName: string, args: unknown): Promise { - // Must await here (unlike the fire-and-forget in call()) because - // registration needs the new session_id before hitting /_api/subscribe. - const currentToken = await this.getToken(); - const currentHash = this.hashToken(currentToken); - if (this.sessionId && currentHash !== this.connectedTokenHash) { + /** + * Ensure the SSE session is settled and matches the current auth token before + * a subscription registers against it. Awaits any in-flight reconnect, then + * reconnects if the token rotated since the session was established. Without + * this, a subscription can bind to a session that's mid-reconnect and about to + * be abandoned — its server-side entry then lives on a dead session and every + * push is silently dropped. All subscription kinds funnel through here so they + * always land on the live session. 
+ */ + private async ensureCurrentSession(): Promise { + if (this.reconnectPromise) await this.reconnectPromise; + const token = await this.getToken(); + const hash = await this.hashToken(token); + if (this.sessionId && hash !== this.connectedTokenHash) { await this.reconnect(); } + } + + async _registerQuery(subscriptionId: string, functionName: string, args: unknown): Promise { + // Settle the session before registering (needs the live session_id before + // hitting /_api/subscribe). + await this.ensureCurrentSession(); this.subscriptionMeta.set(subscriptionId, { functionName, args, failedAttempts: 0 }); @@ -436,6 +529,7 @@ export class ForgeClient { method: "POST", headers: { "Content-Type": "application/json", + "X-Forge-CSRF": "1", ...(token ? { Authorization: `Bearer ${token}` } : {}), }, body: JSON.stringify({ @@ -466,6 +560,7 @@ export class ForgeClient { method: "POST", headers: { "Content-Type": "application/json", + "X-Forge-CSRF": "1", ...(token ? { Authorization: `Bearer ${token}` } : {}), }, body: JSON.stringify({ @@ -477,13 +572,18 @@ export class ForgeClient { }); } - async _registerJob(clientSubId: string, jobId: string): Promise { + // Raw job-subscribe POST against the current session. Used both by the + // public `_registerJob` (after settling the session) and by + // `reregisterSubscriptions` during a reconnect — the latter must NOT settle + // the session (it runs inside the reconnect) or it would deadlock. + private async registerSseJob(clientSubId: string, jobId: string): Promise { if (!this.sessionId) return null; const token = await this.getToken(); const res = await fetch(`${this.config.url}/_api/subscribe-job`, { method: "POST", headers: { "Content-Type": "application/json", + "X-Forge-CSRF": "1", ...(token ? { Authorization: `Bearer ${token}` } : {}), }, body: JSON.stringify({ @@ -494,20 +594,29 @@ export class ForgeClient { }), credentials: "include", }); - if (res.ok) { - const json = await res.json(); - return json.data ?? 
null; - } - return null; + return this.parseTrackerResponse(res, "JOB_SUBSCRIBE_FAILED"); } - async _registerWorkflow(clientSubId: string, workflowId: string): Promise { + async _registerJob(clientSubId: string, jobId: string): Promise { + // Track for re-registration on reconnect (see `jobMeta`). Recorded even if + // there's no session yet so a subscription made pre-connect is restored. + this.jobMeta.set(clientSubId, jobId); + await this.ensureCurrentSession(); + return this.registerSseJob(clientSubId, jobId); + } + + // Raw workflow-subscribe POST. See `registerSseJob` for why this is split. + private async registerSseWorkflow( + clientSubId: string, + workflowId: string, + ): Promise { if (!this.sessionId) return null; const token = await this.getToken(); const res = await fetch(`${this.config.url}/_api/subscribe-workflow`, { method: "POST", headers: { "Content-Type": "application/json", + "X-Forge-CSRF": "1", ...(token ? { Authorization: `Bearer ${token}` } : {}), }, body: JSON.stringify({ @@ -518,18 +627,66 @@ export class ForgeClient { }), credentials: "include", }); + return this.parseTrackerResponse(res, "WORKFLOW_SUBSCRIBE_FAILED"); + } + + async _registerWorkflow(clientSubId: string, workflowId: string): Promise { + this.workflowMeta.set(clientSubId, workflowId); + await this.ensureCurrentSession(); + return this.registerSseWorkflow(clientSubId, workflowId); + } + + /** Drop a job subscription's local state so it isn't re-registered on reconnect. */ + _unregisterJob(clientSubId: string): void { + this.jobMeta.delete(clientSubId); + this.subscriptions.delete(`job:${clientSubId}`); + } + + /** Drop a workflow subscription's local state so it isn't re-registered on reconnect. */ + _unregisterWorkflow(clientSubId: string): void { + this.workflowMeta.delete(clientSubId); + this.subscriptions.delete(`wf:${clientSubId}`); + } + + /** + * Common envelope handling for job/workflow subscribe endpoints. 
Surfaces + * non-200 responses as ForgeClientError so the store can render a real + * failure state instead of "loading=false, error=null, data=null". + */ + private async parseTrackerResponse(res: Response, fallbackCode: string): Promise { if (res.ok) { const json = await res.json(); + if (json && json.success === false) { + const err = json.error ?? {}; + throw new ForgeClientError(err.code ?? fallbackCode, err.message ?? `Server returned ${res.status}`); + } return json.data ?? null; } - return null; + let body: { error?: { code?: string; message?: string } } | null = null; + try { + body = await res.json(); + } catch { + // Non-JSON error body — fall through to status-based message. + } + const code = body?.error?.code ?? fallbackCode; + const message = body?.error?.message ?? `Server returned ${res.status}`; + throw new ForgeClientError(code, message); } private async reregisterSubscriptions(): Promise { // Snapshot keys; callbacks for failed subs may mutate the map mid-loop const subscriptionIds = Array.from(this.subscriptionMeta.keys()); + let first = true; for (const id of subscriptionIds) { + // Stagger re-registrations so a page with 100 subs doesn't issue 100 + // POSTs in <1s after a reconnect. Skip the very first to keep latency + // low for single-sub pages. + if (!first) { + await new Promise((r) => setTimeout(r, 50 + Math.random() * 100)); + } + first = false; + const meta = this.subscriptionMeta.get(id); if (!meta) continue; @@ -553,7 +710,9 @@ export class ForgeClient { meta.failedAttempts = 0; } catch (err) { meta.failedAttempts++; - console.error(`Failed to re-register subscription ${id} (attempt ${meta.failedAttempts}):`, err); + // Routed through onDebug so production builds don't leak subscription + // IDs into the browser console. + this.config.onDebug?.(`Failed to re-register subscription ${id} (attempt ${meta.failedAttempts}): ${err instanceof Error ? 
err.message : String(err)}`); const callback = this.subscriptions.get(`sub:${id}`); if (callback) { callback({ @@ -566,13 +725,66 @@ export class ForgeClient { } } } + + // Re-register job/workflow subscriptions on the new session. Their server + // entries are bound to the prior (now-closed) session, so without this their + // progress/status pushes land on a channel no client reads. Skip any whose + // store has already torn down (callback removed) and forward the fresh + // snapshot so the store catches up to any state missed during the gap. + for (const [clientSubId, jobId] of Array.from(this.jobMeta)) { + if (!this.subscriptions.has(`job:${clientSubId}`)) { + this.jobMeta.delete(clientSubId); + continue; + } + try { + // Raw register — we're already inside the reconnect, so don't funnel + // through `_registerJob` (which would await this same reconnect). + const data = await this.registerSseJob(clientSubId, jobId); + if (data) this.subscriptions.get(`job:${clientSubId}`)?.(data); + } catch (err) { + this.config.onDebug?.( + `Failed to re-register job ${clientSubId}: ${err instanceof Error ? err.message : String(err)}`, + ); + } + } + for (const [clientSubId, workflowId] of Array.from(this.workflowMeta)) { + if (!this.subscriptions.has(`wf:${clientSubId}`)) { + this.workflowMeta.delete(clientSubId); + continue; + } + try { + // Raw register — see the job loop above. + const data = await this.registerSseWorkflow(clientSubId, workflowId); + if (data) this.subscriptions.get(`wf:${clientSubId}`)?.(data); + } catch (err) { + this.config.onDebug?.( + `Failed to re-register workflow ${clientSubId}: ${err instanceof Error ? 
err.message : String(err)}`, + ); + } + } } - private containsFiles(obj: unknown): boolean { + private containsFiles(obj: unknown, seen: WeakSet = new WeakSet()): boolean { if (obj instanceof File || obj instanceof Blob) return true; - if (Array.isArray(obj)) return obj.some((item) => this.containsFiles(item)); if (obj && typeof obj === "object") { - return Object.values(obj).some((value) => this.containsFiles(value)); + if (seen.has(obj)) return false; + seen.add(obj); + // Date has no nested user data; skip to avoid scanning its prototype chain. + if (obj instanceof Date) return false; + if (obj instanceof Map) { + for (const v of obj.values()) { + if (this.containsFiles(v, seen)) return true; + } + return false; + } + if (obj instanceof Set) { + for (const v of obj.values()) { + if (this.containsFiles(v, seen)) return true; + } + return false; + } + if (Array.isArray(obj)) return obj.some((item) => this.containsFiles(item, seen)); + return Object.values(obj).some((value) => this.containsFiles(value, seen)); } return false; } @@ -611,16 +823,24 @@ export class ForgeClient { private setConnectionState(state: ConnectionState): void { this.connectionState = state; - this.connectionListeners.forEach((listener) => listener(state)); + for (const listener of this.connectionListeners) { + try { + listener(state); + } catch (err) { + // One misbehaving listener must not abort the rest. Surface via + // opt-in debug channel rather than the console. + this.config.onDebug?.(`connection listener threw: ${err instanceof Error ? err.message : String(err)}`); + } + } } private scheduleReconnect(): void { if (this.reconnectAttempts >= this.maxReconnectAttempts) return; - // Exponential backoff with jitter to prevent synchronized retry storms + // Exponential backoff with full jitter: multiplier is [0,1) so retries + // are uniformly spread across the window rather than biased upward. 
const exponentialDelay = this.reconnectDelay * Math.pow(2, this.reconnectAttempts); - const jitter = 0.5 + Math.random(); - const delay = Math.min(exponentialDelay * jitter, 30000); + const delay = Math.min(exponentialDelay * Math.random(), 30000); this.reconnectAttempts++; setTimeout(() => { diff --git a/packages/forge-svelte/context.ts b/packages/forge-svelte/context.ts index 37c96b22..93a03427 100644 --- a/packages/forge-svelte/context.ts +++ b/packages/forge-svelte/context.ts @@ -5,6 +5,11 @@ import type { AuthState } from "./types.js"; const FORGE_CLIENT_KEY = Symbol("forge-client"); const FORGE_AUTH_KEY = Symbol("forge-auth"); + +// Module-level fallback for callers outside Svelte's component tree (rare, +// but used by ad-hoc test harnesses). Multiple providers in the same process +// silently shared this slot — now we warn loudly so the test/embedded app +// catches it instead of debugging stale-client bugs. let globalClient: ForgeClient | null = null; export function getForgeClient(): ForgeClient { @@ -20,6 +25,13 @@ export function getForgeClient(): ForgeClient { export function setForgeClient(client: ForgeClient): void { setContext(FORGE_CLIENT_KEY, client); + if (globalClient && globalClient !== client && typeof console !== "undefined") { + console.warn( + "[forge] setForgeClient called with a second client. The module-level " + + "fallback now points at the new instance; any code still using getForgeClient() " + + "without a Svelte context will see the replacement. Mount one ForgeProvider per app.", + ); + } globalClient = client; } diff --git a/packages/forge-svelte/signals.ts b/packages/forge-svelte/signals.ts index bacf3213..5f9bc44d 100644 --- a/packages/forge-svelte/signals.ts +++ b/packages/forge-svelte/signals.ts @@ -141,8 +141,13 @@ export class ForgeSignals { async page(properties?: Record): Promise { if (!this.config.enabled) return; try { + // Re-extract UTM on each call so SPA navigations to new utm_*-bearing + // URLs still propagate. 
Drop the querystring from the captured URL + // since arbitrary query params often carry secrets (?ssoToken=…, + // ?reset_token=…). + this.utmParams = this.extractUtm(); const payload: Record = { - url: location.href, + url: location.pathname, referrer: document.referrer || undefined, title: document.title || undefined, ...this.utmParams, @@ -177,7 +182,7 @@ export class ForgeSignals { context, correlation_id: this.lastCorrelationId ?? undefined, breadcrumbs: [...this.breadcrumbs], - page_url: typeof location !== "undefined" ? location.href : undefined, + page_url: typeof location !== "undefined" ? location.pathname : undefined, }]); } @@ -211,7 +216,7 @@ export class ForgeSignals { value, rating: extra?.rating, attribution: extra?.attribution, - page_url: typeof location !== "undefined" ? location.href : undefined, + page_url: typeof location !== "undefined" ? location.pathname : undefined, }, correlation_id: this.lastCorrelationId ?? undefined, }); @@ -229,6 +234,10 @@ export class ForgeSignals { headers: { "Content-Type": "application/json", "x-forge-platform": "web", + // CSRF mitigation: a custom header forces a preflight on cross-origin + // requests. If the server's CORS allowlist excludes attacker origins, + // the preflight is rejected before the credentialed POST is sent. + "X-Forge-CSRF": "1", ...(this.sessionId ? { "x-session-id": this.sessionId } : {}), }, credentials: "include", @@ -264,7 +273,7 @@ export class ForgeSignals { payload: { events, context: { - page_url: typeof location !== "undefined" ? location.href : undefined, + page_url: typeof location !== "undefined" ? location.pathname : undefined, session_id: this.sessionId, }, }, @@ -293,7 +302,7 @@ export class ForgeSignals { payload: { events, context: { - page_url: typeof location !== "undefined" ? location.href : undefined, + page_url: typeof location !== "undefined" ? 
location.pathname : undefined, session_id: this.sessionId, }, }, @@ -344,14 +353,14 @@ export class ForgeSignals { if (typeof window === "undefined") return; if (this.config.autoPageViews) { - this.lastPageUrl = location.href; + this.lastPageUrl = location.pathname; this.page(); this.originalPushState = history.pushState.bind(history); this.originalReplaceState = history.replaceState.bind(history); const onNavigation = () => { - const current = location.href; + const current = location.pathname; if (current !== this.lastPageUrl) { this.lastPageUrl = current; this.page(); diff --git a/packages/forge-svelte/stores.ts b/packages/forge-svelte/stores.ts index b60c9aee..92032b2f 100644 --- a/packages/forge-svelte/stores.ts +++ b/packages/forge-svelte/stores.ts @@ -61,9 +61,21 @@ export function createConnectionStore(): ConnectionStatusStore { type RejectEmptyObject = T extends Record ? never : T; +/** Optional runtime validator for store payloads. Return null to surface a + * schema mismatch as an error instead of letting the UI deref garbage. */ +export interface StoreOptions { + validate?: (data: unknown) => T | null; +} + +const VALIDATION_ERROR: ForgeError = new ForgeClientError( + "VALIDATION_ERROR", + "Response failed runtime validation", +); + export function createQueryStore( functionName: string, - args: RejectEmptyObject + args: RejectEmptyObject, + options?: StoreOptions, ): QueryStore { const client = getForgeClient(); const subscribers = new Set<(value: QueryResult) => void>(); @@ -80,8 +92,13 @@ export function createQueryStore( notify(); try { - const data = await client.call(functionName, args); - state = { loading: false, data, error: null }; + const raw = await client.call(functionName, args); + const data = options?.validate ? 
options.validate(raw) : (raw as TResult); + if (data === null && options?.validate) { + state = { loading: false, data: null, error: VALIDATION_ERROR }; + } else { + state = { loading: false, data: data as TResult, error: null }; + } } catch (e) { state = { loading: false, data: null, error: e as ForgeError }; } @@ -106,7 +123,8 @@ export function createQueryStore( export function createSubscriptionStore( functionName: string, - args: RejectEmptyObject + args: RejectEmptyObject, + options?: StoreOptions, ): SubscriptionStore { const client = getForgeClient(); const subscribers = new Set<(value: SubscriptionResult) => void>(); @@ -135,14 +153,31 @@ export function createSubscriptionStore( try { subscriptionId = crypto.randomUUID(); - const initialData = await client._registerQuery(subscriptionId, functionName, args); - state = { loading: false, data: initialData as TResult, error: null, stale: false }; - notify(); - unsubscribeFn = client._subscribe(`sub:${subscriptionId}`, (data: unknown) => { - state = { loading: false, data: data as TResult, error: null, stale: false }; + // Register the update callback BEFORE the fallible initial registration. + // If the first registration fails — e.g. an auth-required query subscribed + // while still anonymous returns 401 — the callback must already be wired so + // that once a later reconnect re-registers the subscription (after login), + // reactor pushes are delivered and the store recovers. Wiring it only after + // a successful registration left such subscriptions permanently dead. + unsubscribeFn = client._subscribe(`sub:${subscriptionId}`, (raw: unknown) => { + const data = options?.validate ? 
options.validate(raw) : (raw as TResult); + if (data === null && options?.validate) { + state = { loading: false, data: null, error: VALIDATION_ERROR, stale: false }; + } else { + state = { loading: false, data: data as TResult, error: null, stale: false }; + } notify(); }); + + const initialRaw = await client._registerQuery(subscriptionId, functionName, args); + const initial = options?.validate ? options.validate(initialRaw) : (initialRaw as TResult); + if (initial === null && options?.validate) { + state = { loading: false, data: null, error: VALIDATION_ERROR, stale: false }; + } else { + state = { loading: false, data: initial as TResult, error: null, stale: false }; + } + notify(); } catch (e) { state = { loading: false, data: null, error: e as ForgeError, stale: false }; notify(); @@ -186,6 +221,16 @@ export function createSubscriptionStore( const uuidRegex = /^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$/i; +const JOB_STATUSES = new Set([ + "pending", "claimed", "running", "completed", "failed", + "retry", "dead_letter", "cancel_requested", "cancelled", +]); + +const WORKFLOW_STATUSES = new Set([ + "pending", "running", "sleeping", "waiting", "completed", "failed", + "blocked_missing_version", "blocked_signature_mismatch", "blocked_missing_handler", +]); + function asValidRecord( data: unknown, ...requiredStringFields: string[] @@ -205,6 +250,7 @@ export function createJobStore( const client = getForgeClient(); const subscribers = new Set<(value: JobState & { loading: boolean }) => void>(); let unsubscribeFn: (() => void) | null = null; + let clientSubId: string | null = null; let state: JobState & { loading: boolean } = { jobId: "", status: "pending", @@ -229,12 +275,9 @@ export function createJobStore( state = { ...state, jobId, loading: false }; notify(); - const clientSubId = crypto.randomUUID(); - const initialData = await client._registerJob(clientSubId, jobId); - const applyJobData = (data: unknown) => { const jobData = 
asValidRecord(data, "job_id", "status"); - if (!jobData) { + if (!jobData || !JOB_STATUSES.has(jobData.status as string)) { state = { ...state, status: "failed", error: "Invalid job update", loading: false }; notify(); return; @@ -251,9 +294,13 @@ export function createJobStore( notify(); }; - if (initialData) applyJobData(initialData); - + clientSubId = crypto.randomUUID(); + // Register the update callback before the server registration so the + // subscription survives a reconnect-driven re-registration (the client + // re-registers job subs whose callback is still present). unsubscribeFn = client._subscribe(`job:${clientSubId}`, applyJobData); + const initialData = await client._registerJob(clientSubId, jobId); + if (initialData) applyJobData(initialData); } catch (e) { state = { ...state, @@ -273,16 +320,20 @@ export function createJobStore( run(state); return () => { subscribers.delete(run); - if (subscribers.size === 0 && unsubscribeFn) { - unsubscribeFn(); + if (subscribers.size === 0) { unsubscribeFn = null; + if (clientSubId) { + client._unregisterJob(clientSubId); + clientSubId = null; + } } }; }, unsubscribe: () => { - if (unsubscribeFn) { - unsubscribeFn(); - unsubscribeFn = null; + unsubscribeFn = null; + if (clientSubId) { + client._unregisterJob(clientSubId); + clientSubId = null; } }, }; @@ -295,6 +346,7 @@ export function createWorkflowStore( const client = getForgeClient(); const subscribers = new Set<(value: WorkflowState & { loading: boolean }) => void>(); let unsubscribeFn: (() => void) | null = null; + let clientSubId: string | null = null; let state: WorkflowState & { loading: boolean } = { workflowId: "", status: "pending", @@ -320,12 +372,9 @@ export function createWorkflowStore( state = { ...state, workflowId, loading: false }; notify(); - const clientSubId = crypto.randomUUID(); - const initialData = await client._registerWorkflow(clientSubId, workflowId); - const applyWorkflowData = (data: unknown) => { const wfData = 
asValidRecord(data, "workflow_id", "status"); - if (!wfData) { + if (!wfData || !WORKFLOW_STATUSES.has(wfData.status as string)) { state = { ...state, status: "failed", error: "Invalid workflow update", loading: false }; notify(); return; @@ -351,9 +400,12 @@ export function createWorkflowStore( notify(); }; - if (initialData) applyWorkflowData(initialData); - + clientSubId = crypto.randomUUID(); + // Register the callback before the server registration so the subscription + // survives a reconnect-driven re-registration. unsubscribeFn = client._subscribe(`wf:${clientSubId}`, applyWorkflowData); + const initialData = await client._registerWorkflow(clientSubId, workflowId); + if (initialData) applyWorkflowData(initialData); } catch (e) { state = { ...state, @@ -373,16 +425,20 @@ export function createWorkflowStore( run(state); return () => { subscribers.delete(run); - if (subscribers.size === 0 && unsubscribeFn) { - unsubscribeFn(); + if (subscribers.size === 0) { unsubscribeFn = null; + if (clientSubId) { + client._unregisterWorkflow(clientSubId); + clientSubId = null; + } } }; }, unsubscribe: () => { - if (unsubscribeFn) { - unsubscribeFn(); - unsubscribeFn = null; + unsubscribeFn = null; + if (clientSubId) { + client._unregisterWorkflow(clientSubId); + clientSubId = null; } }, }; From 15b00cbc2a76ee132351ee71b8dfbd9448df0b5e Mon Sep 17 00:00:00 2001 From: Isala Piyarisi Date: Thu, 28 May 2026 02:53:50 +0530 Subject: [PATCH 5/7] wire per-user auth and make examples testable end-to-end Per-user auth in realtime-todo-list, auth-gated demo panels, idempotent seed migrations, regenerated .sqlx for auth-scoped queries, plain-reqwest webhook loopback, and tuned Playwright timeouts across all six templates. 
--- examples/with-dioxus/demo/.env | 20 +- examples/with-dioxus/demo/.env.example | 21 +- examples/with-dioxus/demo/.gitignore | 2 + examples/with-dioxus/demo/Cargo.toml | 6 +- examples/with-dioxus/demo/Dockerfile | 8 +- examples/with-dioxus/demo/docker-compose.yml | 4 +- examples/with-dioxus/demo/forge.toml | 19 +- .../demo/frontend/playwright.config.ts | 2 +- .../demo/frontend/src/components/auth_card.rs | 38 +- .../frontend/src/components/cache_card.rs | 15 +- .../frontend/src/components/users_section.rs | 21 +- .../frontend/src/components/webhook_card.rs | 77 +--- .../demo/frontend/src/forge/api.rs | 10 + .../demo/frontend/src/forge/types.rs | 93 ++--- .../demo/frontend/src/pages/demo.rs | 2 +- .../demo/frontend/src/signals_bridge.rs | 11 +- .../demo/frontend/tests/fixtures.ts | 2 +- .../demo/frontend/tests/home.spec.ts | 28 ++ .../demo/migrations/0001_initial.sql | 32 +- .../with-dioxus/demo/src/functions/auth.rs | 7 +- .../with-dioxus/demo/src/functions/export.rs | 7 +- .../with-dioxus/demo/src/functions/iss.rs | 14 +- .../with-dioxus/demo/src/functions/mcp.rs | 4 +- .../with-dioxus/demo/src/functions/stats.rs | 2 + .../with-dioxus/demo/src/functions/trades.rs | 60 ++- .../with-dioxus/demo/src/functions/users.rs | 191 ++++++++- .../demo/src/functions/verification.rs | 3 +- .../with-dioxus/demo/src/functions/webhook.rs | 52 +++ examples/with-dioxus/minimal/Dockerfile | 8 +- .../with-dioxus/minimal/docker-compose.yml | 4 +- examples/with-dioxus/minimal/forge.toml | 11 +- .../minimal/frontend/playwright.config.ts | 2 +- .../minimal/frontend/src/forge/types.rs | 7 +- .../migrations/0001_initial.sql.example | 11 +- examples/with-dioxus/realtime-todo-list/.env | 4 +- .../realtime-todo-list/.env.example | 4 +- ...f63c641c19b14ceb671016d66310a74dbca26.json | 46 +++ ...678ae2f60db0bfdbd437fb00618618b0ab5e.json} | 5 +- ...beb5479ac15d11cb1c4fdfd0cd7247f519f3c.json | 57 +++ ...992a65a3d646156b70a30a3377a4e0f19f1f3.json | 52 +++ 
...904549a2f663613d7570b529eeaa13a6d9aa.json} | 15 +- ...52ef8f47e93152275994011c4382e85ed7f2.json} | 15 +- ...b97af559e04fa847bfa0ec37ecc45354d3d22.json | 52 +++ ...b73e7b6282402018f150a58eefe67319ba763.json | 38 -- .../with-dioxus/realtime-todo-list/Cargo.toml | 2 + .../with-dioxus/realtime-todo-list/Dockerfile | 8 +- .../realtime-todo-list/docker-compose.yml | 4 +- .../with-dioxus/realtime-todo-list/forge.toml | 18 +- .../frontend/playwright.config.ts | 2 +- .../frontend/src/forge/api.rs | 41 ++ .../frontend/src/forge/types.rs | 122 +++++- .../realtime-todo-list/frontend/src/main.rs | 5 +- .../frontend/src/todo_app.rs | 339 ++++++++++----- .../frontend/tests/home.spec.ts | 103 +++-- .../migrations/0001_todos.sql | 16 +- .../realtime-todo-list/src/functions/auth.rs | 140 +++++++ .../realtime-todo-list/src/functions/mod.rs | 1 + .../realtime-todo-list/src/functions/todos.rs | 186 ++++++++- .../realtime-todo-list/src/schema/todo.rs | 36 ++ examples/with-svelte/demo/.env | 20 +- examples/with-svelte/demo/.env.example | 21 +- examples/with-svelte/demo/.gitignore | 2 + examples/with-svelte/demo/Cargo.toml | 5 +- examples/with-svelte/demo/Dockerfile | 8 +- examples/with-svelte/demo/docker-compose.yml | 6 +- examples/with-svelte/demo/forge.toml | 19 +- .../demo/frontend/playwright.config.ts | 2 +- .../demo/frontend/src/lib/forge/api.ts | 4 + .../frontend/src/lib/forge/reactive.svelte.ts | 6 + .../demo/frontend/src/lib/forge/types.ts | 15 +- .../demo/frontend/src/routes/+page.svelte | 56 +-- .../demo/frontend/tests/fixtures.ts | 2 +- .../demo/frontend/tests/home.spec.ts | 34 ++ .../demo/migrations/0001_initial.sql | 32 +- .../with-svelte/demo/src/functions/auth.rs | 7 +- .../with-svelte/demo/src/functions/export.rs | 7 +- .../with-svelte/demo/src/functions/iss.rs | 38 +- .../with-svelte/demo/src/functions/mcp.rs | 4 +- .../with-svelte/demo/src/functions/stats.rs | 2 + .../with-svelte/demo/src/functions/trades.rs | 87 ++-- .../with-svelte/demo/src/functions/users.rs | 60 
++- .../demo/src/functions/verification.rs | 103 +---- .../with-svelte/demo/src/functions/webhook.rs | 52 +++ examples/with-svelte/minimal/Dockerfile | 8 +- .../with-svelte/minimal/docker-compose.yml | 6 +- examples/with-svelte/minimal/forge.toml | 11 +- .../minimal/frontend/playwright.config.ts | 2 +- .../migrations/0001_initial.sql.example | 11 +- examples/with-svelte/realtime-todo-list/.env | 4 +- .../realtime-todo-list/.env.example | 4 +- ...f63c641c19b14ceb671016d66310a74dbca26.json | 46 +++ ...678ae2f60db0bfdbd437fb00618618b0ab5e.json} | 5 +- ...beb5479ac15d11cb1c4fdfd0cd7247f519f3c.json | 57 +++ ...992a65a3d646156b70a30a3377a4e0f19f1f3.json | 52 +++ ...904549a2f663613d7570b529eeaa13a6d9aa.json} | 15 +- ...52ef8f47e93152275994011c4382e85ed7f2.json} | 15 +- ...b97af559e04fa847bfa0ec37ecc45354d3d22.json | 52 +++ ...b73e7b6282402018f150a58eefe67319ba763.json | 38 -- .../with-svelte/realtime-todo-list/Cargo.toml | 2 + .../with-svelte/realtime-todo-list/Dockerfile | 8 +- .../realtime-todo-list/docker-compose.yml | 6 +- .../with-svelte/realtime-todo-list/forge.toml | 18 +- .../frontend/playwright.config.ts | 2 +- .../frontend/src/lib/forge/api.ts | 16 + .../frontend/src/lib/forge/auth.svelte.ts | 160 ++++++++ .../frontend/src/lib/forge/reactive.svelte.ts | 12 +- .../frontend/src/lib/forge/types.ts | 41 ++ .../frontend/src/routes/+layout.svelte | 8 +- .../frontend/src/routes/+page.svelte | 386 +++++++++++++----- .../frontend/tests/todo.spec.ts | 107 +++-- .../migrations/0001_todos.sql | 16 +- .../realtime-todo-list/src/functions/auth.rs | 140 +++++++ .../realtime-todo-list/src/functions/mod.rs | 1 + .../realtime-todo-list/src/functions/todos.rs | 194 +++++---- .../realtime-todo-list/src/schema/todo.rs | 36 ++ 115 files changed, 3139 insertions(+), 1037 deletions(-) create mode 100644 examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json rename 
examples/with-dioxus/realtime-todo-list/.sqlx/{query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json => query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json} (50%) create mode 100644 examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json create mode 100644 examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json rename examples/{with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json => with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json} (68%) rename examples/with-dioxus/realtime-todo-list/.sqlx/{query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json => query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json} (74%) create mode 100644 examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json delete mode 100644 examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json create mode 100644 examples/with-dioxus/realtime-todo-list/src/functions/auth.rs create mode 100644 examples/with-svelte/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json rename examples/with-svelte/realtime-todo-list/.sqlx/{query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json => query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json} (50%) create mode 100644 examples/with-svelte/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json create mode 100644 examples/with-svelte/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json rename 
examples/{with-dioxus/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json => with-svelte/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json} (68%) rename examples/with-svelte/realtime-todo-list/.sqlx/{query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json => query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json} (74%) create mode 100644 examples/with-svelte/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json delete mode 100644 examples/with-svelte/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json create mode 100644 examples/with-svelte/realtime-todo-list/frontend/src/lib/forge/auth.svelte.ts create mode 100644 examples/with-svelte/realtime-todo-list/src/functions/auth.rs diff --git a/examples/with-dioxus/demo/.env b/examples/with-dioxus/demo/.env index be175897..71462349 100644 --- a/examples/with-dioxus/demo/.env +++ b/examples/with-dioxus/demo/.env @@ -1,21 +1,17 @@ -# Server +# Dev-only environment for `forge test` and local runs. NOT shipped to users: +# `scripts/build-template-archive.sh` excludes `.env`, and the webhook secret is +# used server-side only (never in the browser bundle). Users copy `.env.example` +# and generate their own secrets. Mirrors the realtime-todo-list convention. 
HOST=0.0.0.0 PORT=9081 - -# Logging (error, warn, info, debug, trace) RUST_LOG=info,forge_runtime::function::executor=trace - -# Postgres container settings POSTGRES_USER=postgres POSTGRES_PASSWORD=forge POSTGRES_DB=forge_dioxus_demo_template POSTGRES_PORT=5432 - -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production - -# Webhook HMAC secret (must match client-side secret) +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate +JWT_AUDIENCE=forge-demo-dev WEBHOOK_SECRET=demo-secret - -# Enable offline mode for sqlx compile-time checks +SEED_DEMO_USER=true +CORS_ORIGIN=http://localhost:9080 SQLX_OFFLINE=true diff --git a/examples/with-dioxus/demo/.env.example b/examples/with-dioxus/demo/.env.example index be175897..b5e0fd5d 100644 --- a/examples/with-dioxus/demo/.env.example +++ b/examples/with-dioxus/demo/.env.example @@ -1,3 +1,5 @@ +# Copy to `.env` and fill in real values. Never commit `.env`. + # Server HOST=0.0.0.0 PORT=9081 @@ -11,11 +13,22 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=forge_dioxus_demo_template POSTGRES_PORT=5432 -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 + +# JWT audience claim. Must match the audience configured in your auth provider. +JWT_AUDIENCE=CHANGE_ME_YOUR_AUDIENCE + +# HMAC secret used to verify inbound webhook signatures. +# Generate with: openssl rand -hex 32 +WEBHOOK_SECRET=CHANGE_ME_USE_OPENSSL_RAND_HEX_32 + +# Seed the demo user (demo@example.com / password123) at first migration. +# DEV ONLY. Leave unset (or `false`) in any deployed environment. +SEED_DEMO_USER=true -# Webhook HMAC secret (must match client-side secret) -WEBHOOK_SECRET=demo-secret +# CORS origin for the Dioxus frontend. Override per environment. 
+CORS_ORIGIN=http://localhost:9080 # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-dioxus/demo/.gitignore b/examples/with-dioxus/demo/.gitignore index 6c4eb93f..e017b398 100644 --- a/examples/with-dioxus/demo/.gitignore +++ b/examples/with-dioxus/demo/.gitignore @@ -13,6 +13,8 @@ frontend/playwright-report/ frontend/test-results/ # Environment +# `.env` is tracked: it holds dev-only secrets for `forge test` and local runs. +# Real deployments use their own secrets; the template archive excludes `.env`. .env.local .env.*.local diff --git a/examples/with-dioxus/demo/Cargo.toml b/examples/with-dioxus/demo/Cargo.toml index 43d88d0d..fa961d5a 100644 --- a/examples/with-dioxus/demo/Cargo.toml +++ b/examples/with-dioxus/demo/Cargo.toml @@ -6,8 +6,9 @@ rust-version = "1.92" publish = false [features] -default = ["embedded-frontend"] +default = [] embedded-frontend = ["dep:rust-embed", "forge/embedded-frontend"] +testcontainers = ["forge/testcontainers"] [dependencies] forge = { workspace = true } @@ -24,6 +25,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-util = "0.3" argon2 = "0.5" password-hash = "0.5" +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-dioxus/demo/Dockerfile b/examples/with-dioxus/demo/Dockerfile index e9f03715..69dcb57f 100644 --- a/examples/with-dioxus/demo/Dockerfile +++ b/examples/with-dioxus/demo/Dockerfile @@ -1,9 +1,9 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app RUN cargo install cargo-watch --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -11,7 +11,7 @@ COPY 
frontend/Cargo.toml frontend/Dioxus.toml ./ COPY frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -24,7 +24,7 @@ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/forge-dioxus-demo-template /app/forge-dioxus-demo-template diff --git a/examples/with-dioxus/demo/docker-compose.yml b/examples/with-dioxus/demo/docker-compose.yml index 47bf375a..81e79910 100644 --- a/examples/with-dioxus/demo/docker-compose.yml +++ b/examples/with-dioxus/demo/docker-compose.yml @@ -6,7 +6,7 @@ services: target: dev working_dir: /workspace/examples/with-dioxus/demo ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/demo/forge.toml b/examples/with-dioxus/demo/forge.toml index ab9dfad8..ad88bc51 100644 --- a/examples/with-dioxus/demo/forge.toml +++ b/examples/with-dioxus/demo/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -42,7 +42,7 @@ otlp_endpoint = "http://localhost:4318" [auth] jwt_algorithm = "HS256" jwt_secret = "${JWT_SECRET}" -jwt_audience = "${JWT_AUDIENCE-https://api.forge-demo.local}" +jwt_audience = 
"${JWT_AUDIENCE}" [mcp] enabled = true @@ -52,9 +52,13 @@ session_ttl = "1h" allowed_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] require_protocol_version_header = true -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path, PG fallback for global keys (DDoS-grade). +# strict: every check round-trips to PG (cluster-wide correct, billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. +# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted @@ -66,3 +70,8 @@ require_protocol_version_header = true # [node] # roles = ["gateway", "function", "worker", "scheduler"] # worker_capabilities = ["general"] # general, gpu, high_cpu + +[signals] +# Product analytics + diagnostics are off by default; this demo opts in to +# exercise the /_api/signal endpoint and the client SDK. +enabled = true diff --git a/examples/with-dioxus/demo/frontend/playwright.config.ts b/examples/with-dioxus/demo/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/demo/frontend/playwright.config.ts +++ b/examples/with-dioxus/demo/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 
1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/demo/frontend/src/components/auth_card.rs b/examples/with-dioxus/demo/frontend/src/components/auth_card.rs index ab11a886..935ddccf 100644 --- a/examples/with-dioxus/demo/frontend/src/components/auth_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/auth_card.rs @@ -13,9 +13,23 @@ pub fn AuthCard() -> Element { let signals = use_signals(); let mut mode = use_signal(|| "login".to_string()); - let mut auth_email = use_signal(|| "demo@example.com".to_string()); - let mut auth_password = use_signal(|| "password123".to_string()); - let mut auth_name = use_signal(|| String::new()); + // Prefill credentials only in debug builds. Release WASM ships empty fields so a + // public demo is not a one-click login when combined with the seeded admin user. + let mut auth_email = use_signal(|| { + if cfg!(debug_assertions) { + "demo@example.com".to_string() + } else { + String::new() + } + }); + let mut auth_password = use_signal(|| { + if cfg!(debug_assertions) { + "password123".to_string() + } else { + String::new() + } + }); + let mut auth_name = use_signal(String::new); let mut auth_error = use_signal(|| None::); let mut loading = use_signal(|| false); @@ -99,7 +113,10 @@ pub fn AuthCard() -> Element { auth_error.set(None); match refresh_mut.call(RefreshInput::new(&rt)).await { Ok(pair) => { - signals.track_with_properties("token_refresh", json!({"count": refresh_count() + 1})); + signals.track_with_properties( + "token_refresh", + json!({"count": refresh_count() + 1}), + ); let claims = parse_jwt_claims(&pair.access_token); token_claims.set(Some(claims)); auth.update_tokens( @@ -134,13 +151,14 @@ pub fn AuthCard() -> Element { // Restore viewer on mount (persisted in localStorage by ForgeAuthProvider) use_effect(move || { - if auth.is_authenticated() && auth_user.read().is_none() { - if let Some(viewer) = auth.viewer::() { - if let Some(token) = auth.access_token() { - 
token_claims.set(Some(parse_jwt_claims(&token))); - } - auth_user.set(Some(viewer)); + if auth.is_authenticated() + && auth_user.read().is_none() + && let Some(viewer) = auth.viewer::() + { + if let Some(token) = auth.access_token() { + token_claims.set(Some(parse_jwt_claims(&token))); } + auth_user.set(Some(viewer)); } }); diff --git a/examples/with-dioxus/demo/frontend/src/components/cache_card.rs b/examples/with-dioxus/demo/frontend/src/components/cache_card.rs index d56f7ace..3732ed98 100644 --- a/examples/with-dioxus/demo/frontend/src/components/cache_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/cache_card.rs @@ -36,15 +36,12 @@ pub fn CacheCard() -> Element { spawn(async move { loading.set(true); let start = now_ms(); - match forge::get_demo_stats(&client).await { - Ok(stats) => { - let elapsed = now_ms() - start; - signals.track_with_properties("cache_fetch", json!({"response_ms": elapsed, "cache_hit": elapsed < 100.0, "fetch_number": fetch_count() + 1})); - data.set(Some(stats)); - response_ms.set(Some(elapsed)); - fetch_count.set(fetch_count() + 1); - } - Err(_) => {} + if let Ok(stats) = forge::get_demo_stats(&client).await { + let elapsed = now_ms() - start; + signals.track_with_properties("cache_fetch", json!({"response_ms": elapsed, "cache_hit": elapsed < 100.0, "fetch_number": fetch_count() + 1})); + data.set(Some(stats)); + response_ms.set(Some(elapsed)); + fetch_count.set(fetch_count() + 1); } loading.set(false); }); diff --git a/examples/with-dioxus/demo/frontend/src/components/users_section.rs b/examples/with-dioxus/demo/frontend/src/components/users_section.rs index d519f828..8afb7ba0 100644 --- a/examples/with-dioxus/demo/frontend/src/components/users_section.rs +++ b/examples/with-dioxus/demo/frontend/src/components/users_section.rs @@ -4,11 +4,30 @@ use serde_json::json; use crate::forge::{ CreateUserParams, DeleteUserParams, UpdateUserParams, User, use_create_user, use_delete_user, - use_get_users_subscription, 
use_update_user, + use_forge_auth, use_get_users_subscription, use_update_user, }; +/// `get_users` requires an authenticated session, so only mount the subscribing +/// inner component once logged in. Subscribing while anonymous would fire a 401 +/// on every page load (Dioxus hooks can't be called conditionally, so the gate +/// has to live at the component boundary). #[component] pub fn UsersSection(selected_user: Signal>) -> Element { + let auth = use_forge_auth(); + rsx! { + if auth.is_authenticated() { + UsersSectionInner { selected_user } + } else { + section { class: "card", + h2 { "Users " span { class: "badge green", "crud + subscribe" } } + p { class: "muted", "Log in to manage users." } + } + } + } +} + +#[component] +fn UsersSectionInner(selected_user: Signal>) -> Element { let create_user = use_create_user(); let update_user = use_update_user(); let delete_user = use_delete_user(); diff --git a/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs b/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs index 3d5b365e..b93126ed 100644 --- a/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs +++ b/examples/with-dioxus/demo/frontend/src/components/webhook_card.rs @@ -1,14 +1,15 @@ use dioxus::prelude::*; use forge_dioxus::use_signals; -use hmac::{Hmac, Mac}; use serde_json::json; -use sha2::Sha256; use super::{format_time, generate_key}; -use crate::forge::use_get_webhook_events_subscription; +use crate::forge::{ + TriggerDemoWebhookInput, trigger_demo_webhook, use_forge_client, + use_get_webhook_events_subscription, +}; #[component] -pub fn WebhookCard(api_url: String) -> Element { +pub fn WebhookCard() -> Element { let signals = use_signals(); let state = use_get_webhook_events_subscription(); let events = state.data.clone().unwrap_or_default(); @@ -17,6 +18,8 @@ pub fn WebhookCard(api_url: String) -> Element { let mut key_used = use_signal(|| false); let mut webhook_error = use_signal(|| None::); + let client = 
use_forge_client(); + rsx! { section { class: "card", h2 { "Webhook " span { class: "badge", "webhook" } } @@ -42,23 +45,27 @@ pub fn WebhookCard(api_url: String) -> Element { } button { disabled: key_used(), onclick: { - let api_url = api_url.clone(); let signals = signals.clone(); + let client = client.clone(); move |_| { if key_used() { return; } webhook_error.set(None); let key = idempotency_key(); - let api_url = api_url.clone(); let signals = signals.clone(); + let client = client.clone(); spawn(async move { - match trigger_webhook(&api_url, &key).await { - Ok(()) => { + // The HMAC secret lives on the server. The backend signs + // and POSTs the webhook to itself so the WASM bundle + // never ships the secret. + let input = TriggerDemoWebhookInput::new(key.clone()); + match trigger_demo_webhook(&client, input).await { + Ok(_) => { signals.track_with_properties("webhook_sent", json!({"idempotency_key": &key})); key_used.set(true); } - Err(msg) => { + Err(e) => { signals.track("webhook_error"); - webhook_error.set(Some(msg)); + webhook_error.set(Some(e.to_string())); } } }); @@ -87,53 +94,3 @@ pub fn WebhookCard(api_url: String) -> Element { } } } - -async fn trigger_webhook(api_url: &str, idempotency_key: &str) -> Result<(), String> { - #[cfg(target_arch = "wasm32")] - let now = js_sys::Date::now(); - #[cfg(not(target_arch = "wasm32"))] - let now = std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_millis() as f64; - let payload = serde_json::json!({ "action": "test", "ts": now }).to_string(); - - let mut mac = Hmac::::new_from_slice(b"demo-secret").map_err(|e| e.to_string())?; - mac.update(payload.as_bytes()); - let signature = hex::encode(mac.finalize().into_bytes()); - - // HMAC-SHA256 webhooks enforce a replay window: the server rejects any - // request whose `X-Webhook-Timestamp` (unix seconds) is missing or outside - // the 300s window. Send it alongside the signature. 
- let timestamp = (now / 1000.0) as i64; - - // In same-origin builds `api_url` is empty. Unlike browser `fetch`, reqwest - // can't parse a relative URL, so resolve it against the current origin. - #[cfg(target_arch = "wasm32")] - let base = if api_url.is_empty() { - web_sys::window() - .and_then(|w| w.location().origin().ok()) - .unwrap_or_default() - } else { - api_url.to_string() - }; - #[cfg(not(target_arch = "wasm32"))] - let base = api_url.to_string(); - - let resp = reqwest::Client::new() - .post(format!("{base}/_api/webhooks/demo")) - .header("Content-Type", "application/json") - .header("X-Webhook-Signature", signature) - .header("X-Webhook-Timestamp", timestamp.to_string()) - .header("X-Idempotency-Key", idempotency_key) - .body(payload) - .send() - .await - .map_err(|e| e.to_string())?; - - if resp.status().is_success() { - Ok(()) - } else { - Err(format!("Error: {}", resp.status().as_u16())) - } -} diff --git a/examples/with-dioxus/demo/frontend/src/forge/api.rs b/examples/with-dioxus/demo/frontend/src/forge/api.rs index 6729583d..5fa4545f 100644 --- a/examples/with-dioxus/demo/frontend/src/forge/api.rs +++ b/examples/with-dioxus/demo/frontend/src/forge/api.rs @@ -192,6 +192,16 @@ pub async fn register( pub fn use_register() -> Mutation { use_forge_mutation("register") } +pub async fn trigger_demo_webhook( + client: &ForgeClient, + args: TriggerDemoWebhookInput, +) -> Result { + client.call("trigger_demo_webhook", args).await +} + +pub fn use_trigger_demo_webhook() -> Mutation { + use_forge_mutation("trigger_demo_webhook") +} #[derive(Debug, Clone, PartialEq, serde::Serialize)] pub struct UpdateUserParams { pub id: String, diff --git a/examples/with-dioxus/demo/frontend/src/forge/types.rs b/examples/with-dioxus/demo/frontend/src/forge/types.rs index d0342e7a..de63438e 100644 --- a/examples/with-dioxus/demo/frontend/src/forge/types.rs +++ b/examples/with-dioxus/demo/frontend/src/forge/types.rs @@ -1,11 +1,6 @@ // @generated by FORGE - DO NOT EDIT 
-#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -24,34 +19,34 @@ impl AuthResponse { Self { access_token: access_token.into(), refresh_token: refresh_token.into(), - user: user, + user, } } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct BinanceTrade { - pub symbol: String, - pub price: String, - pub quantity: String, - pub trade_time: i64, - pub is_buyer_maker: bool, + pub s: String, + pub p: String, + pub q: String, + pub T: i64, + pub m: bool, } impl BinanceTrade { pub fn new( - symbol: impl Into, - price: impl Into, - quantity: impl Into, - trade_time: i64, - is_buyer_maker: bool, + s: impl Into, + p: impl Into, + q: impl Into, + T: i64, + m: bool, ) -> Self { Self { - symbol: symbol.into(), - price: price.into(), - quantity: quantity.into(), - trade_time: trade_time, - is_buyer_maker: is_buyer_maker, + s: s.into(), + p: p.into(), + q: q.into(), + T, + m, } } } @@ -85,9 +80,9 @@ impl DemoStats { computed_at: impl Into, ) -> Self { Self { - total_users: total_users, - total_trades: total_trades, - total_webhooks: total_webhooks, + total_users, + total_trades, + total_webhooks, computed_at: computed_at.into(), } } @@ -108,15 +103,15 @@ impl ExportInput { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct ExportOutput { - pub count: i64, + pub count: usize, pub data: String, pub format: String, } impl ExportOutput { - pub fn new(count: i64, data: impl Into, format: impl Into) -> Self { + pub fn new(count: usize, data: impl Into, format: impl Into) -> Self { Self { - count: count, + count, data: data.into(), format: format.into(), } @@ -133,8 +128,8 @@ pub struct IssApiResponse { impl IssApiResponse { pub fn new(iss_position: IssPosition, timestamp: i64, message: impl Into) -> Self { Self { - 
iss_position: iss_position, - timestamp: timestamp, + iss_position, + timestamp, message: message.into(), } } @@ -159,8 +154,8 @@ impl IssLocation { ) -> Self { Self { id: id.into(), - latitude: latitude, - longitude: longitude, + latitude, + longitude, api_timestamp: api_timestamp.into(), created_at: created_at.into(), } @@ -229,7 +224,7 @@ impl McpUserInfo { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, } } } @@ -292,15 +287,28 @@ impl Trade { Self { id: id.into(), symbol: symbol.into(), - price: price, - quantity: quantity, + price, + quantity, trade_time: trade_time.into(), - is_buyer_maker: is_buyer_maker, + is_buyer_maker, created_at: created_at.into(), } } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TriggerDemoWebhookInput { + pub idempotency_key: String, +} + +impl TriggerDemoWebhookInput { + pub fn new(idempotency_key: impl Into) -> Self { + Self { + idempotency_key: idempotency_key.into(), + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct User { pub id: String, @@ -309,7 +317,6 @@ pub struct User { pub role: UserRole, pub created_at: String, pub updated_at: String, - pub password_hash: Option, } impl User { @@ -325,17 +332,11 @@ impl User { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, created_at: created_at.into(), updated_at: updated_at.into(), - password_hash: None, } } - - pub fn password_hash(mut self, password_hash: impl Into) -> Self { - self.password_hash = Some(password_hash.into()); - self - } } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] @@ -361,7 +362,7 @@ impl UserPublic { id: id.into(), email: email.into(), name: name.into(), - role: role, + role, created_at: created_at.into(), updated_at: updated_at.into(), } @@ -392,7 +393,7 @@ pub struct VerificationOutput { impl VerificationOutput { pub fn new(verified: bool, token: impl Into) -> Self { Self { - verified: verified, + verified, token: token.into(), } 
} diff --git a/examples/with-dioxus/demo/frontend/src/pages/demo.rs b/examples/with-dioxus/demo/frontend/src/pages/demo.rs index 321d1f88..9c7bcd25 100644 --- a/examples/with-dioxus/demo/frontend/src/pages/demo.rs +++ b/examples/with-dioxus/demo/frontend/src/pages/demo.rs @@ -23,7 +23,7 @@ pub fn DemoPage() -> Element { div { class: "col", TradesCard {} AuthCard {} - WebhookCard { api_url: API_URL.to_string() } + WebhookCard {} VerificationCard { selected_user } } } diff --git a/examples/with-dioxus/demo/frontend/src/signals_bridge.rs b/examples/with-dioxus/demo/frontend/src/signals_bridge.rs index 954d6a45..02128c58 100644 --- a/examples/with-dioxus/demo/frontend/src/signals_bridge.rs +++ b/examples/with-dioxus/demo/frontend/src/signals_bridge.rs @@ -124,11 +124,12 @@ fn install_window_bridge(signals: forge_dioxus::ForgeSignals) { obj.insert("stack".to_string(), serde_json::Value::String(stack)); } } - let context = if ctx_val.is_object() && ctx_val.as_object().is_some_and(|o| !o.is_empty()) { - Some(ctx_val) - } else { - None - }; + let context = + if ctx_val.is_object() && ctx_val.as_object().is_some_and(|o| !o.is_empty()) { + Some(ctx_val) + } else { + None + }; signals.capture_error(&*message, context); }, ) diff --git a/examples/with-dioxus/demo/frontend/tests/fixtures.ts b/examples/with-dioxus/demo/frontend/tests/fixtures.ts index 8c312780..3435b927 100644 --- a/examples/with-dioxus/demo/frontend/tests/fixtures.ts +++ b/examples/with-dioxus/demo/frontend/tests/fixtures.ts @@ -6,7 +6,7 @@ export const API_URL = process.env.FORGE_TEST_URL || process.env.VITE_API_URL || "http://localhost:9081"; -export const ACTION_TIMEOUT = process.env.CI ? 
30_000 : 30_000; +export const ACTION_TIMEOUT = 30_000; export function uniqueId(prefix: string): string { return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`; diff --git a/examples/with-dioxus/demo/frontend/tests/home.spec.ts b/examples/with-dioxus/demo/frontend/tests/home.spec.ts index e4799ca5..80cd2394 100644 --- a/examples/with-dioxus/demo/frontend/tests/home.spec.ts +++ b/examples/with-dioxus/demo/frontend/tests/home.spec.ts @@ -1,3 +1,4 @@ +import type { Page } from "@playwright/test"; import { test, expect, @@ -7,6 +8,28 @@ import { trackConsoleErrors, } from "./fixtures"; +// The release WASM bundle ships empty credential fields (prefill is debug-only), +// so fill them explicitly. Logging in rotates the token: the client tears down +// the anonymous SSE stream and re-subscribes every query over a fresh +// authenticated one. Wait for that re-subscription so reactive reads (and +// job/workflow push updates) reflect the authenticated session. +async function loginAsAdmin(page: Page) { + const auth = page.locator("section", { + has: page.getByText("refresh tokens"), + }); + await auth.getByPlaceholder("Email").fill("demo@example.com"); + await auth.getByPlaceholder(/Password/).fill("password123"); + const resubscribed = page.waitForResponse( + (res) => res.url().includes("/_api/subscribe") && res.status() === 200, + { timeout: ACTION_TIMEOUT * 3 }, + ); + await auth.locator('button[type="submit"]').click(); + await expect(auth.getByText("Logged in as")).toBeVisible({ + timeout: ACTION_TIMEOUT, + }); + await resubscribed; +} + async function signDemoWebhook(body: string): Promise { const encoder = new TextEncoder(); const keyData = await crypto.subtle.importKey( @@ -32,6 +55,7 @@ test("users CRUD stays reactive through create, edit, and delete", async ({ const updatedName = uniqueId("Edited"); await gotoReady(); + await loginAsAdmin(page); const section = page.locator("section", { has: page.getByRole("heading", { name: /users/i }), @@ 
-67,6 +91,7 @@ test("export job and verification workflow complete from the UI", async ({ gotoReady, }) => { await gotoReady(); + await loginAsAdmin(page); const exportSection = page.locator("section", { has: page.getByText("Export Job"), @@ -110,6 +135,8 @@ test("auth flow logs in, refreshes, and logs out cleanly", async ({ has: page.getByText("refresh tokens"), }); + await section.getByPlaceholder("Email").fill("demo@example.com"); + await section.getByPlaceholder(/Password/).fill("password123"); await section.locator('button[type="submit"]').click(); await expect(section.getByText("Logged in as")).toBeVisible({ timeout: ACTION_TIMEOUT, @@ -159,6 +186,7 @@ test("webhook endpoint rejects bad signatures and surfaces accepted events", asy expect(accepted.ok()).toBeTruthy(); await gotoReady(); + await loginAsAdmin(page); const section = page.locator("section", { has: page.getByText("Webhook"), }); diff --git a/examples/with-dioxus/demo/migrations/0001_initial.sql b/examples/with-dioxus/demo/migrations/0001_initial.sql index a854effb..66f90ec4 100644 --- a/examples/with-dioxus/demo/migrations/0001_initial.sql +++ b/examples/with-dioxus/demo/migrations/0001_initial.sql @@ -1,6 +1,10 @@ -CREATE TYPE user_role AS ENUM ('admin', 'member', 'guest'); +DO $$ BEGIN + CREATE TYPE user_role AS ENUM ('admin', 'member', 'guest'); +EXCEPTION + WHEN duplicate_object THEN NULL; +END $$; -CREATE TABLE users ( +CREATE TABLE IF NOT EXISTS users ( id UUID PRIMARY KEY, email VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL, @@ -10,9 +14,9 @@ CREATE TABLE users ( updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE UNIQUE INDEX idx_users_email ON users(email); +CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email); -CREATE TABLE iss_location ( +CREATE TABLE IF NOT EXISTS iss_location ( id UUID PRIMARY KEY, latitude DOUBLE PRECISION NOT NULL, longitude DOUBLE PRECISION NOT NULL, @@ -20,7 +24,7 @@ CREATE TABLE iss_location ( created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); 
-CREATE TABLE trades ( +CREATE TABLE IF NOT EXISTS trades ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), symbol VARCHAR(20) NOT NULL, price DOUBLE PRECISION NOT NULL, @@ -30,9 +34,9 @@ CREATE TABLE trades ( created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE INDEX idx_trades_created_at ON trades(created_at DESC); +CREATE INDEX IF NOT EXISTS idx_trades_created_at ON trades(created_at DESC); -CREATE TABLE webhook_events ( +CREATE TABLE IF NOT EXISTS webhook_events ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), idempotency_key VARCHAR(255) NOT NULL, webhook_name VARCHAR(100) NOT NULL, @@ -40,7 +44,7 @@ CREATE TABLE webhook_events ( processed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); -CREATE INDEX idx_webhook_events_processed_at ON webhook_events(processed_at DESC); +CREATE INDEX IF NOT EXISTS idx_webhook_events_processed_at ON webhook_events(processed_at DESC); SELECT forge_enable_reactivity('users'); SELECT forge_enable_reactivity('iss_location'); @@ -48,7 +52,7 @@ SELECT forge_enable_reactivity('trades'); SELECT forge_enable_reactivity('webhook_events'); -- Stats snapshot table for cached query demo -CREATE TABLE demo_stats ( +CREATE TABLE IF NOT EXISTS demo_stats ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), total_users INTEGER NOT NULL DEFAULT 0, total_trades INTEGER NOT NULL DEFAULT 0, @@ -56,14 +60,18 @@ CREATE TABLE demo_stats ( computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() ); --- Sample user for demo (password: "password123") +-- Demo admin user (password: "password123"). Idempotent on re-run. +-- IMPORTANT: This is a known-credential account for the demo only. +-- For production deployments, delete this seed block before running migrations +-- and create your first admin via a separate one-off script with a strong password. 
INSERT INTO users (id, email, name, role, password_hash, created_at, updated_at) VALUES ( 'a1b2c3d4-e5f6-4a5b-8c9d-0e1f2a3b4c5d', 'demo@example.com', 'Demo User', - 'member', + 'admin', '$argon2id$v=19$m=19456,t=2,p=1$AjozmE60AjazLA3S4LXuvw$v+Jo+M5NZ+Q1K4ro1pDS4Hx0/cnHJ3uvmJC7RiNJkUg', NOW(), NOW() -); +) +ON CONFLICT (id) DO UPDATE SET role = EXCLUDED.role; diff --git a/examples/with-dioxus/demo/src/functions/auth.rs b/examples/with-dioxus/demo/src/functions/auth.rs index acfcd4fd..4c0ab42c 100644 --- a/examples/with-dioxus/demo/src/functions/auth.rs +++ b/examples/with-dioxus/demo/src/functions/auth.rs @@ -20,7 +20,12 @@ pub struct RefreshInput { } async fn auth_response(ctx: &MutationContext, user: &User) -> Result { - let pair = ctx.issue_token_pair(user.id, &["user"]).await?; + let role = match user.role { + UserRole::Admin => "admin", + UserRole::Member => "user", + UserRole::Guest => "guest", + }; + let pair = ctx.issue_token_pair(user.id, &[role]).await?; Ok(AuthResponse { access_token: pair.access_token, refresh_token: pair.refresh_token, diff --git a/examples/with-dioxus/demo/src/functions/export.rs b/examples/with-dioxus/demo/src/functions/export.rs index 365fe3b5..9c5fd87b 100644 --- a/examples/with-dioxus/demo/src/functions/export.rs +++ b/examples/with-dioxus/demo/src/functions/export.rs @@ -14,7 +14,12 @@ pub struct ExportOutput { pub format: String, } -/// Export users as CSV or JSON with progress reporting +/// Export users as CSV or JSON with progress reporting. +/// +/// The `tokio::time::sleep` calls below are SIMULATED work — they exist solely so the +/// progress UI is visible in the demo. Replace them with real I/O (S3 puts, large +/// DB scans, format conversion) in production code. Never ship sleep-padded jobs: +/// they pin worker slots and inflate p99 for no value. 
#[forge::job( timeout = "5m", priority = "low", diff --git a/examples/with-dioxus/demo/src/functions/iss.rs b/examples/with-dioxus/demo/src/functions/iss.rs index 247d5bec..49875212 100644 --- a/examples/with-dioxus/demo/src/functions/iss.rs +++ b/examples/with-dioxus/demo/src/functions/iss.rs @@ -47,7 +47,7 @@ pub async fn iss_location(ctx: &CronContext) -> Result<()> { let response = ctx .http() - .get("http://api.open-notify.org/iss-now.json") + .get("https://api.open-notify.org/iss-now.json") .send() .await .map_err(|e| ForgeError::internal(format!("HTTP request failed: {}", e)))?; @@ -69,8 +69,16 @@ pub async fn iss_location(ctx: &CronContext) -> Result<()> { tracing::warn!(message = %data.message, "ISS API non-success"); } - let latitude: f64 = data.iss_position.latitude.parse().unwrap_or(0.0); - let longitude: f64 = data.iss_position.longitude.parse().unwrap_or(0.0); + let latitude: f64 = data + .iss_position + .latitude + .parse() + .map_err(|e| ForgeError::Deserialization(format!("invalid latitude: {e}")))?; + let longitude: f64 = data + .iss_position + .longitude + .parse() + .map_err(|e| ForgeError::Deserialization(format!("invalid longitude: {e}")))?; sqlx::query!( "INSERT INTO iss_location (id, latitude, longitude, api_timestamp, created_at) \ diff --git a/examples/with-dioxus/demo/src/functions/mcp.rs b/examples/with-dioxus/demo/src/functions/mcp.rs index 31d7c945..5a0d77e3 100644 --- a/examples/with-dioxus/demo/src/functions/mcp.rs +++ b/examples/with-dioxus/demo/src/functions/mcp.rs @@ -30,10 +30,10 @@ pub async fn mcp_me(ctx: &McpToolContext) -> forge::forge_core::Result forge::forge_core::Result> { + let _ = ctx.user_id()?; let mut conn = ctx.conn().await?; let users = sqlx::query_as!( @@ -63,13 +63,13 @@ pub struct McpGetUserInput { name = "demo.get_user_by_email", title = "Get User by Email", description = "Look up a single user by their email address", - public, read_only )] pub async fn mcp_get_user_by_email( ctx: &McpToolContext, input: 
McpGetUserInput, ) -> forge::forge_core::Result> { + let _ = ctx.user_id()?; let mut conn = ctx.conn().await?; let user = sqlx::query_as!( diff --git a/examples/with-dioxus/demo/src/functions/stats.rs b/examples/with-dioxus/demo/src/functions/stats.rs index 907ac11f..0eb5e8d3 100644 --- a/examples/with-dioxus/demo/src/functions/stats.rs +++ b/examples/with-dioxus/demo/src/functions/stats.rs @@ -3,6 +3,8 @@ use forge::prelude::*; #[forge::query(cache = "10s", auth = "none")] pub async fn get_demo_stats(ctx: &QueryContext) -> Result { + // Simulated work to make the `cache = "10s"` demo visible to a human watching the UI. + // Real handlers must not call sleep — it pins a worker thread for no useful work. tokio::time::sleep(std::time::Duration::from_millis(500)).await; let row = sqlx::query!( diff --git a/examples/with-dioxus/demo/src/functions/trades.rs b/examples/with-dioxus/demo/src/functions/trades.rs index 1b1792f7..bdf94df3 100644 --- a/examples/with-dioxus/demo/src/functions/trades.rs +++ b/examples/with-dioxus/demo/src/functions/trades.rs @@ -69,35 +69,53 @@ pub async fn trade_stream(ctx: &DaemonContext) -> Result<()> { msg = read.next() => { match msg { Some(Ok(Message::Text(text))) => { - if let Ok(trade) = serde_json::from_str::(&text) { - let price: f64 = trade.price.parse().unwrap_or(0.0); - let quantity: f64 = trade.quantity.parse().unwrap_or(0.0); - let trade_time = chrono::DateTime::from_timestamp_millis(trade.trade_time) + match serde_json::from_str::(&text) { + Ok(trade) => { + let price: f64 = trade.price.parse().map_err(|e| { + ForgeError::Deserialization(format!("invalid trade price: {e}")) + })?; + let quantity: f64 = trade.quantity.parse().map_err(|e| { + ForgeError::Deserialization(format!( + "invalid trade quantity: {e}" + )) + })?; + let trade_time = chrono::DateTime::from_timestamp_millis( + trade.trade_time, + ) .unwrap_or_else(Utc::now); - sqlx::query!( - "INSERT INTO trades (id, symbol, price, quantity, trade_time, is_buyer_maker, 
created_at) \ - VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, NOW())", - &trade.symbol, - price, - quantity, - trade_time, - trade.is_buyer_maker - ) - .execute(ctx.db()) - .await - .ok(); + sqlx::query!( + "INSERT INTO trades (id, symbol, price, quantity, trade_time, is_buyer_maker, created_at) \ + VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, NOW())", + &trade.symbol, + price, + quantity, + trade_time, + trade.is_buyer_maker + ) + .execute(ctx.db()) + .await?; + } + Err(e) => { + tracing::warn!("Skipping unparsable trade message: {e}"); + } } } Some(Ok(Message::Close(_))) => { - tracing::warn!("WebSocket closed by server"); - break; + return Err(ForgeError::internal( + "Binance WebSocket closed by server; daemon will restart", + )); } Some(Err(e)) => { - tracing::error!("WebSocket error: {}", e); - break; + return Err(ForgeError::internal(format!( + "Binance WebSocket error: {e}" + ))); + } + None => { + return Err(ForgeError::internal( + "Binance WebSocket stream ended; daemon will restart", + )); } - None => break, _ => {} } } diff --git a/examples/with-dioxus/demo/src/functions/users.rs b/examples/with-dioxus/demo/src/functions/users.rs index 9d7b90f2..33955888 100644 --- a/examples/with-dioxus/demo/src/functions/users.rs +++ b/examples/with-dioxus/demo/src/functions/users.rs @@ -1,9 +1,11 @@ use crate::schema::{User, UserRole}; use forge::prelude::*; -/// List all users with reactive subscription support -#[forge::query(cache = "30s", auth = "none")] +/// List all users with reactive subscription support. +/// Reading the global user list requires an authenticated session. +#[forge::query(cache = "30s", unscoped)] pub async fn get_users(ctx: &QueryContext) -> Result> { + let _ = ctx.user_id()?; sqlx::query_as!( User, r#" @@ -24,9 +26,10 @@ pub async fn get_users(ctx: &QueryContext) -> Result> { .map_err(Into::into) } -/// Get single user by ID -#[forge::query(timeout = "10s", auth = "none")] +/// Get single user by ID. Requires an authenticated session. 
+#[forge::query(timeout = "10s", unscoped)] pub async fn get_user(ctx: &QueryContext, id: Uuid) -> Result> { + let _ = ctx.user_id()?; sqlx::query_as!( User, r#" @@ -48,14 +51,15 @@ pub async fn get_user(ctx: &QueryContext, id: Uuid) -> Result> { .map_err(Into::into) } -/// Create a new user -#[forge::mutation(auth = "none")] +/// Create a new user. Requires the `admin` role. +#[forge::mutation(scope = "global")] pub async fn create_user( ctx: &MutationContext, email: String, name: String, role: Option, ) -> Result { + ctx.auth.require_role("admin")?; let id = Uuid::new_v4(); let now = Utc::now(); let role = role.unwrap_or_default(); @@ -87,8 +91,8 @@ pub async fn create_user( .map_err(Into::into) } -/// Update user with partial fields -#[forge::mutation(timeout = "30s", auth = "none")] +/// Update user with partial fields. Requires the `admin` role. +#[forge::mutation(timeout = "30s", scope = "global")] pub async fn update_user( ctx: &MutationContext, id: Uuid, @@ -96,6 +100,7 @@ pub async fn update_user( name: Option, role: Option, ) -> Result { + ctx.auth.require_role("admin")?; let mut conn = ctx.conn().await.map_err(ForgeError::Database)?; sqlx::query_as!( User, @@ -126,9 +131,10 @@ pub async fn update_user( .map_err(Into::into) } -/// Delete user by ID -#[forge::mutation(auth = "none")] +/// Delete user by ID. Requires the `admin` role. 
+#[forge::mutation(scope = "global")] pub async fn delete_user(ctx: &MutationContext, id: Uuid) -> Result { + ctx.auth.require_role("admin")?; let mut conn = ctx.conn().await.map_err(ForgeError::Database)?; let result = sqlx::query!("DELETE FROM users WHERE id = $1", id) .execute(&mut conn) @@ -136,3 +142,168 @@ pub async fn delete_user(ctx: &MutationContext, id: Uuid) -> Result { Ok(result.rows_affected() > 0) } + +#[cfg(all(test, feature = "testcontainers"))] +mod tests { + use super::*; + use forge::forge_core::function::{AuthContext, RequestMetadata}; + use forge::testing::{IsolatedTestDb, TestDatabase}; + use std::path::Path; + + async fn setup_db() -> IsolatedTestDb { + let base = TestDatabase::from_env().await.unwrap(); + let db = base.isolated("users_test").await.unwrap(); + db.run_sql(&forge::get_internal_sql()).await.unwrap(); + db.migrate(Path::new("migrations")).await.unwrap(); + db + } + + fn admin_auth() -> AuthContext { + AuthContext::authenticated(Uuid::new_v4(), vec!["admin".into()], Default::default()) + } + + fn query_ctx(pool: sqlx::PgPool) -> QueryContext { + QueryContext::new(pool, admin_auth(), RequestMetadata::default()) + } + + fn mutation_ctx(pool: sqlx::PgPool) -> MutationContext { + MutationContext::new(pool, admin_auth(), RequestMetadata::default()) + } + + #[tokio::test] + async fn test_create_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "test@example.com".into(), "Test User".into(), None) + .await + .unwrap(); + + assert_eq!(user.email, "test@example.com"); + assert_eq!(user.name, "Test User"); + assert_eq!(user.role, UserRole::default()); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_create_user_with_role() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user( + &ctx, + "admin@example.com".into(), + "Admin".into(), + Some(UserRole::Admin), + ) + .await + .unwrap(); + + assert_eq!(user.role, 
UserRole::Admin); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_create_user_requires_admin_role() { + let db = setup_db().await; + let ctx = MutationContext::new( + db.pool().clone(), + AuthContext::authenticated(Uuid::new_v4(), vec!["member".into()], Default::default()), + RequestMetadata::default(), + ); + + let result = create_user(&ctx, "nope@example.com".into(), "No Admin".into(), None).await; + + assert!( + matches!(result, Err(ForgeError::Forbidden(_))), + "non-admin must be rejected with Forbidden, got {result:?}" + ); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_users() { + let db = setup_db().await; + let m_ctx = mutation_ctx(db.pool().clone()); + + create_user(&m_ctx, "a@test.com".into(), "User A".into(), None) + .await + .unwrap(); + create_user(&m_ctx, "b@test.com".into(), "User B".into(), None) + .await + .unwrap(); + + let q_ctx = query_ctx(db.pool().clone()); + let users = get_users(&q_ctx).await.unwrap(); + assert!(users.len() >= 2); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_user_by_id() { + let db = setup_db().await; + let m_ctx = mutation_ctx(db.pool().clone()); + + let created = create_user(&m_ctx, "find@test.com".into(), "Find Me".into(), None) + .await + .unwrap(); + + let q_ctx = query_ctx(db.pool().clone()); + let found = get_user(&q_ctx, created.id).await.unwrap(); + assert!(found.is_some()); + assert_eq!(found.unwrap().id, created.id); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_get_user_not_found() { + let db = setup_db().await; + let ctx = query_ctx(db.pool().clone()); + + let result = get_user(&ctx, Uuid::new_v4()).await.unwrap(); + assert!(result.is_none()); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_update_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "update@test.com".into(), "Original".into(), None) + .await + .unwrap(); + + let 
updated = update_user( + &ctx, + user.id, + Some("new@test.com".into()), + Some("Updated".into()), + None, + ) + .await + .unwrap(); + + assert_eq!(updated.email, "new@test.com"); + assert_eq!(updated.name, "Updated"); + db.cleanup().await.unwrap(); + } + + #[tokio::test] + async fn test_delete_user() { + let db = setup_db().await; + let ctx = mutation_ctx(db.pool().clone()); + + let user = create_user(&ctx, "delete@test.com".into(), "Delete Me".into(), None) + .await + .unwrap(); + + let deleted = delete_user(&ctx, user.id).await.unwrap(); + assert!(deleted); + + let q_ctx = query_ctx(db.pool().clone()); + let found = get_user(&q_ctx, user.id).await.unwrap(); + assert!(found.is_none()); + db.cleanup().await.unwrap(); + } +} diff --git a/examples/with-dioxus/demo/src/functions/verification.rs b/examples/with-dioxus/demo/src/functions/verification.rs index f4dd645c..91f6e7f3 100644 --- a/examples/with-dioxus/demo/src/functions/verification.rs +++ b/examples/with-dioxus/demo/src/functions/verification.rs @@ -112,14 +112,15 @@ pub struct ConfirmVerificationInput { pub workflow_id: String, } -#[forge::mutation(auth = "none")] // forge_workflow_events is owned by the runtime, so the framework user's .sqlx // cache doesn't see it. Runtime sqlx::query is the right tool here. +#[forge::mutation(tables("forge_workflow_events"), scope = "global")] #[allow(clippy::disallowed_methods)] pub async fn confirm_verification( ctx: &MutationContext, input: ConfirmVerificationInput, ) -> Result { + let _ = ctx.user_id()?; // Insert the confirmation event into the workflow events table. // The scheduler's NOTIFY trigger will wake the waiting workflow. 
sqlx::query( diff --git a/examples/with-dioxus/demo/src/functions/webhook.rs b/examples/with-dioxus/demo/src/functions/webhook.rs index 8fe6cef8..70468ad9 100644 --- a/examples/with-dioxus/demo/src/functions/webhook.rs +++ b/examples/with-dioxus/demo/src/functions/webhook.rs @@ -1,4 +1,6 @@ use forge::prelude::*; +use hmac::{Hmac, Mac}; +use sha2::Sha256; /// Webhook event record stored in database #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, sqlx::FromRow)] @@ -59,3 +61,53 @@ pub async fn demo_webhook(ctx: &WebhookContext, payload: Value) -> Result Result { + let secret = ctx.env_require("WEBHOOK_SECRET")?; + let port: u16 = ctx.env_parse_or("PORT", 9081u16)?; + let payload = serde_json::json!({ + "action": "test", + "ts": Utc::now().timestamp_millis(), + }) + .to_string(); + + let mut mac = as Mac>::new_from_slice(secret.as_bytes()) + .map_err(|e| ForgeError::internal(format!("HMAC key init failed: {e}")))?; + mac.update(payload.as_bytes()); + let signature = hex::encode(mac.finalize().into_bytes()); + let timestamp = Utc::now().timestamp(); + + // Deliberate loopback call to this server's own webhook endpoint. The + // framework's `ctx.http()` client blocks private/loopback IPs (SSRF guard), + // so use a plain reqwest client for this intentional self-call. 
+ let response = reqwest::Client::new() + .post(format!("http://127.0.0.1:{port}/_api/webhooks/demo")) + .header("Content-Type", "application/json") + .header("X-Webhook-Signature", signature) + .header("X-Webhook-Timestamp", timestamp.to_string()) + .header("X-Idempotency-Key", &input.idempotency_key) + .body(payload) + .send() + .await + .map_err(|e| ForgeError::internal(format!("Webhook self-call failed: {e}")))?; + + if !response.status().is_success() { + return Err(ForgeError::internal(format!( + "Webhook returned status {}", + response.status().as_u16() + ))); + } + + Ok(true) +} diff --git a/examples/with-dioxus/minimal/Dockerfile b/examples/with-dioxus/minimal/Dockerfile index 61905ba6..ac170b65 100644 --- a/examples/with-dioxus/minimal/Dockerfile +++ b/examples/with-dioxus/minimal/Dockerfile @@ -1,9 +1,9 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app RUN cargo install cargo-watch --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -11,7 +11,7 @@ COPY frontend/Cargo.toml frontend/Dioxus.toml ./ COPY frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -24,7 +24,7 @@ COPY --from=frontend-builder /app/frontend/dist ./frontend/dist RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/forge-dioxus-minimal-template /app/forge-dioxus-minimal-template diff --git 
a/examples/with-dioxus/minimal/docker-compose.yml b/examples/with-dioxus/minimal/docker-compose.yml index 308334b6..46487659 100644 --- a/examples/with-dioxus/minimal/docker-compose.yml +++ b/examples/with-dioxus/minimal/docker-compose.yml @@ -5,7 +5,7 @@ services: dockerfile: Dockerfile target: dev ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/minimal/forge.toml b/examples/with-dioxus/minimal/forge.toml index db95d297..0edf5eb7 100644 --- a/examples/with-dioxus/minimal/forge.toml +++ b/examples/with-dioxus/minimal/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -48,9 +48,12 @@ otlp_endpoint = "http://localhost:4318" # --- RSA (RS256/RS384/RS512) - asymmetric, use for external providers --- # jwks_url = "" # Provider JWKS URLs: see auth reference docs -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path (DDoS-grade). strict: PG round-trip (billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. 
+# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-dioxus/minimal/frontend/playwright.config.ts b/examples/with-dioxus/minimal/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/minimal/frontend/playwright.config.ts +++ b/examples/with-dioxus/minimal/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/minimal/frontend/src/forge/types.rs b/examples/with-dioxus/minimal/frontend/src/forge/types.rs index 58a2ca15..d0e29b99 100644 --- a/examples/with-dioxus/minimal/frontend/src/forge/types.rs +++ b/examples/with-dioxus/minimal/frontend/src/forge/types.rs @@ -1,10 +1,5 @@ // @generated by FORGE - DO NOT EDIT -#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; diff --git a/examples/with-dioxus/minimal/migrations/0001_initial.sql.example b/examples/with-dioxus/minimal/migrations/0001_initial.sql.example index efa755eb..3d416208 100644 --- a/examples/with-dioxus/minimal/migrations/0001_initial.sql.example +++ b/examples/with-dioxus/minimal/migrations/0001_initial.sql.example @@ -1,6 +1,4 @@ --- @up - --- Replace with your tables here +-- Migrations are forward-only. Add your schema changes below. 
-- Example: -- CREATE TABLE IF NOT EXISTS users ( -- id UUID PRIMARY KEY, @@ -9,10 +7,3 @@ -- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -- ); -- SELECT forge_enable_reactivity('users'); - --- @down - --- Add your rollback statements here --- Example: --- SELECT forge_disable_reactivity('users'); --- DROP TABLE IF EXISTS users; diff --git a/examples/with-dioxus/realtime-todo-list/.env b/examples/with-dioxus/realtime-todo-list/.env index 05b41b5e..87b0a353 100644 --- a/examples/with-dioxus/realtime-todo-list/.env +++ b/examples/with-dioxus/realtime-todo-list/.env @@ -11,8 +11,8 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=todo-dioxus POSTGRES_PORT=5432 -# Optional: JWT secret for authentication -# FORGE_SECRET=your-secret-key-here +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-dioxus/realtime-todo-list/.env.example b/examples/with-dioxus/realtime-todo-list/.env.example index bd7f2e73..4e5bf825 100644 --- a/examples/with-dioxus/realtime-todo-list/.env.example +++ b/examples/with-dioxus/realtime-todo-list/.env.example @@ -13,5 +13,5 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=todo-dioxus POSTGRES_PORT=5432 -# Optional: JWT secret for authentication -# FORGE_SECRET=your-secret-key-here +# JWT signing secret. 
Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json new file mode 100644 index 00000000..fd027d65 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26.json @@ -0,0 +1,46 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT * FROM todos WHERE user_id = $1 ORDER BY created_at DESC", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "title", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "completed", + "type_info": "Bool" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + false, + false + ] + }, + "hash": "0ef63404257f0212092be1612e8f63c641c19b14ceb671016d66310a74dbca26" +} diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json similarity index 50% rename from examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json index fb994d1e..5b63fbe4 100644 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19.json +++ 
b/examples/with-dioxus/realtime-todo-list/.sqlx/query-2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e.json @@ -1,14 +1,15 @@ { "db_name": "PostgreSQL", - "query": "DELETE FROM todos WHERE id = $1", + "query": "DELETE FROM todos WHERE id = $1 AND user_id = $2", "describe": { "columns": [], "parameters": { "Left": [ + "Uuid", "Uuid" ] }, "nullable": [] }, - "hash": "183ad1d8316ef2ae5ac6ae4811b8a2bdbaeabbe137a871e26741a419a1aa5b19" + "hash": "2e465c3f5f3b3fb29f51cefabfab678ae2f60db0bfdbd437fb00618618b0ab5e" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json new file mode 100644 index 00000000..e2ac3226 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c.json @@ -0,0 +1,57 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO users (id, email, name, password_hash, created_at, updated_at)\n VALUES ($1, $2, $3, $4, $5, $6)\n RETURNING id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Varchar", + "Varchar", + "Text", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "4aaff6edf11f4e43ee07cf8f58ebeb5479ac15d11cb1c4fdfd0cd7247f519f3c" +} diff --git 
a/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json new file mode 100644 index 00000000..3e00e871 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE email = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "6f810436b0b1e5e2e79b283ce91992a65a3d646156b70a30a3377a4e0f19f1f3" +} diff --git a/examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json similarity index 68% rename from examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json index f5b84d9f..6ed720db 100644 --- a/examples/with-svelte/realtime-todo-list/.sqlx/query-289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39.json +++ 
b/examples/with-dioxus/realtime-todo-list/.sqlx/query-cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "INSERT INTO todos (title) VALUES ($1) RETURNING *", + "query": "INSERT INTO todos (user_id, title) VALUES ($1, $2) RETURNING *", "describe": { "columns": [ { @@ -10,22 +10,28 @@ }, { "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } ], "parameters": { "Left": [ + "Uuid", "Text" ] }, @@ -33,8 +39,9 @@ false, false, false, + false, false ] }, - "hash": "289c71ceebdcb32b1fa7de751cca0918c3286db00bfe90e56cdec7458e1e7b39" + "hash": "cd3a68eb363ca38467993dda8a6b904549a2f663613d7570b529eeaa13a6d9aa" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json similarity index 74% rename from examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json rename to examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json index 1ebed98d..37c21eb7 100644 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f.json +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = COALESCE($2, completed)\n WHERE id = $3\n RETURNING *", + "query": "UPDATE todos\n SET title = COALESCE($1, title),\n completed = 
COALESCE($2, completed)\n WHERE id = $3 AND user_id = $4\n RETURNING *", "describe": { "columns": [ { @@ -10,16 +10,21 @@ }, { "ordinal": 1, + "name": "user_id", + "type_info": "Uuid" + }, + { + "ordinal": 2, "name": "title", "type_info": "Text" }, { - "ordinal": 2, + "ordinal": 3, "name": "completed", "type_info": "Bool" }, { - "ordinal": 3, + "ordinal": 4, "name": "created_at", "type_info": "Timestamptz" } @@ -28,6 +33,7 @@ "Left": [ "Text", "Bool", + "Uuid", "Uuid" ] }, @@ -35,8 +41,9 @@ false, false, false, + false, false ] }, - "hash": "c2eda736e5f6342831005dfbd5281fbeb29cb84e74ff483828f9e3ee0fcc517f" + "hash": "d1b4c05fc3f85f6412e22e47208952ef8f47e93152275994011c4382e85ed7f2" } diff --git a/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json new file mode 100644 index 00000000..a22ee7e8 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/.sqlx/query-e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT id, email, name, password_hash as \"password_hash!\", created_at, updated_at\n FROM users WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "email", + "type_info": "Varchar" + }, + { + "ordinal": 2, + "name": "name", + "type_info": "Varchar" + }, + { + "ordinal": 3, + "name": "password_hash!", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 5, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + true, + false, + false + ] + }, + "hash": "e7bce187d4ced5dfaa7bbc448fbb97af559e04fa847bfa0ec37ecc45354d3d22" +} diff --git 
a/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json b/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json deleted file mode 100644 index 71889f63..00000000 --- a/examples/with-dioxus/realtime-todo-list/.sqlx/query-fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT * FROM todos ORDER BY created_at DESC", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "title", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "completed", - "type_info": "Bool" - }, - { - "ordinal": 3, - "name": "created_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false, - false - ] - }, - "hash": "fe67fe1d5492dd97f324a0fa1a6b73e7b6282402018f150a58eefe67319ba763" -} diff --git a/examples/with-dioxus/realtime-todo-list/Cargo.toml b/examples/with-dioxus/realtime-todo-list/Cargo.toml index 2d0a46ab..f77143b6 100644 --- a/examples/with-dioxus/realtime-todo-list/Cargo.toml +++ b/examples/with-dioxus/realtime-todo-list/Cargo.toml @@ -18,6 +18,8 @@ uuid = { version = "1", features = ["v4", "serde"] } chrono = { version = "0.4", features = ["serde"] } sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "chrono", "uuid", "macros", "derive"] } dotenvy = "0.15" +argon2 = "0.5" +password-hash = "0.5" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-dioxus/realtime-todo-list/Dockerfile b/examples/with-dioxus/realtime-todo-list/Dockerfile index 509573a0..99f22236 100644 --- a/examples/with-dioxus/realtime-todo-list/Dockerfile +++ b/examples/with-dioxus/realtime-todo-list/Dockerfile @@ -1,11 +1,11 @@ -FROM rust:1.92 AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR 
/app RUN cargo install cargo-watch --locked RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked RUN apt-get update && apt-get install -y curl pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* -FROM rust:1.92 AS frontend-builder +FROM rust:1.92-slim-bookworm AS frontend-builder WORKDIR /app/examples/todo-dioxus/frontend RUN rustup target add wasm32-unknown-unknown RUN cargo install dioxus-cli --version 0.7.3 --locked @@ -14,7 +14,7 @@ COPY examples/todo-dioxus/frontend/.forge ./.forge COPY examples/todo-dioxus/frontend/src ./src RUN dx build --web --release -FROM rust:1.92 AS builder +FROM rust:1.92-slim-bookworm AS builder WORKDIR /app RUN apt-get update && apt-get install -y pkg-config libssl-dev && rm -rf /var/lib/apt/lists/* @@ -26,7 +26,7 @@ COPY --from=frontend-builder /app/examples/todo-dioxus/frontend/dist ./examples/ WORKDIR /app/examples/todo-dioxus RUN cargo build --release -FROM debian:bookworm-slim AS runtime +FROM debian:bookworm-20250203-slim AS runtime RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* WORKDIR /app COPY --from=builder /app/target/release/todo-dioxus /app/todo-dioxus diff --git a/examples/with-dioxus/realtime-todo-list/docker-compose.yml b/examples/with-dioxus/realtime-todo-list/docker-compose.yml index b4402e0c..69158e0d 100644 --- a/examples/with-dioxus/realtime-todo-list/docker-compose.yml +++ b/examples/with-dioxus/realtime-todo-list/docker-compose.yml @@ -6,7 +6,7 @@ services: target: dev working_dir: /workspace/examples/with-dioxus/realtime-todo-list ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -44,7 +44,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-dioxus/realtime-todo-list/forge.toml b/examples/with-dioxus/realtime-todo-list/forge.toml index 75f6f81f..3e711262 100644 --- 
a/examples/with-dioxus/realtime-todo-list/forge.toml +++ b/examples/with-dioxus/realtime-todo-list/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -39,7 +39,12 @@ otlp_endpoint = "http://localhost:4318" # job_timeout = "5m" # poll_interval = "1s" -# [auth] +[auth] +jwt_algorithm = "HS256" +jwt_secret = "${JWT_SECRET}" +jwt_audience = "${JWT_AUDIENCE}" + +# Legacy template values below kept for reference. # jwt_algorithm = "HS256" # HS256, HS384, HS512, RS256, RS384, RS512 # # --- HMAC (HS256/HS384/HS512) - symmetric, use for self-issued tokens --- @@ -48,9 +53,12 @@ otlp_endpoint = "http://localhost:4318" # --- RSA (RS256/RS384/RS512) - asymmetric, use for external providers --- # jwks_url = "" # Provider JWKS URLs: see auth reference docs -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path (DDoS-grade). strict: PG round-trip (billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. 
+# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts b/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts index a6685492..7a839a3a 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts +++ b/examples/with-dioxus/realtime-todo-list/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 180_000, workers: process.env.CI ? 1 : undefined, reporter: "html", diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs index 99a20a4e..66adc0e2 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/api.rs @@ -29,6 +29,17 @@ pub fn use_list_todos() -> QueryState> { pub fn use_list_todos_subscription() -> SubscriptionState> { use_forge_subscription("list_todos", ()) } +pub async fn me(client: &ForgeClient) -> Result { + client.call("me", ()).await +} + +pub fn use_me() -> QueryState { + use_forge_query("me", ()) +} + +pub fn use_me_subscription() -> SubscriptionState { + use_forge_subscription("me", ()) +} pub async fn create_todo( client: &ForgeClient, args: CreateTodoInput, @@ -59,6 +70,36 @@ pub async fn delete_todo( pub fn use_delete_todo() -> Mutation { use_forge_mutation("delete_todo") } +pub async fn login( + client: &ForgeClient, + args: LoginInput, +) -> Result { + client.call("login", args).await +} + +pub fn use_login() -> Mutation { + use_forge_mutation("login") +} +pub async fn refresh_token( + client: &ForgeClient, + args: RefreshInput, +) -> Result { + client.call("refresh_token", args).await 
+} + +pub fn use_refresh_token() -> Mutation { + use_forge_mutation("refresh_token") +} +pub async fn register( + client: &ForgeClient, + args: RegisterInput, +) -> Result { + client.call("register", args).await +} + +pub fn use_register() -> Mutation { + use_forge_mutation("register") +} pub async fn update_todo( client: &ForgeClient, args: UpdateTodoInput, diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs index 42c8491f..d77824ca 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/forge/types.rs @@ -1,13 +1,29 @@ // @generated by FORGE - DO NOT EDIT -#![allow( - dead_code, - unused_imports, - clippy::redundant_field_names, - clippy::too_many_arguments -)] +#![allow(dead_code, unused_imports, clippy::too_many_arguments)] use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct AuthResponse { + pub access_token: String, + pub refresh_token: String, + pub user: UserPublic, +} + +impl AuthResponse { + pub fn new( + access_token: impl Into, + refresh_token: impl Into, + user: UserPublic, + ) -> Self { + Self { + access_token: access_token.into(), + refresh_token: refresh_token.into(), + user, + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct CreateTodoInput { pub title: String, @@ -21,9 +37,59 @@ impl CreateTodoInput { } } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct LoginInput { + pub email: String, + pub password: String, +} + +impl LoginInput { + pub fn new(email: impl Into, password: impl Into) -> Self { + Self { + email: email.into(), + password: password.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RefreshInput { + pub refresh_token: String, +} + +impl RefreshInput { + pub fn new(refresh_token: impl Into) -> Self { + 
Self { + refresh_token: refresh_token.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RegisterInput { + pub email: String, + pub name: String, + pub password: String, +} + +impl RegisterInput { + pub fn new( + email: impl Into, + name: impl Into, + password: impl Into, + ) -> Self { + Self { + email: email.into(), + name: name.into(), + password: password.into(), + } + } +} + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Todo { pub id: String, + pub user_id: String, pub title: String, pub completed: bool, pub created_at: String, @@ -55,3 +121,47 @@ impl UpdateTodoInput { self } } + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct User { + pub id: String, + pub email: String, + pub name: String, + pub created_at: String, + pub updated_at: String, +} + +impl User { + pub fn new( + id: impl Into, + email: impl Into, + name: impl Into, + created_at: impl Into, + updated_at: impl Into, + ) -> Self { + Self { + id: id.into(), + email: email.into(), + name: name.into(), + created_at: created_at.into(), + updated_at: updated_at.into(), + } + } +} + +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct UserPublic { + pub id: String, + pub email: String, + pub name: String, +} + +impl UserPublic { + pub fn new(id: impl Into, email: impl Into, name: impl Into) -> Self { + Self { + id: id.into(), + email: email.into(), + name: name.into(), + } + } +} diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs index 6d95b780..40e94038 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/main.rs @@ -3,7 +3,7 @@ mod todo_app; mod todo_item; use dioxus::prelude::*; -use forge::ForgeProvider; +use forge::ForgeAuthProvider; use todo_app::TodoApp; fn api_url() -> &'static str { @@ -19,8 +19,9 @@ fn App() -> Element { rsx! 
{ document::Title { "Todos" } document::Stylesheet { href: asset!("/public/style.css") } - ForgeProvider { + ForgeAuthProvider { url: api_url().to_string(), + app_name: "todo-dioxus".to_string(), TodoApp {} } } diff --git a/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs b/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs index ec5980e0..dacb6a85 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs +++ b/examples/with-dioxus/realtime-todo-list/frontend/src/todo_app.rs @@ -2,11 +2,164 @@ use dioxus::prelude::*; use forge_dioxus::use_signals; use serde_json::json; -use crate::forge::{CreateTodoInput, use_create_todo, use_list_todos_subscription}; +use crate::forge::{ + CreateTodoInput, LoginInput, RegisterInput, UserPublic, use_create_todo, use_forge_auth, + use_list_todos_subscription, use_login, use_register, +}; use crate::todo_item::TodoItem; #[component] pub fn TodoApp() -> Element { + let auth = use_forge_auth(); + + rsx! { + main { + div { class: "shell", + header { class: "hero", + h1 { "Todos" } + if auth.is_authenticated() { + UserBar {} + } + } + if auth.is_authenticated() { + TodoList {} + } else { + AuthPanel {} + } + } + } + } +} + +#[component] +fn UserBar() -> Element { + let mut auth = use_forge_auth(); + let viewer = auth.viewer::(); + let label = viewer + .as_ref() + .map(|u| u.name.clone()) + .unwrap_or_default(); + + rsx! 
{ + div { class: "user-row", + span { class: "user", "{label}" } + button { + class: "logout", + onclick: move |_| auth.logout(), + "Sign out" + } + } + } +} + +#[component] +fn AuthPanel() -> Element { + let mut auth = use_forge_auth(); + let signals = use_signals(); + let login_mut = use_login(); + let register_mut = use_register(); + + let mut mode = use_signal(|| "login".to_string()); + let mut email = use_signal(String::new); + let mut name = use_signal(String::new); + let mut password = use_signal(String::new); + let mut error = use_signal(|| None::); + let mut loading = use_signal(|| false); + + let handle_submit = { + let login_mut = login_mut.clone(); + let register_mut = register_mut.clone(); + let signals = signals.clone(); + move |evt: FormEvent| { + evt.prevent_default(); + let is_register = mode.read().as_str() == "register"; + let e = email.read().clone(); + let n = name.read().clone(); + let p = password.read().clone(); + let login_mut = login_mut.clone(); + let register_mut = register_mut.clone(); + let signals = signals.clone(); + spawn(async move { + loading.set(true); + error.set(None); + let res = if is_register { + register_mut.call(RegisterInput::new(&e, &n, &p)).await + } else { + login_mut.call(LoginInput::new(&e, &p)).await + }; + match res { + Ok(r) => { + signals.track_with_properties( + "auth_success", + json!({"mode": is_register}), + ); + auth.login_with_viewer( + r.access_token.clone(), + r.refresh_token.clone(), + &r.user, + ); + } + Err(e) => error.set(Some(e.message)), + } + loading.set(false); + }); + } + }; + + rsx! 
{ + section { class: "auth-panel", + div { class: "tabs", + button { + class: if mode.read().as_str() == "login" { "active" } else { "" }, + onclick: move |_| mode.set("login".into()), + "Sign in" + } + button { + class: if mode.read().as_str() == "register" { "active" } else { "" }, + onclick: move |_| mode.set("register".into()), + "Sign up" + } + } + form { onsubmit: handle_submit, + if mode.read().as_str() == "register" { + input { + r#type: "text", + placeholder: "Name", + value: "{name}", + oninput: move |e: FormEvent| name.set(e.value()), + required: true, + } + } + input { + r#type: "email", + placeholder: "Email", + value: "{email}", + oninput: move |e: FormEvent| email.set(e.value()), + required: true, + } + input { + r#type: "password", + placeholder: "Password (min 8 chars)", + value: "{password}", + oninput: move |e: FormEvent| password.set(e.value()), + minlength: "8", + required: true, + } + button { + r#type: "submit", + disabled: loading(), + if loading() { "..." } else if mode.read().as_str() == "login" { "Sign in" } else { "Sign up" } + } + } + if let Some(msg) = error() { + p { class: "error", "{msg}" } + } + } + } +} + +#[component] +fn TodoList() -> Element { let signals = use_signals(); let create_todo = use_create_todo(); let todo_state = use_list_todos_subscription(); @@ -17,121 +170,91 @@ pub fn TodoApp() -> Element { let todo_items = todo_state.data.clone().unwrap_or_default(); let remaining_count = todo_items.iter().filter(|t| !t.completed).count(); - rsx! 
{ - main { - div { - class: "shell", - header { - class: "hero", - h1 { "Todos" } - } - - section { - class: "input-panel", - div { - class: "input-row", - input { - r#type: "text", - placeholder: "What needs to be done?", - value: new_title(), - disabled: adding(), - oninput: move |event| new_title.set(event.value()), - onkeydown: { - let create_todo = create_todo.clone(); - let signals = signals.clone(); - move |event: KeyboardEvent| { - if event.key().to_string() == "Enter" { - let title = new_title().trim().to_string(); - if title.is_empty() || adding() { - return; - } - error.set(None); - adding.set(true); - let create_todo = create_todo.clone(); - let signals = signals.clone(); - spawn(async move { - match create_todo.call(CreateTodoInput::new(title.clone())).await { - Ok(_) => { - signals.track_with_properties("todo_created", json!({"title": &title})); - new_title.set(String::new()); - } - Err(err) => { - signals.track_with_properties("todo_create_error", json!({"error": &err.message})); - error.set(Some(err.message)); - } - } - adding.set(false); - }); - } - } - }, - } - button { - disabled: adding() || new_title().trim().is_empty(), - onclick: { - let create_todo = create_todo.clone(); - let signals = signals.clone(); - move |_| { - let title = new_title().trim().to_string(); - if title.is_empty() || adding() { - return; - } - error.set(None); - adding.set(true); - let create_todo = create_todo.clone(); - let signals = signals.clone(); - spawn(async move { - match create_todo.call(CreateTodoInput::new(title.clone())).await { - Ok(_) => { - signals.track_with_properties("todo_created", json!({"title": &title})); - new_title.set(String::new()); - } - Err(err) => { - signals.track_with_properties("todo_create_error", json!({"error": &err.message})); - error.set(Some(err.message)); - } - } - adding.set(false); - }); - } - }, - if adding() { "Adding..." 
} else { "Add" } - } + let submit = { + let create_todo = create_todo.clone(); + let signals = signals.clone(); + move || { + let title = new_title().trim().to_string(); + if title.is_empty() || adding() { + return; + } + error.set(None); + adding.set(true); + let create_todo = create_todo.clone(); + let signals = signals.clone(); + spawn(async move { + match create_todo.call(CreateTodoInput::new(title.clone())).await { + Ok(_) => { + signals.track_with_properties("todo_created", json!({"title": &title})); + new_title.set(String::new()); } - - if let Some(message) = error() { - p { class: "error", "{message}" } + Err(err) => { + signals.track_with_properties( + "todo_create_error", + json!({"error": &err.message}), + ); + error.set(Some(err.message)); } } + adding.set(false); + }); + } + }; - section { - class: "list-panel", - if !todo_items.is_empty() { - div { - class: "list-head", - span { class: "summary", "{remaining_count} remaining" } - } - } - - if todo_state.loading { - p { class: "status", "Loading..." } - } else if let Some(todo_error) = todo_state.error.as_ref() { - p { class: "error", "{todo_error.message}" } - } else if todo_items.is_empty() { - p { class: "status", "No todos yet. Add one above!" } - } else { - ul { - for todo in todo_items { - TodoItem { - key: "{todo.id}", - todo: todo, - error: error, - } + rsx! { + section { class: "input-panel", + div { class: "input-row", + input { + r#type: "text", + placeholder: "What needs to be done?", + value: new_title(), + disabled: adding(), + oninput: move |event| new_title.set(event.value()), + onkeydown: { + let mut submit = submit.clone(); + move |event: KeyboardEvent| { + if event.key().to_string() == "Enter" { + submit(); } } - p { class: "count", "{remaining_count} remaining" } + }, + } + button { + disabled: adding() || new_title().trim().is_empty(), + onclick: { + let mut submit = submit.clone(); + move |_| submit() + }, + if adding() { "Adding..." 
} else { "Add" } + } + } + if let Some(message) = error() { + p { class: "error", "{message}" } + } + } + section { class: "list-panel", + if !todo_items.is_empty() { + div { class: "list-head", + span { class: "summary", "{remaining_count} remaining" } + } + } + if todo_state.loading { + p { class: "status", "Loading..." } + } else if let Some(todo_error) = todo_state.error.as_ref() { + p { class: "error", "{todo_error.message}" } + } else if todo_items.is_empty() { + p { class: "status", "No todos yet. Add one above!" } + } else { + ul { + for todo in todo_items { + TodoItem { + key: "{todo.id}", + todo: todo, + error: error, + } } } + p { class: "count", "{remaining_count} remaining" } } } } diff --git a/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts b/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts index d2495029..3a663824 100644 --- a/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts +++ b/examples/with-dioxus/realtime-todo-list/frontend/tests/home.spec.ts @@ -5,42 +5,55 @@ import { uniqueId, trackConsoleErrors, } from "./fixtures"; +import type { Page } from "@playwright/test"; const INPUT = 'input[placeholder="What needs to be done?"]'; +const EMAIL = 'input[type="email"]'; +const PASSWORD = 'input[type="password"]'; +const NAME = 'input[placeholder="Name"]'; -async function deleteAllTodos( - rpc: (fn: string, args?: unknown) => Promise, +async function signUp( + page: Page, + email: string, + name: string, + password: string, ) { - const todos = await rpc("list_todos"); - if (!Array.isArray(todos)) return; - for (const todo of todos) { - await rpc("delete_todo", { id: todo.id }); - } + await page.getByRole("button", { name: "Sign up" }).first().click(); + await page.fill(NAME, name); + await page.fill(EMAIL, email); + await page.fill(PASSWORD, password); + await page.getByRole("button", { name: "Sign up" }).last().click(); + await expect(page.locator(INPUT)).toBeVisible({ timeout: ACTION_TIMEOUT }); 
} -test.beforeEach(async ({ rpc }) => { - await deleteAllTodos(rpc); -}); - -test.afterEach(async ({ rpc }) => { - await deleteAllTodos(rpc); -}); +// The app only subscribes to the todos query once authenticated, so reactivity +// readiness can't be detected until after sign-up. Arm the subscribe wait +// before submitting, then await it once the authed view renders. (×3 timeout +// for the WASM download → instantiate → init → SSE → subscribe path.) +async function signUpReady( + page: Page, + email: string, + name: string, + password: string, +) { + const subscribed = page.waitForResponse( + (res) => res.url().includes("/_api/subscribe") && res.status() === 200, + { timeout: ACTION_TIMEOUT * 3 }, + ); + await signUp(page, email, name, password); + await subscribed; +} -test("todo flow stays reactive through create, toggle, and delete", async ({ +test("authenticated user can create, toggle, and delete their todos", async ({ page, - gotoReady, }) => { - const title = uniqueId("release"); const errors = trackConsoleErrors(page); + const email = `${uniqueId("user")}@test.local`; - await gotoReady(); - await expect(page.locator("h1")).toHaveText("Todos"); - await expect( - page.locator(".status", { hasText: "No todos yet" }), - ).toBeVisible({ - timeout: ACTION_TIMEOUT, - }); + await page.goto("/"); + await signUpReady(page, email, "Solo", "password123"); + const title = uniqueId("release"); await page.fill(INPUT, title); await page.click(".input-row button"); @@ -49,22 +62,44 @@ test("todo flow stays reactive through create, toggle, and delete", async ({ await expect(page.locator(".count")).toHaveText("1 remaining", { timeout: ACTION_TIMEOUT, }); - await expect(page.locator(INPUT)).toHaveValue("", { - timeout: ACTION_TIMEOUT, - }); await todoItem.locator("button.toggle").click(); await expect(todoItem).toHaveClass(/completed/, { timeout: ACTION_TIMEOUT }); - await expect(page.locator(".count")).toHaveText("0 remaining", { - timeout: ACTION_TIMEOUT, - }); await 
todoItem.locator("button.delete").click(); await expect(todoItem).not.toBeVisible({ timeout: ACTION_TIMEOUT }); - await expect( - page.locator(".status", { hasText: "No todos yet" }), - ).toBeVisible({ + expect(errors).toHaveLength(0); +}); + +test("two users cannot see each other's todos", async ({ browser }) => { + const aliceEmail = `${uniqueId("alice")}@test.local`; + const bobEmail = `${uniqueId("bob")}@test.local`; + const aliceTitle = uniqueId("alice-task"); + const bobTitle = uniqueId("bob-task"); + + const aliceCtx = await browser.newContext(); + const alice = await aliceCtx.newPage(); + await alice.goto("/"); + await signUpReady(alice, aliceEmail, "Alice", "password123"); + await alice.fill(INPUT, aliceTitle); + await alice.click(".input-row button"); + await expect(alice.locator("li", { hasText: aliceTitle })).toBeVisible({ timeout: ACTION_TIMEOUT, }); - expect(errors).toHaveLength(0); + + const bobCtx = await browser.newContext(); + const bob = await bobCtx.newPage(); + await bob.goto("/"); + await signUpReady(bob, bobEmail, "Bob", "password123"); + await bob.fill(INPUT, bobTitle); + await bob.click(".input-row button"); + await expect(bob.locator("li", { hasText: bobTitle })).toBeVisible({ + timeout: ACTION_TIMEOUT, + }); + + await expect(bob.locator("li", { hasText: aliceTitle })).toHaveCount(0); + await expect(alice.locator("li", { hasText: bobTitle })).toHaveCount(0); + + await aliceCtx.close(); + await bobCtx.close(); }); diff --git a/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql b/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql index e717bf19..4f584e0f 100644 --- a/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql +++ b/examples/with-dioxus/realtime-todo-list/migrations/0001_todos.sql @@ -1,8 +1,22 @@ -CREATE TABLE todos ( +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY, + email VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + password_hash TEXT NOT NULL, + created_at 
TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_users_email ON users(email); + +CREATE TABLE IF NOT EXISTS todos ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, title TEXT NOT NULL, completed BOOLEAN NOT NULL DEFAULT false, created_at TIMESTAMPTZ NOT NULL DEFAULT now() ); +CREATE INDEX IF NOT EXISTS idx_todos_user_id ON todos(user_id); + SELECT forge_enable_reactivity('todos'); diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs b/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs new file mode 100644 index 00000000..b23c4ff0 --- /dev/null +++ b/examples/with-dioxus/realtime-todo-list/src/functions/auth.rs @@ -0,0 +1,140 @@ +use crate::schema::{AuthResponse, User, UserPublic}; +use forge::prelude::*; + +#[derive(Debug, Serialize, Deserialize)] +pub struct RegisterInput { + pub email: String, + pub name: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LoginInput { + pub email: String, + pub password: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RefreshInput { + pub refresh_token: String, +} + +async fn auth_response(ctx: &MutationContext, user: &User) -> Result { + let pair = ctx.issue_token_pair(user.id, &["user"]).await?; + Ok(AuthResponse { + access_token: pair.access_token, + refresh_token: pair.refresh_token, + user: UserPublic::from(user.clone()), + }) +} + +fn validate_register(input: &RegisterInput) -> Result<(String, String)> { + let email = input.email.trim(); + if email.is_empty() { + return Err(ForgeError::Validation("Email is required".into())); + } + let name = input.name.trim(); + if name.is_empty() { + return Err(ForgeError::Validation("Name is required".into())); + } + if input.password.len() < 8 { + return Err(ForgeError::Validation( + "Password must be at least 8 characters".into(), + )); + } + 
Ok((email.to_string(), name.to_string())) +} + +#[forge::mutation(auth = "none")] +pub async fn register(ctx: &MutationContext, input: RegisterInput) -> Result { + let (email, name) = validate_register(&input)?; + + let password_hash = { + use argon2::PasswordHasher; + use password_hash::SaltString; + let salt = SaltString::generate(&mut password_hash::rand_core::OsRng); + argon2::Argon2::default() + .hash_password(input.password.as_bytes(), &salt) + .map_err(|e| ForgeError::internal(e.to_string()))? + .to_string() + }; + + let id = Uuid::new_v4(); + let now = Utc::now(); + let mut conn = ctx.conn().await?; + + let user = sqlx::query_as!( + User, + r#" + INSERT INTO users (id, email, name, password_hash, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, email, name, password_hash as "password_hash!", created_at, updated_at + "#, + id, + &email, + &name, + &password_hash, + now, + now + ) + .fetch_one(&mut conn) + .await + .map_err(|e| match &e { + sqlx::Error::Database(db_err) if db_err.constraint() == Some("idx_users_email") => { + ForgeError::Validation("Email already registered".into()) + } + _ => ForgeError::from(e), + })?; + + auth_response(ctx, &user).await +} + +#[forge::mutation(auth = "none")] +pub async fn login(ctx: &MutationContext, input: LoginInput) -> Result { + let mut conn = ctx.conn().await?; + + let user = sqlx::query_as!( + User, + r#" + SELECT id, email, name, password_hash as "password_hash!", created_at, updated_at + FROM users WHERE email = $1 + "#, + &input.email + ) + .fetch_optional(&mut conn) + .await? 
+ .ok_or_else(|| ForgeError::Validation("Invalid email or password".into()))?; + + { + use argon2::PasswordVerifier; + let parsed = password_hash::PasswordHash::new(&user.password_hash) + .map_err(|e| ForgeError::internal(e.to_string()))?; + argon2::Argon2::default() + .verify_password(input.password.as_bytes(), &parsed) + .map_err(|_| ForgeError::Validation("Invalid email or password".into()))?; + } + + auth_response(ctx, &user).await +} + +#[forge::mutation(auth = "none")] +pub async fn refresh_token(ctx: &MutationContext, input: RefreshInput) -> Result { + ctx.rotate_refresh_token(&input.refresh_token).await +} + +#[forge::query(scope = "global")] +pub async fn me(ctx: &QueryContext) -> Result { + let user_id = ctx.user_id()?; + let user = sqlx::query_as!( + User, + r#" + SELECT id, email, name, password_hash as "password_hash!", created_at, updated_at + FROM users WHERE id = $1 + "#, + user_id + ) + .fetch_optional(ctx.db()) + .await? + .ok_or_else(|| ForgeError::NotFound("User not found".into()))?; + Ok(UserPublic::from(user)) +} diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs b/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs index 2fd7ed7a..5afdcc48 100644 --- a/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs +++ b/examples/with-dioxus/realtime-todo-list/src/functions/mod.rs @@ -1 +1,2 @@ +mod auth; mod todos; diff --git a/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs b/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs index 657492ec..fd49d3ab 100644 --- a/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs +++ b/examples/with-dioxus/realtime-todo-list/src/functions/todos.rs @@ -15,26 +15,33 @@ pub struct UpdateTodoInput { pub completed: Option, } -#[forge::query(auth = "none", tables("todos"))] +#[forge::query(tables("todos"))] pub async fn list_todos(ctx: &QueryContext) -> Result> { - sqlx::query_as!(Todo, "SELECT * FROM todos ORDER BY created_at DESC") - 
.fetch_all(ctx.db()) - .await - .map_err(Into::into) + let user_id = ctx.user_id()?; + sqlx::query_as!( + Todo, + "SELECT * FROM todos WHERE user_id = $1 ORDER BY created_at DESC", + user_id + ) + .fetch_all(ctx.db()) + .await + .map_err(Into::into) } -#[forge::mutation(auth = "none")] +#[forge::mutation(scope = "global")] pub async fn create_todo(ctx: &MutationContext, input: CreateTodoInput) -> Result { if input.title.trim().is_empty() { return Err(ForgeError::Validation("Title cannot be empty".into())); } + let user_id = ctx.user_id()?; let title = input.title.trim().to_string(); let mut conn = ctx.conn().await?; sqlx::query_as!( Todo, - "INSERT INTO todos (title) VALUES ($1) RETURNING *", + "INSERT INTO todos (user_id, title) VALUES ($1, $2) RETURNING *", + user_id, title ) .fetch_one(&mut conn) @@ -42,8 +49,9 @@ pub async fn create_todo(ctx: &MutationContext, input: CreateTodoInput) -> Resul .map_err(Into::into) } -#[forge::mutation(auth = "none")] +#[forge::mutation] pub async fn update_todo(ctx: &MutationContext, input: UpdateTodoInput) -> Result { + let user_id = ctx.user_id()?; let title = input.title.as_deref(); let mut conn = ctx.conn().await?; @@ -52,24 +60,174 @@ pub async fn update_todo(ctx: &MutationContext, input: UpdateTodoInput) -> Resul "UPDATE todos SET title = COALESCE($1, title), completed = COALESCE($2, completed) - WHERE id = $3 + WHERE id = $3 AND user_id = $4 RETURNING *", title, input.completed, - input.id + input.id, + user_id ) .fetch_optional(&mut conn) .await? 
.ok_or_else(|| ForgeError::NotFound("Todo not found".into())) } -#[forge::mutation(auth = "none")] +#[forge::mutation] pub async fn delete_todo(ctx: &MutationContext, id: Uuid) -> Result { + let user_id = ctx.user_id()?; let mut conn = ctx.conn().await?; - let result = sqlx::query!("DELETE FROM todos WHERE id = $1", id) - .execute(&mut conn) - .await?; + let result = sqlx::query!( + "DELETE FROM todos WHERE id = $1 AND user_id = $2", + id, + user_id + ) + .execute(&mut conn) + .await?; Ok(result.rows_affected() > 0) } + +#[cfg(all(test, feature = "testcontainers"))] +mod tests { + use super::*; + use forge::forge_core::function::{AuthContext, RequestMetadata}; + use forge::testing::{IsolatedTestDb, TestDatabase}; + use std::path::Path; + + async fn setup_db() -> IsolatedTestDb { + let base = TestDatabase::from_env().await.expect("test db"); + let db = base.isolated("todos_test").await.expect("isolated db"); + db.run_sql(&forge::get_internal_sql()) + .await + .expect("internal sql"); + db.migrate(Path::new("migrations")) + .await + .expect("migrations"); + db + } + + async fn seed_user(pool: &sqlx::PgPool) -> Uuid { + let id = Uuid::new_v4(); + sqlx::query!( + "INSERT INTO users (id, email, name, password_hash) VALUES ($1, $2, $3, $4)", + id, + format!("{id}@test.local"), + "Test User", + "x" + ) + .execute(pool) + .await + .expect("seed user"); + id + } + + fn query_ctx(pool: sqlx::PgPool, user_id: Uuid) -> QueryContext { + QueryContext::new( + pool, + AuthContext::authenticated(user_id, vec!["user".into()], Default::default()), + RequestMetadata::default(), + ) + } + + fn mutation_ctx(pool: sqlx::PgPool, user_id: Uuid) -> MutationContext { + MutationContext::new( + pool, + AuthContext::authenticated(user_id, vec!["user".into()], Default::default()), + RequestMetadata::default(), + ) + } + + #[tokio::test] + async fn create_todo_trims_and_persists_title() { + let db = setup_db().await; + let uid = seed_user(db.pool()).await; + let ctx = 
mutation_ctx(db.pool().clone(), uid); + + let todo = create_todo( + &ctx, + CreateTodoInput { + title: " ship tests ".into(), + }, + ) + .await + .expect("create"); + + assert_eq!(todo.title, "ship tests"); + assert_eq!(todo.user_id, uid); + assert!(!todo.completed); + db.cleanup().await.expect("cleanup"); + } + + #[tokio::test] + async fn list_todos_isolates_by_user() { + let db = setup_db().await; + let alice = seed_user(db.pool()).await; + let bob = seed_user(db.pool()).await; + + let alice_mut = mutation_ctx(db.pool().clone(), alice); + let bob_mut = mutation_ctx(db.pool().clone(), bob); + create_todo( + &alice_mut, + CreateTodoInput { + title: "alice".into(), + }, + ) + .await + .expect("alice todo"); + create_todo( + &bob_mut, + CreateTodoInput { + title: "bob".into(), + }, + ) + .await + .expect("bob todo"); + + let alice_q = query_ctx(db.pool().clone(), alice); + let bob_q = query_ctx(db.pool().clone(), bob); + let alice_todos = list_todos(&alice_q).await.expect("alice list"); + let bob_todos = list_todos(&bob_q).await.expect("bob list"); + + assert_eq!(alice_todos.len(), 1); + assert_eq!(alice_todos[0].title, "alice"); + assert_eq!(bob_todos.len(), 1); + assert_eq!(bob_todos[0].title, "bob"); + db.cleanup().await.expect("cleanup"); + } + + #[tokio::test] + async fn update_todo_blocks_other_users() { + let db = setup_db().await; + let alice = seed_user(db.pool()).await; + let bob = seed_user(db.pool()).await; + + let alice_mut = mutation_ctx(db.pool().clone(), alice); + let todo = create_todo( + &alice_mut, + CreateTodoInput { + title: "hers".into(), + }, + ) + .await + .expect("create"); + + let bob_mut = mutation_ctx(db.pool().clone(), bob); + let err = update_todo( + &bob_mut, + UpdateTodoInput { + id: todo.id, + title: Some("stolen".into()), + completed: None, + }, + ) + .await + .expect_err("bob must not update alice's todo"); + assert!(matches!(err, ForgeError::NotFound(_))); + + let deleted = delete_todo(&bob_mut, todo.id).await.expect("delete 
call"); + assert!(!deleted, "bob must not delete alice's todo"); + + db.cleanup().await.expect("cleanup"); + } +} diff --git a/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs b/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs index fccb7763..45fa3645 100644 --- a/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs +++ b/examples/with-dioxus/realtime-todo-list/src/schema/todo.rs @@ -6,7 +6,43 @@ use uuid::Uuid; #[forge::model] pub struct Todo { pub id: Uuid, + pub user_id: Uuid, pub title: String, pub completed: bool, pub created_at: DateTime, } + +#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)] +pub struct User { + pub id: Uuid, + pub email: String, + pub name: String, + pub created_at: DateTime, + pub updated_at: DateTime, + #[serde(skip_serializing)] + pub password_hash: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserPublic { + pub id: Uuid, + pub email: String, + pub name: String, +} + +impl From for UserPublic { + fn from(u: User) -> Self { + Self { + id: u.id, + email: u.email, + name: u.name, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AuthResponse { + pub access_token: String, + pub refresh_token: String, + pub user: UserPublic, +} diff --git a/examples/with-svelte/demo/.env b/examples/with-svelte/demo/.env index 7344ecbe..60da3aae 100644 --- a/examples/with-svelte/demo/.env +++ b/examples/with-svelte/demo/.env @@ -1,21 +1,17 @@ -# Server +# Dev-only environment for `forge test` and local runs. NOT shipped to users: +# `scripts/build-template-archive.sh` excludes `.env`, and the webhook secret is +# used server-side only (never in the browser bundle). Users copy `.env.example` +# and generate their own secrets. Mirrors the realtime-todo-list convention. 
HOST=0.0.0.0 PORT=9081 - -# Logging (error, warn, info, debug, trace) RUST_LOG=info,forge_runtime::function::executor=trace - -# Postgres container settings POSTGRES_USER=postgres POSTGRES_PASSWORD=forge POSTGRES_DB=forge_svelte_demo_template POSTGRES_PORT=5432 - -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production - -# Webhook secret for HMAC-SHA256 signature validation +JWT_SECRET=dev-jwt-secret-not-for-production-use-please-rotate +JWT_AUDIENCE=forge-demo-dev WEBHOOK_SECRET=demo-secret - -# Enable offline mode for sqlx compile-time checks +SEED_DEMO_USER=true +CORS_ORIGIN=http://localhost:9080 SQLX_OFFLINE=true diff --git a/examples/with-svelte/demo/.env.example b/examples/with-svelte/demo/.env.example index 7344ecbe..c0e7972e 100644 --- a/examples/with-svelte/demo/.env.example +++ b/examples/with-svelte/demo/.env.example @@ -1,3 +1,5 @@ +# Copy to `.env` and fill in real values. Never commit `.env`. + # Server HOST=0.0.0.0 PORT=9081 @@ -11,11 +13,22 @@ POSTGRES_PASSWORD=forge POSTGRES_DB=forge_svelte_demo_template POSTGRES_PORT=5432 -# JWT secret for authentication -JWT_SECRET=demo-jwt-secret-change-me-in-production +# JWT signing secret. Generate with: openssl rand -base64 32 +JWT_SECRET=CHANGE_ME_USE_OPENSSL_RAND_BASE64_32 + +# JWT audience claim. Must match the audience configured in your auth provider. +JWT_AUDIENCE=CHANGE_ME_YOUR_AUDIENCE + +# HMAC secret used to verify inbound webhook signatures. +# Generate with: openssl rand -hex 32 +WEBHOOK_SECRET=CHANGE_ME_USE_OPENSSL_RAND_HEX_32 + +# Seed the demo user (demo@example.com / password123) at first migration. +# DEV ONLY. Leave unset (or `false`) in any deployed environment. +SEED_DEMO_USER=true -# Webhook secret for HMAC-SHA256 signature validation -WEBHOOK_SECRET=demo-secret +# CORS origin for the SvelteKit frontend. Override per environment. 
+CORS_ORIGIN=http://localhost:9080 # Enable offline mode for sqlx compile-time checks SQLX_OFFLINE=true diff --git a/examples/with-svelte/demo/.gitignore b/examples/with-svelte/demo/.gitignore index 6c4eb93f..c082ec29 100644 --- a/examples/with-svelte/demo/.gitignore +++ b/examples/with-svelte/demo/.gitignore @@ -13,6 +13,8 @@ frontend/playwright-report/ frontend/test-results/ # Environment +# `.env` is tracked with dev-only secrets so `forge test` works from a clean +# checkout; the template archive excludes it (see build-template-archive.sh). .env.local .env.*.local diff --git a/examples/with-svelte/demo/Cargo.toml b/examples/with-svelte/demo/Cargo.toml index 080cf5c3..f2df1231 100644 --- a/examples/with-svelte/demo/Cargo.toml +++ b/examples/with-svelte/demo/Cargo.toml @@ -6,7 +6,7 @@ rust-version = "1.92" publish = false [features] -default = ["embedded-frontend"] +default = [] embedded-frontend = ["dep:rust-embed", "forge/embedded-frontend"] testcontainers = ["forge/testcontainers"] @@ -25,6 +25,9 @@ tokio-tungstenite = { version = "0.26", features = ["rustls-tls-webpki-roots"] } futures-util = "0.3" argon2 = "0.5" password-hash = "0.5" +hmac = "0.12" +sha2 = "0.10" +hex = "0.4" rust-embed = { version = "8", optional = true } [build-dependencies] diff --git a/examples/with-svelte/demo/Dockerfile b/examples/with-svelte/demo/Dockerfile index e9d2f26c..57890816 100644 --- a/examples/with-svelte/demo/Dockerfile +++ b/examples/with-svelte/demo/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1-slim-bookworm AS dev +FROM rust:1.92-slim-bookworm AS dev WORKDIR /app @@ -15,19 +15,19 @@ RUN cargo install cargo-watch --locked && \ # Development command - no frontend embedding, watch for changes CMD ["cargo", "watch", "-x", "run --no-default-features"] -FROM oven/bun:1-alpine AS frontend-builder +FROM oven/bun:1.1.34-alpine AS frontend-builder WORKDIR /app/frontend COPY frontend/package.json frontend/bun.lock* ./ -RUN bun install --frozen-lockfile || bun install +RUN bun install 
--frozen-lockfile COPY frontend ./ RUN bun run build -FROM rust:1-alpine AS builder +FROM rust:1.92-alpine AS builder WORKDIR /app diff --git a/examples/with-svelte/demo/docker-compose.yml b/examples/with-svelte/demo/docker-compose.yml index 18ed2553..eb9c716b 100644 --- a/examples/with-svelte/demo/docker-compose.yml +++ b/examples/with-svelte/demo/docker-compose.yml @@ -5,7 +5,7 @@ services: dockerfile: Dockerfile target: dev ports: - - "9081:9081" + - "127.0.0.1:9081:9081" env_file: - .env environment: @@ -34,7 +34,7 @@ services: working_dir: /app command: sh -c "bun install && bun run dev --host 0.0.0.0 --port 9080" ports: - - "9080:9080" + - "127.0.0.1:9080:9080" env_file: - ./frontend/.env volumes: @@ -60,7 +60,7 @@ services: otel: build: ../../../docker/otel-lgtm ports: - - "3000:3000" + - "127.0.0.1:3000:3000" env_file: - .env environment: diff --git a/examples/with-svelte/demo/forge.toml b/examples/with-svelte/demo/forge.toml index dec1a2c7..ad7aa4a4 100644 --- a/examples/with-svelte/demo/forge.toml +++ b/examples/with-svelte/demo/forge.toml @@ -15,7 +15,7 @@ url = "${DATABASE_URL}" [gateway] port = 9081 cors_enabled = true -cors_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] +cors_origins = ["${CORS_ORIGIN-http://localhost:9080}", "http://127.0.0.1:9080"] # request_timeout = "30s" # max_body_size = "10mb" # quiet_paths = ["/_api/health", "/_api/ready"] # Routes excluded from traces/metrics/logs @@ -42,7 +42,7 @@ otlp_endpoint = "${FORGE_OTEL_ENDPOINT-http://localhost:4318}" [auth] jwt_algorithm = "HS256" jwt_secret = "${JWT_SECRET}" -jwt_audience = "${JWT_AUDIENCE-https://api.forge-demo.local}" +jwt_audience = "${JWT_AUDIENCE}" [mcp] enabled = true @@ -52,9 +52,18 @@ session_ttl = "1h" allowed_origins = ["http://localhost:9080", "http://127.0.0.1:9080"] require_protocol_version_header = true -# [rate_limit] -# mode = "local" # local, distributed -# max_local_buckets = 10000 +[rate_limit] +# hybrid: per-node DashMap fast path, PG fallback for 
global keys (DDoS-grade). +# strict: every check round-trips to PG (cluster-wide correct, billing-grade). +mode = "hybrid" +max_local_buckets = 10000 +# Per-handler quotas live on the function macros, e.g. +# #[forge::mutation(rate_limit_requests = 10, rate_limit_per_secs = 60, rate_limit_key = "ip")] + +[signals] +# Product analytics + diagnostics are off by default; this demo opts in to +# exercise the /_api/signal endpoint and the client SDK. +enabled = true # [cluster] # name = "node-1" # auto-generated if omitted diff --git a/examples/with-svelte/demo/frontend/playwright.config.ts b/examples/with-svelte/demo/frontend/playwright.config.ts index 678e4e7b..0c5b1573 100644 --- a/examples/with-svelte/demo/frontend/playwright.config.ts +++ b/examples/with-svelte/demo/frontend/playwright.config.ts @@ -6,7 +6,7 @@ export default defineConfig({ testDir: "./tests", fullyParallel: false, forbidOnly: !!process.env.CI, - retries: process.env.CI ? 1 : 1, + retries: process.env.CI ? 2 : 0, timeout: 90_000, workers: process.env.CI ? 
1 : undefined, reporter: "html", diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/api.ts b/examples/with-svelte/demo/frontend/src/lib/forge/api.ts index ea248e6e..462a83ff 100644 --- a/examples/with-svelte/demo/frontend/src/lib/forge/api.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/api.ts @@ -18,6 +18,7 @@ import type { RegisterInput, TokenPair, Trade, + TriggerDemoWebhookInput, User, UserRole, VerificationInput, @@ -70,6 +71,9 @@ export const refreshToken = (args: RefreshInput): Promise => getForgeClient().call("refresh_token", args); export const register = (args: RegisterInput): Promise => getForgeClient().call("register", args); +export const triggerDemoWebhook = ( + args: TriggerDemoWebhookInput, +): Promise => getForgeClient().call("trigger_demo_webhook", args); export const updateUser = (args: { id: string; email: string | null; diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts b/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts index 520357f8..6f651f14 100644 --- a/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/reactive.svelte.ts @@ -13,6 +13,7 @@ import { login, refreshToken, register, + triggerDemoWebhook, updateUser, } from "./api"; import { @@ -31,6 +32,7 @@ import type { RegisterInput, TokenPair, Trade, + TriggerDemoWebhookInput, User, UserRole, WebhookEvent, @@ -63,6 +65,10 @@ export const refreshToken$ = (): ReactiveMutation => toReactiveMutation(refreshToken); export const register$ = (): ReactiveMutation => toReactiveMutation(register); +export const triggerDemoWebhook$ = (): ReactiveMutation< + TriggerDemoWebhookInput, + boolean +> => toReactiveMutation(triggerDemoWebhook); export const updateUser$ = (): ReactiveMutation< { id: string; diff --git a/examples/with-svelte/demo/frontend/src/lib/forge/types.ts b/examples/with-svelte/demo/frontend/src/lib/forge/types.ts index 7a0560c9..d93cdd3a 100644 --- 
a/examples/with-svelte/demo/frontend/src/lib/forge/types.ts +++ b/examples/with-svelte/demo/frontend/src/lib/forge/types.ts @@ -7,11 +7,11 @@ export interface AuthResponse { } export interface BinanceTrade { - symbol: string; - price: string; - quantity: string; - trade_time: number; - is_buyer_maker: boolean; + s: string; + p: string; + q: string; + T: number; + m: boolean; } export interface ConfirmVerificationInput { @@ -90,6 +90,10 @@ export interface Trade { created_at: string; } +export interface TriggerDemoWebhookInput { + idempotency_key: string; +} + export interface User { id: string; email: string; @@ -97,7 +101,6 @@ export interface User { role: UserRole; created_at: string; updated_at: string; - password_hash?: string; } export interface UserPublic { diff --git a/examples/with-svelte/demo/frontend/src/routes/+page.svelte b/examples/with-svelte/demo/frontend/src/routes/+page.svelte index 7f6c5842..70a4e262 100644 --- a/examples/with-svelte/demo/frontend/src/routes/+page.svelte +++ b/examples/with-svelte/demo/frontend/src/routes/+page.svelte @@ -10,6 +10,7 @@ trackExportUsers, trackAccountVerification, confirmVerification, + triggerDemoWebhook, getUsers$, getIssLocation$, getTrades$, @@ -26,7 +27,8 @@ const signals = getForgeSignals(); const apiUrl = PUBLIC_API_URL; - const users = getUsers$(); + // `get_users` requires auth; only subscribe once logged in (avoids an + // anonymous 401 and a wasted SSE subscription). Created in the template. const issLocation = getIssLocation$(); const trades = getTrades$(); const webhookEvents = getWebhookEvents$(); @@ -55,8 +57,11 @@ // Auth form state (only form inputs and UI state are local) let authMode = $state<"login" | "register">("login"); - let authEmail = $state("demo@example.com"); - let authPassword = $state("password123"); + // Prefill credentials only when the SvelteKit build runs in dev mode (Vite `import.meta.env.DEV`). 
+ // Production bundles ship with empty fields so leaked demos don't double as one-click logins. + const DEV_PREFILL = import.meta.env.DEV; + let authEmail = $state(DEV_PREFILL ? "demo@example.com" : ""); + let authPassword = $state(DEV_PREFILL ? "password123" : ""); let authName = $state(""); let authLoading = $state(false); let authError = $state(null); @@ -173,39 +178,15 @@ async function triggerWebhook() { signals.breadcrumb("Sending webhook"); webhookError = null; - const secret = "demo-secret"; - const payload = JSON.stringify({ action: "test", ts: Date.now() }); - - const encoder = new TextEncoder(); - const key = await crypto.subtle.importKey( - "raw", - encoder.encode(secret), - { name: "HMAC", hash: "SHA-256" }, - false, - ["sign"] - ); - const signature = await crypto.subtle.sign("HMAC", key, encoder.encode(payload)); - const signatureHex = Array.from(new Uint8Array(signature)) - .map((b) => b.toString(16).padStart(2, "0")) - .join(""); - - const res = await fetch(`${apiUrl}/_api/webhooks/demo`, { - method: "POST", - headers: { - "Content-Type": "application/json", - "X-Webhook-Signature": signatureHex, - "X-Webhook-Timestamp": Math.floor(Date.now() / 1000).toString(), - "X-Idempotency-Key": idempotencyKey, - }, - body: payload, - }); - - if (res.ok) { + // The HMAC secret lives on the server. The backend signs and POSTs the + // webhook to itself so the browser bundle never ships the secret. + try { + await triggerDemoWebhook({ idempotency_key: idempotencyKey }); keyUsed = true; signals.track("webhook_sent", { idempotency_key: idempotencyKey }); - } else { - webhookError = `Error: ${res.status}`; - signals.track("webhook_error", { status: res.status }); + } catch (err: unknown) { + webhookError = err instanceof Error ? err.message : String(err); + signals.track("webhook_error"); } } @@ -562,7 +543,9 @@ -
+ {#if auth.isAuthenticated} + {@const users = getUsers$()} +

Users crud + subscribe

@@ -612,7 +595,8 @@ {:else if !users.loading}

No users yet. Create one above.

{/if} -
+
+ {/if}