diff --git a/crates/browser-use-agent/src/entrypoint/mod.rs b/crates/browser-use-agent/src/entrypoint/mod.rs index a90b4a68..7a9b96da 100644 --- a/crates/browser-use-agent/src/entrypoint/mod.rs +++ b/crates/browser-use-agent/src/entrypoint/mod.rs @@ -242,6 +242,7 @@ struct StoreTurnState { base_instructions: String, current_model: Option, previous_model_compaction: Option, + cancel: CancellationToken, /// The model-based summary pass for [`compact`](TurnState::compact). `None` /// disables compaction (the no-sampler / `Fake` path); the production run sets /// a real [`EntrypointSampler`]. @@ -276,6 +277,7 @@ impl StoreTurnState { base_instructions: crate::prompts::browser_agent_system_prompt(), current_model: None, previous_model_compaction: None, + cancel: CancellationToken::new(), compaction_sampler: None, compacted: Mutex::new(None), } @@ -321,6 +323,11 @@ impl StoreTurnState { self } + fn with_cancel(mut self, cancel: CancellationToken) -> Self { + self.cancel = cancel; + self + } + /// Assemble the current prompt as typed [`Message`]s (synchronously). The base /// is the compacted override when present, else the lowered durable log; this /// run's recorded turns are appended. @@ -1230,14 +1237,18 @@ async fn run_compaction_with_retries( token_limit: usize, max_retries: u32, context_window: Option, + cancel: CancellationToken, ) -> Result { let mut retries = 0; let mut working: Vec = history.to_vec(); working.push(compaction_prompt_item(compact_prompt)); loop { + if cancel.is_cancelled() { + return Err(AgentError::TurnAborted); + } let request = compaction_request_messages(&working); let request_len = request.len(); - match sampler.summarize(request, CancellationToken::new()).await { + match sampler.summarize(request, cancel.clone()).await { Ok(summary) => { append_compaction_token_usage( store, @@ -1269,10 +1280,13 @@ async fn run_compaction_with_retries( session_id.as_str(), &format!("Reconnecting... {retries}/{max_retries}"), ); - tokio::time::sleep(std::time::Duration::from_millis( + let delay = tokio::time::sleep(std::time::Duration::from_millis( crate::decision::backoff_ms(retries), - )) - .await; + )); + tokio::select! { + _ = cancel.cancelled() => return Err(AgentError::TurnAborted), + _ = delay => {} + } } Err(error) => return Err(error), } @@ -1667,6 +1681,7 @@ impl StoreTurnState { COMPACT_USER_MESSAGE_MAX_TOKENS, DEFAULT_STREAM_MAX_RETRIES, self.context_window(), + self.cancel.clone(), ) .await { @@ -2053,7 +2068,8 @@ async fn drive_run( base_instructions.unwrap_or_else(|| crate::prompts::browser_agent_system_prompt()), developer_instructions, ) - .with_previous_model_compaction(previous_model_compaction); + .with_previous_model_compaction(previous_model_compaction) + .with_cancel(cancel.clone()); let pre_turn_replay_from_seq = if turn_has_fresh_input { let events = events_from_store(&store, session_id.as_str()); @@ -2687,6 +2703,21 @@ mod tests { } } + struct WaitForCancelSampler; + + impl CompactionSampler for WaitForCancelSampler { + fn summarize( + &self, + _request: Vec, + cancel: CancellationToken, + ) -> impl Future> + Send { + async move { + cancel.cancelled().await; + Err(AgentError::TurnAborted) + } + } + } + struct WindowThenRetryableThenSuccessSampler { attempts: AtomicUsize, request_lens: Mutex>, @@ -3521,6 +3552,7 @@ mod tests { COMPACT_USER_MESSAGE_MAX_TOKENS, 1, Some(1_000), + CancellationToken::new(), ) .await .expect("mixed window and retryable errors should eventually succeed"); @@ -3550,6 +3582,7 @@ mod tests { COMPACT_USER_MESSAGE_MAX_TOKENS, 0, Some(1_000), + CancellationToken::new(), ) .await .expect_err("unshrinkable summary request should fail"); @@ -3574,6 +3607,38 @@ mod tests { .any(|event| event.event_type == "model.turn.context_overflow")); } + #[tokio::test] + async fn compaction_uses_callers_cancellation_token() { + let (_dir, store, session_id) = store_with_session(); + let cancel = CancellationToken::new(); + let cancel_from_task = cancel.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(50)).await; + cancel_from_task.cancel(); + }); + + let sampler: Arc = Arc::new(WaitForCancelSampler); + let err = tokio::time::timeout( + Duration::from_secs(2), + run_compaction_with_retries( + &SessionId(session_id.clone()), + &store, + &[], + sampler.as_ref(), + crate::compact::SUMMARIZATION_PROMPT, + COMPACT_USER_MESSAGE_MAX_TOKENS, + 0, + Some(1_000), + cancel, + ), + ) + .await + .expect("compaction should observe cancellation") + .expect_err("cancelled compaction should fail"); + + assert!(matches!(err, AgentError::TurnAborted)); + } + #[tokio::test] async fn previous_model_downshift_compacts_before_current_model_sampling() { let (_dir, store, session_id) = store_with_session(); diff --git a/crates/browser-use-agent/src/mcp/stdio.rs b/crates/browser-use-agent/src/mcp/stdio.rs index e24a55f5..d54764a2 100644 --- a/crates/browser-use-agent/src/mcp/stdio.rs +++ b/crates/browser-use-agent/src/mcp/stdio.rs @@ -39,6 +39,8 @@ use crate::mcp::protocol::{ type PendingMap = Arc>>>; +const MCP_SERVER_REPLY_TIMEOUT: Duration = Duration::from_secs(5); + /// A connected stdio MCP server. pub struct StdioTransport { /// Kept alive (killed on drop via `kill_on_drop`) for the transport's life. @@ -147,7 +149,7 @@ impl StdioTransport { .await .context("initialize handshake failed")?; transport - .notify(initialized_notification()) + .notify(initialized_notification(), startup_timeout) .await .context("sending notifications/initialized failed")?; @@ -191,16 +193,17 @@ impl StdioTransport { } let req = JsonRpcRequest::new(id.clone(), method, params); - if let Err(err) = write_json(&self.writer, &req).await { - self.pending.lock().await.remove(&id); - return Err(err); - } + let operation = async { + write_json(&self.writer, &req).await?; + rx.await + .map_err(|_| anyhow!("MCP server closed stdout before responding to {method}")) + }; - let msg = match timeout(timeout_dur, rx).await { + let msg = match timeout(timeout_dur, operation).await { Ok(Ok(msg)) => msg, - Ok(Err(_)) => { + Ok(Err(err)) => { self.pending.lock().await.remove(&id); - bail!("MCP server closed stdout before responding to {method}"); + return Err(err); } Err(_) => { self.pending.lock().await.remove(&id); @@ -218,8 +221,13 @@ impl StdioTransport { Ok(msg.result.unwrap_or(Value::Null)) } - async fn notify(&self, notification: JsonRpcNotification) -> Result<()> { - write_json(&self.writer, ¬ification).await + async fn notify(&self, notification: JsonRpcNotification, timeout_dur: Duration) -> Result<()> { + match timeout(timeout_dur, write_json(&self.writer, ¬ification)).await { + Ok(result) => result, + Err(_) => { + bail!("MCP server timed out while sending notification after {timeout_dur:?}") + } + } } } @@ -242,7 +250,7 @@ async fn handle_server_request(writer: &Arc>, msg: JsonRpcMess "error": { "code": -32601, "message": format!("method not found: {method}") }, }) }; - let _ = write_json(writer, &reply).await; + let _ = timeout(MCP_SERVER_REPLY_TIMEOUT, write_json(writer, &reply)).await; } /// Write a serializable message as a newline-terminated JSON line, then flush. diff --git a/crates/browser-use-agent/src/mcp/tests.rs b/crates/browser-use-agent/src/mcp/tests.rs index 722caa4c..4dd06d1e 100644 --- a/crates/browser-use-agent/src/mcp/tests.rs +++ b/crates/browser-use-agent/src/mcp/tests.rs @@ -161,6 +161,33 @@ async fn stdio_error_result_maps_is_error() { assert_eq!(mcp_result_tool_content(&result.into_seam()), "kaboom"); } +#[tokio::test] +async fn stdio_connect_times_out_when_child_never_answers_initialize() { + let result = tokio::time::timeout( + Duration::from_secs(2), + StdioTransport::connect( + "python3", + &["-c".to_string(), "import time; time.sleep(60)".to_string()], + &HashMap::new(), + None, + Duration::from_millis(100), + Duration::from_millis(100), + ), + ) + .await + .expect("transport watchdog should return"); + let err = match result { + Ok(_) => panic!("connect should fail"), + Err(err) => err, + }; + + let message = format!("{err:#}"); + assert!( + message.contains("timed out"), + "expected timeout error, got {message}" + ); +} + // --------------------------------------------------------------------------- // http transport (loopback TcpListener) // --------------------------------------------------------------------------- diff --git a/crates/browser-use-agent/src/tools/handlers/apply_patch.rs b/crates/browser-use-agent/src/tools/handlers/apply_patch.rs index 466f761d..5261588f 100644 --- a/crates/browser-use-agent/src/tools/handlers/apply_patch.rs +++ b/crates/browser-use-agent/src/tools/handlers/apply_patch.rs @@ -486,6 +486,7 @@ pub fn apply_patch_operations( match op { PatchOperation::AddFile { path, contents } => { let real = resolve_patch_path(root, path); + reject_existing_special_file(&real, path)?; if let Some(parent) = real.parent() { std::fs::create_dir_all(parent).map_err(|e| { ToolError::Other(anyhow::anyhow!("creating parent dirs for {path}: {e}")) @@ -508,6 +509,7 @@ pub fn apply_patch_operations( hunks, } => { let real = resolve_patch_path(root, path); + ensure_regular_file_for_read(&real, path)?; let original = std::fs::read_to_string(&real).map_err(|e| { ToolError::Other(anyhow::anyhow!("reading file to update {path}: {e}")) })?; @@ -521,6 +523,7 @@ pub fn apply_patch_operations( } else { real.clone() }; + reject_existing_special_file(&dest, move_to.as_deref().unwrap_or(path))?; if let Some(parent) = dest.parent() { std::fs::create_dir_all(parent).map_err(|e| { ToolError::Other(anyhow::anyhow!("creating parent dirs for {path}: {e}")) @@ -545,6 +548,26 @@ pub fn apply_patch_operations( Ok(ApplyPatchSummary { changed }) } +fn ensure_regular_file_for_read(real: &Path, display: &str) -> Result<(), ToolError> { + let meta = std::fs::metadata(real) + .map_err(|e| ToolError::Other(anyhow::anyhow!("reading metadata for {display}: {e}")))?; + if !meta.file_type().is_file() { + return Err(ToolError::Rejected(format!( + "apply_patch refuses to read non-regular file {display}" + ))); + } + Ok(()) +} + +fn reject_existing_special_file(real: &Path, display: &str) -> Result<(), ToolError> { + match std::fs::metadata(real) { + Ok(meta) if !meta.file_type().is_file() => Err(ToolError::Rejected(format!( + "apply_patch refuses to write non-regular file {display}" + ))), + Ok(_) | Err(_) => Ok(()), + } +} + /// Apply update hunks to the original file contents, returning the new contents. /// /// Parity: legacy `apply_hunk_to_lines` (files.rs:789). Each hunk is located by diff --git a/crates/browser-use-agent/src/tools/handlers/apply_patch_tests.rs b/crates/browser-use-agent/src/tools/handlers/apply_patch_tests.rs index 5a7d2927..e2e5e081 100644 --- a/crates/browser-use-agent/src/tools/handlers/apply_patch_tests.rs +++ b/crates/browser-use-agent/src/tools/handlers/apply_patch_tests.rs @@ -97,6 +97,30 @@ async fn add_file_creates_nested_dirs() { assert_eq!(written, "nested"); } +#[tokio::test] +async fn add_file_over_directory_is_rejected_before_write() { + let dir = tempfile::tempdir().unwrap(); + let ctx = ctx_in(dir.path()); + std::fs::create_dir(dir.path().join("existing.txt")).unwrap(); + + let patch = "\ +*** Begin Patch +*** Add File: existing.txt ++replacement +*** End Patch +"; + let req = ApplyPatchRequest::new(patch); + match run_direct(&req, &ctx).await { + Err(ToolError::Rejected(msg)) => { + assert!( + msg.contains("non-regular file"), + "should reject directories before writing, got: {msg}" + ); + } + other => panic!("expected Rejected for directory add, got {other:?}"), + } +} + // (2) Update File applies a hunk (context + +/- lines) correctly. #[tokio::test] async fn update_file_applies_hunk() { @@ -127,6 +151,31 @@ async fn update_file_applies_hunk() { ); } +#[tokio::test] +async fn update_directory_is_rejected_before_read() { + let dir = tempfile::tempdir().unwrap(); + let ctx = ctx_in(dir.path()); + std::fs::create_dir(dir.path().join("not-a-file.txt")).unwrap(); + + let patch = "\ +*** Begin Patch +*** Update File: not-a-file.txt +@@ + anything +*** End Patch +"; + let req = ApplyPatchRequest::new(patch); + match run_direct(&req, &ctx).await { + Err(ToolError::Rejected(msg)) => { + assert!( + msg.contains("non-regular file"), + "should reject directories before reading, got: {msg}" + ); + } + other => panic!("expected Rejected for directory update, got {other:?}"), + } +} + // (2b) Update File with an *** End of File marker terminating the hunk. #[tokio::test] async fn update_file_with_end_of_file_marker() { diff --git a/crates/browser-use-agent/src/tools/handlers/browser.rs b/crates/browser-use-agent/src/tools/handlers/browser.rs index f0f6c8a1..8bc4cdd9 100644 --- a/crates/browser-use-agent/src/tools/handlers/browser.rs +++ b/crates/browser-use-agent/src/tools/handlers/browser.rs @@ -39,6 +39,7 @@ use std::fs; use std::path::PathBuf; use std::sync::Arc; +use std::time::Duration; use anyhow::{anyhow, bail}; use base64::{engine::general_purpose, Engine as _}; @@ -67,6 +68,16 @@ pub const DEFAULT_BROWSER_SCRIPT_TIMEOUT_SECS: u64 = 120; /// Mirrors the legacy default observe window used by the browser_script runtime. pub const DEFAULT_OBSERVE_TIMEOUT_MS: u64 = 1_000; +/// Final async watchdog for browser command calls. +pub const DEFAULT_BROWSER_COMMAND_OUTER_TIMEOUT_SECS: u64 = 180; + +/// Grace added around script backend timeouts so the handler can serialize the +/// timeout result instead of racing the lower layer. +pub const BROWSER_SCRIPT_OUTER_GRACE_SECS: u64 = 30; + +/// Final async watchdog for cancel calls. +pub const BROWSER_CANCEL_OUTER_TIMEOUT_SECS: u64 = 15; + /// Appended to `browser_script` stdout when the response carries image parts. /// /// The dispatch layer strips this marker and re-wraps the JSON payload as typed @@ -192,6 +203,26 @@ impl BrowserRequest { } } +fn browser_action_outer_timeout( + action: &BrowserAction, + timeout_secs: u64, + observe_ms: u64, +) -> Duration { + match action { + BrowserAction::Command { .. } => { + Duration::from_secs(DEFAULT_BROWSER_COMMAND_OUTER_TIMEOUT_SECS) + } + BrowserAction::Execute { .. } => Duration::from_secs( + timeout_secs + .saturating_add(BROWSER_SCRIPT_OUTER_GRACE_SECS) + .max(1), + ), + BrowserAction::Observe { .. } => Duration::from_millis(observe_ms) + .saturating_add(Duration::from_secs(BROWSER_SCRIPT_OUTER_GRACE_SECS)), + BrowserAction::Cancel { .. } => Duration::from_secs(BROWSER_CANCEL_OUTER_TIMEOUT_SECS), + } +} + /// Model-facing wire arguments for the browser tool. /// /// [`BrowserRequest`] is a PARSED form: its [`BrowserAction`] is an internally @@ -1489,6 +1520,7 @@ impl ToolRuntime for BrowserTool { ) -> Result { // No sandbox backend is exercised here (the browser runtime spawns its // own processes); acknowledge the attempt to make the seam explicit. + let cancel = attempt.cancel.clone(); let _ = attempt; let effective_session_id = if req.session_id.trim().is_empty() { @@ -1534,6 +1566,7 @@ impl ToolRuntime for BrowserTool { let timeout_secs = req.effective_timeout_secs(self.default_script_timeout_secs); let observe_ms = req.effective_observe_ms(); let action = req.action.clone(); + let outer_timeout = browser_action_outer_timeout(&action, timeout_secs, observe_ms); let persistence = self.persistence.clone(); let selected_browser_mode = self.selected_browser_mode.clone(); let tool_call_id = ctx.call_id.clone(); @@ -1550,7 +1583,7 @@ impl ToolRuntime for BrowserTool { // The browser fns are synchronous and spawn external processes; run on a // blocking thread so we never stall the async runtime. - let result = tokio::task::spawn_blocking(move || -> Result { + let task = tokio::task::spawn_blocking(move || -> Result { match action { BrowserAction::Command { command } => { let selected_browser_mode = selected_browser_mode.as_deref(); @@ -1661,9 +1694,31 @@ impl ToolRuntime for BrowserTool { Ok(map_script_output(out)) } } - }) - .await - .map_err(|e| ToolError::Other(anyhow::anyhow!("browser task panicked: {e}")))?; + }); + + let result = tokio::select! { + biased; + _ = async { + if let Some(cancel) = cancel { + cancel.cancelled().await; + } else { + std::future::pending::<()>().await; + } + } => { + return Err(ToolError::Other(anyhow::anyhow!("browser task cancelled"))); + } + timed = tokio::time::timeout(outer_timeout, task) => { + match timed { + Ok(joined) => joined + .map_err(|e| ToolError::Other(anyhow::anyhow!("browser task panicked: {e}")))?, + Err(_) => { + return Err(ToolError::Other(anyhow::anyhow!( + "browser task timed out after {outer_timeout:?}" + ))); + } + } + } + }; result } diff --git a/crates/browser-use-agent/src/tools/handlers/browser_tests.rs b/crates/browser-use-agent/src/tools/handlers/browser_tests.rs index c63c3a36..302e3b35 100644 --- a/crates/browser-use-agent/src/tools/handlers/browser_tests.rs +++ b/crates/browser-use-agent/src/tools/handlers/browser_tests.rs @@ -58,6 +58,60 @@ struct FakeBackend { fail: bool, } +struct SlowCommandBackend; + +impl BrowserBackend for SlowCommandBackend { + fn command( + &self, + _session_id: &str, + _cwd: &std::path::Path, + _artifact_dir: &std::path::Path, + _command: &str, + ) -> anyhow::Result { + std::thread::sleep(std::time::Duration::from_secs(2)); + Ok(FakeBackend::ok_command()) + } + + fn run_script( + &self, + _session_id: &str, + _cwd: &std::path::Path, + _artifact_dir: &std::path::Path, + _code: &str, + _timeout_secs: u64, + ) -> anyhow::Result { + anyhow::bail!("not used") + } + + fn start_script( + &self, + _session_id: &str, + _cwd: &std::path::Path, + _artifact_dir: &std::path::Path, + _code: &str, + _timeout_secs: u64, + ) -> anyhow::Result { + anyhow::bail!("not used") + } + + fn observe_script( + &self, + _session_id: &str, + _run_id: &str, + _observe_timeout_ms: u64, + ) -> anyhow::Result { + anyhow::bail!("not used") + } + + fn cancel_script( + &self, + _session_id: &str, + _run_id: &str, + ) -> anyhow::Result { + anyhow::bail!("not used") + } +} + impl FakeBackend { fn last(&self) -> LastCall { self.last.lock().unwrap().clone() @@ -286,6 +340,38 @@ async fn command_routes_and_maps_output() { ); } +#[tokio::test] +async fn cancelled_browser_task_returns_before_blocking_backend_finishes() { + let tool = BrowserTool::with_backend(Arc::new(SlowCommandBackend)); + let req = BrowserRequest::command("sess-1", "go https://example.com"); + let launch = none_launch(); + let cancel = tokio_util::sync::CancellationToken::new(); + cancel.cancel(); + let attempt = SandboxAttempt { + sandbox: SandboxType::None, + permissions: SandboxPermissions::UseDefault, + enforce_managed_network: false, + launch: &launch, + cancel: Some(cancel), + }; + + let err = tokio::time::timeout( + std::time::Duration::from_millis(500), + tool.run(&req, &attempt, &ctx()), + ) + .await + .expect("outer cancellation should win") + .expect_err("cancelled browser task should error"); + + match err { + ToolError::Other(error) => assert!( + error.to_string().contains("cancelled"), + "expected cancellation error, got {error:#}" + ), + other => panic!("expected Other cancellation error, got {other:?}"), + } +} + #[tokio::test] async fn bare_browser_connect_resolves_to_selected_local_mode() { let backend = Arc::new(FakeBackend::default()); diff --git a/crates/browser-use-agent/src/tools/handlers/mcp.rs b/crates/browser-use-agent/src/tools/handlers/mcp.rs index 0e52a2dd..1fd55174 100644 --- a/crates/browser-use-agent/src/tools/handlers/mcp.rs +++ b/crates/browser-use-agent/src/tools/handlers/mcp.rs @@ -468,7 +468,8 @@ impl ToolRuntime for McpTool { ) -> Result { // The MCP call runs in the server process; no sandbox backend is // exercised here. Acknowledge the seam args explicitly. - let _ = (attempt, ctx); + let cancel = attempt.cancel.clone(); + let _ = ctx; // Validate the request before touching the client (legacy unknown-server // / empty-name guards, lib.rs:13404-13411). @@ -491,7 +492,7 @@ impl ToolRuntime for McpTool { // The real client is synchronous (blocking stdio JSON-RPC); run it on a // blocking thread so we never stall the async runtime (mirrors the // browser/python handlers). - let result = tokio::task::spawn_blocking(move || -> Result { + let task = tokio::task::spawn_blocking(move || -> Result { let call = client .call_tool(&server, &tool, args) // Parity with legacy lib.rs:13416-13419 / codex mcp_tool_call.rs:579: @@ -502,9 +503,22 @@ impl ToolRuntime for McpTool { )) })?; Ok(map_call_result(call)) - }) - .await - .map_err(|e| ToolError::Other(anyhow::anyhow!("MCP task panicked: {e}")))?; + }); + + let result = tokio::select! { + biased; + _ = async { + if let Some(cancel) = cancel { + cancel.cancelled().await; + } else { + std::future::pending::<()>().await; + } + } => { + return Err(ToolError::Other(anyhow::anyhow!("MCP task cancelled"))); + } + joined = task => joined + .map_err(|e| ToolError::Other(anyhow::anyhow!("MCP task panicked: {e}")))?, + }; result } diff --git a/crates/browser-use-agent/src/tools/handlers/view_image.rs b/crates/browser-use-agent/src/tools/handlers/view_image.rs index 95e92249..2993f399 100644 --- a/crates/browser-use-agent/src/tools/handlers/view_image.rs +++ b/crates/browser-use-agent/src/tools/handlers/view_image.rs @@ -249,6 +249,27 @@ impl ToolRuntime for ViewImageTool { )) })?; + let meta = std::fs::metadata(&real).map_err(|e| { + ToolError::Other(anyhow::anyhow!( + "view_image: cannot stat {}: {e}", + req.path.display() + )) + })?; + if !meta.file_type().is_file() { + return Err(ToolError::Rejected(format!( + "view_image refuses to read non-regular file {}", + req.path.display() + ))); + } + if meta.len() > MAX_INLINE_LOCAL_IMAGE_BYTES as u64 { + return Err(ToolError::Rejected(format!( + "view_image cannot inline {}: image is {} bytes, above the {} byte inline limit", + req.path.display(), + meta.len(), + MAX_INLINE_LOCAL_IMAGE_BYTES + ))); + } + // Blocking, deliberately-serial read (see the module doc: this tool is // not parallel-safe, so a blocking std read is the right choice — no // benefit to yielding the async runtime for a fast local file). A diff --git a/crates/browser-use-agent/src/tools/handlers/view_image_tests.rs b/crates/browser-use-agent/src/tools/handlers/view_image_tests.rs index 3af7bb87..933ae9f5 100644 --- a/crates/browser-use-agent/src/tools/handlers/view_image_tests.rs +++ b/crates/browser-use-agent/src/tools/handlers/view_image_tests.rs @@ -168,14 +168,32 @@ async fn nonexistent_file_errors() { match run_direct(&req, &ctx).await { Err(ToolError::Other(e)) => { assert!( - e.to_string().contains("cannot read"), - "should report it could not read the file, got: {e}" + e.to_string().contains("cannot stat"), + "should report it could not stat the file, got: {e}" ); } other => panic!("expected Other for nonexistent file, got {other:?}"), } } +#[tokio::test] +async fn directory_with_image_extension_is_rejected_before_read() { + let dir = tempfile::tempdir().unwrap(); + let ctx = ctx_in(dir.path()); + std::fs::create_dir(dir.path().join("not-a-file.png")).unwrap(); + + let req = ViewImageRequest::new("not-a-file.png"); + match run_direct(&req, &ctx).await { + Err(ToolError::Rejected(msg)) => { + assert!( + msg.contains("non-regular file"), + "should reject directories before reading, got: {msg}" + ); + } + other => panic!("expected Rejected for directory, got {other:?}"), + } +} + // (3b) An unsupported extension is rejected cleanly (not a panic). #[tokio::test] async fn unsupported_extension_is_rejected() { diff --git a/crates/browser-use-browser/src/lib.rs b/crates/browser-use-browser/src/lib.rs index 0a7e6a16..e5e3804f 100644 --- a/crates/browser-use-browser/src/lib.rs +++ b/crates/browser-use-browser/src/lib.rs @@ -30,6 +30,7 @@ const LOG_LIMIT: usize = 250; const SCRIPT_MAX_OUTPUT_CHARS: usize = 120_000; const BROWSER_SCRIPT_INITIAL_WAIT_MS: u64 = 750; const BROWSER_SCRIPT_DEFAULT_OBSERVE_MS: u64 = 1_000; +const CDP_CONNECT_TIMEOUT: Duration = Duration::from_secs(10); const BROWSER_SCRIPT_HELPERS: &str = include_str!("browser_script_helpers.py"); #[derive(Debug)] @@ -2598,8 +2599,7 @@ impl BrowserSession { impl CdpConnection { fn connect(ws_url: &str) -> Result { - let (mut socket, _) = - connect(ws_url).with_context(|| format!("connect CDP websocket {ws_url}"))?; + let mut socket = connect_cdp_websocket(ws_url)?; set_cdp_socket_timeouts(&mut socket); Ok(Self { socket, next_id: 1 }) } @@ -2727,8 +2727,7 @@ struct CdpDispatcher { impl CdpDispatcher { fn connect(ws_url: &str) -> Result> { - let (mut socket, _) = - connect(ws_url).with_context(|| format!("connect CDP websocket {ws_url}"))?; + let mut socket = connect_cdp_websocket(ws_url)?; set_cdp_dispatcher_socket_timeouts(&mut socket); let (tx, rx) = std::sync::mpsc::channel::(); let reader = thread::spawn(move || cdp_dispatcher_loop(socket, rx)); @@ -2772,6 +2771,27 @@ impl CdpDispatcher { } } +fn connect_cdp_websocket(ws_url: &str) -> Result>> { + let (tx, rx) = std::sync::mpsc::channel(); + let url = ws_url.to_string(); + thread::spawn(move || { + let result = connect(url.as_str()) + .map(|(socket, _)| socket) + .with_context(|| format!("connect CDP websocket {url}")); + let _ = tx.send(result); + }); + + match rx.recv_timeout(CDP_CONNECT_TIMEOUT) { + Ok(result) => result, + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + bail!("connect CDP websocket {ws_url} timed out after {CDP_CONNECT_TIMEOUT:?}") + } + Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => { + bail!("connect CDP websocket {ws_url} worker disconnected") + } + } +} + impl Drop for CdpDispatcher { fn drop(&mut self) { let _ = self diff --git a/crates/browser-use-llm/src/route/client.rs b/crates/browser-use-llm/src/route/client.rs index 121bdabc..d8c89e44 100644 --- a/crates/browser-use-llm/src/route/client.rs +++ b/crates/browser-use-llm/src/route/client.rs @@ -35,6 +35,7 @@ use std::time::Duration; use futures_util::{Stream, StreamExt}; use serde_json::Value; +use tokio::time::timeout; use crate::route::framing::{SseDecoder, SseFrame}; use crate::route::protocol::{Protocol, ProtocolStream}; @@ -538,6 +539,18 @@ fn aggregate(events: Vec) -> LlmResponse { // Async client + streaming state machine // =========================================================================== +/// Upper bound for opening an LLM request and receiving response headers. +/// +/// Streaming generations may run much longer, but establishing the stream must +/// not wait forever on a wedged socket or provider edge. +pub const DEFAULT_LLM_REQUEST_TIMEOUT: Duration = Duration::from_secs(120); + +/// Maximum idle gap between byte chunks once a streaming response has opened. +/// +/// This mirrors Codex's SSE idle-timeout shape while keeping the production +/// default generous enough for long model turns. +pub const DEFAULT_LLM_STREAM_IDLE_TIMEOUT: Duration = Duration::from_secs(300); + /// Where the streaming state machine is in the response lifecycle. enum Phase { /// Still pulling byte chunks off the HTTP body. @@ -561,6 +574,8 @@ struct StreamState { protocol_stream: Box, /// Events decoded but not yet yielded. ready: VecDeque>, + /// Maximum time to wait for the next byte chunk. + idle_timeout: Duration, /// Lifecycle phase. phase: Phase, } @@ -599,12 +614,18 @@ pub struct ModelClient { http: reqwest::Client, /// Retry/backoff configuration. retry: RetryPolicy, + /// Timeout for opening a request and reading non-2xx body snippets. + request_timeout: Duration, + /// Timeout for idle streaming body reads. + stream_idle_timeout: Duration, } impl fmt::Debug for ModelClient { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ModelClient") .field("retry", &self.retry) + .field("request_timeout", &self.request_timeout) + .field("stream_idle_timeout", &self.stream_idle_timeout) .finish_non_exhaustive() } } @@ -621,6 +642,8 @@ impl ModelClient { Self { http: reqwest::Client::new(), retry: RetryPolicy::default(), + request_timeout: DEFAULT_LLM_REQUEST_TIMEOUT, + stream_idle_timeout: DEFAULT_LLM_STREAM_IDLE_TIMEOUT, } } @@ -629,12 +652,34 @@ impl ModelClient { Self { http: reqwest::Client::new(), retry, + request_timeout: DEFAULT_LLM_REQUEST_TIMEOUT, + stream_idle_timeout: DEFAULT_LLM_STREAM_IDLE_TIMEOUT, } } /// Construct a client from an existing `reqwest::Client` and retry policy. pub fn from_parts(http: reqwest::Client, retry: RetryPolicy) -> Self { - Self { http, retry } + Self { + http, + retry, + request_timeout: DEFAULT_LLM_REQUEST_TIMEOUT, + stream_idle_timeout: DEFAULT_LLM_STREAM_IDLE_TIMEOUT, + } + } + + /// Construct a client from explicit parts, including watchdog timeouts. + pub fn from_parts_with_timeouts( + http: reqwest::Client, + retry: RetryPolicy, + request_timeout: Duration, + stream_idle_timeout: Duration, + ) -> Self { + Self { + http, + retry, + request_timeout, + stream_idle_timeout, + } } /// The retry policy in effect. @@ -677,9 +722,17 @@ impl ModelClient { for (k, v) in headers { builder = builder.header(k.as_str(), v.as_str()); } - let resp = builder - .send() + let resp = timeout(self.request_timeout, builder.send()) .await + .map_err(|_| { + ( + LlmError::transport(format!( + "LLM request timed out after {:?}", + self.request_timeout + )), + None, + ) + })? .map_err(|e| (LlmError::transport(scrub(&e.to_string())), None))?; let status = resp.status(); @@ -690,7 +743,11 @@ impl ModelClient { // Non-2xx: collect headers for the rate-limit hint, then the body snippet. let info = RateLimitInfo::from_headers(&header_map(resp.headers())); let code = status.as_u16(); - let text = resp.text().await.unwrap_or_default(); + let text = timeout(self.request_timeout, resp.text()) + .await + .ok() + .and_then(Result::ok) + .unwrap_or_default(); Err((error_for_status(code, &scrub(&text)), info.retry_after_ms)) } @@ -743,6 +800,7 @@ impl ModelClient { sse: SseDecoder::new(), protocol_stream: route.protocol.decoder(), ready: VecDeque::new(), + idle_timeout: self.stream_idle_timeout, phase: Phase::Streaming, }; @@ -754,17 +812,32 @@ impl ModelClient { return Some((ev, st)); } match st.phase { - Phase::Streaming => match st.byte_stream.next().await { - Some(Ok(chunk)) => { - let frames = st.sse.push(chunk.as_ref()); - st.decode_frames(frames); - } - Some(Err(e)) => { - st.phase = Phase::Done; - return Some((Err(LlmError::transport(scrub(&e.to_string()))), st)); + Phase::Streaming => { + let next = match timeout(st.idle_timeout, st.byte_stream.next()).await { + Ok(next) => next, + Err(_) => { + st.phase = Phase::Done; + return Some(( + Err(LlmError::transport(format!( + "LLM stream stalled for {:?}", + st.idle_timeout + ))), + st, + )); + } + }; + match next { + Some(Ok(chunk)) => { + let frames = st.sse.push(chunk.as_ref()); + st.decode_frames(frames); + } + Some(Err(e)) => { + st.phase = Phase::Done; + return Some((Err(LlmError::transport(scrub(&e.to_string()))), st)); + } + None => st.phase = Phase::Flushing, } - None => st.phase = Phase::Flushing, - }, + } Phase::Flushing => { st.phase = Phase::Done; match st.protocol_stream.finish() { @@ -1285,4 +1358,50 @@ mod tests { // The bearer token must never leak into the transport error message. assert!(!err.message.contains("sk-not-used"), "leaked token: {err}"); } + + #[tokio::test] + async fn stream_body_idle_timeout_returns_transport_error() { + use std::io::{Read as _, Write as _}; + + let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind listener"); + let addr = listener.local_addr().expect("local addr"); + std::thread::spawn(move || { + let (mut stream, _) = listener.accept().expect("accept"); + let mut buf = [0u8; 4096]; + let _ = stream.read(&mut buf); + stream + .write_all( + b"HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nTransfer-Encoding: chunked\r\n\r\n", + ) + .expect("write headers"); + std::thread::sleep(Duration::from_secs(10)); + }); + + let client = ModelClient::from_parts_with_timeouts( + reqwest::Client::new(), + RetryPolicy { + max_attempts: 1, + ..RetryPolicy::default() + }, + Duration::from_secs(2), + Duration::from_millis(100), + ); + let route = Route::new( + Box::new(OpenAiResponsesProtocol::new()), + Endpoint::new(format!("http://{addr}"), "/v1/responses"), + Auth::bearer("sk-not-used"), + ); + let mut req = LlmRequest::new("gpt-5.1-codex", "openai"); + req.messages.push(crate::schema::Message::user_text("hi")); + + let mut stream = client.stream(&route, &req).await.expect("open stream"); + let err = tokio::time::timeout(Duration::from_secs(2), stream.next()) + .await + .expect("idle watchdog should fire") + .expect("stream should yield timeout error") + .expect_err("idle timeout should be an error item"); + + assert_eq!(err.reason, LlmErrorReason::Transport); + assert!(err.message.contains("LLM stream stalled"), "{err}"); + } } diff --git a/crates/browser-use-python-worker/src/lib.rs b/crates/browser-use-python-worker/src/lib.rs index b2692b47..ee8c0cf4 100644 --- a/crates/browser-use-python-worker/src/lib.rs +++ b/crates/browser-use-python-worker/src/lib.rs @@ -12,6 +12,13 @@ use anyhow::{bail, Context, Result}; use serde::{Deserialize, Serialize}; use serde_json::Value; +pub const DEFAULT_PYTHON_TOOL_TIMEOUT_SECONDS: f64 = 120.0; +const PYTHON_WORKER_IO_GRACE: Duration = Duration::from_secs(5); + +fn resolved_timeout_seconds(timeout_seconds: Option) -> f64 { + timeout_seconds.unwrap_or(DEFAULT_PYTHON_TOOL_TIMEOUT_SECONDS) +} + pub struct PythonWorker { child: Child, stdin: ChildStdin, @@ -256,6 +263,43 @@ impl PythonWorker { }) } + fn write_request_line( + &mut self, + request_id: &str, + line: &str, + deadline: Instant, + ) -> Result { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Ok(false); + } + + let payload = line.to_string(); + let (tx, rx) = mpsc::channel(); + let child = &mut self.child; + let stdin = &mut self.stdin; + std::thread::scope(|scope| { + scope.spawn(move || { + let result = writeln!(stdin, "{payload}").and_then(|_| stdin.flush()); + let _ = tx.send(result); + }); + + match rx.recv_timeout(remaining) { + Ok(Ok(())) => Ok(true), + Ok(Err(err)) => Err(err.into()), + Err(mpsc::RecvTimeoutError::Timeout) => { + kill_worker_child(child); + let _ = rx.recv_timeout(PYTHON_WORKER_IO_GRACE); + Ok(false) + } + Err(mpsc::RecvTimeoutError::Disconnected) => { + bail!("python worker writer disconnected") + } + } + }) + .with_context(|| format!("write python worker request {request_id}")) + } + pub fn run( &mut self, session_id: &str, @@ -304,6 +348,12 @@ impl PythonWorker { timeout_seconds: Option, mut on_event: impl FnMut(PythonWorkerEvent), ) -> Result { + let effective_timeout_seconds = resolved_timeout_seconds(timeout_seconds); + let deadline = { + let seconds = effective_timeout_seconds.max(0.0); + let grace = (seconds * 0.1).clamp(1.0, 2.0); + Instant::now() + Duration::from_secs_f64(seconds + grace) + }; let request = RunPythonRequest { id: format!("py-{}", self.next_id), session_id: session_id.to_string(), @@ -311,40 +361,25 @@ impl PythonWorker { artifact_dir: artifact_dir.as_ref().display().to_string(), code: code.to_string(), cancel_requested: false, - timeout_seconds, + timeout_seconds: Some(effective_timeout_seconds), control: None, }; self.next_id += 1; let line = serde_json::to_string(&request)?; - writeln!(self.stdin, "{line}")?; - self.stdin.flush()?; - - let deadline = timeout_seconds.map(|seconds| { - let seconds = seconds.max(0.0); - let grace = (seconds * 0.1).clamp(1.0, 2.0); - Instant::now() + Duration::from_secs_f64(seconds + grace) - }); + if !self.write_request_line(&request.id, &line, deadline)? { + self.restart()?; + return Ok(timeout_response(&request.id, effective_timeout_seconds)); + } loop { - let Some(response) = self.read_response_line(&request.id, timeout_seconds, deadline)? + let Some(response) = self.read_response_line( + &request.id, + Some(effective_timeout_seconds), + Some(deadline), + )? else { self.restart()?; - return Ok(RunPythonResponse { - id: request.id.clone(), - ok: false, - text: String::new(), - error: Some(format!( - "python tool timed out after {} seconds", - timeout_seconds.unwrap_or_default() - )), - data: Value::Null, - outputs: Vec::new(), - artifacts: Vec::new(), - images: Vec::new(), - browser_events: Vec::new(), - browser_harness_available: false, - browser_harness_error: None, - }); + return Ok(timeout_response(&request.id, effective_timeout_seconds)); }; let trimmed = response.trim(); let value: Value = match serde_json::from_str(trimmed) { @@ -368,11 +403,18 @@ impl PythonWorker { } continue; } - return serde_json::from_value(value).context("parse python worker response"); + let parsed: RunPythonResponse = + serde_json::from_value(value).context("parse python worker response")?; + if parsed.id != request.id { + continue; + } + return Ok(parsed); } } pub fn shutdown_owned_cloud_browser(&mut self) -> Result> { + let effective_timeout_seconds = 5.0; + let deadline = Instant::now() + Duration::from_secs_f64(effective_timeout_seconds + 1.0); let cwd = std::env::current_dir()?; let request = RunPythonRequest { id: format!("py-{}", self.next_id), @@ -385,20 +427,26 @@ impl PythonWorker { .to_string(), code: String::new(), cancel_requested: false, - timeout_seconds: Some(5.0), + timeout_seconds: Some(effective_timeout_seconds), control: Some("shutdown_owned_cloud_browser".to_string()), }; self.next_id += 1; let line = serde_json::to_string(&request)?; - writeln!(self.stdin, "{line}")?; - self.stdin.flush()?; + if !self.write_request_line(&request.id, &line, deadline)? { + self.restart()?; + return Ok(None); + } loop { - let mut response = String::new(); - let bytes = self.stdout.read_line(&mut response)?; - if bytes == 0 { + let Some(response) = self.read_response_line( + &request.id, + Some(effective_timeout_seconds), + Some(deadline), + )? + else { + self.restart()?; return Ok(None); - } + }; let trimmed = response.trim(); let value: Value = match serde_json::from_str(trimmed) { Ok(value) => value, @@ -407,11 +455,34 @@ impl PythonWorker { if value.get("event").is_some() { continue; } - return Ok(value.get("data").cloned()); + let parsed: RunPythonResponse = + serde_json::from_value(value).context("parse python worker control response")?; + if parsed.id != request.id { + continue; + } + return Ok(Some(parsed.data)); } } } +fn timeout_response(request_id: &str, timeout_seconds: f64) -> RunPythonResponse { + RunPythonResponse { + id: request_id.to_string(), + ok: false, + text: String::new(), + error: Some(format!( + "python tool timed out after {timeout_seconds} seconds" + )), + data: Value::Null, + outputs: Vec::new(), + artifacts: Vec::new(), + images: Vec::new(), + browser_events: Vec::new(), + browser_harness_available: false, + browser_harness_error: None, + } +} + fn installed_python_path() -> Option { let exe = std::env::current_exe().ok()?; let exe_dir = exe.parent()?; @@ -448,7 +519,7 @@ fn spawn_python_worker( let mut child = command .stdin(Stdio::piped()) .stdout(Stdio::piped()) - .stderr(Stdio::piped()) + .stderr(Stdio::null()) .spawn() .with_context(|| { format!( @@ -499,6 +570,15 @@ mod tests { use super::*; use std::path::PathBuf; + #[test] + fn omitted_timeout_uses_safe_default() { + assert_eq!( + resolved_timeout_seconds(None), + DEFAULT_PYTHON_TOOL_TIMEOUT_SECONDS + ); + assert_eq!(resolved_timeout_seconds(Some(0.25)), 0.25); + } + #[test] fn worker_keeps_a_persistent_namespace_per_session() -> Result<()> { let repo_root = PathBuf::from(env!("CARGO_MANIFEST_DIR"))