diff --git a/rust/tests/e2e/tools.rs b/rust/tests/e2e/tools.rs index 85d15b571..2c474bca1 100644 --- a/rust/tests/e2e/tools.rs +++ b/rust/tests/e2e/tools.rs @@ -4,10 +4,10 @@ use github_copilot_sdk::handler::{ApproveAllHandler, PermissionHandler, Permissi use github_copilot_sdk::tool::ToolHandler; use github_copilot_sdk::{ Error, PermissionRequestData, RequestId, SessionConfig, SessionId, Tool, ToolInvocation, - ToolResult, + ToolResult, ToolSet, }; use serde_json::json; -use tokio::sync::mpsc; +use tokio::sync::{Mutex, mpsc}; use super::support::{assistant_message_content, recv_with_timeout, with_e2e_context}; @@ -73,6 +73,55 @@ async fn invokes_custom_tool() { .await; } +#[tokio::test] +async fn low_level_tool_definition() { + with_e2e_context("tools", "low_level_tool_definition", |ctx| { + Box::pin(async move { + ctx.set_default_copilot_user(); + let client = ctx.start_client().await; + let __perm = Arc::new(ApproveAllHandler); + let current_phase = Arc::new(Mutex::new(String::new())); + let tools = vec![ + set_current_phase_tool(current_phase.clone()), + search_items_tool(), + ]; + let available_tools = ToolSet::new() + .add_custom("*") + .expect("add custom wildcard") + .add_builtin("web_fetch") + .expect("add web_fetch") + .into_vec(); + let session = client + .create_session( + SessionConfig::default() + .with_github_token(super::support::DEFAULT_TEST_TOKEN) + .with_permission_handler(__perm) + .with_tools(tools) + .with_available_tools(available_tools), + ) + .await + .expect("create session"); + + let answer = session + .send_and_wait( + "First, set the current phase to 'analyzing'. Then search for items with keyword 'copilot'. Report the phase and search results.", + ) + .await + .expect("send") + .expect("assistant message"); + let content = assistant_message_content(&answer); + assert!(!content.is_empty()); + assert!(content.to_lowercase().contains("analyzing")); + assert!(content.contains("item_alpha") || content.contains("item_beta")); + assert_eq!(current_phase.lock().await.clone(), "analyzing"); + + session.disconnect().await.expect("disconnect session"); + client.stop().await.expect("stop client"); + }) + }) + .await; +} + #[tokio::test] async fn handles_tool_calling_errors() { with_e2e_context("tools", "handles_tool_calling_errors", |ctx| { @@ -502,6 +551,69 @@ impl ToolHandler for ErrorTool { struct CustomGrepTool; +struct SetCurrentPhaseTool { + current_phase: Arc>, +} + +fn set_current_phase_tool(current_phase: Arc>) -> Tool { + Tool::new("set_current_phase") + .with_description("Sets the current phase of the agent") + .with_parameters(json!({ + "type": "object", + "properties": { + "phase": { + "type": "string", + "description": "Current phase", + "pattern": "^(searching|analyzing|done)$" + } + }, + "required": ["phase"] + })) + .with_handler(Arc::new(SetCurrentPhaseTool { current_phase })) +} + +#[async_trait::async_trait] +impl ToolHandler for SetCurrentPhaseTool { + async fn call(&self, invocation: ToolInvocation) -> Result { + let phase = invocation + .arguments + .get("phase") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + *self.current_phase.lock().await = phase.clone(); + Ok(ToolResult::Text(format!("Phase set to {phase}"))) + } +} + +struct SearchItemsTool; + +fn search_items_tool() -> Tool { + Tool::new("search_items") + .with_description("Search for items by keyword") + .with_parameters(json!({ + "type": "object", + "properties": { + "keyword": { "type": "string" } + }, + "required": ["keyword"] + })) + .with_handler(Arc::new(SearchItemsTool)) +} + +#[async_trait::async_trait] +impl ToolHandler for SearchItemsTool { + async fn call(&self, invocation: ToolInvocation) -> Result { + let keyword = invocation + .arguments + .get("keyword") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + assert_eq!(keyword, "copilot"); + Ok(ToolResult::Text("Found: item_alpha, item_beta".to_string())) + } +} + fn custom_grep_tool() -> Tool { Tool::new("grep") .with_description("A custom grep implementation that overrides the built-in")