diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py new file mode 100644 index 0000000..8a264cf --- /dev/null +++ b/sdk/adrian/__init__.py @@ -0,0 +1,1247 @@ +"""Adrian: multi-agent event capture SDK for LangChain/LangGraph as of 2026-05-10. + +Initialise with a single call and all LLM / tool activity is automatically +captured, paired, and emitted as ``PairedEvent`` objects through registered +handlers:: + + import adrian + + adrian.init(api_key="...") + +Events are paired (chat_model_start + llm_end, tool_start + tool_end), +enriched with agent identity and parent context, and emitted through +pluggable handlers (JSONL, WebSocket, custom). + +""" + +# pyright: reportUnknownVariableType=false +# pyright: reportUnknownMemberType=false +# pyright: reportUnknownArgumentType=false +# pyright: reportUnknownLambdaType=false + +from __future__ import annotations + +import asyncio +import atexit +import logging +import os +from pathlib import Path +from typing import Any +from uuid import uuid4 + +from langchain_core.callbacks.manager import CallbackManager +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage, ToolMessage +from langchain_core.runnables.base import Runnable +from langchain_core.runnables.config import ensure_config + +from adrian.config import ( + AdrianConfig, + OnAuditCallback, + OnBlockCallback, + OnDisconnectCallback, + OnEventCallback, + OnMcpServerCallback, + OnReconnectCallback, + OnVerdictCallback, + get_config, + is_initialized, + set_config, +) +from adrian.context import AgentContextTracker, get_invocation_id, set_invocation_id +from adrian.format.types import PairedEvent +from adrian.handler import AdrianCallbackHandler +from adrian.handlers.jsonl import JSONLHandler +from adrian.hooks import EventHandler, HookRegistry +from adrian.mcp import ( + McpServer, + _patch_mcp_adapter, # pyright: ignore[reportPrivateUsage] + mcp_servers, +) +from adrian.mcp import ( + _reset as _reset_mcp, # pyright: ignore[reportPrivateUsage] +) +from adrian.pairing import EventPairBuffer +from adrian.pii import ( + PiiConfig, + PiiRedactor, + RedactingHandler, + RedactionStrategy, + redact_text, +) +from adrian.proto import event_pb2 as pb +from adrian.session_persistence import resolve_session_id +from adrian.types import ToolCallRecord, VerdictContext +from adrian.ws import WebSocketClient + +__version__ = "1.0.2" +__all__ = [ + "init", + "shutdown", + "get_handler", + "AdrianCallbackHandler", + "AdrianConfig", + "EventHandler", + "JSONLHandler", + "McpServer", + "OnAuditCallback", + "OnBlockCallback", + "OnDisconnectCallback", + "OnEventCallback", + "OnMcpServerCallback", + "OnReconnectCallback", + "OnVerdictCallback", + "PairedEvent", + "PiiConfig", + "PiiRedactor", + "RedactingHandler", + "RedactionStrategy", + "ToolCallRecord", + "VerdictContext", + "__version__", + "mcp_servers", + "redact_text", +] + +logger = logging.getLogger("adrian") + +_hooks: HookRegistry | None = None +_handler: AdrianCallbackHandler | None = None +_ws_client: WebSocketClient | None = None +_fork_handler_registered: bool = False + + +# ------------------------------------------------------------------ +# Fork safety +# ------------------------------------------------------------------ + + +def _reset_after_fork() -> None: + """Drop inherited Adrian state in a forked child process. + + Registered via ``os.register_at_fork`` on the first :func:`init` call. + Nulls out module globals so the child does not silently share the + parent's WebSocket socket, writing to a shared socket from two + processes interleaves bytes on the wire, corrupting frames the + server cannot parse. + + Triggered by pre-fork deployments (``gunicorn --preload``, + ``multiprocessing.Pool``, Celery prefork). The child must call + :func:`init` again from its worker startup hook to establish its + own connection. + """ + global _hooks, _handler, _ws_client # noqa: PLW0603 + + _hooks = None + _handler = None + _ws_client = None + _reset_mcp() + + +# ------------------------------------------------------------------ +# Public API +# ------------------------------------------------------------------ + + +def init( + api_key: str | None = None, + log_file: str | Path = "events.jsonl", + handlers: list[EventHandler] | None = None, + auto_instrument: bool = True, + log_level: str | None = None, + ws_url: str | None = None, + session_id: str | None = None, + block_timeout: float = 30.0, + on_event: OnEventCallback | None = None, + on_verdict: OnVerdictCallback | None = None, + on_block: OnBlockCallback | None = None, + on_audit: OnAuditCallback | None = None, + on_disconnect: OnDisconnectCallback | None = None, + on_reconnect: OnReconnectCallback | None = None, + on_mcp_server: OnMcpServerCallback | None = None, + replay_buffer_frames: int = 1000, +) -> None: + """Initialise the Adrian SDK. + + Creates the event pairing buffer, agent context tracker, and hook + registry, then monkey-patches LangChain so every LLM call and tool + invocation is captured as a ``PairedEvent``. + + Events are emitted through registered handlers. If no handlers are + provided, defaults to a ``JSONLHandler`` writing to ``log_file``. + + Transport (WebSocket, HTTP, etc.) is not managed by the SDK, pass + a pre-configured handler via the ``handlers`` list instead. + + Args: + api_key: Adrian API key. Falls back to ``ADRIAN_API_KEY`` env + var. Stored in config for handlers that need it. + log_file: Path to the JSONL output file (used when no handlers + are explicitly provided). + handlers: List of ``EventHandler`` instances to receive paired + events. If ``None``, defaults to ``JSONLHandler(log_file)``. + auto_instrument: Patch LangChain / LangGraph at import time. + log_level: Optional override for the ``adrian`` logger's level. + ``None`` (default) inherits from the application's logging + config; pass e.g. ``"DEBUG"`` to force-enable verbose SDK + logging without touching global config. + ws_url: WebSocket URL for the Adrian server (e.g. + ``"ws://localhost:8080/ws"``). Falls back to ``ADRIAN_WS_URL``. + When set and ``handlers`` is ``None``, a ``WebSocketClient`` is + auto-registered alongside the default ``JSONLHandler``. Requires + ``api_key``. + session_id: Session identifier. Falls back to + ``ADRIAN_SESSION_ID``, then to a per-cwd persistent UUID. + See :mod:`adrian.session_persistence`. + block_timeout: Max seconds to wait for a verdict in ``MODE_BLOCK`` + before fail-open. Ignored in ``MODE_ALERT`` (no wait) and + ``MODE_HITL`` (wait indefinitely). Falls back to + ``ADRIAN_BLOCK_TIMEOUT``. + on_event: Callback for every paired event. + on_verdict: Callback for every verdict. + on_block: Callback for BLOCK-tier verdicts (M3 / M4). Notification + only; return value is ignored. + on_audit: Callback for NOTIFY-tier verdicts (M2). + on_disconnect: Callback fired when the WebSocket is lost. Receives + a reason string. Sync or async. + on_reconnect: Callback fired when the WebSocket reconnects after a + prior disconnect. Does not fire on initial connection. Sync + or async. + on_mcp_server: Callback fired when an MCP server is registered or + updated. Receives the freshly-registered ``McpServer``. Does + NOT fire on no-op re-observations. Sync or async. + replay_buffer_frames: Max serialised frames kept in the in-memory + ring for replay after a transient WS outage (server restart, + ALB shuffle). Each frame is one ``ClientFrame.paired_batch`` + (~4KB). Default 1000 frames ≈ ~4MB RAM. Falls back to + ``ADRIAN_REPLAY_BUFFER_FRAMES``. At capacity each further + append evicts the oldest; a one-shot WARN fires on first fill + and cumulative drops are logged on the next reconnect. + """ + global _hooks, _handler, _ws_client, _fork_handler_registered # noqa: PLW0603 + + if not _fork_handler_registered and hasattr(os, "register_at_fork"): + os.register_at_fork(after_in_child=_reset_after_fork) + _fork_handler_registered = True + + try: + loop: asyncio.AbstractEventLoop | None = asyncio.get_running_loop() + except RuntimeError: + loop = None + + resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None + resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) + # Default to the hosted Adrian backend so `adrian.init(api_key=...)` + # Just Works for freemium users. Self-hosted users override via + # ws_url= or ADRIAN_WS_URL. + resolved_ws_url = ( + os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" + ) + resolved_session = ( + os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() + ) + resolved_block_timeout = float( + os.getenv("ADRIAN_BLOCK_TIMEOUT", str(block_timeout)), + ) + + resolved_replay_buffer_frames = replay_buffer_frames + env_replay = os.getenv("ADRIAN_REPLAY_BUFFER_FRAMES", "").strip() + + if env_replay: + try: + resolved_replay_buffer_frames = int(env_replay) + except ValueError: + logger.warning( + "ADRIAN_REPLAY_BUFFER_FRAMES=%r is not an int; " + "falling back to kwarg default %d", + env_replay, + replay_buffer_frames, + ) + + if resolved_ws_url and not resolved_key: + logger.warning( + "ws_url is set but no api_key provided. Set api_key or " + "ADRIAN_API_KEY; the server will reject the WS connection." + ) + + config = AdrianConfig( + api_key=resolved_key, + log_file=resolved_file, + log_level=log_level, + session_id=resolved_session, + ws_url=resolved_ws_url, + block_timeout=resolved_block_timeout, + on_event=on_event, + on_verdict=on_verdict, + on_block=on_block, + on_audit=on_audit, + on_disconnect=on_disconnect, + on_reconnect=on_reconnect, + on_mcp_server=_make_on_mcp_server_chain(on_mcp_server), + replay_buffer_frames=resolved_replay_buffer_frames, + ) + + set_config(config) + + if log_level is not None: + # Only override the adrian logger's level when the caller asks + # for it explicitly. Default behaviour respects whatever the + # application configured via logging.basicConfig / .config. + logging.getLogger("adrian").setLevel( + getattr(logging, log_level.upper(), logging.INFO), + ) + + # Build handler list, then optionally wrap with PII redaction + handler_list: list[EventHandler] = [] + + if handlers: + handler_list = list(handlers) + else: + handler_list.append(JSONLHandler(path=resolved_file)) + + if resolved_ws_url: + _ws_client = WebSocketClient( + url=resolved_ws_url, + session_id=config.session_id, + api_key=resolved_key or "", + on_disconnect=on_disconnect, + on_reconnect=on_reconnect, + on_login_ack=_send_mcp_inventory, + replay_buffer_frames=resolved_replay_buffer_frames, + ) + handler_list.append(_ws_client) + + handler_list = [RedactingHandler(h) for h in handler_list] + + # Create hook registry and register handlers + _hooks = HookRegistry() + + for h in handler_list: + _hooks.register(h) + + # Create pairing and context tracking components + pair_buffer = EventPairBuffer() + context_tracker = AgentContextTracker() + + # Create handler with new components + _handler = AdrianCallbackHandler( + pair_buffer=pair_buffer, + context_tracker=context_tracker, + hooks=_hooks, + config=config, + ) + + if _ws_client is not None: + # Back-reference so the recv loop can dispatch verdicts into the + # handler's block/audit/verdict callback machinery. + _ws_client._handler = _handler # pyright: ignore[reportPrivateUsage] + + if loop is not None: + _ws_client.schedule_connect(loop) + else: + logger.debug( + "No running event loop at init(); WebSocket will connect on " + "first send from within an async context." + ) + + if auto_instrument: + _auto_instrument_langchain() + + # MCP server tracking is independent of LangChain auto-instrumentation, + # it observes a different library (langchain-mcp-adapters) and is the + # only path the SDK has to learn about MCP servers. Always run. + _patch_mcp_adapter() + + atexit.register(shutdown) + logger.info( + "Adrian v%s initialised (handlers=%d, ws=%s)", + __version__, + len(_hooks), + resolved_ws_url or "disabled", + ) + + +def shutdown() -> None: + """Close all handlers and reset state.""" + global _hooks, _handler, _ws_client # noqa: PLW0603 + + if _hooks is not None: + try: + loop = asyncio.get_running_loop() + loop.create_task(_hooks.close()) + except RuntimeError: + asyncio.run(_hooks.close()) + + _hooks = None + + _handler = None + _ws_client = None + set_config(None) + + +def get_handler() -> AdrianCallbackHandler | None: + """Return the SDK's callback handler, or ``None`` if uninitialised. + + Useful when ``adrian.init(auto_instrument=False)`` is set and you + need to attach the handler to LangChain calls explicitly, e.g.:: + + adrian.init(api_key=..., auto_instrument=False) + handler = adrian.get_handler() + await llm.ainvoke(prompt, config={"callbacks": [handler]}) + + The handler is wired into Adrian's WS hook chain at ``init()`` + time; constructing a fresh ``AdrianCallbackHandler`` directly will + not emit events. + """ + return _handler + + +# ------------------------------------------------------------------ +# Internal helpers +# ------------------------------------------------------------------ + + +def _get_callback_handler() -> AdrianCallbackHandler | None: + """Return the current callback handler (closure helper).""" + return _handler + + +def _get_config() -> AdrianConfig | None: + """Return the current config without raising (closure helper).""" + if not is_initialized(): + return None + + return get_config() + + +async def _send_mcp_inventory() -> None: + """Send the current MCP server registry as a ``ClientFrame``. + + Triggers: once per connect (after each ``LoginAck``) and on every + ``on_mcp_server`` registry change. The server replaces its full + list on every frame, so a fresh snapshot is correct on every fire. + No-op when the WebSocket transport is disabled or when the registry + is empty (the registry is additive, so an empty snapshot is + indistinguishable from "not yet observed", sending it would only + log a ``which=`` warning on the server). + """ + ws = _ws_client + + if ws is None: + return + + servers = mcp_servers() + + if not servers: + return + + frame = pb.ClientFrame() + + for server in servers: + added = frame.mcp_inventory.servers.add() + added.name = server.name + added.transport = server.transport + added.endpoint = server.endpoint + + await ws._send_frame(frame) # pyright: ignore[reportPrivateUsage] + + +def _make_on_mcp_server_chain( + user_cb: OnMcpServerCallback | None, +) -> OnMcpServerCallback: + """Compose ``_send_mcp_inventory`` with the user's ``on_mcp_server``. + + Schedules the inventory sync as a fire-and-forget task on the + running loop (if any) and forwards transparently to the user's + callback so its sync-vs-async return shape is preserved for + :func:`adrian.callbacks.fire` to handle. When no loop is running, + the inventory sync is skipped, the next ``LoginAck`` (which only + fires once a loop is up) will catch up. + """ + + def chain(server: McpServer) -> Any: # noqa: ANN401 + try: + loop = asyncio.get_running_loop() + except RuntimeError: + pass + else: + loop.create_task(_send_mcp_inventory()) + + if user_cb is None: + return None + + return user_cb(server) + + return chain + + +def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 + """Merge the Adrian handler into a LangChain ``RunnableConfig``. + + Args: + config: An existing LangChain RunnableConfig or ``None``. + + Returns: + A config dict guaranteed to contain the Adrian handler. + """ + handler = _get_callback_handler() + + if handler is None: + return ensure_config(config) + + config = ensure_config(config) + callbacks = config.get("callbacks") or [] + + if hasattr(callbacks, "handlers"): + callbacks = list(callbacks.handlers) # pyright: ignore[reportAttributeAccessIssue] + elif not isinstance(callbacks, list): + callbacks = [callbacks] if callbacks else [] + else: + callbacks = list(callbacks) + + handler_types = [type(h).__name__ for h in callbacks] + + if "AdrianCallbackHandler" not in handler_types: + callbacks.insert(0, handler) + + config["callbacks"] = callbacks + + return config + + +# ------------------------------------------------------------------ +# Auto-instrumentation +# ------------------------------------------------------------------ + + +def _auto_instrument_langchain() -> None: + """Apply all monkey-patches to LangChain / LangGraph.""" + try: + _patch_runnable() + _patch_callback_manager() + _patch_chat_model() + _patch_langgraph() + _patch_tool_node() + _patch_base_tool() + _patch_agent_executor() + logger.debug("LangChain auto-instrumentation applied") + except ImportError: + logger.debug("LangChain not found, skipping auto-instrumentation") + except Exception: + logger.exception("Auto-instrumentation failed") + + +# --- 1. Runnable --- + + +def _patch_runnable() -> None: + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" + if getattr(Runnable, "_adrian_patched", False): + return + + original_invoke = Runnable.invoke + original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream + + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return original_invoke(self, input, config, **kwargs) + + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return await original_ainvoke(self, input, config, **kwargs) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + + Runnable.invoke = patched_invoke # type: ignore[assignment] + Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] + Runnable._adrian_patched = True # type: ignore[attr-defined] + logger.debug("Patched Runnable.invoke / ainvoke") + + +# --- 2. CallbackManager --- + + +def _patch_callback_manager() -> None: + """Patch ``CallbackManager.__init__`` to always include Adrian.""" + if getattr(CallbackManager, "_adrian_cbm_patched", False): + return + + original_configure = CallbackManager.configure + + def patched_configure( + _cls: Any, # noqa: ANN401 + inheritable_callbacks: Any = None, # noqa: ANN401 + local_callbacks: Any = None, # noqa: ANN401 + verbose: bool = False, + inheritable_tags: Any = None, # noqa: ANN401 + local_tags: Any = None, # noqa: ANN401 + inheritable_metadata: Any = None, # noqa: ANN401 + local_metadata: Any = None, # noqa: ANN401 + **extra: Any, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Inject Adrian handler into inheritable callbacks. + + ``**extra`` forwards any kwargs newer langchain-core releases + add to ``CallbackManager.configure`` (e.g. 1.3 added + ``langsmith_inheritable_metadata``) so the patch stays + forward-compatible without re-declaring every signature change. + """ + handler = _get_callback_handler() + + if handler: + if inheritable_callbacks is None: + inheritable_callbacks = [handler] + elif isinstance(inheritable_callbacks, list): + handler_types = [type(h).__name__ for h in inheritable_callbacks] + + if "AdrianCallbackHandler" not in handler_types: + inheritable_callbacks = [handler, *inheritable_callbacks] + elif hasattr(inheritable_callbacks, "handlers"): + handler_types = [ + type(h).__name__ for h in inheritable_callbacks.handlers + ] + + if "AdrianCallbackHandler" not in handler_types: + inheritable_callbacks.handlers.insert(0, handler) + + return original_configure( + inheritable_callbacks=inheritable_callbacks, + local_callbacks=local_callbacks, + verbose=verbose, + inheritable_tags=inheritable_tags, + local_tags=local_tags, + inheritable_metadata=inheritable_metadata, + local_metadata=local_metadata, + **extra, + ) + + CallbackManager.configure = classmethod( # type: ignore[assignment] + lambda _cls, *a, **kw: patched_configure(_cls, *a, **kw), # pyright: ignore[reportCallIssue] + ) + CallbackManager._adrian_cbm_patched = True # type: ignore[attr-defined] + logger.debug("Patched CallbackManager.configure") + + +# --- 3. BaseChatModel --- + + +def _patch_chat_model() -> None: + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" + if getattr(BaseChatModel, "_adrian_chat_model_patched", False): + return + + original_invoke = BaseChatModel.invoke + original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream + + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return original_invoke(self, input, config=config, **kwargs) + + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] + BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] + BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseChatModel.invoke / ainvoke") + + +# --- 4. LangGraph Pregel --- + + +def _patch_langgraph() -> None: + """Patch ``Pregel.invoke`` / ``ainvoke`` / ``astream``. + + The async patches also set the invocation_id ContextVar at the + top-level call so all sub-agent events share the same ID. + """ + try: + from langgraph.pregel import Pregel + except ImportError: + return + + if getattr(Pregel, "_adrian_pregel_patched", False): + return + + original_invoke = Pregel.invoke + original_ainvoke = Pregel.ainvoke + original_astream = Pregel.astream + + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks into sync graph invocation.""" + config = _inject_callbacks(config) + + return original_invoke(self, input, config=config, **kwargs) + + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks and set invocation_id. + + Only the top-level call sets the invocation_id. Nested calls + (sub-agent ainvoke) inherit it via contextvars propagation. + """ + config = _inject_callbacks(config) + + current = get_invocation_id() + token = None + + if current is None: + uuid_ = uuid4() + token = set_invocation_id(str(uuid_)) + + try: + return await original_ainvoke(self, input, config=config, **kwargs) + finally: + if token is not None: + token.var.reset(token) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks and set invocation_id for streaming.""" + config = _inject_callbacks(config) + + current = get_invocation_id() + token = None + + if current is None: + uuid_ = uuid4() + token = set_invocation_id(str(uuid_)) + + try: + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + finally: + if token is not None: + token.var.reset(token) + + Pregel.invoke = patched_invoke # type: ignore[assignment] + Pregel.ainvoke = patched_ainvoke # type: ignore[assignment] + Pregel.astream = patched_astream # type: ignore[assignment] + Pregel._adrian_pregel_patched = True # type: ignore[attr-defined] + logger.debug("Patched Pregel.invoke / ainvoke / astream") + + +# --- 5. ToolNode --- + + +def _extract_tool_calls( + state: dict[str, Any] | list[BaseMessage] | Any, +) -> list[dict[str, Any]]: + """Extract tool_calls from ToolNode input (all three dispatch shapes). + + Returns full tool_call dicts (with id, name, args) for backward + compat with tests and callers that need the full shape. + """ + # Shape 3: per-tool-call dict from _afunc dispatch + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + tc_id = getattr(tc, "id", None) + if tc_id: + return [ + { + "id": tc_id, + "name": getattr(tc, "name", ""), + "args": getattr(tc, "args", {}), + } + ] + return [] + + # Shape 1/2: state dict or message list + if isinstance(state, dict): + messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + elif isinstance(state, list): + messages = list(state) + else: + return [] + + for msg in reversed(messages): + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): + return msg.tool_calls # type: ignore[no-any-return] + + return [] + + +def _should_halt(verdict: pb.Verdict) -> bool: + """Decide whether a verdict should halt tool execution. + + HITL resolutions override per-MAD policy when present. + """ + if verdict.HasField("hitl"): + return not verdict.hitl.continue_execution + + mad_prefix = verdict.mad_code[:2] + return { + "M0": verdict.policy.policy_m0, + "M2": verdict.policy.policy_m2, + "M3": verdict.policy.policy_m3, + "M4": verdict.policy.policy_m4, + }.get(mad_prefix, False) + + +def _patch_tool_node() -> None: + """Patch ToolNode for callback injection + async verdict gate. + + ToolNode dispatches tools via tool.invoke (sync) even within async + Pregel. BaseTool.invoke can't await a verdict from the event loop + thread, so we add the verdict gate here on ToolNode.ainvoke — the + entry point Pregel calls before tool dispatch begins. This is a + complementary gate to BaseTool (which covers direct callers). + """ + try: + from langgraph.prebuilt import ToolNode + except ImportError: + return + + if getattr(ToolNode, "_adrian_tool_node_patched", False): + return + + original_invoke = ToolNode.invoke + original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) + + async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 + """Returns True if tools should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] + except TimeoutError: + logger.warning("ToolNode: LoginAck not received within 5s; blocking") + return True + if not ws.policy_active(): + return False + + tc_ids: list[str] = [ + str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") + ] + if not tc_ids: + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) + if verdict is None: + logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") + return True + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return True + return False + + def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 + tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] + return { + "messages": [ + ToolMessage( + content="[BLOCKED by security policy]", tool_call_id=tid, name="" + ) + for tid in tc_ids + ] + } + + def patched_invoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return original_invoke(self, input, config=config, **kwargs) + + async def patched_ainvoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + # Verdict gate removed — BaseTool.ainvoke/arun is the single + # gate layer. Gating here too caused double-gate: ToolNode + # consumed the verdict future, BaseTool's gate registered a + # fresh future that never resolved → 30s timeout on a benign + # verdict. Callback injection is kept so events still flow. + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + assert original_astream is not None # guarded by line below + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode.invoke / ainvoke / astream") + + +# --- 6. BaseTool (universal verdict gate) --- + + +_BLOCKED_CONTENT = "[BLOCKED by security policy]" + + +def _patch_base_tool() -> None: + """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. + + Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, + create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — + funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` + (async). Gating here covers all frameworks in one place. + + The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` + TypedDict), awaits the classifier verdict for the producing LLM + event, and returns a ``[BLOCKED]`` string instead of running the + tool body when the verdict is in-scope (M3/M4 under MODE_BLOCK). + + In MODE_BLOCK, verdict timeout is fail-closed (block the tool) + because the absence of a verdict in block mode is a policy violation. + In MODE_ALERT, no gate fires at all (skip). + """ + from langchain_core.tools import BaseTool + from langchain_core.tools.base import ( + _is_tool_call, # pyright: ignore[reportPrivateUsage] + ) + + if getattr(BaseTool, "_adrian_base_tool_patched", False): + return + + original_invoke = BaseTool.invoke + original_ainvoke = BaseTool.ainvoke + + def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 + """Extract tool_call_id from a ToolCall input, or None.""" + if isinstance(input, dict) and _is_tool_call(input): + return input.get("id") + return None + + async def _async_gate(tool_call_id: str) -> bool: + """Returns True if the tool should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "BaseTool: LoginAck not received within 5s; " + "blocking tool (refusing to run without verified policy)" + ) + return True + + if not ws.policy_active(): + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + # Fail-closed in block mode: no verdict = block. + logger.warning( + "BaseTool: verdict timeout for tool_call_id=%s; " + "blocking (fail-closed in MODE_BLOCK)", + tool_call_id, + ) + return True + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return True + + return False + + def _sync_gate(tool_call_id: str) -> bool: + """Sync verdict gate — works for pure-sync and worker-thread callers. + + Pure-sync (no event loop): runs ``_async_gate`` via + ``loop.run_until_complete``. + + Worker-thread (Pregel dispatches sync tools on a thread-pool + worker while the event loop runs on the main thread): bridges + the async gate to the main loop via ``run_coroutine_threadsafe`` + and blocks the worker thread until the verdict resolves. + + Event-loop thread (calling tool.invoke directly from async + code): cannot block — returns False (skip). The async path + (BaseTool.ainvoke) handles this case. + """ + ws = _ws_client + if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] + return False + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return False + + if not loop.is_running(): + # Pure-sync caller — safe to block + return loop.run_until_complete(_async_gate(tool_call_id)) + + # Check if we're on a worker thread (no running loop on THIS + # thread) vs the event-loop thread itself. + try: + asyncio.get_running_loop() + # We ARE on the event-loop thread — can't block it. + return False + except RuntimeError: + pass + + # Worker thread: bridge the async gate to the main loop. + main_loop = getattr(ws, "_loop", None) + if main_loop is None or not main_loop.is_running(): + return False + + try: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result(timeout=timeout if timeout else 60.0) + except Exception: + return False + + def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 + """Return a blocked response compatible with ToolNode. + + Returns a ToolMessage for create_react_agent / ToolNode + compatibility. Falls back to bare string on import failure. + """ + try: + return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") + except Exception: + return _BLOCKED_CONTENT + + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and _sync_gate(tc_id): + return _blocked_response(tc_id) + return original_invoke(self, input, config=config, **kwargs) + + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and await _async_gate(tc_id): + return _blocked_response(tc_id) + return await original_ainvoke(self, input, config=config, **kwargs) + + original_arun = BaseTool.arun + + async def patched_arun( + self: Any, # noqa: ANN401 + tool_input: Any, # noqa: ANN401 + *args: Any, + tool_call_id: str | None = None, + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Gate on arun — AgentExecutor calls tool.arun directly.""" + if tool_call_id and await _async_gate(tool_call_id): + return _blocked_response(tool_call_id) + return await original_arun( + self, tool_input, *args, tool_call_id=tool_call_id, **kwargs + ) + + BaseTool.invoke = patched_invoke # type: ignore[assignment] + BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseTool.arun = patched_arun # type: ignore[assignment] + BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") + + +# --- 7. AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- + + +def _patch_agent_executor() -> None: + """Patch AgentExecutor._aperform_agent_action for the executor path. + + AgentExecutor calls tool.arun without forwarding tool_call_id, + so the BaseTool.arun gate can't extract it. The tool_call_id lives + on agent_action.tool_call_id (set by OpenAI-style parsers). We + intercept here, await the verdict, and return a blocked observation + instead of calling the tool. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + + async def patched_aperform( + self: Any, + name_to_tool_map: Any, + color_mapping: Any, # noqa: ANN401 + agent_action: Any, + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + tc_id = getattr(agent_action, "tool_call_id", None) + if tc_id: + ws = _ws_client + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; blocking" + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if ws.policy_active(): + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py new file mode 100644 index 0000000..169cbdc --- /dev/null +++ b/sdk/adrian/ws.py @@ -0,0 +1,1038 @@ +"""Async WebSocket ``EventHandler`` that streams ``PairedEvent`` to the worker core API. + +Converts each ``PairedEvent`` into a ``pb.PairedEvent`` protobuf, wraps it in a +``ClientFrame.paired_batch``, and sends it over a long-lived WebSocket +connection. Verdicts received back resolve block-mode futures and fire the +callback handler's verdict processing. + +Implements the ``EventHandler`` protocol so it slots into the SDK's hook +registry alongside ``JSONLHandler``. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import time +from collections import OrderedDict, deque +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import websockets + +if TYPE_CHECKING: + from adrian.config import OnDisconnectCallback, OnReconnectCallback + from adrian.handler import AdrianCallbackHandler + +from adrian.format.types import ( + AgentContext, + LlmPairData, + PairedEvent, + ParentContext, +) +from adrian.proto import event_pb2 as pb + +logger = logging.getLogger("adrian.ws") + +SCHEMA_VERSION = 2 + +_INITIAL_BACKOFF = 1.0 +_MAX_BACKOFF = 30.0 +# Server close code: quota exhausted. Spec'd in +# server/internal/websocket/handler.go (closeQuotaExceeded). Returning +# every 30s would hammer the server while quota is depleted; one +# minute is slow enough to be cheap, fast enough that the next hourly +# / daily / monthly window-rollover is picked up within tolerance. +_QUOTA_EXHAUSTED_CLOSE_CODE = 4003 +_QUOTA_RECONNECT_DELAY = 60.0 +# Cap on in-flight LLM run_id → event_id mappings. Evicted LRU-style; +# block-mode lookups for evicted entries fail open. +_MAX_RUN_ID_MAP = 1024 +# Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). +_MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 + +_DEFAULT_REPLAY_BUFFER_FRAMES = 1000 + +# Heartbeat tuning. 10s interval / 15s pong timeout detects half-open +# connections (ALB idle cut, NAT drop, dead remote process) without +# flooding the wire. Kept in sync with the backend's pingInterval / +# pongTimeout, if these change, update server/internal/websocket/handler.go. +_PING_INTERVAL = 10.0 +_PING_TIMEOUT = 15.0 + +_PROVIDER_PREFIXES: dict[str, str] = { + "chatanthropic": "anthropic", + "chatopenai": "openai", + "chatgooglegenai": "google", + "chatcohere": "cohere", + "chatmistralai": "mistral", +} + +_PAIR_TYPE_MAP: dict[str, pb.PairType.ValueType] = { + "llm": pb.PAIR_TYPE_LLM, + "tool": pb.PAIR_TYPE_TOOL, +} + + +def _derive_provider(model_class_name: str) -> str: + """Derive the LLM provider from the model class name. + + Args: + model_class_name: Class name like ``"ChatAnthropic"`` or ``"ChatOpenAI"``. + + Returns: + Provider string (e.g. ``"anthropic"``), or the class name lower-cased + if no known prefix matches. + """ + key = model_class_name.lower() + + return _PROVIDER_PREFIXES.get(key, key) + + +def _fill_agent_context( + pb_ctx: pb.AgentContext, src: AgentContext | ParentContext +) -> None: + """Copy an AgentContext / ParentContext dataclass into its proto counterpart.""" + pb_ctx.agent_id = src.agent_id + pb_ctx.system_prompt = src.system_prompt + pb_ctx.user_instruction = src.user_instruction + + +def _safe_cancel( + task_or_future: asyncio.Task[Any] | asyncio.Future[Any] | None, +) -> None: + """Cancel a task / future, ignoring closed-loop errors at shutdown. + + Adrian's ``atexit`` handler may run after the user's loop has been + closed; in that path ``adrian.shutdown`` spawns a new ``asyncio.run`` + and walks each handler's ``close()``. Tasks bound to the *old* loop + can no longer be cancelled (``call_soon`` raises ``Event loop is + closed``). Swallowing the error here keeps the cleanup path quiet, + the task will be reaped when the dead loop is GC'd. + """ + if task_or_future is None or task_or_future.done(): + return + # "Event loop is closed", old loop is gone, nothing to cancel. + with contextlib.suppress(RuntimeError): + task_or_future.cancel() + + +def _paired_event_to_proto(event: PairedEvent) -> pb.PairedEvent: + """Convert a ``PairedEvent`` dataclass into its protobuf form. + + ``parent.agent_id`` empty-string signals "no parent agent". + ``parent_run_id`` empty-string signals "no parent in run tree". + """ + proto = pb.PairedEvent( + event_id=event.event_id, + invocation_id=event.invocation_id, + session_id=event.session_id, + run_id=event.run_id, + parent_run_id=event.parent_run_id, + timestamp=event.timestamp, + pair_type=_PAIR_TYPE_MAP.get(event.pair_type, pb.PAIR_TYPE_UNSPECIFIED), + ) + + _fill_agent_context(proto.agent, event.agent) + + if event.parent is not None: + _fill_agent_context(proto.parent, event.parent) + + if isinstance(event.data, LlmPairData): + proto.llm.model = event.data.model + + for msg in event.data.messages: + pb_msg = proto.llm.messages.add() + pb_msg.role = msg["role"] + pb_msg.content = msg["content"] + + proto.llm.output = event.data.output + + for tc in event.data.tool_calls: + pb_tc = proto.llm.tool_calls.add() + pb_tc.name = tc["name"] + pb_tc.args = json.dumps(tc["args"], default=str) + pb_tc.id = tc["id"] + + if event.data.usage is not None: + proto.llm.usage.prompt_tokens = event.data.usage["prompt_tokens"] + proto.llm.usage.completion_tokens = event.data.usage["completion_tokens"] + proto.llm.usage.total_tokens = event.data.usage["total_tokens"] + else: + # Union is LlmPairData | ToolPairData; this branch is the + # ToolPairData case. + proto.tool.tool_name = event.data.tool_name + proto.tool.tool_call_id = event.data.tool_call_id or "" + proto.tool.input = event.data.input + proto.tool.output = event.data.output + + if event.metadata: + proto.metadata_json = json.dumps(event.metadata, default=str).encode() + + return proto + + +class WebSocketClient: + """Streams ``PairedEvent`` instances to the worker core API. + + Connects eagerly via :meth:`schedule_connect` with exponential backoff, + auto-detects the LLM provider on the first LLM pair, sends paired events + as protobuf frames, and resolves block-mode futures when verdicts arrive. + """ + + def __init__( + self, + url: str, + session_id: str, + api_key: str, + handler: AdrianCallbackHandler | None = None, + on_disconnect: OnDisconnectCallback | None = None, + on_reconnect: OnReconnectCallback | None = None, + on_login_ack: Callable[[], Awaitable[None]] | None = None, + replay_buffer_frames: int = _DEFAULT_REPLAY_BUFFER_FRAMES, + ) -> None: + """Initialise without connecting. + + Args: + url: WebSocket endpoint URL. + session_id: Session ID sent in the login frame. + api_key: Adrian API key for the ``Authorization`` header. + handler: Callback handler for verdict processing. + on_disconnect: Fired when the connection is lost (sync or async). + Receives a reason string. + on_reconnect: Fired when the connection re-establishes after a + prior disconnect (sync or async). Does not fire on initial + connect. + on_login_ack: Async hook fired after each ``LoginAck`` frame is + applied, once per (re)connect. Used internally to push a + fresh ``McpInventory`` on every login. Exceptions are + logged and swallowed. + replay_buffer_frames: Ring-buffer capacity (frame count, not + bytes). When the cap is reached each further append evicts + the oldest frame; a one-shot WARN fires on first fill, and + the cumulative drop count is logged at WARN on the next + reconnect. + """ + self._url = url + self._session_id = session_id + self._api_key = api_key + self._handler = handler + self._on_disconnect = on_disconnect + self._on_reconnect = on_reconnect + self._on_login_ack_cb = on_login_ack + self._provider = "" + self._model = "" + # Server-supplied execution-mode policy. Populated when the + # first ServerFrame{login_ack} arrives after each (re)connect. + # ``policy_active()`` and ``block_timeout()`` read this state + # to decide whether the patched ToolNode should wait for a + # verdict and how long. + self._mode: int = pb.MODE_UNSPECIFIED + self._policy: pb.PolicySnapshot | None = None + # Set the first time a ``ServerFrame{login_ack}`` is applied. + # Used in two places: + # 1. ``on_paired_event`` defensively pre-registers a + # verdict-wait future when this event is unset, so the + # very first tool-bearing LLM emission is covered even + # though the recv loop hasn't yet processed LoginAck and + # ``policy_active()`` reads False. + # 2. The patched ``ToolNode.ainvoke`` ``await``s this event + # (with a short timeout) before deciding whether to wait + # for a verdict, so the first ToolNode invocation cannot + # run-through-without-waiting in the same window. + # Stays set across disconnect/reconnect because mode/policy + # state survives, a fresh LoginAck on reconnect simply re-sets + # an already-set event. + self._login_ack_received: asyncio.Event = asyncio.Event() + self._ws: websockets.ClientConnection | None = None + self._logged_in = False + self._connected = asyncio.Event() + self._connect_task: asyncio.Task[None] | None = None + self._recv_task: asyncio.Task[None] | None = None + # Set by close() so _handle_disconnect knows not to spawn a reconnect + # during a graceful shutdown. + self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None + # Futures awaited by the patched ToolNode.ainvoke when the + # active mode requires a wait (BLOCK or HITL). Each resolves + # with the matching ``Verdict`` proto. Futures survive a + # disconnect: a late verdict after reconnect still resolves + # the wait; if none arrives, ``wait_for_verdict``'s timeout + # produces a natural fail-open in BLOCK mode. + self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} + # Maps LLM pair run_id → event_id so a subsequent tool call can + # look up the verdict by its parent_run_id (the LLM's run_id). + # LRU-capped at _MAX_RUN_ID_MAP to bound memory on long sessions. + self._run_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Verdict-correlation map: maps each tool_call.id emitted by + # an LLM to the event_id of the LLM pair that emitted it. + # Populated on every LLM PairedEvent that has tool_calls. + # Consulted by the patched ``ToolNode.ainvoke`` so each tool + # in a parallel fan-out waits on its own producing LLM's + # verdict, not a global "last" pointer. LRU-capped at + # ``_MAX_TOOL_CALL_MAP``. + self._tool_call_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Serialises the lazy login-then-send sequence so two concurrent + # on_paired_event calls (parallel agents) cannot both send a login. + # Reused by _replay_buffer_to_ws to coordinate with live sends. + self._login_lock = asyncio.Lock() + # Ring buffer of recently serialised ClientFrame bytes. Appended + # only from the offline-or-send-failure paths in _send_frame; the + # happy path bypasses the ring entirely. Drained on reconnect. + self._replay_buffer: deque[bytes] = deque(maxlen=replay_buffer_frames) + # Flips True on the first append that reaches maxlen. Gates the + # one-shot "buffer full" WARN so we don't flood logs. + self._replay_buffer_filled: bool = False + # Monotonic counter of frames dropped due to buffer overflow + # (oldest evicted when a new append arrives at a full ring). + # Logged at WARN on the next reconnect. + self._replay_buffer_dropped: int = 0 + # True while the reconnect path is draining the replay buffer. + # Live sends observed during this window are routed back into + # the same deque so they slot in AFTER the pre-outage tail + # rather than racing onto the wire ahead of older buffered + # frames. Flipped on as the first sync line of + # _replay_buffer_to_ws and cleared in its finally. + self._replaying: bool = False + # Set by _handle_disconnect, cleared on successful reconnect. + # Used to gate on_reconnect and measure downtime. + self._disconnected_at: float | None = None + # One-shot delay applied before the next ``connect()`` attempt. + # Set when the server closes with a code that requests a longer + # wait (currently only 4003 quota exhausted); cleared by + # ``connect()`` after honouring it. ``None`` means use the + # standard exponential schedule. + self._next_reconnect_delay: float | None = None + + # -- Mode / policy state (populated by LoginAck) -- + + def policy_active(self) -> bool: + """Whether the active server mode requires waiting on verdicts. + + Single predicate consulted by the patched ``ToolNode.ainvoke``. + Returns ``True`` for ``MODE_BLOCK`` and ``MODE_HITL``; ``False`` + for ``MODE_ALERT`` and unset (pre-login) state. + """ + return self._mode in (pb.MODE_BLOCK, pb.MODE_HITL) + + def block_timeout(self, kwarg_default: float) -> float | None: + """Effective per-tool-call wait timeout for the active mode. + + - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open + if the server doesn't classify in time. + - ``MODE_HITL``: ``None``, wait indefinitely for human review. + - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before + registering a future. + """ + if self._mode == pb.MODE_BLOCK: + return kwarg_default + elif self._mode == pb.MODE_HITL: + return None + else: + return 0 + + # -- EventHandler protocol -- + + async def on_paired_event(self, event: PairedEvent) -> None: + """Send a paired event over the WebSocket. + + Auto-detects the LLM provider on the first LLM pair, updates the + run_id → event_id map for block mode, converts the dataclass to + protobuf, and sends a ``ClientFrame.paired_batch`` frame. + + For LLM pairs that carry tool_calls, registers the verdict-wait + future *before* the frame leaves the SDK. This closes the race + where a fast verdict roundtrip resolves and is dropped before + the patched ``ToolNode.ainvoke`` reaches its own + ``register_pending`` call. The matching ``register_pending`` + from the wait site is a get-or-create that returns the existing + future. + + Args: + event: The paired event to stream. + """ + if ( + event.pair_type == "llm" + and not self._provider + and isinstance(event.data, LlmPairData) + ): + self._model = event.data.model + self._provider = _derive_provider(event.data.model) + + if event.pair_type == "llm": + self._run_id_to_event_id[event.run_id] = event.event_id + self._run_id_to_event_id.move_to_end(event.run_id) + + if len(self._run_id_to_event_id) > _MAX_RUN_ID_MAP: + self._run_id_to_event_id.popitem(last=False) + + # Populate tool_call.id → event_id so each tool call can block + # on its own producing LLM's verdict under parallel fan-out. + if isinstance(event.data, LlmPairData) and event.data.tool_calls: + for tc in event.data.tool_calls: + tc_id = tc.get("id") or "" + + if not tc_id: + continue + + self._tool_call_id_to_event_id[tc_id] = event.event_id + self._tool_call_id_to_event_id.move_to_end(tc_id) + + if len(self._tool_call_id_to_event_id) > _MAX_TOOL_CALL_MAP: + self._tool_call_id_to_event_id.popitem(last=False) + + # Pre-register the wait future so an eager verdict + # cannot race ahead of the ToolNode patch. Gated on + # ``policy_active()`` so ALERT-mode sessions don't + # accumulate futures that will never be resolved or + # awaited, except for the very first event of the + # session, where ``LoginAck`` may not yet have been + # processed by the recv loop and ``policy_active()`` + # therefore reads False even when the mode will + # imminently be set to BLOCK or HITL. Pre-register + # defensively in that window; in ALERT mode the gate + # filters out every subsequent event so the leak is + # bounded to one orphan future per session. + if self.policy_active() or not self._login_ack_received.is_set(): + self.register_pending(event.event_id) + + proto = _paired_event_to_proto(event) + frame = pb.ClientFrame() + added = frame.paired_batch.events.add() + added.CopyFrom(proto) + + await self._send_frame(frame) + + async def close(self) -> None: + """Cancel background tasks and close the WebSocket. + + Sets ``_closing`` so any in-flight ``_handle_disconnect`` does not + spawn a reconnect during graceful shutdown. + + Defensive against the ``atexit`` shutdown path: ``adrian.shutdown`` + spawns a fresh ``asyncio.run`` loop after the user's loop has + already closed, so background tasks bound to the old loop can no + longer be cancelled cleanly (``call_soon`` raises + ``Event loop is closed``). Skip the cancel in that case, the + old loop is gone, the task will be reaped by GC. + """ + self._closing = True + + _safe_cancel(self._recv_task) + self._recv_task = None + _safe_cancel(self._connect_task) + self._connect_task = None + + if self._ws is not None: + with contextlib.suppress(Exception): + await asyncio.wait_for(self._ws.close(), timeout=2.0) + self._ws = None + + for fut in self._pending_verdicts.values(): + if not fut.done(): + _safe_cancel(fut) + self._pending_verdicts.clear() + + # -- Connection lifecycle -- + + def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None: + """Schedule :meth:`connect` as a background task on the given loop.""" + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def connect(self) -> None: + """Establish the WebSocket with exponential-backoff retry. + + Heartbeat (``ping_interval`` / ``ping_timeout``) is configured on + the underlying ``websockets`` client; if the server fails to pong + within ``_PING_TIMEOUT`` the library closes the connection and + ``_recv_loop`` surfaces the disconnect via ``_handle_disconnect``. + + On a reconnect (``_disconnected_at`` set by a prior disconnect), + drains the replay buffer and fires ``on_reconnect``. Login is + deferred to ``_send_frame`` / ``_replay_buffer_to_ws`` so the + auto-detected provider/model is included. An ``api_key``, if + configured, is sent as an ``Authorization: Bearer `` header. + + Honours ``_next_reconnect_delay`` if a previous disconnect set + it (e.g. 4003 quota exhausted requests a slower retry). The + delay is consumed on the first attempt; subsequent failures + fall back to the standard exponential schedule. + """ + initial_delay = self._next_reconnect_delay + self._next_reconnect_delay = None + + if initial_delay is not None: + logger.info( + "delaying reconnect by %.0fs (server-requested)", + initial_delay, + ) + await asyncio.sleep(initial_delay) + + backoff = _INITIAL_BACKOFF + loop = asyncio.get_running_loop() + self._loop = loop + + headers: dict[str, str] = {} + + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + while True: + try: + self._ws = await websockets.connect( + self._url, + additional_headers=headers, + ping_interval=_PING_INTERVAL, + ping_timeout=_PING_TIMEOUT, + ) + self._connected.set() + self._recv_task = loop.create_task(self._recv_loop()) + + disconnected_at = self._disconnected_at + is_reconnect = disconnected_at is not None + if disconnected_at is not None: + downtime = time.monotonic() - disconnected_at + self._disconnected_at = None + logger.warning( + "WebSocket reconnected: %s (session_id=%s, downtime=%.2fs)", + self._url, + self._session_id, + downtime, + ) + + if self._replay_buffer_dropped > 0: + logger.warning( + "replay buffer dropped %d frames due to overflow " + "before this reconnect (session_id=%s); " + "increase replay_buffer_frames if this recurs", + self._replay_buffer_dropped, + self._session_id, + ) + else: + logger.info("WebSocket connected: %s", self._url) + + # Drain anything buffered while we were offline, even + # on the very first connect. ``_send_mcp_inventory`` + # and other init-time emitters queue frames before the + # WS is open; without this drain those frames never + # ship until something else triggers a live send. + if self._replay_buffer: + logger.info( + "replaying %d buffered frames after connect", + len(self._replay_buffer), + ) + await self._replay_buffer_to_ws() + + if is_reconnect: + await self._fire_on_reconnect() + + return + except Exception: + logger.warning( + "WebSocket connect to %s failed, retrying in %.0fs", + self._url, + backoff, + ) + try: + await asyncio.sleep(backoff) + except RuntimeError: + # Loop closed mid-retry (atexit shutdown). Bail out + # quietly rather than dumping a traceback. + return + backoff = min(backoff * 2, _MAX_BACKOFF) + + async def _send_login(self, ws: websockets.ClientConnection) -> None: + """Send the mandatory SessionLogin frame.""" + frame = pb.ClientFrame() + frame.login.session_id = self._session_id + frame.login.llm_stack.provider = self._provider + frame.login.llm_stack.model = self._model + frame.login.schema_version = SCHEMA_VERSION + await ws.send(frame.SerializeToString()) + logger.debug( + "Sent login (session=%s, provider=%s, model=%s, schema=%d)", + self._session_id, + self._provider, + self._model, + SCHEMA_VERSION, + ) + + async def _send_frame(self, frame: pb.ClientFrame) -> None: + """Serialise and send a ``ClientFrame``, buffering on failure. + + Happy path (connected + healthy): send over WS, bypass the ring + entirely, zero overhead. Offline on entry: buffer for replay. + During reconnect replay: buffer as well, so the drain loop picks + this frame up after the pre-outage tail (preserves order across + the outage boundary). Send raises: buffer the in-flight frame + then trigger ``_handle_disconnect`` so state is cleared and + reconnect is spawned. + """ + frame_bytes = frame.SerializeToString() + kind = frame.WhichOneof("frame") + + if not self._connected.is_set() or self._replaying: + self._buffer_frame(frame_bytes) + reason = "disconnected" if not self._connected.is_set() else "replaying" + logger.info( + "buffered for replay (session_id=%s, kind=%s, " + "buffer_size=%d, reason=%s)", + self._session_id, + kind, + len(self._replay_buffer), + reason, + ) + + return + + ws = self._ws + + if ws is None: + self._buffer_frame(frame_bytes) + + return + + try: + async with self._login_lock: + if not self._logged_in: + await self._send_login(ws) + self._logged_in = True + + await ws.send(frame_bytes) + logger.debug("Sent %s frame", kind) + except Exception: + # Send raised, we cannot confirm the server received this frame. + # Buffer it so the reconnect replay ships it, then clean up state. + self._buffer_frame(frame_bytes) + await self._handle_disconnect("send_failure") + + async def _recv_loop(self) -> None: + """Read ``ServerFrame``s, dispatch by oneof kind. + + First frame after each (re)login MUST be ``login_ack``; anything + else is a protocol error and we tear the connection down so the + reconnect path can try again. Subsequent frames are + ``verdict``s. Unknown oneof kinds (future server additions like + a quota-exhausted signal) are logged and dropped rather than + crashing the loop. + + Any exit path (clean close, exception, cancellation) calls + ``_handle_disconnect`` via ``finally`` so state is cleared and a + reconnect is spawned. + """ + ws = self._ws + + if ws is None: + return + + awaiting_login_ack = True + try: + async for message in ws: + if not isinstance(message, bytes): + continue + + frame = pb.ServerFrame() + frame.ParseFromString(message) + kind = frame.WhichOneof("frame") + + if awaiting_login_ack: + awaiting_login_ack = False + if kind != "login_ack": + logger.error( + "expected ServerFrame{login_ack} as first frame, " + "got %r, closing connection", + kind, + ) + return + + if kind == "login_ack": + self._on_login_ack(frame.login_ack) + elif kind == "verdict": + await self._on_verdict_frame(frame.verdict) + else: + logger.warning( + "ignoring unknown ServerFrame kind %r " + "(future server addition?)", + kind, + ) + except asyncio.CancelledError: + # Expected on graceful shutdown or when _handle_disconnect cancels + # us from the send_failure path. Re-raise to honour cancellation. + raise + except Exception as exc: + logger.warning("recv_loop exited: %s", exc) + finally: + close_code = getattr(ws, "close_code", None) + + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE: + self._next_reconnect_delay = _QUOTA_RECONNECT_DELAY + + reason = ( + f"quota_exhausted (close={close_code})" + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE + else "recv_loop_exit" + ) + await self._handle_disconnect(reason) + + def _on_login_ack(self, ack: pb.LoginAck) -> None: + """Apply the org's effective execution-mode policy. + + Fires the ``on_login_ack`` hook (if configured) as a fire-and-forget + task on the running loop so the recv loop doesn't block waiting on it. + """ + self._mode = ack.policy.mode + self._policy = ack.policy + self._login_ack_received.set() + logger.info( + "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " + "policy_m3=%s policy_m4=%s", + pb.Mode.Name(ack.policy.mode), + ack.policy.policy_m0, + ack.policy.policy_m2, + ack.policy.policy_m3, + ack.policy.policy_m4, + ) + + if self._on_login_ack_cb is not None: + asyncio.create_task(self._run_login_ack_cb()) + + async def _run_login_ack_cb(self) -> None: + """Invoke the on_login_ack hook, swallowing exceptions.""" + if self._on_login_ack_cb is None: + return + try: + await self._on_login_ack_cb() + except Exception: + logger.exception("on_login_ack hook raised") + + async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: + """Fire callbacks then resolve the matching pending future, if any. + + The future is left in ``_pending_verdicts`` after ``set_result`` so + a later ``register_pending`` (e.g. from the patched ToolNode after + the verdict has already round-tripped) returns the resolved + future and the wait completes immediately. ``wait_for_verdict`` + owns the cleanup: its ``finally`` pops the entry after the await + returns. + """ + logger.info( + "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", + verdict.event_id, + verdict.mad_code or "-", + pb.Mode.Name(verdict.policy.mode), + verdict.HasField("hitl"), + ) + + if self._handler is not None: + await self._handler.handle_verdict(verdict) + + fut = self._pending_verdicts.get(verdict.event_id) + + if fut is None: + if verdict.HasField("hitl"): + logger.warning( + "HITL resolution for unknown event_id=%s, ignoring " + "(stale resolution from a prior SDK process)", + verdict.event_id, + ) + return + + if not fut.done(): + fut.set_result(verdict) + + # -- Resilience: buffering, replay, disconnect/reconnect -- + + def _buffer_frame(self, frame_bytes: bytes) -> None: + """Append a serialised frame to the replay ring. + + Tracks overflow drops and fires the one-shot "buffer full" WARN. + Called only from the offline or send-failure paths in + ``_send_frame``, the happy path bypasses the ring entirely. + """ + if len(self._replay_buffer) == self._replay_buffer.maxlen: + self._replay_buffer_dropped += 1 + + self._replay_buffer.append(frame_bytes) + + if ( + not self._replay_buffer_filled + and len(self._replay_buffer) == self._replay_buffer.maxlen + ): + self._replay_buffer_filled = True + logger.warning( + "adrian replay buffer reached capacity (%d frames); " + "further frames will evict oldest. Tune via " + "replay_buffer_frames or ADRIAN_REPLAY_BUFFER_FRAMES.", + self._replay_buffer.maxlen, + ) + + async def _replay_buffer_to_ws(self) -> None: + """Reissue buffered frames over the current WebSocket. + + Sends ``SessionLogin`` first if not already logged in (the server + requires it as the first frame on every new connection). Uses + ``_login_lock`` so a concurrent live send does not race on the + login check. + + Drains the deque one frame at a time via ``popleft`` inside a + ``while`` loop, rather than taking a snapshot up front. That + way, a live ``_send_frame`` call observed during the drain + routes its frame to the back of the same deque (because + ``_replaying`` is set) and this loop picks it up in the next + iteration, preserving across-outage order + ``[pre-outage] → [live during replay] → [post-replay live]``. + + On a mid-drain send failure, the failed frame is put back at + the front with ``appendleft`` and the function returns; the + next reconnect resumes from exactly where this one stopped. + """ + ws = self._ws + + if ws is None: + return + + self._replaying = True + try: + async with self._login_lock: + if not self._logged_in: + try: + await self._send_login(ws) + self._logged_in = True + except Exception as exc: + logger.warning( + "replay aborted: login send failed: %s", + exc, + ) + + return + + sent = 0 + while self._replay_buffer: + frame_bytes = self._replay_buffer.popleft() + try: + await ws.send(frame_bytes) + except Exception as exc: + # Put the failed frame back at the front so the next + # reconnect's drain resumes from exactly this point. + self._replay_buffer.appendleft(frame_bytes) + logger.warning( + "replay aborted after %d frame(s), %d remaining: %s", + sent, + len(self._replay_buffer), + exc, + ) + + return + sent += 1 + + logger.info("replayed %d buffered frames", sent) + self._replay_buffer_dropped = 0 + self._replay_buffer_filled = False + finally: + self._replaying = False + + async def _handle_disconnect(self, reason: str) -> None: + """Clear connection state and spawn a reconnect. + + Idempotent: if already disconnected or closing, returns immediately. + Pending verdict futures are intentionally left pending across the + disconnect, a late verdict after reconnect resolves them; if none + arrives, ``wait_for_verdict``'s timeout fires naturally. + """ + if self._closing or not self._connected.is_set(): + return + + self._connected.clear() + self._disconnected_at = time.monotonic() + + # Only cancel the recv task if we are not currently running inside it. + # When _recv_loop's own finally calls us, self._recv_task IS the + # current task, cancelling it would raise CancelledError inside the + # finally and prevent us from finishing disconnect handling. + current = asyncio.current_task() + + if self._recv_task is not None and self._recv_task is not current: + self._recv_task.cancel() + + self._recv_task = None + self._ws = None + self._logged_in = False + + logger.warning( + "disconnected (session_id=%s, reason=%s, pending_verdicts=%d)", + self._session_id, + reason, + len(self._pending_verdicts), + ) + + await self._fire_on_disconnect(reason) + + if self._closing: + return + + loop = asyncio.get_running_loop() + + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def _fire_on_disconnect(self, reason: str) -> None: + """Invoke the on_disconnect callback, catching any exception.""" + if self._on_disconnect is None: + return + + try: + result = self._on_disconnect(reason) + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_disconnect callback raised") + + async def _fire_on_reconnect(self) -> None: + """Invoke the on_reconnect callback, catching any exception.""" + if self._on_reconnect is None: + return + + try: + result = self._on_reconnect() + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_reconnect callback raised") + + # -- Verdict-wait support -- + + def register_pending( + self, + event_id: str, + ) -> asyncio.Future[pb.Verdict]: + """Return a future awaiting a verdict for ``event_id``. + + Reuses an existing pending future if one is already registered, + so concurrent callers waiting on the same event_id see the same + verdict once it arrives. Must be called BEFORE sending the event + to avoid the race where the verdict arrives before the future exists. + """ + existing = self._pending_verdicts.get(event_id) + + if existing is not None: + return existing + + loop = asyncio.get_running_loop() + fut: asyncio.Future[pb.Verdict] = loop.create_future() + self._pending_verdicts[event_id] = fut + + return fut + + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + + async def wait_for_verdict( + self, + event_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for a verdict for ``event_id``. + + ``timeout`` is mode-derived (see :meth:`block_timeout`): + a positive float for ``MODE_BLOCK`` (fail-open at timeout), + ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the + verdict, or ``None`` on timeout (fail-open). + + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. + """ + fut = self.register_pending(event_id) + + try: + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result + except TimeoutError: + logger.warning( + "Verdict timeout for event_id=%s after %ss", + event_id, + timeout, + ) + # Timed-out future is useless — remove so a retry can + # register a fresh one. + self._pending_verdicts.pop(event_id, None) + return None + + async def wait_for_tool_verdict( + self, + parent_run_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that produced this tool call. + + Looks up the LLM event_id from the run_id map and awaits its verdict. + Returns ``None`` (fail-open) when the parent LLM has not been seen, + e.g. tools invoked outside an LLM flow. + """ + event_id = self._run_id_to_event_id.get(parent_run_id) + + if event_id is None: + logger.debug( + "No LLM context for parent_run_id=%s, skipping verdict wait", + parent_run_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) + + async def wait_for_tool_call_verdict( + self, + tool_call_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that emitted ``tool_call_id``. + + Every tool call in an AIMessage carries the id the LLM assigned + to it; that id is threaded through LangChain to the ToolNode + invocation. Looking it up against ``_tool_call_id_to_event_id`` + gives the producing LLM's event_id, correct under parallel + agents where a ``last_llm_event_id``-style global would race. + + Returns ``None`` (fail-open) when ``tool_call_id`` is empty or + unknown (direct ToolNode invocation, pre-LLM tool, or the LLM + pair that produced it was evicted from the LRU map). + """ + if not tool_call_id: + return None + + event_id = self._tool_call_id_to_event_id.get(tool_call_id) + + if event_id is None: + logger.debug( + "No LLM context for tool_call_id=%s, skipping verdict wait", + tool_call_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 03b7fd4..d0d6d81 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -231,10 +231,12 @@ def init( resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) - # Default to a local self-hosted backend (the one `make dev` brings - # up at deploy/compose.yaml). OSS users pointing at a remote - # deployment override via ws_url= or ADRIAN_WS_URL. - resolved_ws_url = os.getenv("ADRIAN_WS_URL") or ws_url or "ws://localhost:8080/ws" + # Default to the hosted Adrian backend so `adrian.init(api_key=...)` + # Just Works for freemium users. Self-hosted users override via + # ws_url= or ADRIAN_WS_URL. + resolved_ws_url = ( + os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" + ) resolved_session = ( os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() ) @@ -520,6 +522,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_base_tool() + _patch_agent_executor() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -531,12 +535,14 @@ def _auto_instrument_langchain() -> None: def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(Runnable, "_adrian_patched", False): return original_invoke = Runnable.invoke original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -544,9 +550,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync Runnable call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config, **kwargs) async def patched_ainvoke( @@ -555,13 +559,32 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async Runnable call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + Runnable.invoke = patched_invoke # type: ignore[assignment] Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] logger.debug("Patched Runnable.invoke / ainvoke") @@ -634,12 +657,14 @@ def patched_configure( def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(BaseChatModel, "_adrian_chat_model_patched", False): return original_invoke = BaseChatModel.invoke original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -647,9 +672,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync chat model call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -658,13 +681,32 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async chat model call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] logger.debug("Patched BaseChatModel.invoke / ainvoke") @@ -760,29 +802,15 @@ async def patched_astream( # --- 5. ToolNode --- -def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage], +def _extract_tool_calls( # pyright: ignore[reportUnusedFunction] + state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, Any]]: - """Extract tool_calls from the ToolNode input. + """Extract tool_calls from ToolNode input (all three dispatch shapes). - ``ToolNode`` is reached with three input shapes: - 1. a state dict whose ``"messages"`` key holds the message list - (hand-built ``StateGraph`` with ``ToolNode`` as a node), or - 2. a bare list of messages, or - 3. a single per-tool-call dict ``{"__type", "tool_call", "state"}`` - — how langgraph-prebuilt / ``create_react_agent`` dispatch each - tool call. The id lives at ``input["tool_call"]["id"]``. - - Shape 3 was previously unhandled: the function returned ``[]``, so the - block/HITL gate never found a tool_call_id and ran the tool un-gated. - - Args: - state: The ToolNode input (any of the three shapes above). - - Returns: - List of tool call dicts, or an empty list when none is found. + Returns full tool_call dicts (with id, name, args) for backward + compat with tests and callers that need the full shape. """ - # Shape 3: per-tool-call dispatch (create_react_agent / prebuilt ToolNode). + # Shape 3: per-tool-call dict from _afunc dispatch if isinstance(state, dict) and "tool_call" in state: tc = state["tool_call"] if isinstance(tc, dict) and tc.get("id"): @@ -798,10 +826,13 @@ def _extract_tool_calls( ] return [] + # Shape 1/2: state dict or message list if isinstance(state, dict): messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - else: + elif isinstance(state, list): messages = list(state) + else: + return [] for msg in reversed(messages): if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): @@ -813,55 +844,28 @@ def _extract_tool_calls( def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is - in-scope, halt; if not, continue. + HITL resolutions override per-MAD policy when present. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution mad_prefix = verdict.mad_code[:2] - in_scope = { + return { "M0": verdict.policy.policy_m0, "M2": verdict.policy.policy_m2, "M3": verdict.policy.policy_m3, "M4": verdict.policy.policy_m4, }.get(mad_prefix, False) - return in_scope - - -def _build_blocked_response( - tool_calls: list[dict[str, str]], -) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls. - - Args: - tool_calls: List of tool call dicts extracted from the AIMessage. - - Returns: - Dict in the format ToolNode expects. - """ - blocked_messages: list[ToolMessage] = [ - ToolMessage( - content="[BLOCKED by security policy]", - tool_call_id=str(tc.get("id", "")), - name=str(tc.get("name", "")), - ) - for tc in tool_calls - ] - - return {"messages": blocked_messages} - def _patch_tool_node() -> None: - """Patch ``ToolNode.invoke`` / ``ainvoke``. + """Patch ToolNode for callback injection + async verdict gate. - In block mode, the async patch waits for the preceding LLM's verdict - before executing tools. On BLOCK (unless overridden by ``on_block``) - it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + ToolNode dispatches tools via tool.invoke (sync) even within async + Pregel. BaseTool.invoke can't await a verdict from the event loop + thread, so we add the verdict gate here on ToolNode.ainvoke — the + entry point Pregel calls before tool dispatch begins. This is a + complementary gate to BaseTool (which covers direct callers). """ try: from langgraph.prebuilt import ToolNode @@ -873,43 +877,96 @@ def _patch_tool_node() -> None: original_invoke = ToolNode.invoke original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) def patched_invoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync ToolNode invocation.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - **kwargs: Any, + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks; in BLOCK / HITL modes wait for verdict. + config = _inject_callbacks(config) + # Verdict gate removed — BaseTool.ainvoke/arun is the single + # gate layer. Gating here too caused double-gate: ToolNode + # consumed the verdict future, BaseTool's gate registered a + # fresh future that never resolved → 30s timeout on a benign + # verdict. Callback injection is kept so events still flow. + return await original_ainvoke(self, input, config=config, **kwargs) - Per-tool-call correlation: every tool_call.id is mapped (in - ``WebSocketClient`` ) to the event_id of the LLM that emitted - it. Each ToolNode invocation awaits its specific LLM's verdict, - race-free under parallel agents, no graph-wide pause. - """ + async def patched_astream( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - ws = _ws_client + assert original_astream is not None # guarded by line below + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode.invoke / ainvoke / astream") + + +# --- 6. BaseTool (universal verdict gate) --- + + +_BLOCKED_CONTENT = "[BLOCKED by security policy]" + + +def _patch_base_tool() -> None: + """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. + + Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, + create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — + funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` + (async). Gating here covers all frameworks in one place. + + The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` + TypedDict), awaits the classifier verdict for the producing LLM + event, and returns a ``[BLOCKED]`` string instead of running the + tool body when the verdict is in-scope (M3/M4 under MODE_BLOCK). + In MODE_BLOCK, verdict timeout is fail-closed (block the tool) + because the absence of a verdict in block mode is a policy violation. + In MODE_ALERT, no gate fires at all (skip). + """ + from langchain_core.tools import BaseTool + from langchain_core.tools.base import ( + _is_tool_call, # pyright: ignore[reportPrivateUsage] + ) + + if getattr(BaseTool, "_adrian_base_tool_patched", False): + return + + original_invoke = BaseTool.invoke + original_ainvoke = BaseTool.ainvoke + + def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 + """Extract tool_call_id from a ToolCall input, or None.""" + if isinstance(input, dict) and _is_tool_call(input): + return input.get("id") + return None + + async def _async_gate(tool_call_id: str) -> bool: + """Returns True if the tool should be BLOCKED.""" + ws = _ws_client if ws is None: - return await original_ainvoke(self, input, config=config, **kwargs) + return False - # First-tool-call window: the recv loop may not have processed - # ``LoginAck`` yet, so ``policy_active()`` reads False even - # when the org is in BLOCK or HITL. Wait for the LoginAck - # event before checking. If it doesn't arrive within the - # window, halt, refusing to run is the only safe outcome - # when we can't verify the org's policy. if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] try: await asyncio.wait_for( @@ -918,36 +975,26 @@ async def patched_ainvoke( ) except TimeoutError: logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" + "BaseTool: LoginAck not received within 5s; " + "blocking tool (refusing to run without verified policy)" ) - return _build_blocked_response(_extract_tool_calls(input)) + return True if not ws.policy_active(): - return await original_ainvoke(self, input, config=config, **kwargs) - - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) - - if not tool_call_id: - # Direct ToolNode invocation outside an LLM flow, no - # producing event_id to wait on, so let the tool run. - return await original_ainvoke(self, input, config=config, **kwargs) + return False cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) if verdict is None: + # Fail-closed in block mode: no verdict = block. logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", + "BaseTool: verdict timeout for tool_call_id=%s; " + "blocking (fail-closed in MODE_BLOCK)", tool_call_id, ) - return await original_ainvoke(self, input, config=config, **kwargs) + return True if _should_halt(verdict): logger.warning( @@ -955,11 +1002,200 @@ async def patched_ainvoke( verdict.event_id, verdict.mad_code, ) - return _build_blocked_response(tool_calls) + return True + + return False + + def _sync_gate(tool_call_id: str) -> bool: + """Sync verdict gate — works for pure-sync and worker-thread callers. + + Pure-sync (no event loop): runs ``_async_gate`` via + ``loop.run_until_complete``. + + Worker-thread (Pregel dispatches sync tools on a thread-pool + worker while the event loop runs on the main thread): bridges + the async gate to the main loop via ``run_coroutine_threadsafe`` + and blocks the worker thread until the verdict resolves. + + Event-loop thread (calling tool.invoke directly from async + code): cannot block — returns False (skip). The async path + (BaseTool.ainvoke) handles this case. + """ + ws = _ws_client + if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] + return False + + try: + loop = asyncio.get_event_loop() + except RuntimeError: + return False + + if not loop.is_running(): + # Pure-sync caller — safe to block + return loop.run_until_complete(_async_gate(tool_call_id)) + + # Check if we're on a worker thread (no running loop on THIS + # thread) vs the event-loop thread itself. + try: + asyncio.get_running_loop() + # We ARE on the event-loop thread — can't block it. + return False + except RuntimeError: + pass + + # Worker thread: bridge the async gate to the main loop. + main_loop = getattr(ws, "_loop", None) + if main_loop is None or not main_loop.is_running(): + return False + + try: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result(timeout=timeout if timeout else 60.0) + except Exception: + return False + + def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 + """Return a blocked response compatible with ToolNode. + + Returns a ToolMessage for create_react_agent / ToolNode + compatibility. Falls back to bare string on import failure. + """ + try: + return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") + except Exception: + return _BLOCKED_CONTENT + def patched_invoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and _sync_gate(tc_id): + return _blocked_response(tc_id) + return original_invoke(self, input, config=config, **kwargs) + + async def patched_ainvoke( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + tc_id = _extract_tool_call_id(input) + if tc_id and await _async_gate(tc_id): + return _blocked_response(tc_id) return await original_ainvoke(self, input, config=config, **kwargs) - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke") + original_arun = BaseTool.arun + + async def patched_arun( + self: Any, # noqa: ANN401 + tool_input: Any, # noqa: ANN401 + *args: Any, + tool_call_id: str | None = None, + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Gate on arun — AgentExecutor calls tool.arun directly.""" + if tool_call_id and await _async_gate(tool_call_id): + return _blocked_response(tool_call_id) + return await original_arun( + self, tool_input, *args, tool_call_id=tool_call_id, **kwargs + ) + + BaseTool.invoke = patched_invoke # type: ignore[assignment] + BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseTool.arun = patched_arun # type: ignore[assignment] + BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") + + +# --- 7. AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- + + +def _patch_agent_executor() -> None: + """Patch AgentExecutor._aperform_agent_action for the executor path. + + AgentExecutor calls tool.arun without forwarding tool_call_id, + so the BaseTool.arun gate can't extract it. The tool_call_id lives + on agent_action.tool_call_id (set by OpenAI-style parsers). We + intercept here, await the verdict, and return a blocked observation + instead of calling the tool. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + + async def patched_aperform( + self: Any, + name_to_tool_map: Any, + color_mapping: Any, # noqa: ANN401 + agent_action: Any, + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + tc_id = getattr(agent_action, "tool_call_id", None) + if tc_id: + ws = _ws_client + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; blocking" + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if ws.policy_active(): + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/python/adrian/ws.py b/sdk/python/adrian/ws.py index 1ab5df4..169cbdc 100644 --- a/sdk/python/adrian/ws.py +++ b/sdk/python/adrian/ws.py @@ -52,6 +52,8 @@ _MAX_RUN_ID_MAP = 1024 # Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). _MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 _DEFAULT_REPLAY_BUFFER_FRAMES = 1000 @@ -254,6 +256,10 @@ def __init__( # Set by close() so _handle_disconnect knows not to spawn a reconnect # during a graceful shutdown. self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None # Futures awaited by the patched ToolNode.ainvoke when the # active mode requires a wait (BLOCK or HITL). Each resolves # with the matching ``Verdict`` proto. Futures survive a @@ -472,6 +478,7 @@ async def connect(self) -> None: backoff = _INITIAL_BACKOFF loop = asyncio.get_running_loop() + self._loop = loop headers: dict[str, str] = {} @@ -491,7 +498,6 @@ async def connect(self) -> None: disconnected_at = self._disconnected_at is_reconnect = disconnected_at is not None - if disconnected_at is not None: downtime = time.monotonic() - disconnected_at self._disconnected_at = None @@ -927,6 +933,18 @@ def register_pending( return fut + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + async def wait_for_verdict( self, event_id: str, @@ -939,25 +957,30 @@ async def wait_for_verdict( ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the verdict, or ``None`` on timeout (fail-open). - Cleans up the ``_pending_verdicts`` entry on either path: - ``_on_verdict_frame`` only resolves the future, the dict - ownership belongs here so a late ``register_pending`` after the - verdict has already arrived can still find the resolved future. + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. """ fut = self.register_pending(event_id) try: - return await asyncio.wait_for(fut, timeout=timeout) + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result except TimeoutError: logger.warning( "Verdict timeout for event_id=%s after %ss", event_id, timeout, ) - - return None - finally: + # Timed-out future is useless — remove so a retry can + # register a fresh one. self._pending_verdicts.pop(event_id, None) + return None async def wait_for_tool_verdict( self, diff --git a/sdk/python/tests/test_block_mode.py b/sdk/python/tests/test_block_mode.py index 0d1c352..0bbbdaf 100644 --- a/sdk/python/tests/test_block_mode.py +++ b/sdk/python/tests/test_block_mode.py @@ -142,10 +142,16 @@ async def test_looks_up_llm_event_id_and_resolves(self) -> None: class TestToolNodePatchBlocking: async def test_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: - """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → halt with synthetic ToolMessage.""" + """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → BaseTool.ainvoke gate blocks. - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + The verdict gate lives on BaseTool (the universal layer), not + ToolNode.ainvoke. Uses an async tool so BaseTool.ainvoke (not + BaseTool.invoke) is the entry point — matching the production + path for create_react_agent with async tools. + """ + + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" _real_tool.called = True # type: ignore[attr-defined] return x @@ -180,6 +186,7 @@ def _real_tool(x: str) -> str: result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + # BaseTool.ainvoke gate blocks — tool body does NOT run. assert _real_tool.called is False # type: ignore[attr-defined] msgs = result["messages"] assert len(msgs) == 1 @@ -190,7 +197,7 @@ async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) @@ -226,11 +233,12 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] - async def test_timeout_fail_open_runs_tool(self, tmp_path: Path) -> None: + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + """Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run).""" captured: list[str] = [] - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" captured.append(x) return x @@ -248,7 +256,7 @@ def _real_tool(x: str) -> str: _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) ws._connected.set() ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" - # No pending future → wait_for_verdict times out → fail-open. + # No pending future → wait_for_verdict times out → fail-closed (MODE_BLOCK). tool_node = ToolNode([_real_tool]) ai = AIMessage( @@ -259,7 +267,8 @@ def _real_tool(x: str) -> str: await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] - assert captured == ["hi"] + # Fail-closed: tool should NOT have run. + assert captured == [] class TestModeAlert: @@ -268,7 +277,7 @@ async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) diff --git a/sdk/python/tests/test_block_mode_races.py b/sdk/python/tests/test_block_mode_races.py index fa0ad57..16d8e4a 100644 --- a/sdk/python/tests/test_block_mode_races.py +++ b/sdk/python/tests/test_block_mode_races.py @@ -5,17 +5,17 @@ LLM calls; no running backend. Scenarios mirror the validated shapes from the multi-agent work: - S1 subagents-as-tools - director → worker (nested) - S2 handoffs - triage → specialist (sequential) - S3 router - parallel fan-out via Send() - S4 hierarchical - 3-level deep (director → team-lead → worker) - S5 custom workflow - deterministic + LLM nodes mixed - S6 swarm - back-and-forth handoffs (Alice ↔ Bob) - S7 supervisor - central dispatcher to N workers - S8 deep research - parallel researchers via asyncio.gather + S1 subagents-as-tools , director → worker (nested) + S2 handoffs , triage → specialist (sequential) + S3 router , parallel fan-out via Send() + S4 hierarchical , 3-level deep (director → team-lead → worker) + S5 custom workflow , deterministic + LLM nodes mixed + S6 swarm , back-and-forth handoffs (Alice ↔ Bob) + S7 supervisor , central dispatcher to N workers + S8 deep research , parallel researchers via asyncio.gather The invariant under test: for EVERY pattern, each ToolNode invocation -blocks on the verdict of the LLM that emitted its specific tool_call.id - +blocks on the verdict of the LLM that emitted its specific tool_call.id , never a sibling, never a parent, never a stale global. """ @@ -117,9 +117,9 @@ def _init_block_mode(tmp_path: Path, block_timeout: float = 1.0) -> Any: def _tool(name: str, captured: list[str]) -> Any: - """Build a named stub tool that records its argument.""" + """Build a named async stub tool that records its argument.""" - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(f"{name}:{x}") diff --git a/sdk/python/tests/test_exec_modes.py b/sdk/python/tests/test_exec_modes.py index 1ea8ae1..f3f5e42 100644 --- a/sdk/python/tests/test_exec_modes.py +++ b/sdk/python/tests/test_exec_modes.py @@ -61,7 +61,7 @@ def _cleanup() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] def _stub_tool(captured: list[str]) -> Any: # noqa: ANN401 - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(x)