diff --git a/python/README.md b/python/README.md index e77fcf16c..37f4bac5c 100644 --- a/python/README.md +++ b/python/README.md @@ -329,6 +329,27 @@ async def lookup_issue(params: LookupParams) -> str: # your logic ``` +#### Providing Per-Call Tool Context + +Pass `on_provide_tool_context` to `create_session` (or `resume_session`) to inject application context into your tool handlers without exposing it to the model. The provider is invoked once per tool call with the `ToolInvocation` (sync or async); its return value is assigned to `invocation.context` before the handler runs. Use it to hand per-request services or state to handlers that would otherwise need a global lookup. `invocation.context` defaults to `None` when no provider is registered, and is never sent over the wire. + +```python +from copilot.tools import ToolInvocation + +@define_tool(description="List the current user's open issues") +async def my_issues(invocation: ToolInvocation) -> str: + ctx = invocation.context # whatever the provider returned + return await ctx.db.open_issues_for(ctx.user_id) + +async with await client.create_session( + on_permission_request=PermissionHandler.approve_all, + model="gpt-5", + tools=[my_issues], + on_provide_tool_context=lambda invocation: build_request_context(), +) as session: + ... +``` + ## Image Support The SDK supports image attachments via the `attachments` parameter. You can attach images by providing their file path, or by passing base64-encoded data directly using a blob attachment: diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 3f1a84d25..e72209ee1 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -128,6 +128,7 @@ SessionUiApi, SessionUiCapabilities, SystemMessageConfig, + ToolContextProvider, UserInputHandler, UserInputRequest, UserInputResponse, @@ -275,6 +276,7 @@ "TelemetryConfig", "Tool", "ToolBinaryResult", + "ToolContextProvider", "ToolInvocation", "ToolResult", "ToolResultType", diff --git a/python/copilot/client.py b/python/copilot/client.py index 24eec9d72..35b72e3d1 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -94,6 +94,7 @@ SessionFsConfig, SessionHooks, SystemMessageConfig, + ToolContextProvider, UserInputHandler, _PermissionHandlerFn, ) @@ -1565,6 +1566,7 @@ async def create_session( reasoning_summary: ReasoningSummary | None = None, context_tier: ContextTier | None = None, tools: list[Tool] | None = None, + on_provide_tool_context: ToolContextProvider | None = None, system_message: SystemMessageConfig | None = None, available_tools: list[str] | ToolSet | None = None, excluded_tools: list[str] | ToolSet | None = None, @@ -1639,6 +1641,10 @@ async def create_session( context_tier: Context window tier for models that support it. Use ``"long_context"`` to pin the session to the long-context tier. tools: Custom tools to register with the session. + on_provide_tool_context: Optional provider invoked once per tool call + with the ``ToolInvocation``; its return value (awaited when a + coroutine) is assigned to ``ToolInvocation.context`` before the + handler runs. system_message: System message configuration. available_tools: Allowlist of tools to enable. When specified, only these tools will be available. Applies to the full merged tool @@ -2008,6 +2014,7 @@ def _initialize_session(sid: str) -> CopilotSession: ) s._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) s._register_tools(tools) + s._register_tool_context_provider(on_provide_tool_context) s._register_commands(commands) s._register_permission_handler(on_permission_request) if on_user_input_request: @@ -2136,6 +2143,7 @@ async def resume_session( reasoning_summary: ReasoningSummary | None = None, context_tier: ContextTier | None = None, tools: list[Tool] | None = None, + on_provide_tool_context: ToolContextProvider | None = None, system_message: SystemMessageConfig | None = None, available_tools: list[str] | ToolSet | None = None, excluded_tools: list[str] | ToolSet | None = None, @@ -2211,6 +2219,10 @@ async def resume_session( context_tier: Context window tier for models that support it. Use ``"long_context"`` to pin the session to the long-context tier. tools: Custom tools to register with the session. + on_provide_tool_context: Optional provider invoked once per tool call + with the ``ToolInvocation``; its return value (awaited when a + coroutine) is assigned to ``ToolInvocation.context`` before the + handler runs. system_message: System message configuration. available_tools: Allowlist of tools to enable. When specified, only these tools will be available. Applies to the full merged tool @@ -2533,6 +2545,7 @@ async def resume_session( ) session._client_session_apis.session_fs = create_session_fs_adapter(fs_provider) session._register_tools(tools) + session._register_tool_context_provider(on_provide_tool_context) session._register_commands(commands) session._register_permission_handler(on_permission_request) if on_user_input_request: diff --git a/python/copilot/session.py b/python/copilot/session.py index 32201870c..648c53414 100644 --- a/python/copilot/session.py +++ b/python/copilot/session.py @@ -280,6 +280,10 @@ class PermissionNoResult: PermissionRequestResult = PermissionDecision | PermissionNoResult +ToolContextProvider = Callable[[ToolInvocation], Any] +"""Per-call tool context provider: receives the invocation, returns ``context``.""" + + _PermissionHandlerFn = Callable[ [PermissionRequest, dict[str, str]], PermissionRequestResult | Awaitable[PermissionRequestResult], @@ -1119,6 +1123,8 @@ def __init__( self._event_handlers_lock = threading.Lock() self._tool_handlers: dict[str, ToolHandler] = {} self._tool_handlers_lock = threading.Lock() + self._tool_context_provider: ToolContextProvider | None = None + self._tool_context_provider_lock = threading.Lock() self._permission_handler: _PermissionHandlerFn | None = None self._permission_handler_lock = threading.Lock() self._user_input_handler: UserInputHandler | None = None @@ -1592,6 +1598,13 @@ async def _execute_tool_and_respond( arguments=arguments, ) + provider = self._get_tool_context_provider() + if provider is not None: + tool_context = provider(invocation) + if inspect.isawaitable(tool_context): + tool_context = await tool_context + invocation.context = tool_context + with trace_context(traceparent, tracestate): handler_start = time.perf_counter() result = handler(invocation) @@ -1989,6 +2002,25 @@ def _get_tool_handler(self, name: str) -> ToolHandler | None: with self._tool_handlers_lock: return self._tool_handlers.get(name) + def _register_tool_context_provider(self, provider: ToolContextProvider | None) -> None: + """ + Register the provider that supplies per-call tool context. + + Note: + This method is internal. The provider is typically registered when + creating a session via :meth:`CopilotClient.create_session`. + + Args: + provider: The tool context provider, or None to remove it. + """ + with self._tool_context_provider_lock: + self._tool_context_provider = provider + + def _get_tool_context_provider(self) -> ToolContextProvider | None: + """Retrieve the registered tool context provider, if any.""" + with self._tool_context_provider_lock: + return self._tool_context_provider + def _register_permission_handler(self, handler: _PermissionHandlerFn | None) -> None: """ Register a handler for permission requests. diff --git a/python/copilot/tools.py b/python/copilot/tools.py index a82a48b1e..d64476868 100644 --- a/python/copilot/tools.py +++ b/python/copilot/tools.py @@ -49,6 +49,7 @@ class ToolInvocation: tool_call_id: str = "" tool_name: str = "" arguments: Any = None + context: Any = None ToolHandler = Callable[[ToolInvocation], ToolResult | Awaitable[ToolResult]] diff --git a/python/test_tool_context.py b/python/test_tool_context.py new file mode 100644 index 000000000..28fc7b88c --- /dev/null +++ b/python/test_tool_context.py @@ -0,0 +1,167 @@ +"""Unit tests for the per-call tool context provider. + +The provider is registered on a session and invoked once per tool call to +populate ``ToolInvocation.context`` before the handler runs. These tests drive +``CopilotSession._execute_tool_and_respond`` directly with a fake RPC so the +injection path is exercised without a live runtime connection. +""" + +from __future__ import annotations + +from typing import Any + +from copilot import define_tool +from copilot.session import CopilotSession +from copilot.tools import ToolInvocation + + +class _FakeToolsRpc: + def __init__(self) -> None: + self.calls: list[Any] = [] + + async def handle_pending_tool_call(self, request: Any) -> None: + self.calls.append(request) + + +class _FakeRpc: + def __init__(self) -> None: + self.tools = _FakeToolsRpc() + + +def _session_with_fake_rpc(session_id: str = "sess-1") -> CopilotSession: + session = CopilotSession(session_id, client=None) + session._rpc = _FakeRpc() # type: ignore[assignment] + return session + + +async def test_provider_value_injected_into_invocation_context(): + seen: dict[str, Any] = {} + + @define_tool("echo", description="Echo tool") + def echo(invocation: ToolInvocation) -> str: + seen["context"] = invocation.context + return "ok" + + session = _session_with_fake_rpc() + session._register_tool_context_provider(lambda inv: {"user": "alice", "tool": inv.tool_name}) + + await session._execute_tool_and_respond( + request_id="r1", + tool_name="echo", + tool_call_id="c1", + arguments={}, + handler=echo.handler, + ) + + assert seen["context"] == {"user": "alice", "tool": "echo"} + + +async def test_async_provider_is_awaited(): + seen: dict[str, Any] = {} + + @define_tool("echo", description="Echo tool") + def echo(invocation: ToolInvocation) -> str: + seen["context"] = invocation.context + return "ok" + + async def provider(_: ToolInvocation) -> dict[str, Any]: + return {"async": True} + + session = _session_with_fake_rpc() + session._register_tool_context_provider(provider) + + await session._execute_tool_and_respond( + request_id="r1", + tool_name="echo", + tool_call_id="c1", + arguments={}, + handler=echo.handler, + ) + + assert seen["context"] == {"async": True} + + +async def test_provider_receives_full_invocation(): + seen: dict[str, ToolInvocation] = {} + + @define_tool("echo", description="Echo tool") + def echo(invocation: ToolInvocation) -> str: + return "ok" + + def provider(inv: ToolInvocation) -> str: + seen["invocation"] = inv + return "ctx" + + session = _session_with_fake_rpc("sess-42") + session._register_tool_context_provider(provider) + + await session._execute_tool_and_respond( + request_id="r1", + tool_name="echo", + tool_call_id="call-7", + arguments={"q": "hello"}, + handler=echo.handler, + ) + + inv = seen["invocation"] + assert inv.session_id == "sess-42" + assert inv.tool_name == "echo" + assert inv.tool_call_id == "call-7" + assert inv.arguments == {"q": "hello"} + + +async def test_no_provider_leaves_context_none(): + seen: dict[str, Any] = {} + + @define_tool("echo", description="Echo tool") + def echo(invocation: ToolInvocation) -> str: + seen["context"] = invocation.context + return "ok" + + session = _session_with_fake_rpc() + + await session._execute_tool_and_respond( + request_id="r1", + tool_name="echo", + tool_call_id="c1", + arguments={}, + handler=echo.handler, + ) + + assert seen["context"] is None + + +async def test_provider_returning_none_leaves_context_none(): + seen: dict[str, Any] = {} + + @define_tool("echo", description="Echo tool") + def echo(invocation: ToolInvocation) -> str: + seen["context"] = invocation.context + return "ok" + + session = _session_with_fake_rpc() + session._register_tool_context_provider(lambda _: None) + + await session._execute_tool_and_respond( + request_id="r1", + tool_name="echo", + tool_call_id="c1", + arguments={}, + handler=echo.handler, + ) + + assert seen["context"] is None + + +def test_register_and_clear_provider_round_trip(): + session = CopilotSession("sess-1", client=None) + assert session._get_tool_context_provider() is None + + def provider(_: ToolInvocation) -> str: + return "ctx" + + session._register_tool_context_provider(provider) + assert session._get_tool_context_provider() is provider + + session._register_tool_context_provider(None) + assert session._get_tool_context_provider() is None diff --git a/python/test_tools.py b/python/test_tools.py index d583b59c0..30b933a51 100644 --- a/python/test_tools.py +++ b/python/test_tools.py @@ -249,6 +249,49 @@ class Params(BaseModel): ) +class TestToolInvocation: + def test_context_defaults_to_none(self): + inv = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="t", + arguments={}, + ) + assert inv.context is None + + def test_context_can_be_set(self): + sentinel = object() + inv = ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="t", + arguments={}, + context=sentinel, + ) + assert inv.context is sentinel + + async def test_handler_can_read_context(self): + seen = None + + @define_tool("t", description="Reads context") + def tool(invocation: ToolInvocation) -> str: + nonlocal seen + seen = invocation.context + return "ok" + + await tool.handler( + ToolInvocation( + session_id="s1", + tool_call_id="c1", + tool_name="t", + arguments={}, + context={"user": "alice"}, + ) + ) + + assert seen == {"user": "alice"} + + class TestNormalizeResult: def test_none_returns_empty_success(self): result = _normalize_result(None)