Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 70 additions & 5 deletions crates/browser-use-agent/src/entrypoint/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ struct StoreTurnState {
base_instructions: String,
current_model: Option<String>,
previous_model_compaction: Option<PreviousModelCompaction>,
cancel: CancellationToken,
/// The model-based summary pass for [`compact`](TurnState::compact). `None`
/// disables compaction (the no-sampler / `Fake` path); the production run sets
/// a real [`EntrypointSampler`].
Expand Down Expand Up @@ -276,6 +277,7 @@ impl StoreTurnState {
base_instructions: crate::prompts::browser_agent_system_prompt(),
current_model: None,
previous_model_compaction: None,
cancel: CancellationToken::new(),
compaction_sampler: None,
compacted: Mutex::new(None),
}
Expand Down Expand Up @@ -321,6 +323,11 @@ impl StoreTurnState {
self
}

/// Builder-style setter: stores the caller's cancellation token in `self.cancel`
/// (replacing whatever token was there) and hands the state back for chaining.
fn with_cancel(self, cancel: CancellationToken) -> Self {
    let mut state = self;
    state.cancel = cancel;
    state
}

/// Assemble the current prompt as typed [`Message`]s (synchronously). The base
/// is the compacted override when present, else the lowered durable log; this
/// run's recorded turns are appended.
Expand Down Expand Up @@ -1230,14 +1237,18 @@ async fn run_compaction_with_retries(
token_limit: usize,
max_retries: u32,
context_window: Option<i64>,
cancel: CancellationToken,
) -> Result<crate::compact::CompactedHistory, AgentError> {
let mut retries = 0;
let mut working: Vec<Item> = history.to_vec();
working.push(compaction_prompt_item(compact_prompt));
loop {
if cancel.is_cancelled() {
return Err(AgentError::TurnAborted);
}
let request = compaction_request_messages(&working);
let request_len = request.len();
match sampler.summarize(request, CancellationToken::new()).await {
match sampler.summarize(request, cancel.clone()).await {
Ok(summary) => {
append_compaction_token_usage(
store,
Expand Down Expand Up @@ -1269,10 +1280,13 @@ async fn run_compaction_with_retries(
session_id.as_str(),
&format!("Reconnecting... {retries}/{max_retries}"),
);
tokio::time::sleep(std::time::Duration::from_millis(
let delay = tokio::time::sleep(std::time::Duration::from_millis(
crate::decision::backoff_ms(retries),
))
.await;
));
tokio::select! {
_ = cancel.cancelled() => return Err(AgentError::TurnAborted),
_ = delay => {}
}
}
Err(error) => return Err(error),
}
Expand Down Expand Up @@ -1667,6 +1681,7 @@ impl StoreTurnState {
COMPACT_USER_MESSAGE_MAX_TOKENS,
DEFAULT_STREAM_MAX_RETRIES,
self.context_window(),
self.cancel.clone(),
)
.await
{
Expand Down Expand Up @@ -2053,7 +2068,8 @@ async fn drive_run<Sd: SamplingDriver>(
base_instructions.unwrap_or_else(|| crate::prompts::browser_agent_system_prompt()),
developer_instructions,
)
.with_previous_model_compaction(previous_model_compaction);
.with_previous_model_compaction(previous_model_compaction)
.with_cancel(cancel.clone());

let pre_turn_replay_from_seq = if turn_has_fresh_input {
let events = events_from_store(&store, session_id.as_str());
Expand Down Expand Up @@ -2687,6 +2703,21 @@ mod tests {
}
}

/// Test sampler that never produces a summary: every `summarize` call waits for
/// the supplied cancellation token and then fails with `AgentError::TurnAborted`.
struct WaitForCancelSampler;

impl CompactionSampler for WaitForCancelSampler {
/// Ignores the request, suspends until `cancel` is cancelled, then returns
/// `Err(AgentError::TurnAborted)`. It never returns `Ok`.
fn summarize(
&self,
_request: Vec<Message>,
cancel: CancellationToken,
) -> impl Future<Output = Result<CompactionSummary, AgentError>> + Send {
async move {
// Only the token passed in by the caller can complete this future.
cancel.cancelled().await;
Err(AgentError::TurnAborted)
}
}
}

struct WindowThenRetryableThenSuccessSampler {
attempts: AtomicUsize,
request_lens: Mutex<Vec<usize>>,
Expand Down Expand Up @@ -3521,6 +3552,7 @@ mod tests {
COMPACT_USER_MESSAGE_MAX_TOKENS,
1,
Some(1_000),
CancellationToken::new(),
)
.await
.expect("mixed window and retryable errors should eventually succeed");
Expand Down Expand Up @@ -3550,6 +3582,7 @@ mod tests {
COMPACT_USER_MESSAGE_MAX_TOKENS,
0,
Some(1_000),
CancellationToken::new(),
)
.await
.expect_err("unshrinkable summary request should fail");
Expand All @@ -3574,6 +3607,38 @@ mod tests {
.any(|event| event.event_type == "model.turn.context_overflow"));
}

// Compaction must observe the token supplied by its caller: the sampler used here
// only returns once that token is cancelled, so the run can finish (with
// `TurnAborted`) only if the same token reaches the sampler. The outer 2s timeout
// turns a hang into a test failure.
#[tokio::test]
async fn compaction_uses_callers_cancellation_token() {
    let (_dir, store, session_id) = store_with_session();
    let session = SessionId(session_id.clone());
    let sampler: Arc<dyn DynCompactionSampler> = Arc::new(WaitForCancelSampler);

    // Cancel from a background task shortly after compaction has started.
    let cancel = CancellationToken::new();
    let trigger = cancel.clone();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        trigger.cancel();
    });

    let compaction = run_compaction_with_retries(
        &session,
        &store,
        &[],
        sampler.as_ref(),
        crate::compact::SUMMARIZATION_PROMPT,
        COMPACT_USER_MESSAGE_MAX_TOKENS,
        0,
        Some(1_000),
        cancel,
    );
    let outcome = tokio::time::timeout(Duration::from_secs(2), compaction)
        .await
        .expect("compaction should observe cancellation");
    let err = outcome.expect_err("cancelled compaction should fail");

    assert!(matches!(err, AgentError::TurnAborted));
}

#[tokio::test]
async fn previous_model_downshift_compacts_before_current_model_sampling() {
let (_dir, store, session_id) = store_with_session();
Expand Down
30 changes: 19 additions & 11 deletions crates/browser-use-agent/src/mcp/stdio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ use crate::mcp::protocol::{

/// Shared, lock-guarded map from a request id to the oneshot sender for a
/// `JsonRpcMessage`. Presumably the table of in-flight requests awaiting a
/// reply — confirm against the transport's `pending` field.
type PendingMap = Arc<Mutex<HashMap<RequestId, oneshot::Sender<JsonRpcMessage>>>>;

/// Upper bound (5 seconds) on writing a reply to a server-initiated request;
/// applied around the `write_json` call in `handle_server_request`.
const MCP_SERVER_REPLY_TIMEOUT: Duration = Duration::from_secs(5);

/// A connected stdio MCP server.
pub struct StdioTransport {
/// Kept alive (killed on drop via `kill_on_drop`) for the transport's life.
Expand Down Expand Up @@ -147,7 +149,7 @@ impl StdioTransport {
.await
.context("initialize handshake failed")?;
transport
.notify(initialized_notification())
.notify(initialized_notification(), startup_timeout)
.await
.context("sending notifications/initialized failed")?;

Expand Down Expand Up @@ -191,16 +193,17 @@ impl StdioTransport {
}

let req = JsonRpcRequest::new(id.clone(), method, params);
if let Err(err) = write_json(&self.writer, &req).await {
self.pending.lock().await.remove(&id);
return Err(err);
}
let operation = async {
write_json(&self.writer, &req).await?;
rx.await
.map_err(|_| anyhow!("MCP server closed stdout before responding to {method}"))
};

let msg = match timeout(timeout_dur, rx).await {
let msg = match timeout(timeout_dur, operation).await {
Ok(Ok(msg)) => msg,
Ok(Err(_)) => {
Ok(Err(err)) => {
self.pending.lock().await.remove(&id);
bail!("MCP server closed stdout before responding to {method}");
return Err(err);
}
Err(_) => {
self.pending.lock().await.remove(&id);
Expand All @@ -218,8 +221,13 @@ impl StdioTransport {
Ok(msg.result.unwrap_or(Value::Null))
}

async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
write_json(&self.writer, &notification).await
async fn notify(&self, notification: JsonRpcNotification, timeout_dur: Duration) -> Result<()> {
match timeout(timeout_dur, write_json(&self.writer, &notification)).await {
Ok(result) => result,
Err(_) => {
bail!("MCP server timed out while sending notification after {timeout_dur:?}")
}
}
}
}

Expand All @@ -242,7 +250,7 @@ async fn handle_server_request(writer: &Arc<Mutex<ChildStdin>>, msg: JsonRpcMess
"error": { "code": -32601, "message": format!("method not found: {method}") },
})
};
let _ = write_json(writer, &reply).await;
let _ = timeout(MCP_SERVER_REPLY_TIMEOUT, write_json(writer, &reply)).await;
}

/// Write a serializable message as a newline-terminated JSON line, then flush.
Expand Down
27 changes: 27 additions & 0 deletions crates/browser-use-agent/src/mcp/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,33 @@ async fn stdio_error_result_maps_is_error() {
assert_eq!(mcp_result_tool_content(&result.into_seam()), "kaboom");
}

// A child process that only sleeps never answers `initialize`, so `connect` must
// fail with an error mentioning "timed out" rather than hang. The outer 2s timeout
// fails the test if `connect` itself never returns.
#[tokio::test]
async fn stdio_connect_times_out_when_child_never_answers_initialize() {
    let args = ["-c".to_string(), "import time; time.sleep(60)".to_string()];
    let env = HashMap::new();
    let hundred_ms = Duration::from_millis(100);

    let connect = StdioTransport::connect("python3", &args, &env, None, hundred_ms, hundred_ms);
    let result = tokio::time::timeout(Duration::from_secs(2), connect)
        .await
        .expect("transport watchdog should return");
    let Err(err) = result else {
        panic!("connect should fail");
    };

    // `{:#}` renders the whole context chain, so a nested timeout cause is visible.
    let message = format!("{err:#}");
    assert!(
        message.contains("timed out"),
        "expected timeout error, got {message}"
    );
}

// ---------------------------------------------------------------------------
// http transport (loopback TcpListener)
// ---------------------------------------------------------------------------
Expand Down
23 changes: 23 additions & 0 deletions crates/browser-use-agent/src/tools/handlers/apply_patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ pub fn apply_patch_operations(
match op {
PatchOperation::AddFile { path, contents } => {
let real = resolve_patch_path(root, path);
reject_existing_special_file(&real, path)?;
if let Some(parent) = real.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
ToolError::Other(anyhow::anyhow!("creating parent dirs for {path}: {e}"))
Expand All @@ -508,6 +509,7 @@ pub fn apply_patch_operations(
hunks,
} => {
let real = resolve_patch_path(root, path);
ensure_regular_file_for_read(&real, path)?;
let original = std::fs::read_to_string(&real).map_err(|e| {
ToolError::Other(anyhow::anyhow!("reading file to update {path}: {e}"))
})?;
Expand All @@ -521,6 +523,7 @@ pub fn apply_patch_operations(
} else {
real.clone()
};
reject_existing_special_file(&dest, move_to.as_deref().unwrap_or(path))?;
if let Some(parent) = dest.parent() {
std::fs::create_dir_all(parent).map_err(|e| {
ToolError::Other(anyhow::anyhow!("creating parent dirs for {path}: {e}"))
Expand All @@ -545,6 +548,26 @@ pub fn apply_patch_operations(
Ok(ApplyPatchSummary { changed })
}

/// Check that `real` is a regular file before it is read.
///
/// `display` is the path text used in error messages.
///
/// # Errors
/// - `ToolError::Other` when the metadata lookup for `real` fails.
/// - `ToolError::Rejected` when the lookup succeeds but the entry is not a
///   regular file (for example a directory).
fn ensure_regular_file_for_read(real: &Path, display: &str) -> Result<(), ToolError> {
    let file_type = match std::fs::metadata(real) {
        Ok(meta) => meta.file_type(),
        Err(e) => {
            return Err(ToolError::Other(anyhow::anyhow!(
                "reading metadata for {display}: {e}"
            )));
        }
    };

    if file_type.is_file() {
        Ok(())
    } else {
        Err(ToolError::Rejected(format!(
            "apply_patch refuses to read non-regular file {display}"
        )))
    }
}

/// Guard for write targets: refuse when `real` already exists as something
/// other than a regular file (for example a directory).
///
/// `display` is the path text used in the rejection message.
///
/// Returns `Ok(())` when the target is a regular file or when its metadata
/// cannot be read at all (including "does not exist"); returns
/// `ToolError::Rejected` only for an existing non-regular entry.
fn reject_existing_special_file(real: &Path, display: &str) -> Result<(), ToolError> {
    // A failed lookup is deliberately not an error here: the caller goes on to
    // attempt the write.
    let Ok(meta) = std::fs::metadata(real) else {
        return Ok(());
    };
    if meta.file_type().is_file() {
        return Ok(());
    }
    Err(ToolError::Rejected(format!(
        "apply_patch refuses to write non-regular file {display}"
    )))
}

/// Apply update hunks to the original file contents, returning the new contents.
///
/// Parity: legacy `apply_hunk_to_lines` (files.rs:789). Each hunk is located by
Expand Down
49 changes: 49 additions & 0 deletions crates/browser-use-agent/src/tools/handlers/apply_patch_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,30 @@ async fn add_file_creates_nested_dirs() {
assert_eq!(written, "nested");
}

// Add File aimed at a path that already exists as a directory must come back as
// `ToolError::Rejected` with a "non-regular file" message.
#[tokio::test]
async fn add_file_over_directory_is_rejected_before_write() {
let dir = tempfile::tempdir().unwrap();
let ctx = ctx_in(dir.path());
// Occupy the target path with a directory, despite its file-like name.
std::fs::create_dir(dir.path().join("existing.txt")).unwrap();

let patch = "\
*** Begin Patch
*** Add File: existing.txt
+replacement
*** End Patch
";
let req = ApplyPatchRequest::new(patch);
match run_direct(&req, &ctx).await {
Err(ToolError::Rejected(msg)) => {
assert!(
msg.contains("non-regular file"),
"should reject directories before writing, got: {msg}"
);
}
// Any other outcome (success or a different error variant) fails the test.
other => panic!("expected Rejected for directory add, got {other:?}"),
}
}

// (2) Update File applies a hunk (context + +/- lines) correctly.
#[tokio::test]
async fn update_file_applies_hunk() {
Expand Down Expand Up @@ -127,6 +151,31 @@ async fn update_file_applies_hunk() {
);
}

// Update File aimed at a path that is actually a directory must come back as
// `ToolError::Rejected` with a "non-regular file" message.
#[tokio::test]
async fn update_directory_is_rejected_before_read() {
let dir = tempfile::tempdir().unwrap();
let ctx = ctx_in(dir.path());
// The update target exists, but as a directory rather than a file.
std::fs::create_dir(dir.path().join("not-a-file.txt")).unwrap();

let patch = "\
*** Begin Patch
*** Update File: not-a-file.txt
@@
anything
*** End Patch
";
let req = ApplyPatchRequest::new(patch);
match run_direct(&req, &ctx).await {
Err(ToolError::Rejected(msg)) => {
assert!(
msg.contains("non-regular file"),
"should reject directories before reading, got: {msg}"
);
}
// Any other outcome (success or a different error variant) fails the test.
other => panic!("expected Rejected for directory update, got {other:?}"),
}
}

// (2b) Update File with an *** End of File marker terminating the hunk.
#[tokio::test]
async fn update_file_with_end_of_file_marker() {
Expand Down
Loading