diff --git a/.gitignore b/.gitignore index 077bfbece7..66351b131c 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,5 @@ Cargo.lock **/.forge/request.body.json node_modules/ bench/__pycache__ +/hooksref* +#/cc diff --git a/crates/forge_app/src/agent_executor.rs b/crates/forge_app/src/agent_executor.rs index 4c5ed94ff2..37e761742e 100644 --- a/crates/forge_app/src/agent_executor.rs +++ b/crates/forge_app/src/agent_executor.rs @@ -94,6 +94,7 @@ impl AgentExecutor { ChatResponse::ToolCallStart { .. } => ctx.send(message).await?, ChatResponse::ToolCallEnd(_) => ctx.send(message).await?, ChatResponse::RetryAttempt { .. } => ctx.send(message).await?, + ChatResponse::HookError { .. } => ctx.send(message).await?, ChatResponse::Interrupt { reason } => { return Err(Error::AgentToolInterrupted(reason)) .context(format!( diff --git a/crates/forge_app/src/app.rs b/crates/forge_app/src/app.rs index 169a002f03..1e1f3f7e36 100644 --- a/crates/forge_app/src/app.rs +++ b/crates/forge_app/src/app.rs @@ -9,10 +9,14 @@ use forge_stream::MpscStream; use crate::apply_tunable_parameters::ApplyTunableParameters; use crate::changed_files::ChangedFiles; use crate::dto::ToolsOverview; -use crate::hooks::{CompactionHandler, DoomLoopDetector, TitleGenerationHandler, TracingHandler}; +use crate::hooks::{ + CompactionHandler, DoomLoopDetector, TitleGenerationHandler, TracingHandler, UserHookHandler, +}; use crate::init_conversation_metrics::InitConversationMetrics; use crate::orch::Orchestrator; -use crate::services::{AgentRegistry, CustomInstructionsService, ProviderAuthService}; +use crate::services::{ + AgentRegistry, CustomInstructionsService, ProviderAuthService, UserHookConfigService, +}; use crate::set_conversation_id::SetConversationId; use crate::system_prompt::SystemPrompt; use crate::tool_registry::ToolRegistry; @@ -143,7 +147,7 @@ impl ForgeApp { // Create the orchestrator with all necessary dependencies let tracing_handler = TracingHandler::new(); let title_handler = TitleGenerationHandler::new(services.clone()); - let hook = Hook::default() + let internal_hook = Hook::default() .on_start(tracing_handler.clone().and(title_handler.clone())) .on_request(tracing_handler.clone().and(DoomLoopDetector::default())) .on_response( @@ -155,6 +159,29 @@ impl ForgeApp { .on_toolcall_end(tracing_handler.clone()) .on_end(tracing_handler.and(title_handler)); + // Load user-configurable hooks from settings files + let user_hook_config = services.get_user_hook_config().await?; + + let hook = if !user_hook_config.is_empty() { + let user_handler = UserHookHandler::new( + services.hook_command_service().clone(), + services.get_env_vars(), + user_hook_config, + environment.cwd.clone(), + conversation.id.to_string(), + ); + let user_hook = Hook::default() + .on_start(user_handler.clone()) + .on_request(user_handler.clone()) + .on_response(user_handler.clone()) + .on_toolcall_start(user_handler.clone()) + .on_toolcall_end(user_handler.clone()) + .on_end(user_handler); + internal_hook.zip(user_hook) + } else { + internal_hook + }; + let retry_config = forge_config.retry.clone().unwrap_or_default(); let orch = Orchestrator::new(services.clone(), retry_config, conversation, agent) diff --git a/crates/forge_app/src/hooks/mod.rs b/crates/forge_app/src/hooks/mod.rs index fb5447a8e6..cdc5c8f0af 100644 --- a/crates/forge_app/src/hooks/mod.rs +++ b/crates/forge_app/src/hooks/mod.rs @@ -2,8 +2,11 @@ mod compaction; mod doom_loop; mod title_generation; mod tracing; +mod user_hook_executor; +mod user_hook_handler; pub use compaction::CompactionHandler; pub use doom_loop::DoomLoopDetector; pub use title_generation::TitleGenerationHandler; pub use tracing::TracingHandler; +pub use user_hook_handler::UserHookHandler; diff --git a/crates/forge_app/src/hooks/user_hook_executor.rs b/crates/forge_app/src/hooks/user_hook_executor.rs new file mode 100644 index 0000000000..cc267e77c5 --- /dev/null +++ b/crates/forge_app/src/hooks/user_hook_executor.rs @@ -0,0 +1,245 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::time::Duration; + +use forge_domain::{CommandOutput, HookExecutionResult}; +use tracing::debug; + +use crate::services::HookCommandService; + +/// Executes user hook commands by delegating to a [`HookCommandService`]. +/// +/// Holds the service by value; the service itself is responsible for any +/// internal reference counting (`Arc`). Keeps hook-specific timeout resolution +/// in one place. +#[derive(Clone)] +pub struct UserHookExecutor(S); + +impl UserHookExecutor { + /// Creates a new `UserHookExecutor` backed by the given service. + pub fn new(service: S) -> Self { + Self(service) + } +} + +impl UserHookExecutor { + /// Executes a shell command, piping `input_json` to stdin and capturing + /// stdout/stderr. + /// + /// Applies `timeout_duration` by racing the service call against the + /// deadline. On timeout, returns a `HookExecutionResult` with + /// `exit_code: None` and a descriptive message in `stderr`. + /// + /// # Arguments + /// * `command` - The shell command string to execute. + /// * `input_json` - JSON string to pipe to the command's stdin. + /// * `timeout_duration` - Maximum time to wait for the command. + /// * `cwd` - Working directory for the command. + /// * `env_vars` - Additional environment variables to set. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + pub async fn execute( + &self, + command: &str, + input_json: &str, + timeout_duration: Duration, + cwd: &PathBuf, + env_vars: &HashMap, + ) -> anyhow::Result { + debug!( + command = command, + cwd = %cwd.display(), + timeout_ms = timeout_duration.as_millis() as u64, + "Executing user hook command" + ); + + let result = tokio::time::timeout( + timeout_duration, + self.0.execute_command_with_input( + command.to_string(), + cwd.clone(), + input_json.to_string(), + env_vars.clone(), + ), + ) + .await; + + let output = match result { + Ok(Ok(output)) => output, + Ok(Err(e)) => return Err(e), + Err(_) => { + tracing::warn!( + command = command, + timeout_ms = timeout_duration.as_millis() as u64, + "Hook command timed out" + ); + CommandOutput { + command: command.to_string(), + exit_code: None, + stdout: String::new(), + stderr: format!( + "Hook command timed out after {}ms", + timeout_duration.as_millis() + ), + } + } + }; + + debug!( + command = command, + exit_code = ?output.exit_code, + stdout_len = output.stdout.len(), + stderr_len = output.stderr.len(), + "Hook command completed" + ); + + Ok(HookExecutionResult { + exit_code: output.exit_code, + stdout: output.stdout, + stderr: output.stderr, + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + use std::time::Duration; + + use forge_domain::CommandOutput; + use pretty_assertions::assert_eq; + + use super::*; + + /// A minimal service stub that records calls and returns a fixed result. + #[derive(Clone)] + struct StubInfra { + result: CommandOutput, + } + + impl StubInfra { + fn success(stdout: &str) -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: Some(0), + stdout: stdout.to_string(), + stderr: String::new(), + }, + } + } + + fn exit(code: i32, stderr: &str) -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: Some(code), + stdout: String::new(), + stderr: stderr.to_string(), + }, + } + } + + fn timeout() -> Self { + Self { + result: CommandOutput { + command: String::new(), + exit_code: None, + stdout: String::new(), + stderr: "Hook command timed out after 100ms".to_string(), + }, + } + } + } + + #[async_trait::async_trait] + impl HookCommandService for StubInfra { + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: HashMap, + ) -> anyhow::Result { + let mut out = self.result.clone(); + out.command = command; + Ok(out) + } + } + + #[tokio::test] + async fn test_execute_success() { + let fixture = UserHookExecutor::new(StubInfra::success("hello")); + let actual = fixture + .execute( + "echo hello", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(0)); + assert_eq!(actual.stdout, "hello"); + assert!(actual.is_success()); + } + + #[tokio::test] + async fn test_execute_exit_code_2() { + let fixture = UserHookExecutor::new(StubInfra::exit(2, "blocked")); + let actual = fixture + .execute( + "exit 2", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(2)); + assert!(actual.is_blocking_exit()); + assert!(actual.stderr.contains("blocked")); + } + + #[tokio::test] + async fn test_execute_non_blocking_error() { + let fixture = UserHookExecutor::new(StubInfra::exit(1, "")); + let actual = fixture + .execute( + "exit 1", + "{}", + Duration::from_secs(0), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert_eq!(actual.exit_code, Some(1)); + assert!(actual.is_non_blocking_error()); + } + + #[tokio::test] + async fn test_execute_timeout() { + let fixture = UserHookExecutor::new(StubInfra::timeout()); + let actual = fixture + .execute( + "sleep 10", + "{}", + Duration::from_millis(100), + &std::env::current_dir().unwrap(), + &HashMap::new(), + ) + .await + .unwrap(); + + assert!(actual.exit_code.is_none()); + assert!(actual.stderr.contains("timed out")); + } +} diff --git a/crates/forge_app/src/hooks/user_hook_handler.rs b/crates/forge_app/src/hooks/user_hook_handler.rs new file mode 100644 index 0000000000..57d93ddc9e --- /dev/null +++ b/crates/forge_app/src/hooks/user_hook_handler.rs @@ -0,0 +1,820 @@ +use std::collections::{BTreeMap, HashMap}; +use std::path::PathBuf; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::Duration; + +use async_trait::async_trait; +use forge_domain::{ + ContextMessage, Conversation, EndPayload, EventData, EventHandle, HookEventInput, + HookExecutionResult, HookInput, HookOutput, RequestPayload, ResponsePayload, Role, + StartPayload, ToolcallEndPayload, ToolcallStartPayload, UserHookConfig, UserHookEntry, + UserHookEventName, UserHookMatcherGroup, +}; +use regex::Regex; +use tracing::{debug, warn}; + +use super::user_hook_executor::UserHookExecutor; +use crate::services::HookCommandService; + +/// Default timeout for hook commands (10 minutes). +const DEFAULT_HOOK_TIMEOUT: Duration = Duration::from_secs(600); + +/// EventHandle implementation that bridges user-configured hooks with the +/// existing lifecycle event system. +/// +/// This handler is constructed from a `UserHookConfig` and executes matching +/// hook commands at each lifecycle event point. It wires into the existing +/// `Hook` system via `Hook::zip()`. +#[derive(Clone)] +pub struct UserHookHandler { + executor: UserHookExecutor, + config: UserHookConfig, + cwd: PathBuf, + env_vars: HashMap, + /// Tracks whether a Stop hook has already fired to prevent infinite loops. + stop_hook_active: std::sync::Arc, +} + +impl UserHookHandler { + /// Creates a new user hook handler from configuration. + /// + /// # Arguments + /// * `service` - The hook command service used to execute hook commands. + /// * `config` - The merged user hook configuration. + /// * `cwd` - Current working directory for command execution. + /// * `project_dir` - Project root directory for `FORGE_PROJECT_DIR` env + /// var. + /// * `session_id` - Current session/conversation ID. + /// * `default_hook_timeout` - Default timeout in milliseconds for hook + /// commands. + pub fn new( + service: I, + mut env_vars: BTreeMap, + config: UserHookConfig, + cwd: PathBuf, + session_id: String, + ) -> Self { + env_vars.insert( + "FORGE_PROJECT_DIR".to_string(), + cwd.to_string_lossy().to_string(), + ); + env_vars.insert("FORGE_SESSION_ID".to_string(), session_id); + env_vars.insert("FORGE_CWD".to_string(), cwd.to_string_lossy().to_string()); + + Self { + executor: UserHookExecutor::new(service), + config, + cwd, + env_vars: env_vars.into_iter().collect(), + stop_hook_active: std::sync::Arc::new(AtomicBool::new(false)), + } + } + + /// Checks if the config has any hooks for the given event. + fn has_hooks(&self, event: &UserHookEventName) -> bool { + !self.config.get_groups(event).is_empty() + } + + /// Finds matching hook entries for an event, filtered by the optional + /// matcher regex against the given subject string. + fn find_matching_hooks<'a>( + groups: &'a [UserHookMatcherGroup], + subject: Option<&str>, + ) -> Vec<&'a UserHookEntry> { + let mut matching = Vec::new(); + + for group in groups { + let matches = match (&group.matcher, subject) { + (Some(pattern), Some(subj)) => match Regex::new(pattern) { + Ok(re) => re.is_match(subj), + Err(e) => { + warn!( + pattern = pattern, + error = %e, + "Invalid regex in hook matcher, skipping" + ); + false + } + }, + (Some(_), None) => { + // Matcher specified but no subject to match against; skip + false + } + (None, _) => { + // No matcher means unconditional match + true + } + }; + + if matches { + matching.extend(group.hooks.iter()); + } + } + + matching + } + + /// Executes a list of hook entries and returns their results. + async fn execute_hooks( + &self, + hooks: &[&UserHookEntry], + input: &HookInput, + ) -> Vec + where + I: HookCommandService, + { + let input_json = match serde_json::to_string(input) { + Ok(json) => json, + Err(e) => { + warn!(error = %e, "Failed to serialize hook input"); + return Vec::new(); + } + }; + + let mut results = Vec::new(); + for hook in hooks { + if let Some(command) = &hook.command { + match self + .executor + .execute( + command, + &input_json, + hook.timeout + .map(Duration::from_millis) + .unwrap_or(DEFAULT_HOOK_TIMEOUT), + &self.cwd, + &self.env_vars, + ) + .await + { + Ok(result) => results.push(result), + Err(e) => { + warn!( + command = command, + error = %e, + "Hook command failed to execute" + ); + } + } + } + } + + results + } + + /// Processes hook results, returning a blocking reason if any hook blocked. + fn process_results(results: &[HookExecutionResult]) -> Option { + for result in results { + // Exit code 2 = blocking error + if result.is_blocking_exit() { + let message = result + .blocking_message() + .unwrap_or("Hook blocked execution") + .to_string(); + return Some(message); + } + + // Exit code 0 = check stdout for JSON decisions + if let Some(output) = result.parse_output() + && output.is_blocking() + { + let reason = output + .reason + .unwrap_or_else(|| "Hook blocked execution".to_string()); + return Some(reason); + } + + // Non-blocking errors (exit code 1, etc.) are logged but don't block + if result.is_non_blocking_error() { + warn!( + exit_code = ?result.exit_code, + stderr = result.stderr.as_str(), + "Hook command returned non-blocking error" + ); + } + } + + None + } + + /// Processes PreToolUse results, extracting updated input if present. + fn process_pre_tool_use_output(results: &[HookExecutionResult]) -> PreToolUseDecision { + for result in results { + // Exit code 2 = blocking error + if result.is_blocking_exit() { + let message = result + .blocking_message() + .unwrap_or("Hook blocked tool execution") + .to_string(); + return PreToolUseDecision::Block(message); + } + + // Exit code 0 = check stdout for JSON decisions + if let Some(output) = result.parse_output() { + // Check permission decision + if output.permission_decision.as_deref() == Some("deny") { + let reason = output + .reason + .unwrap_or_else(|| "Tool execution denied by hook".to_string()); + return PreToolUseDecision::Block(reason); + } + + // Check generic block decision + if output.is_blocking() { + let reason = output + .reason + .unwrap_or_else(|| "Hook blocked tool execution".to_string()); + return PreToolUseDecision::Block(reason); + } + + // Check for updated input + if output.updated_input.is_some() { + return PreToolUseDecision::AllowWithUpdate(output); + } + } + + // Non-blocking errors are logged but don't block + if result.is_non_blocking_error() { + warn!( + exit_code = ?result.exit_code, + stderr = result.stderr.as_str(), + "PreToolUse hook command returned non-blocking error" + ); + } + } + + PreToolUseDecision::Allow + } +} + +/// Decision result from PreToolUse hook processing. +enum PreToolUseDecision { + /// Allow the tool call to proceed. + Allow, + /// Allow but with updated input from the hook output. + AllowWithUpdate(HookOutput), + /// Block the tool call with the given reason. + Block(String), +} + +// --- EventHandle implementations --- + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + _event: &EventData, + _conversation: &mut Conversation, + ) -> anyhow::Result<()> { + if !self.has_hooks(&UserHookEventName::SessionStart) { + return Ok(()); + } + + let groups = self.config.get_groups(&UserHookEventName::SessionStart); + let hooks = Self::find_matching_hooks(groups, Some("startup")); + + if hooks.is_empty() { + return Ok(()); + } + + let input = HookInput { + hook_event_name: UserHookEventName::SessionStart.to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::SessionStart { source: "startup".to_string() }, + }; + + let results = self.execute_hooks(&hooks, &input).await; + + // FIXME: SessionStart hooks can provide additional context but not block; + // additional_context is detected here but never injected into the conversation. + for result in &results { + if let Some(output) = result.parse_output() + && let Some(context) = &output.additional_context + { + debug!( + context_len = context.len(), + "SessionStart hook provided additional context" + ); + } + } + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // Only fire on the first request of a turn (user-submitted prompt). + // Subsequent iterations are internal LLM retry/tool-call loops and + // should not re-trigger UserPromptSubmit. + if event.payload.request_count != 0 { + return Ok(()); + } + + if !self.has_hooks(&UserHookEventName::UserPromptSubmit) { + return Ok(()); + } + + let groups = self.config.get_groups(&UserHookEventName::UserPromptSubmit); + let hooks = Self::find_matching_hooks(groups, None); + + if hooks.is_empty() { + return Ok(()); + } + + // Extract the last user message text as the prompt sent to the hook. + let prompt = conversation + .context + .as_ref() + .and_then(|ctx| { + ctx.messages + .iter() + .rev() + .find(|m| m.has_role(Role::User)) + .and_then(|m| m.content()) + .map(|s| s.to_string()) + }) + .unwrap_or_default(); + + let input = HookInput { + hook_event_name: "UserPromptSubmit".to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::UserPromptSubmit { prompt }, + }; + + let results = self.execute_hooks(&hooks, &input).await; + + if let Some(reason) = Self::process_results(&results) { + debug!( + reason = reason.as_str(), + "UserPromptSubmit hook blocked with feedback" + ); + // Inject feedback so the model sees why the prompt was flagged. + if let Some(context) = conversation.context.as_mut() { + let feedback_msg = format!( + "\nUserPromptSubmit\nblocked\n{reason}\n" + ); + context + .messages + .push(ContextMessage::user(feedback_msg, None).into()); + } + } + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + _event: &EventData, + _conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // FIXME: No user hook events map to Response currently + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &EventData, + _conversation: &mut Conversation, + ) -> anyhow::Result<()> { + if !self.has_hooks(&UserHookEventName::PreToolUse) { + return Ok(()); + } + + let tool_name = event.payload.tool_call.name.as_str(); + // FIXME: Add a tool name transformer to map tool names to Forge + // equivalents (e.g. "Bash" → "shell") so that hook configs written + let groups = self.config.get_groups(&UserHookEventName::PreToolUse); + let hooks = Self::find_matching_hooks(groups, Some(tool_name)); + + if hooks.is_empty() { + return Ok(()); + } + + let tool_input = + serde_json::to_value(&event.payload.tool_call.arguments).unwrap_or_default(); + + let input = HookInput { + hook_event_name: "PreToolUse".to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::PreToolUse { tool_name: tool_name.to_string(), tool_input }, + }; + + let results = self.execute_hooks(&hooks, &input).await; + let decision = Self::process_pre_tool_use_output(&results); + + match decision { + PreToolUseDecision::Allow => Ok(()), + PreToolUseDecision::AllowWithUpdate(_output) => { + // FIXME: Updating tool call input would require modifying the tool call + // in-flight, which would need changes to the orchestrator. + // For now, we log and proceed. + debug!( + tool_name = tool_name, + "PreToolUse hook returned updatedInput (not yet supported for modification)" + ); + Ok(()) + } + PreToolUseDecision::Block(reason) => { + debug!( + tool_name = tool_name, + reason = reason.as_str(), + "PreToolUse hook blocked tool call" + ); + // Return an error to signal the orchestrator to skip this tool call. + // The orchestrator converts this into an error ToolResult visible to + // the model. + Err(anyhow::anyhow!( + "Tool call '{}' blocked by PreToolUse hook: {}", + tool_name, + reason + )) + } + } + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + event: &EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + let is_error = event.payload.result.is_error(); + let event_name = if is_error { + UserHookEventName::PostToolUseFailure + } else { + UserHookEventName::PostToolUse + }; + + if !self.has_hooks(&event_name) { + return Ok(()); + } + + let tool_name = event.payload.tool_call.name.as_str(); + let groups = self.config.get_groups(&event_name); + let hooks = Self::find_matching_hooks(groups, Some(tool_name)); + + if hooks.is_empty() { + return Ok(()); + } + + let tool_input = + serde_json::to_value(&event.payload.tool_call.arguments).unwrap_or_default(); + let tool_response = serde_json::to_value(&event.payload.result.output).unwrap_or_default(); + + let input = HookInput { + hook_event_name: event_name.to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::PostToolUse { + tool_name: tool_name.to_string(), + tool_input, + tool_response, + }, + }; + + let results = self.execute_hooks(&hooks, &input).await; + + // PostToolUse can provide feedback via blocking + if let Some(reason) = Self::process_results(&results) { + debug!( + tool_name = tool_name, + event = %event_name, + reason = reason.as_str(), + "PostToolUse hook blocked with feedback" + ); + // Inject feedback as a user message + if let Some(context) = conversation.context.as_mut() { + let feedback_msg = format!( + "\n{}\n{}\nblocked\n{}\n", + event_name, tool_name, reason + ); + context + .messages + .push(forge_domain::ContextMessage::user(feedback_msg, None).into()); + } + } + + Ok(()) + } +} + +#[async_trait] +impl EventHandle> for UserHookHandler { + async fn handle( + &self, + _event: &EventData, + conversation: &mut Conversation, + ) -> anyhow::Result<()> { + // Fire SessionEnd hooks + if self.has_hooks(&UserHookEventName::SessionEnd) { + let groups = self.config.get_groups(&UserHookEventName::SessionEnd); + let hooks = Self::find_matching_hooks(groups, None); + + if !hooks.is_empty() { + let input = HookInput { + hook_event_name: "SessionEnd".to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::Empty {}, + }; + self.execute_hooks(&hooks, &input).await; + } + } + + // Fire Stop hooks + if !self.has_hooks(&UserHookEventName::Stop) { + return Ok(()); + } + + // Prevent infinite loops + let was_active = self.stop_hook_active.swap(true, Ordering::SeqCst); + if was_active { + debug!("Stop hook already active, skipping to prevent infinite loop"); + return Ok(()); + } + + let groups = self.config.get_groups(&UserHookEventName::Stop); + let hooks = Self::find_matching_hooks(groups, None); + + if hooks.is_empty() { + self.stop_hook_active.store(false, Ordering::SeqCst); + return Ok(()); + } + + let input = HookInput { + hook_event_name: "Stop".to_string(), + cwd: self.cwd.to_string_lossy().to_string(), + session_id: self.env_vars.get("FORGE_SESSION_ID").cloned(), + event_data: HookEventInput::Stop { stop_hook_active: was_active }, + }; + + let results = self.execute_hooks(&hooks, &input).await; + + if let Some(reason) = Self::process_results(&results) { + debug!( + reason = reason.as_str(), + "Stop hook wants to continue conversation" + ); + // Inject a message to continue the conversation + if let Some(context) = conversation.context.as_mut() { + let continue_msg = format!( + "\nStop\ncontinue\n{}\n", + reason + ); + context + .messages + .push(forge_domain::ContextMessage::user(continue_msg, None).into()); + } + } + + // Reset the stop hook active flag + self.stop_hook_active.store(false, Ordering::SeqCst); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::path::PathBuf; + + use forge_domain::{ + CommandOutput, HookExecutionResult, UserHookEntry, UserHookEventName, UserHookMatcherGroup, + UserHookType, + }; + use pretty_assertions::assert_eq; + + use super::*; + + /// A no-op service stub for tests that only exercise config/matching logic. + #[derive(Clone)] + struct NullInfra; + + #[async_trait::async_trait] + impl HookCommandService for NullInfra { + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: HashMap, + ) -> anyhow::Result { + Ok(CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } + } + + fn null_handler(config: UserHookConfig) -> UserHookHandler { + UserHookHandler::new( + NullInfra, + BTreeMap::new(), + config, + PathBuf::from("/tmp"), + "sess-1".to_string(), + ) + } + + fn make_entry(command: &str) -> UserHookEntry { + UserHookEntry { + hook_type: UserHookType::Command, + command: Some(command.to_string()), + timeout: None, + } + } + + fn make_group(matcher: Option<&str>, commands: &[&str]) -> UserHookMatcherGroup { + UserHookMatcherGroup { + matcher: matcher.map(|s| s.to_string()), + hooks: commands.iter().map(|c| make_entry(c)).collect(), + } + } + + #[test] + fn test_find_matching_hooks_no_matcher_fires_unconditionally() { + let groups = vec![make_group(None, &["echo hi"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("Bash")); + assert_eq!(actual.len(), 1); + assert_eq!(actual[0].command, Some("echo hi".to_string())); + } + + #[test] + fn test_find_matching_hooks_no_matcher_fires_without_subject() { + let groups = vec![make_group(None, &["echo hi"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, None); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_regex_match() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("Bash")); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_regex_no_match() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("Write")); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_regex_partial_match() { + let groups = vec![make_group(Some("Bash|Write"), &["check.sh"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("Bash")); + assert_eq!(actual.len(), 1); + } + + #[test] + fn test_find_matching_hooks_matcher_but_no_subject() { + let groups = vec![make_group(Some("Bash"), &["block.sh"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, None); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_invalid_regex_skipped() { + let groups = vec![make_group(Some("[invalid"), &["block.sh"])]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("anything")); + assert!(actual.is_empty()); + } + + #[test] + fn test_find_matching_hooks_multiple_groups() { + let groups = vec![ + make_group(Some("Bash"), &["bash-hook.sh"]), + make_group(Some("Write"), &["write-hook.sh"]), + make_group(None, &["always.sh"]), + ]; + let actual = UserHookHandler::::find_matching_hooks(&groups, Some("Bash")); + assert_eq!(actual.len(), 2); // Bash match + unconditional + } + + #[test] + fn test_process_pre_tool_use_output_allow_on_success() { + let results = vec![HookExecutionResult { + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Allow)); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_exit_2() { + let results = vec![HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "Blocked: dangerous command".to_string(), + }]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!( + matches!(actual, PreToolUseDecision::Block(msg) if msg.contains("dangerous command")) + ); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_deny() { + let results = vec![HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"permissionDecision": "deny", "reason": "Not allowed"}"#.to_string(), + stderr: String::new(), + }]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "Not allowed")); + } + + #[test] + fn test_process_pre_tool_use_output_block_on_decision() { + let results = vec![HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "Blocked by policy"}"#.to_string(), + stderr: String::new(), + }]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Block(msg) if msg == "Blocked by policy")); + } + + #[test] + fn test_process_pre_tool_use_output_non_blocking_error_allows() { + let results = vec![HookExecutionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "some error".to_string(), + }]; + let actual = UserHookHandler::::process_pre_tool_use_output(&results); + assert!(matches!(actual, PreToolUseDecision::Allow)); + } + + #[test] + fn test_process_results_no_blocking() { + let results = vec![HookExecutionResult { + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }]; + let actual = UserHookHandler::::process_results(&results); + assert!(actual.is_none()); + } + + #[test] + fn test_process_results_blocking_exit_code() { + let results = vec![HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "stop reason".to_string(), + }]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!(actual, Some("stop reason".to_string())); + } + + #[test] + fn test_process_results_blocking_json_decision() { + let results = vec![HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "keep going"}"#.to_string(), + stderr: String::new(), + }]; + let actual = UserHookHandler::::process_results(&results); + assert_eq!(actual, Some("keep going".to_string())); + } + + #[test] + fn test_has_hooks_returns_false_for_empty_config() { + let config = UserHookConfig::new(); + let handler = null_handler(config); + assert!(!handler.has_hooks(&UserHookEventName::PreToolUse)); + } + + #[test] + fn test_has_hooks_returns_true_when_configured() { + let json = r#"{"PreToolUse": [{"hooks": [{"type": "command", "command": "echo hi"}]}]}"#; + let config: UserHookConfig = serde_json::from_str(json).unwrap(); + let handler = null_handler(config); + assert!(handler.has_hooks(&UserHookEventName::PreToolUse)); + assert!(!handler.has_hooks(&UserHookEventName::Stop)); + } +} diff --git a/crates/forge_app/src/infra.rs b/crates/forge_app/src/infra.rs index 8c1c567772..fa71f7b479 100644 --- a/crates/forge_app/src/infra.rs +++ b/crates/forge_app/src/infra.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::hash::Hash; use std::path::{Path, PathBuf}; @@ -160,6 +160,28 @@ pub trait CommandInfra: Send + Sync { working_dir: PathBuf, env_vars: Option>, ) -> anyhow::Result; + + /// Executes a shell command with stdin input. + /// + /// Pipes `stdin_input` to the process stdin, captures stdout and stderr, + /// and waits for the process to complete. Timeout enforcement is handled + /// by the caller. + /// + /// # Arguments + /// * `command` - Shell command string to execute. + /// * `working_dir` - Working directory for the command. + /// * `stdin_input` - Data to pipe to the process stdin. + /// * `env_vars` - Additional environment variables as key-value pairs. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result; } #[async_trait::async_trait] diff --git a/crates/forge_app/src/orch.rs b/crates/forge_app/src/orch.rs index 86202f9266..636ed3441d 100644 --- a/crates/forge_app/src/orch.rs +++ b/crates/forge_app/src/orch.rs @@ -84,22 +84,35 @@ impl Orchestrator { notifier.notified().await; } - // Fire the ToolcallStart lifecycle event + // Fire the ToolcallStart lifecycle event. + // If a hook returns an error (e.g., PreToolUse hook blocked the + // call), skip execution and record an error result instead. let toolcall_start_event = LifecycleEvent::ToolcallStart(EventData::new( self.agent.clone(), self.agent.model.clone(), ToolcallStartPayload::new(tool_call.clone()), )); - self.hook + let hook_result = self + .hook .handle(&toolcall_start_event, &mut self.conversation) - .await?; - - // Execute the tool - let tool_result = self - .services - .call(&self.agent, tool_context, tool_call.clone()) .await; + let tool_result = if let Err(hook_err) = hook_result { + // Hook blocked this tool call — notify the UI and produce an + // error ToolResult so the model sees feedback without aborting. + self.send(ChatResponse::HookError { + tool_name: tool_call.name.clone(), + reason: hook_err.to_string(), + }) + .await?; + ToolResult::from(tool_call.clone()).failure(hook_err) + } else { + // Execute the tool normally + self.services + .call(&self.agent, tool_context, tool_call.clone()) + .await + }; + // Fire the ToolcallEnd lifecycle event (fires on both success and failure) let toolcall_end_event = LifecycleEvent::ToolcallEnd(EventData::new( self.agent.clone(), diff --git a/crates/forge_app/src/services.rs b/crates/forge_app/src/services.rs index 9cf7a12c89..46386b862e 100644 --- a/crates/forge_app/src/services.rs +++ b/crates/forge_app/src/services.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::time::Duration; @@ -493,6 +494,18 @@ pub trait CommandLoaderService: Send + Sync { async fn get_commands(&self) -> anyhow::Result>; } +#[async_trait::async_trait] +pub trait UserHookConfigService: Send + Sync { + /// Loads and merges user hook configurations from all settings file + /// locations. + /// + /// Resolution order (all merged, not overridden): + /// 1. `~/.forge/settings.json` (user-level, applies to all projects) + /// 2. `.forge/settings.json` (project-level, committable) + /// 3. `.forge/settings.local.json` (project-level, gitignored) + async fn get_user_hook_config(&self) -> anyhow::Result; +} + #[async_trait::async_trait] pub trait PolicyService: Send + Sync { /// Check if an operation is allowed and handle user confirmation if needed @@ -548,6 +561,34 @@ pub trait ProviderAuthService: Send + Sync { ) -> anyhow::Result>; } +/// Service for executing hook commands with stdin input and timeout. +/// +/// Abstracts over the underlying process execution so that `UserHookExecutor` +/// depends on a service rather than infrastructure directly. +#[async_trait::async_trait] +pub trait HookCommandService: Send + Sync { + /// Executes a shell command with stdin input. + /// + /// Pipes `stdin_input` to the process stdin and captures stdout/stderr. + /// Timeout enforcement is handled by the caller. + /// + /// # Arguments + /// * `command` - Shell command string to execute. + /// * `working_dir` - Working directory for the command. + /// * `stdin_input` - Data to pipe to the process stdin. + /// * `env_vars` - Additional environment variables as key-value pairs. + /// + /// # Errors + /// Returns an error if the process cannot be spawned. + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result; +} + pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { type ProviderService: ProviderService; type AppConfigService: AppConfigService; @@ -572,10 +613,12 @@ pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { type AuthService: AuthService; type AgentRegistry: AgentRegistry; type CommandLoaderService: CommandLoaderService; + type UserHookConfigService: UserHookConfigService; type PolicyService: PolicyService; type ProviderAuthService: ProviderAuthService; type WorkspaceService: WorkspaceService; type SkillFetchService: SkillFetchService; + type HookCommandService: HookCommandService + Clone; fn provider_service(&self) -> &Self::ProviderService; fn config_service(&self) -> &Self::AppConfigService; @@ -600,10 +643,12 @@ pub trait Services: Send + Sync + 'static + Clone + EnvironmentInfra { fn auth_service(&self) -> &Self::AuthService; fn agent_registry(&self) -> &Self::AgentRegistry; fn command_loader_service(&self) -> &Self::CommandLoaderService; + fn user_hook_config_service(&self) -> &Self::UserHookConfigService; fn policy_service(&self) -> &Self::PolicyService; fn provider_auth_service(&self) -> &Self::ProviderAuthService; fn workspace_service(&self) -> &Self::WorkspaceService; fn skill_fetch_service(&self) -> &Self::SkillFetchService; + fn hook_command_service(&self) -> &Self::HookCommandService; } #[async_trait::async_trait] @@ -937,6 +982,13 @@ impl CommandLoaderService for I { } } +#[async_trait::async_trait] +impl UserHookConfigService for I { + async fn get_user_hook_config(&self) -> anyhow::Result { + self.user_hook_config_service().get_user_hook_config().await + } +} + #[async_trait::async_trait] impl PolicyService for I { async fn check_operation_permission( diff --git a/crates/forge_config/.forge.toml b/crates/forge_config/.forge.toml index e0e4268789..6bf8423ff3 100644 --- a/crates/forge_config/.forge.toml +++ b/crates/forge_config/.forge.toml @@ -23,6 +23,7 @@ sem_search_top_k = 10 services_url = "https://api.forgecode.dev/" tool_supported = true tool_timeout_secs = 300 +hook_timeout_ms = 600000 top_k = 30 top_p = 0.8 diff --git a/crates/forge_domain/src/chat_response.rs b/crates/forge_domain/src/chat_response.rs index e24cd9d731..3f5f5347ec 100644 --- a/crates/forge_domain/src/chat_response.rs +++ b/crates/forge_domain/src/chat_response.rs @@ -65,6 +65,14 @@ pub enum ChatResponse { notifier: Arc, }, ToolCallEnd(ToolResult), + /// A user-configured hook blocked execution of a tool call. + HookError { + /// Name of the tool that was blocked. + tool_name: ToolName, + /// Human-readable reason provided by the hook (from stderr or JSON + /// output). + reason: String, + }, RetryAttempt { cause: Cause, duration: Duration, diff --git a/crates/forge_domain/src/lib.rs b/crates/forge_domain/src/lib.rs index 13a6b18135..1fd033191c 100644 --- a/crates/forge_domain/src/lib.rs +++ b/crates/forge_domain/src/lib.rs @@ -51,6 +51,8 @@ mod top_k; mod top_p; mod transformer; mod update; +mod user_hook_config; +mod user_hook_io; mod validation; mod workspace; mod xml; @@ -106,6 +108,8 @@ pub use top_k::*; pub use top_p::*; pub use transformer::*; pub use update::*; +pub use user_hook_config::*; +pub use user_hook_io::*; pub use validation::*; pub use workspace::*; pub use xml::*; diff --git a/crates/forge_domain/src/user_hook_config.rs b/crates/forge_domain/src/user_hook_config.rs new file mode 100644 index 0000000000..9a05d77786 --- /dev/null +++ b/crates/forge_domain/src/user_hook_config.rs @@ -0,0 +1,293 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; +use strum_macros::Display; + +/// Top-level user hook configuration. +/// +/// Maps hook event names to a list of matcher groups. This is deserialized +/// from the `"hooks"` key in `.forge/settings.json` or +/// `~/.forge/settings.json`. +/// +/// Example JSON: +/// ```json +/// { +/// "PreToolUse": [ +/// { "matcher": "Bash", "hooks": [{ "type": "command", "command": "echo hi" }] } +/// ] +/// } +/// ``` +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserHookConfig { + /// Map of event name -> list of matcher groups + #[serde(flatten)] + pub events: HashMap>, +} + +impl UserHookConfig { + /// Creates an empty user hook configuration. + pub fn new() -> Self { + Self { events: HashMap::new() } + } + + /// Returns the matcher groups for a given event name, or an empty slice if + /// none. + pub fn get_groups(&self, event: &UserHookEventName) -> &[UserHookMatcherGroup] { + self.events.get(event).map_or(&[], |v| v.as_slice()) + } + + /// Merges another config into this one, appending matcher groups for each + /// event. + pub fn merge(&mut self, other: UserHookConfig) { + for (event, groups) in other.events { + self.events.entry(event).or_default().extend(groups); + } + } + + /// Returns true if no hook events are configured. + pub fn is_empty(&self) -> bool { + self.events.is_empty() + } +} + +/// Supported hook event names that map to lifecycle points in the +/// orchestrator. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Display)] +pub enum UserHookEventName { + /// Fired before a tool call executes. Can block execution. + PreToolUse, + /// Fired after a tool call succeeds. + PostToolUse, + /// Fired after a tool call fails. + PostToolUseFailure, + /// Fired when the agent finishes responding. Can block stop to continue. + Stop, + /// FIXME: Fired when a notification is sent; no lifecycle point fires this + /// event and no handler exists yet. + Notification, + /// Fired when a session starts or resumes. + SessionStart, + /// Fired when a session ends/terminates. + SessionEnd, + /// Fired when a user prompt is submitted. + UserPromptSubmit, + /// FIXME: Fired before context compaction; no lifecycle point fires this + /// event and no handler exists yet. + PreCompact, + /// FIXME: Fired after context compaction; no lifecycle point fires this + /// event and no handler exists yet. + PostCompact, +} + +/// A matcher group pairs an optional regex matcher with a list of hook +/// handlers. +/// +/// When a lifecycle event fires, only matcher groups whose `matcher` regex +/// matches the relevant event context (e.g., tool name) will have their hooks +/// executed. If `matcher` is `None`, all hooks in this group fire +/// unconditionally. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserHookMatcherGroup { + /// Optional regex pattern to match against (e.g., tool name for + /// PreToolUse/PostToolUse). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matcher: Option, + + /// List of hook handlers to execute when this matcher matches. + #[serde(default)] + pub hooks: Vec, +} + +/// A single hook handler entry that defines what action to take. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserHookEntry { + /// The type of hook handler. + #[serde(rename = "type")] + pub hook_type: UserHookType, + + /// The shell command to execute (for `Command` type hooks). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub command: Option, + + /// Timeout in milliseconds for this hook. Defaults to 600000ms (10 + /// minutes). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub timeout: Option, +} + +/// The type of hook handler to execute. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum UserHookType { + /// Executes a shell command, piping JSON to stdin and reading JSON from + /// stdout. + Command, +} + +/// Wrapper for the top-level settings JSON that contains the hooks key. +/// +/// Used for deserializing the entire settings file and extracting just the +/// `"hooks"` section. +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserSettings { + /// User hook configuration. + #[serde(default)] + pub hooks: UserHookConfig, +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_deserialize_empty_config() { + let json = r#"{}"#; + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + let expected = UserHookConfig::new(); + assert_eq!(actual, expected); + } + + #[test] + fn test_deserialize_pre_tool_use_hook() { + let json = r#"{ + "PreToolUse": [ + { + "matcher": "Bash", + "hooks": [ + { + "type": "command", + "command": "echo 'blocked'" + } + ] + } + ] + }"#; + + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PreToolUse); + + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].matcher, Some("Bash".to_string())); + assert_eq!(groups[0].hooks.len(), 1); + assert_eq!(groups[0].hooks[0].hook_type, UserHookType::Command); + assert_eq!( + groups[0].hooks[0].command, + Some("echo 'blocked'".to_string()) + ); + } + + #[test] + fn test_deserialize_multiple_events() { + let json = r#"{ + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "pre.sh" }] } + ], + "PostToolUse": [ + { "hooks": [{ "type": "command", "command": "post.sh" }] } + ], + "Stop": [ + { "hooks": [{ "type": "command", "command": "stop.sh" }] } + ] + }"#; + + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + + assert_eq!(actual.get_groups(&UserHookEventName::PreToolUse).len(), 1); + assert_eq!(actual.get_groups(&UserHookEventName::PostToolUse).len(), 1); + assert_eq!(actual.get_groups(&UserHookEventName::Stop).len(), 1); + assert!( + actual + .get_groups(&UserHookEventName::SessionStart) + .is_empty() + ); + } + + #[test] + fn test_deserialize_hook_with_timeout() { + let json = r#"{ + "PreToolUse": [ + { + "hooks": [ + { "type": "command", "command": "slow.sh", "timeout": 30000 } + ] + } + ] + }"#; + + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PreToolUse); + + assert_eq!(groups[0].hooks[0].timeout, Some(30000)); + } + + #[test] + fn test_merge_configs() { + let json1 = r#"{ + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "hook1.sh" }] } + ] + }"#; + let json2 = r#"{ + "PreToolUse": [ + { "matcher": "Write", "hooks": [{ "type": "command", "command": "hook2.sh" }] } + ], + "Stop": [ + { "hooks": [{ "type": "command", "command": "stop.sh" }] } + ] + }"#; + + let mut actual: UserHookConfig = serde_json::from_str(json1).unwrap(); + let config2: UserHookConfig = serde_json::from_str(json2).unwrap(); + actual.merge(config2); + + assert_eq!(actual.get_groups(&UserHookEventName::PreToolUse).len(), 2); + assert_eq!(actual.get_groups(&UserHookEventName::Stop).len(), 1); + } + + #[test] + fn test_deserialize_settings_with_hooks() { + let json = r#"{ + "hooks": { + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "check.sh" }] } + ] + } + }"#; + + let actual: UserSettings = serde_json::from_str(json).unwrap(); + + assert!(!actual.hooks.is_empty()); + assert_eq!( + actual + .hooks + .get_groups(&UserHookEventName::PreToolUse) + .len(), + 1 + ); + } + + #[test] + fn test_deserialize_settings_without_hooks() { + let json = r#"{}"#; + let actual: UserSettings = serde_json::from_str(json).unwrap(); + + assert!(actual.hooks.is_empty()); + } + + #[test] + fn test_no_matcher_group_fires_unconditionally() { + let json = r#"{ + "PostToolUse": [ + { "hooks": [{ "type": "command", "command": "always.sh" }] } + ] + }"#; + + let actual: UserHookConfig = serde_json::from_str(json).unwrap(); + let groups = actual.get_groups(&UserHookEventName::PostToolUse); + + assert_eq!(groups.len(), 1); + assert_eq!(groups[0].matcher, None); + } +} diff --git a/crates/forge_domain/src/user_hook_io.rs b/crates/forge_domain/src/user_hook_io.rs new file mode 100644 index 0000000000..c66755a8b5 --- /dev/null +++ b/crates/forge_domain/src/user_hook_io.rs @@ -0,0 +1,335 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +/// Exit code constants for hook script results. +pub mod exit_codes { + /// Hook executed successfully. stdout may contain JSON output. + pub const SUCCESS: i32 = 0; + /// Blocking error. stderr is used as feedback message. + pub const BLOCK: i32 = 2; +} + +/// JSON input sent to hook scripts via stdin. +/// +/// Contains common fields shared across all hook events plus event-specific +/// data in the `event_data` field. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct HookInput { + /// The hook event name (e.g., "PreToolUse", "PostToolUse", "Stop"). + pub hook_event_name: String, + + /// Current working directory. + pub cwd: String, + + /// Session/conversation ID. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub session_id: Option, + + /// Event-specific payload data. + #[serde(flatten)] + pub event_data: HookEventInput, +} + +/// Event-specific input data variants. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum HookEventInput { + /// Input for PreToolUse events. + PreToolUse { + /// Name of the tool being called. + tool_name: String, + /// Tool call arguments as a JSON value. + tool_input: Value, + }, + /// Input for PostToolUse events. + PostToolUse { + /// Name of the tool that was called. + tool_name: String, + /// Tool call arguments as a JSON value. + tool_input: Value, + /// Tool output/response as a JSON value. + tool_response: Value, + }, + /// Input for Stop events. + Stop { + /// Whether a Stop hook has already fired (prevents infinite loops). + stop_hook_active: bool, + }, + /// Input for SessionStart events. + SessionStart { + /// Source of the session start (e.g., "startup", "resume"). + source: String, + }, + /// Input for UserPromptSubmit events. + UserPromptSubmit { + /// The raw prompt text submitted by the user. + prompt: String, + }, + /// Empty input for events that don't need event-specific data. + Empty {}, +} + +/// JSON output parsed from hook script stdout. +/// +/// Fields are optional; scripts that don't need to control behavior can simply +/// exit 0 with empty stdout. +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct HookOutput { + /// FIXME: Whether execution should continue; deserialized from hook stdout + /// but never checked in any decision logic (`process_results`, + /// `is_blocking()`). + #[serde(default, rename = "continue", skip_serializing_if = "Option::is_none")] + pub continue_execution: Option, + + /// Decision for blocking events. `"block"` blocks the operation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decision: Option, + + /// Reason for blocking, used as feedback to the agent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub reason: Option, + + /// For PreToolUse: permission decision ("allow", "deny", "ask"). + #[serde( + default, + rename = "permissionDecision", + skip_serializing_if = "Option::is_none" + )] + pub permission_decision: Option, + + /// For PreToolUse: modified tool input to replace the original. + #[serde( + default, + rename = "updatedInput", + skip_serializing_if = "Option::is_none" + )] + pub updated_input: Option, + + /// Additional context to inject into the conversation. + #[serde( + default, + rename = "additionalContext", + skip_serializing_if = "Option::is_none" + )] + pub additional_context: Option, + + /// FIXME: Reason for stopping (for Stop hooks); deserialized from hook + /// stdout but never consumed anywhere in decision logic. + #[serde( + default, + rename = "stopReason", + skip_serializing_if = "Option::is_none" + )] + pub stop_reason: Option, +} + +impl HookOutput { + /// Attempts to parse stdout as JSON. Falls back to empty output on failure. + pub fn parse(stdout: &str) -> Self { + if stdout.trim().is_empty() { + return Self::default(); + } + serde_json::from_str(stdout).unwrap_or_default() + } + + /// Returns true if this output requests blocking. + pub fn is_blocking(&self) -> bool { + self.decision.as_deref() == Some("block") + || self.permission_decision.as_deref() == Some("deny") + } +} + +/// Result of executing a hook command. +#[derive(Debug, Clone)] +pub struct HookExecutionResult { + /// Process exit code (None if terminated by signal). + pub exit_code: Option, + /// Captured stdout. + pub stdout: String, + /// Captured stderr. + pub stderr: String, +} + +impl HookExecutionResult { + /// Returns true if the hook exited with the blocking exit code (2). + pub fn is_blocking_exit(&self) -> bool { + self.exit_code == Some(exit_codes::BLOCK) + } + + /// Returns true if the hook exited successfully (0). + pub fn is_success(&self) -> bool { + self.exit_code == Some(exit_codes::SUCCESS) + } + + /// Returns true if the hook exited with a non-blocking error (non-0, + /// non-2). + pub fn is_non_blocking_error(&self) -> bool { + match self.exit_code { + Some(code) => code != exit_codes::SUCCESS && code != exit_codes::BLOCK, + None => true, + } + } + + /// Parses the stdout as a HookOutput if the exit was successful. + pub fn parse_output(&self) -> Option { + if self.is_success() { + Some(HookOutput::parse(&self.stdout)) + } else { + None + } + } + + /// Returns the feedback message for blocking errors (stderr content). + pub fn blocking_message(&self) -> Option<&str> { + if self.is_blocking_exit() { + let msg = self.stderr.trim(); + if msg.is_empty() { None } else { Some(msg) } + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + + #[test] + fn test_hook_input_serialization_pre_tool_use() { + let fixture = HookInput { + hook_event_name: "PreToolUse".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-123".to_string()), + event_data: HookEventInput::PreToolUse { + tool_name: "Bash".to_string(), + tool_input: serde_json::json!({"command": "ls"}), + }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "PreToolUse"); + assert_eq!(actual["cwd"], "/project"); + assert_eq!(actual["tool_name"], "Bash"); + assert_eq!(actual["tool_input"]["command"], "ls"); + } + + #[test] + fn test_hook_input_serialization_stop() { + let fixture = HookInput { + hook_event_name: "Stop".to_string(), + cwd: "/project".to_string(), + session_id: None, + event_data: HookEventInput::Stop { stop_hook_active: false }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "Stop"); + assert_eq!(actual["stop_hook_active"], false); + } + + #[test] + fn test_hook_input_serialization_user_prompt_submit() { + let fixture = HookInput { + hook_event_name: "UserPromptSubmit".to_string(), + cwd: "/project".to_string(), + session_id: Some("sess-abc".to_string()), + event_data: HookEventInput::UserPromptSubmit { prompt: "fix the bug".to_string() }, + }; + + let actual = serde_json::to_value(&fixture).unwrap(); + + assert_eq!(actual["hook_event_name"], "UserPromptSubmit"); + assert_eq!(actual["cwd"], "/project"); + assert_eq!(actual["session_id"], "sess-abc"); + assert_eq!(actual["prompt"], "fix the bug"); + // No tool_name, stop_hook_active, or other variant fields present + assert!(actual["tool_name"].is_null()); + assert!(actual["stop_hook_active"].is_null()); + } + + #[test] + fn test_hook_output_parse_valid_json() { + let stdout = r#"{"decision": "block", "reason": "unsafe command"}"#; + let actual = HookOutput::parse(stdout); + + assert_eq!(actual.decision, Some("block".to_string())); + assert_eq!(actual.reason, Some("unsafe command".to_string())); + } + + #[test] + fn test_hook_output_parse_empty_string() { + let actual = HookOutput::parse(""); + let expected = HookOutput::default(); + assert_eq!(actual, expected); + } + + #[test] + fn test_hook_output_parse_invalid_json_returns_default() { + let actual = HookOutput::parse("not json at all"); + let expected = HookOutput::default(); + assert_eq!(actual, expected); + } + + #[test] + fn test_hook_output_is_blocking() { + let fixture = HookOutput { decision: Some("block".to_string()), ..Default::default() }; + assert!(fixture.is_blocking()); + + let fixture = HookOutput { + permission_decision: Some("deny".to_string()), + ..Default::default() + }; + assert!(fixture.is_blocking()); + + let fixture = HookOutput::default(); + assert!(!fixture.is_blocking()); + } + + #[test] + fn test_hook_execution_result_blocking() { + let fixture = HookExecutionResult { + exit_code: Some(2), + stdout: String::new(), + stderr: "Blocked: unsafe command".to_string(), + }; + + assert!(fixture.is_blocking_exit()); + assert!(!fixture.is_success()); + assert!(!fixture.is_non_blocking_error()); + assert_eq!(fixture.blocking_message(), Some("Blocked: unsafe command")); + assert!(fixture.parse_output().is_none()); + } + + #[test] + fn test_hook_execution_result_success() { + let fixture = HookExecutionResult { + exit_code: Some(0), + stdout: r#"{"decision": "block", "reason": "test"}"#.to_string(), + stderr: String::new(), + }; + + assert!(fixture.is_success()); + assert!(!fixture.is_blocking_exit()); + assert!(!fixture.is_non_blocking_error()); + let output = fixture.parse_output().unwrap(); + assert!(output.is_blocking()); + } + + #[test] + fn test_hook_execution_result_non_blocking_error() { + let fixture = HookExecutionResult { + exit_code: Some(1), + stdout: String::new(), + stderr: "some error".to_string(), + }; + + assert!(fixture.is_non_blocking_error()); + assert!(!fixture.is_success()); + assert!(!fixture.is_blocking_exit()); + assert!(fixture.blocking_message().is_none()); + } +} diff --git a/crates/forge_infra/src/executor.rs b/crates/forge_infra/src/executor.rs index 13f30d8c8d..a795556abc 100644 --- a/crates/forge_infra/src/executor.rs +++ b/crates/forge_infra/src/executor.rs @@ -1,10 +1,11 @@ +use std::collections::HashMap; use std::io::{self, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; use forge_app::CommandInfra; use forge_domain::{CommandOutput, ConsoleWriter as OutputPrinterTrait, Environment}; -use tokio::io::AsyncReadExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::process::Command; use tokio::sync::Mutex; @@ -224,6 +225,43 @@ impl CommandInfra for ForgeCommandExecutorService { Ok(prepared_command.spawn()?.wait().await?) } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + let mut prepared_command = self.prepare_command(&command, &working_dir, None); + + // Set directly-provided key-value env vars + for (key, value) in &env_vars { + prepared_command.env(key, value); + } + + // Override stdin to piped so we can write to it + prepared_command.stdin(std::process::Stdio::piped()); + + let mut child = prepared_command.spawn()?; + + // Pipe the JSON input to stdin + if let Some(mut stdin) = child.stdin.take() { + let input = stdin_input.clone(); + tokio::spawn(async move { + let _ = stdin.write_all(input.as_bytes()).await; + let _ = stdin.shutdown().await; + }); + } + + let output = child.wait_with_output().await?; + Ok(CommandOutput { + command, + exit_code: output.status.code(), + stdout: String::from_utf8_lossy(&output.stdout).into_owned(), + stderr: String::from_utf8_lossy(&output.stderr).into_owned(), + }) + } } #[cfg(test)] diff --git a/crates/forge_infra/src/forge_infra.rs b/crates/forge_infra/src/forge_infra.rs index 2ae84ab33f..5b7857109b 100644 --- a/crates/forge_infra/src/forge_infra.rs +++ b/crates/forge_infra/src/forge_infra.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::path::{Path, PathBuf}; use std::process::ExitStatus; use std::sync::Arc; @@ -227,6 +227,18 @@ impl CommandInfra for ForgeInfra { .execute_command_raw(command, working_dir, env_vars) .await } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.command_executor_service + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } } #[async_trait::async_trait] diff --git a/crates/forge_main/src/info.rs b/crates/forge_main/src/info.rs index b0815a8799..49d5fd1e23 100644 --- a/crates/forge_main/src/info.rs +++ b/crates/forge_main/src/info.rs @@ -387,6 +387,7 @@ impl From<&ForgeConfig> for Info { .add_key_value("ForgeCode Service URL", config.services_url.to_string()) .add_title("TOOL CONFIGURATION") .add_key_value("Tool Timeout", format!("{}s", config.tool_timeout_secs)) + .add_key("Hook Timed out") .add_key_value( "Max Image Size", format!("{} bytes", config.max_image_size_bytes), diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index f5954dbb91..2a33e42883 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -3227,6 +3227,14 @@ impl A + Send + Sync> UI { self.writeln_title(TitleFormat::error(cause.as_str()))?; } } + ChatResponse::HookError { tool_name, reason } => { + writer.finish()?; + self.spinner.stop(None)?; + self.writeln_title(TitleFormat::error(format!( + "PreToolUse:{tool_name} hook error: {reason}" + )))?; + self.spinner.start(None)?; + } ChatResponse::Interrupt { reason } => { writer.finish()?; self.spinner.stop(None)?; diff --git a/crates/forge_repo/src/forge_repo.rs b/crates/forge_repo/src/forge_repo.rs index ce203f6c4b..4289b74266 100644 --- a/crates/forge_repo/src/forge_repo.rs +++ b/crates/forge_repo/src/forge_repo.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashMap}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -459,6 +459,18 @@ where .execute_command_raw(command, working_dir, env_vars) .await } + + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.infra + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } } #[async_trait::async_trait] diff --git a/crates/forge_services/src/forge_services.rs b/crates/forge_services/src/forge_services.rs index 5ec09e1191..4a02947705 100644 --- a/crates/forge_services/src/forge_services.rs +++ b/crates/forge_services/src/forge_services.rs @@ -26,8 +26,10 @@ use crate::provider_service::ForgeProviderService; use crate::template::ForgeTemplateService; use crate::tool_services::{ ForgeFetch, ForgeFollowup, ForgeFsPatch, ForgeFsRead, ForgeFsRemove, ForgeFsSearch, - ForgeFsUndo, ForgeFsWrite, ForgeImageRead, ForgePlanCreate, ForgeShell, ForgeSkillFetch, + ForgeFsUndo, ForgeFsWrite, ForgeHookCommandService, ForgeImageRead, ForgePlanCreate, + ForgeShell, ForgeSkillFetch, }; +use crate::user_hook_config::ForgeUserHookConfigService; type McpService = ForgeMcpService, F, ::Client>; type AuthService = ForgeAuthService; @@ -79,10 +81,12 @@ pub struct ForgeServices< auth_service: Arc>, agent_registry_service: Arc>, command_loader_service: Arc>, + user_hook_config_service: Arc>, policy_service: ForgePolicyService, provider_auth_service: ForgeProviderAuthService, workspace_service: Arc>>, skill_service: Arc>, + hook_command_service: Arc>, infra: Arc, } @@ -134,6 +138,7 @@ impl< Arc::new(ForgeCustomInstructionsService::new(infra.clone())); let agent_registry_service = Arc::new(ForgeAgentRegistryService::new(infra.clone())); let command_loader_service = Arc::new(ForgeCommandLoaderService::new(infra.clone())); + let user_hook_config_service = Arc::new(ForgeUserHookConfigService::new(infra.clone())); let policy_service = ForgePolicyService::new(infra.clone()); let provider_auth_service = ForgeProviderAuthService::new(infra.clone()); let discovery = Arc::new(FdDefault::new(infra.clone())); @@ -142,6 +147,7 @@ impl< discovery, )); let skill_service = Arc::new(ForgeSkillFetch::new(infra.clone())); + let hook_command_service = Arc::new(ForgeHookCommandService::new(infra.clone())); Self { conversation_service, @@ -166,10 +172,12 @@ impl< config_service, agent_registry_service, command_loader_service, + user_hook_config_service, policy_service, provider_auth_service, workspace_service, skill_service, + hook_command_service, chat_service, infra, } @@ -233,10 +241,12 @@ impl< type AuthService = AuthService; type AgentRegistry = ForgeAgentRegistryService; type CommandLoaderService = ForgeCommandLoaderService; + type UserHookConfigService = ForgeUserHookConfigService; type PolicyService = ForgePolicyService; type ProviderService = ForgeProviderService; type WorkspaceService = crate::context_engine::ForgeWorkspaceService>; type SkillFetchService = ForgeSkillFetch; + type HookCommandService = ForgeHookCommandService; fn config_service(&self) -> &Self::AppConfigService { &self.config_service @@ -322,6 +332,10 @@ impl< &self.command_loader_service } + fn user_hook_config_service(&self) -> &Self::UserHookConfigService { + &self.user_hook_config_service + } + fn policy_service(&self) -> &Self::PolicyService { &self.policy_service } @@ -337,6 +351,10 @@ impl< &self.skill_service } + fn hook_command_service(&self) -> &Self::HookCommandService { + &self.hook_command_service + } + fn provider_service(&self) -> &Self::ProviderService { &self.chat_service } diff --git a/crates/forge_services/src/lib.rs b/crates/forge_services/src/lib.rs index bb102e86c6..36060171e1 100644 --- a/crates/forge_services/src/lib.rs +++ b/crates/forge_services/src/lib.rs @@ -22,6 +22,7 @@ mod range; mod sync; mod template; mod tool_services; +mod user_hook_config; mod utils; pub use app_config::*; diff --git a/crates/forge_services/src/tool_services/hook_command.rs b/crates/forge_services/src/tool_services/hook_command.rs new file mode 100644 index 0000000000..3588bc0fa3 --- /dev/null +++ b/crates/forge_services/src/tool_services/hook_command.rs @@ -0,0 +1,37 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use forge_app::{CommandInfra, HookCommandService}; +use forge_domain::CommandOutput; + +/// Thin wrapper around [`CommandInfra::execute_command_with_input`] that +/// satisfies the [`HookCommandService`] contract. +/// +/// By delegating to the underlying infra this service avoids duplicating +/// process-spawning and stdin-piping logic; those concerns live entirely inside +/// the `CommandInfra` implementation. +#[derive(Clone)] +pub struct ForgeHookCommandService(Arc); + +impl ForgeHookCommandService { + /// Creates a new `ForgeHookCommandService` backed by the given infra. + pub fn new(infra: Arc) -> Self { + Self(infra) + } +} + +#[async_trait::async_trait] +impl HookCommandService for ForgeHookCommandService { + async fn execute_command_with_input( + &self, + command: String, + working_dir: PathBuf, + stdin_input: String, + env_vars: HashMap, + ) -> anyhow::Result { + self.0 + .execute_command_with_input(command, working_dir, stdin_input, env_vars) + .await + } +} diff --git a/crates/forge_services/src/tool_services/mod.rs b/crates/forge_services/src/tool_services/mod.rs index 64a5c6f3c0..75e78f3d7a 100644 --- a/crates/forge_services/src/tool_services/mod.rs +++ b/crates/forge_services/src/tool_services/mod.rs @@ -6,6 +6,7 @@ mod fs_remove; mod fs_search; mod fs_undo; mod fs_write; +mod hook_command; mod image_read; mod plan_create; mod shell; @@ -19,6 +20,7 @@ pub use fs_remove::*; pub use fs_search::*; pub use fs_undo::*; pub use fs_write::*; +pub use hook_command::*; pub use image_read::*; pub use plan_create::*; pub use shell::*; diff --git a/crates/forge_services/src/tool_services/shell.rs b/crates/forge_services/src/tool_services/shell.rs index 74cbe34405..1988779f77 100644 --- a/crates/forge_services/src/tool_services/shell.rs +++ b/crates/forge_services/src/tool_services/shell.rs @@ -108,6 +108,21 @@ mod tests { ) -> anyhow::Result { unimplemented!() } + + async fn execute_command_with_input( + &self, + command: String, + _working_dir: PathBuf, + _stdin_input: String, + _env_vars: std::collections::HashMap, + ) -> anyhow::Result { + Ok(forge_domain::CommandOutput { + command, + exit_code: Some(0), + stdout: String::new(), + stderr: String::new(), + }) + } } impl EnvironmentInfra for MockCommandInfra { diff --git a/crates/forge_services/src/user_hook_config.rs b/crates/forge_services/src/user_hook_config.rs new file mode 100644 index 0000000000..cee18c6439 --- /dev/null +++ b/crates/forge_services/src/user_hook_config.rs @@ -0,0 +1,363 @@ +use std::path::Path; +use std::sync::Arc; + +use forge_app::{EnvironmentInfra, FileInfoInfra, FileReaderInfra}; +use forge_domain::{UserHookConfig, UserSettings}; + +/// Loads and merges user hook configurations from the three settings file +/// locations using infrastructure abstractions. +/// +/// Resolution order (all merged, not overridden): +/// 1. `~/.forge/settings.json` (user-level, applies to all projects) +/// 2. `.forge/settings.json` (project-level, committable) +/// 3. `.forge/settings.local.json` (project-level, gitignored) +pub struct ForgeUserHookConfigService(Arc); + +impl ForgeUserHookConfigService { + /// Creates a new service with the given infrastructure dependency. + pub fn new(infra: Arc) -> Self { + Self(infra) + } +} + +impl ForgeUserHookConfigService { + /// Loads a single settings file and extracts hook configuration. + /// + /// Returns `Ok(None)` if the file does not exist or cannot be read. + /// Returns `Err` if the file exists but fails to deserialize, including the + /// file path in the error message. + async fn load_file(&self, path: &Path) -> anyhow::Result> { + if !self.0.exists(path).await? { + return Ok(None); + } + let contents = self + .0 + .read_utf8(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read '{}': {}", path.display(), e))?; + + match serde_json::from_str::(&contents) { + Ok(settings) => { + if settings.hooks.is_empty() { + Ok(None) + } else { + Ok(Some(settings.hooks)) + } + } + Err(e) => Err(anyhow::anyhow!( + "Failed to deserialize '{}': {}", + path.display(), + e + )), + } + } +} + +#[async_trait::async_trait] +impl forge_app::UserHookConfigService + for ForgeUserHookConfigService +{ + async fn get_user_hook_config(&self) -> anyhow::Result { + let env = self.0.get_environment(); + + // Collect all candidate paths in resolution order + let mut paths: Vec = Vec::new(); + if let Some(home) = &env.home { + paths.push(home.join("forge").join("settings.json")); + } + paths.push(env.cwd.join(".forge").join("settings.json")); + paths.push(env.cwd.join(".forge").join("settings.local.json")); + + // Load every file, keeping the (path, result) pairs + let results = + futures::future::join_all(paths.iter().map(|path| self.load_file(path))).await; + + // Collect the error message from every file that failed + let errors: Vec = results + .iter() + .filter_map(|r| r.as_ref().err().map(|e| e.to_string())) + .collect(); + + if !errors.is_empty() { + return Err(anyhow::anyhow!("{}", errors.join("\n\n"))); + } + + // Merge every successfully loaded config + let mut config = UserHookConfig::new(); + for result in results { + if let Ok(Some(file_config)) = result { + config.merge(file_config); + } + } + + Ok(config) + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use fake::Fake; + use forge_app::UserHookConfigService; + use pretty_assertions::assert_eq; + + use super::*; + + #[tokio::test] + async fn test_load_file_valid_settings() { + let dir = tempfile::tempdir().unwrap(); + let settings_path = dir.path().join("settings.json"); + std::fs::write( + &settings_path, + r#"{ + "hooks": { + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "check.sh" }] } + ] + } + }"#, + ) + .unwrap(); + + let service = fixture(None, PathBuf::from("/nonexistent")); + + let actual = service.load_file(&settings_path).await.unwrap(); + assert!(actual.is_some()); + let config = actual.unwrap(); + assert_eq!( + config + .get_groups(&forge_domain::UserHookEventName::PreToolUse) + .len(), + 1 + ); + } + + #[tokio::test] + async fn test_load_file_settings_without_hooks() { + let dir = tempfile::tempdir().unwrap(); + let settings_path = dir.path().join("settings.json"); + std::fs::write(&settings_path, r#"{"other_key": "value"}"#).unwrap(); + + let service = fixture(None, PathBuf::from("/nonexistent")); + + let actual = service.load_file(&settings_path).await.unwrap(); + assert!(actual.is_none()); + } + + #[tokio::test] + async fn test_load_file_invalid_json_returns_error_with_path() { + let dir = tempfile::tempdir().unwrap(); + let settings_path = dir.path().join("settings.json"); + std::fs::write(&settings_path, r#"{ invalid json }"#).unwrap(); + + let service = fixture(None, PathBuf::from("/nonexistent")); + + let actual = service.load_file(&settings_path).await; + assert!(actual.is_err()); + let err = actual.unwrap_err().to_string(); + assert!( + err.contains(&settings_path.display().to_string()), + "Error message should contain the file path, got: {err}" + ); + } + + #[tokio::test] + async fn test_get_user_hook_config_reports_all_invalid_files() { + let project_dir = tempfile::tempdir().unwrap(); + let project_forge_dir = project_dir.path().join(".forge"); + std::fs::create_dir_all(&project_forge_dir).unwrap(); + + // Both project files have invalid JSON + std::fs::write(project_forge_dir.join("settings.json"), r#"{ bad }"#).unwrap(); + std::fs::write( + project_forge_dir.join("settings.local.json"), + r#"{ also bad }"#, + ) + .unwrap(); + + let service = fixture(None, project_dir.path().to_path_buf()); + + let actual = service.get_user_hook_config().await; + assert!(actual.is_err()); + let err = actual.unwrap_err().to_string(); + assert!( + err.contains("settings.json"), + "Error should mention settings.json, got: {err}" + ); + assert!( + err.contains("settings.local.json"), + "Error should mention settings.local.json, got: {err}" + ); + } + + #[tokio::test] + async fn test_get_user_hook_config_nonexistent_paths() { + let service = fixture( + Some(PathBuf::from("/nonexistent/home")), + PathBuf::from("/nonexistent/project"), + ); + + let actual = service.get_user_hook_config().await.unwrap(); + assert!(actual.is_empty()); + } + + #[tokio::test] + async fn test_get_user_hook_config_merges_all_sources() { + // Set up a fake home directory + let home_dir = tempfile::tempdir().unwrap(); + let forge_dir = home_dir.path().join("forge"); + std::fs::create_dir_all(&forge_dir).unwrap(); + std::fs::write( + forge_dir.join("settings.json"), + r#"{ + "hooks": { + "PreToolUse": [ + { "matcher": "Bash", "hooks": [{ "type": "command", "command": "global.sh" }] } + ] + } + }"#, + ) + .unwrap(); + + // Set up a project directory + let project_dir = tempfile::tempdir().unwrap(); + let project_forge_dir = project_dir.path().join(".forge"); + std::fs::create_dir_all(&project_forge_dir).unwrap(); + std::fs::write( + project_forge_dir.join("settings.json"), + r#"{ + "hooks": { + "PreToolUse": [ + { "matcher": "Write", "hooks": [{ "type": "command", "command": "project.sh" }] } + ] + } + }"#, + ) + .unwrap(); + std::fs::write( + project_forge_dir.join("settings.local.json"), + r#"{ + "hooks": { + "Stop": [ + { "hooks": [{ "type": "command", "command": "local-stop.sh" }] } + ] + } + }"#, + ) + .unwrap(); + + let service = fixture( + Some(home_dir.path().to_path_buf()), + project_dir.path().to_path_buf(), + ); + + let actual = service.get_user_hook_config().await.unwrap(); + + // PreToolUse should have 2 groups (global + project) + assert_eq!( + actual + .get_groups(&forge_domain::UserHookEventName::PreToolUse) + .len(), + 2 + ); + // Stop should have 1 group (local) + assert_eq!( + actual + .get_groups(&forge_domain::UserHookEventName::Stop) + .len(), + 1 + ); + } + + // --- Test helpers --- + + fn fixture(home: Option, cwd: PathBuf) -> ForgeUserHookConfigService { + ForgeUserHookConfigService::new(Arc::new(TestInfra { home, cwd })) + } + + struct TestInfra { + home: Option, + cwd: PathBuf, + } + + #[async_trait::async_trait] + impl FileInfoInfra for TestInfra { + async fn is_binary(&self, _path: &Path) -> anyhow::Result { + Ok(false) + } + + async fn is_file(&self, path: &Path) -> anyhow::Result { + Ok(tokio::fs::metadata(path) + .await + .map(|m| m.is_file()) + .unwrap_or(false)) + } + + async fn exists(&self, path: &Path) -> anyhow::Result { + Ok(tokio::fs::try_exists(path).await.unwrap_or(false)) + } + + async fn file_size(&self, path: &Path) -> anyhow::Result { + Ok(tokio::fs::metadata(path).await?.len()) + } + } + + #[async_trait::async_trait] + impl FileReaderInfra for TestInfra { + async fn read_utf8(&self, path: &Path) -> anyhow::Result { + Ok(tokio::fs::read_to_string(path).await?) + } + + fn read_batch_utf8( + &self, + _batch_size: usize, + _paths: Vec, + ) -> impl futures::Stream)> + Send { + futures::stream::empty() + } + + async fn read(&self, path: &Path) -> anyhow::Result> { + Ok(tokio::fs::read(path).await?) + } + + async fn range_read_utf8( + &self, + _path: &Path, + _start_line: u64, + _end_line: u64, + ) -> anyhow::Result<(String, forge_domain::FileInfo)> { + unimplemented!("not needed for tests") + } + } + + impl EnvironmentInfra for TestInfra { + type Config = forge_config::ForgeConfig; + + fn get_env_var(&self, _key: &str) -> Option { + None + } + + fn get_env_vars(&self) -> std::collections::BTreeMap { + Default::default() + } + + fn get_environment(&self) -> forge_domain::Environment { + let mut env: forge_domain::Environment = fake::Faker.fake(); + env.home = self.home.clone(); + env.cwd = self.cwd.clone(); + env + } + + fn get_config(&self) -> forge_config::ForgeConfig { + Default::default() + } + + async fn update_environment( + &self, + _ops: Vec, + ) -> anyhow::Result<()> { + unimplemented!("not needed for tests") + } + } +} diff --git a/forge.schema.json b/forge.schema.json index 9ba1a41bba..3c87127701 100644 --- a/forge.schema.json +++ b/forge.schema.json @@ -56,6 +56,15 @@ "null" ] }, + "hook_timeout_ms": { + "description": "Default timeout in milliseconds for user hook commands.\nIndividual hooks can override this via their own `timeout` field.", + "type": [ + "integer", + "null" + ], + "format": "uint64", + "minimum": 0 + }, "http": { "description": "HTTP client settings including proxy, TLS, and timeout configuration.", "anyOf": [