From 0de5fed08de994b9ad3640e99e66113ce5b3bab3 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Tue, 2 Jun 2026 23:54:28 +0200 Subject: [PATCH 1/5] Add: add blocking gates for langgraph_stategraph and langchain_executor --- sdk/adrian/__init__.py | 443 +++++++++++++++++++++++++++++++---------- sdk/adrian/ws.py | 12 ++ 2 files changed, 350 insertions(+), 105 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 0b1ed09..1d312ab 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -512,6 +512,23 @@ def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 # ------------------------------------------------------------------ +def _warn_unsupported_frameworks() -> None: + """Warn when raw openai SDK is used without a patchable framework.""" + import importlib.util + + if importlib.util.find_spec("openai") is not None: + # Only warn if openai is present without langchain + if importlib.util.find_spec("langchain_core") is None: + logger.warning( + "Detected raw openai SDK without LangChain/LangGraph. " + "Adrian's pre-execution block gate requires a supported " + "framework (LangGraph ToolNode or LangChain AgentExecutor). " + "Tool calls from raw openai loops are OBSERVED but NOT " + "pre-blocked in MODE_BLOCK. " + "See https://docs.adrian.secureagentics.ai/supported-frameworks" + ) + + def _auto_instrument_langchain() -> None: """Apply all monkey-patches to LangChain / LangGraph.""" try: @@ -520,6 +537,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_agent_executor() + _warn_unsupported_frameworks() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -531,12 +550,14 @@ def _auto_instrument_langchain() -> None: def _patch_runnable() -> None: - """Patch ``Runnable.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``Runnable.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(Runnable, "_adrian_patched", False): return original_invoke = Runnable.invoke original_ainvoke = Runnable.ainvoke + original_astream = Runnable.astream + original_stream = Runnable.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -544,9 +565,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync Runnable call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config, **kwargs) async def patched_ainvoke( @@ -555,15 +574,35 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async Runnable call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """AgentExecutor calls astream on the agent chain by default.""" + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config, **kwargs) + Runnable.invoke = patched_invoke # type: ignore[assignment] Runnable.ainvoke = patched_ainvoke # type: ignore[assignment] + Runnable.astream = patched_astream # type: ignore[assignment] + Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke") + logger.debug("Patched Runnable.invoke / ainvoke / astream / stream") # --- 2. CallbackManager --- @@ -634,12 +673,14 @@ def patched_configure( def _patch_chat_model() -> None: - """Patch ``BaseChatModel.invoke`` / ``ainvoke`` to inject callbacks.""" + """Patch ``BaseChatModel.invoke`` / ``ainvoke`` / ``astream`` / ``stream``.""" if getattr(BaseChatModel, "_adrian_chat_model_patched", False): return original_invoke = BaseChatModel.invoke original_ainvoke = BaseChatModel.ainvoke + original_astream = BaseChatModel.astream + original_stream = BaseChatModel.stream def patched_invoke( self: Any, # noqa: ANN401 @@ -647,9 +688,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync chat model call.""" config = _inject_callbacks(config) - return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -658,15 +697,34 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into async chat model call.""" config = _inject_callbacks(config) - return await original_ainvoke(self, input, config=config, **kwargs) + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk + + def patched_stream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + yield from original_stream(self, input, config=config, **kwargs) + BaseChatModel.invoke = patched_invoke # type: ignore[assignment] BaseChatModel.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseChatModel.astream = patched_astream # type: ignore[assignment] + BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke") + logger.debug("Patched BaseChatModel.invoke / ainvoke / astream / stream") # --- 4. LangGraph Pregel --- @@ -761,26 +819,22 @@ async def patched_astream( def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage], + state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, str]]: - """Extract tool_calls from the last AIMessage in ToolNode state. - - LangGraph's ``ToolNode.ainvoke`` accepts two input shapes: a state - dict whose ``"messages"`` key holds the message list, or a bare - list of messages. We handle both. - - Args: - state: The ToolNode input, a state dict with a ``"messages"`` - key, or a direct list of ``BaseMessage`` instances. + """Extract tool_calls from ToolNode input (state dict, message list, or per-tool-call dict).""" + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + if hasattr(tc, "id") and tc.id: + return [{"id": tc.id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}] - Returns: - List of tool call dicts from the most recent ``AIMessage``, or - an empty list when none is found. - """ if isinstance(state, dict): messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - else: + elif isinstance(state, list): messages = list(state) + else: + return [] for msg in reversed(messages): if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): @@ -792,10 +846,7 @@ def _extract_tool_calls( def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override everything: ``continue_execution=False`` - means halt, ``True`` means continue. Otherwise the per-MAD policy - bool is the sole scope authority, if the verdict's tier is - in-scope, halt; if not, continue. + HITL resolutions override the per-MAD policy scope check. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution @@ -814,14 +865,7 @@ def _should_halt(verdict: pb.Verdict) -> bool: def _build_blocked_response( tool_calls: list[dict[str, str]], ) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls. - - Args: - tool_calls: List of tool call dicts extracted from the AIMessage. - - Returns: - Dict in the format ToolNode expects. - """ + """Build synthetic ToolMessage responses for blocked tool calls.""" blocked_messages: list[ToolMessage] = [ ToolMessage( content="[BLOCKED by security policy]", @@ -834,13 +878,67 @@ def _build_blocked_response( return {"messages": blocked_messages} +async def _adrian_tool_gate( + input: Any, # noqa: A002, ANN401 +) -> tuple[str, dict[str, Any] | None]: + """Pre-execution verdict gate. Returns ("halt", response), ("proceed", None), or ("skip", None).""" + ws = _ws_client + + if ws is None: + return ("skip", None) + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "ToolNode: LoginAck not received within 5s; halting " + "(refusing to run a tool without a verified policy)" + ) + return ("halt", _build_blocked_response(_extract_tool_calls(input))) + + if not ws.policy_active(): + return ("skip", None) + + tool_calls = _extract_tool_calls(input) + tool_call_id = next( + (tc.get("id") for tc in tool_calls if tc.get("id")), + None, + ) + + if not tool_call_id: + return ("skip", None) + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + logger.warning( + "verdict timeout for tool_call_id=%s, fail-open", + tool_call_id, + ) + return ("skip", None) + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return ("halt", _build_blocked_response(tool_calls)) + + return ("proceed", None) + + def _patch_tool_node() -> None: - """Patch ``ToolNode.invoke`` / ``ainvoke``. + """Patch ToolNode._afunc with the verdict gate, and public methods for callback injection. - In block mode, the async patch waits for the preceding LLM's verdict - before executing tools. On BLOCK (unless overridden by ``on_block``) - it returns synthetic ``ToolMessage`` responses instead of running the - tools. On timeout it fails open. + _afunc is the only reliable intercept -- Pregel bypasses ainvoke/astream entirely. """ try: from langgraph.prebuilt import ToolNode @@ -852,6 +950,22 @@ def _patch_tool_node() -> None: original_invoke = ToolNode.invoke original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) + original_stream = getattr(ToolNode, "stream", None) + original_afunc = ToolNode._afunc # type: ignore[attr-defined] + + async def patched_afunc( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + runtime: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate on ToolNode._afunc.""" + decision, blocked = await _adrian_tool_gate(input) + if decision == "halt": + return blocked + + return await original_afunc(self, input, config=config, runtime=runtime) def patched_invoke( self: Any, # noqa: ANN401 @@ -859,7 +973,7 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into sync ToolNode invocation.""" + """Inject Adrian callbacks into ToolNode.invoke.""" config = _inject_callbacks(config) return original_invoke(self, input, config=config, **kwargs) @@ -870,75 +984,194 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks; in BLOCK / HITL modes wait for verdict. + """Inject Adrian callbacks into ToolNode.ainvoke.""" + config = _inject_callbacks(config) - Per-tool-call correlation: every tool_call.id is mapped (in - ``WebSocketClient`` ) to the event_id of the LLM that emitted - it. Each ToolNode invocation awaits its specific LLM's verdict, - race-free under parallel agents, no graph-wide pause. - """ + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, # noqa: ANN401 + input: Any, # noqa: A002, ANN401 + config: Any = None, # noqa: ANN401 + **kwargs: Any, + ) -> Any: # noqa: ANN401 + """Inject Adrian callbacks into ToolNode.astream.""" config = _inject_callbacks(config) - ws = _ws_client - if ws is None: - return await original_ainvoke(self, input, config=config, **kwargs) + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk - # First-tool-call window: the recv loop may not have processed - # ``LoginAck`` yet, so ``policy_active()`` reads False even - # when the org is in BLOCK or HITL. Wait for the LoginAck - # event before checking. If it doesn't arrive within the - # window, halt, refusing to run is the only safe outcome - # when we can't verify the org's policy. - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" - ) - return _build_blocked_response(_extract_tool_calls(input)) - - if not ws.policy_active(): - return await original_ainvoke(self, input, config=config, **kwargs) + ToolNode._afunc = patched_afunc # type: ignore[attr-defined] + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode._afunc / invoke / ainvoke / astream") - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) - if not tool_call_id: - # Direct ToolNode invocation outside an LLM flow, no - # producing event_id to wait on, so let the tool run. - return await original_ainvoke(self, input, config=config, **kwargs) +# --- 6. AgentExecutor (langchain / langchain-classic) --- - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) +_BLOCKED_OBSERVATION = "[BLOCKED by security policy]" - if verdict is None: - logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", - tool_call_id, - ) - return await original_ainvoke(self, input, config=config, **kwargs) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return _build_blocked_response(tool_calls) +def _patch_agent_executor() -> None: + """Patch AgentExecutor tool dispatch with the verdict gate. - return await original_ainvoke(self, input, config=config, **kwargs) + Covers the legacy AgentExecutor path which bypasses ToolNode entirely. + Falls through for ReAct parsers that don't emit tool_call_id. + """ + AgentExecutor = None + AgentStep = None + for mod_path in ("langchain_classic.agents.agent", "langchain.agents.agent"): + try: + mod = __import__(mod_path, fromlist=["AgentExecutor", "AgentStep"]) + AgentExecutor = getattr(mod, "AgentExecutor", None) + AgentStep = getattr(mod, "AgentStep", None) + if AgentExecutor and AgentStep: + break + except ImportError: + continue + + if AgentExecutor is None or AgentStep is None: + return - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode.invoke / ainvoke") + if getattr(AgentExecutor, "_adrian_executor_patched", False): + return + + original_aperform = AgentExecutor._aperform_agent_action + original_perform = AgentExecutor._perform_agent_action + + async def patched_aperform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (async).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; " + "blocking tool %s", + agent_action.tool, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if ws.policy_active(): + cfg = _get_config() + # Short timeout: AgentExecutor LLM callbacks may not propagate, + # so verdicts may never arrive. + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor path)", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for " + "tool_call_id=%s, fail-open", + tool_call_id, + ) + + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + def patched_perform( + self: Any, # noqa: ANN401 + name_to_tool_map: Any, # noqa: ANN401 + color_mapping: Any, # noqa: ANN401 + agent_action: Any, # noqa: ANN401 + run_manager: Any = None, # noqa: ANN401 + ) -> Any: # noqa: ANN401 + """Verdict gate before AgentExecutor dispatches a tool (sync).""" + tool_call_id = getattr(agent_action, "tool_call_id", None) + + if tool_call_id: + ws = _ws_client + + if ws is not None and ws._login_ack_received.is_set() and ws.policy_active(): # pyright: ignore[reportPrivateUsage] + import concurrent.futures + + async def _gate() -> bool: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict( + tool_call_id, timeout, + ) + if verdict is not None and _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s " + "mad_code=%s (AgentExecutor sync path)", + verdict.event_id, + verdict.mad_code, + ) + return True + return False + + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + future: concurrent.futures.Future[bool] = concurrent.futures.Future() + + async def _run() -> None: + try: + result = await _gate() + future.set_result(result) + except Exception as exc: + future.set_exception(exc) + + loop.create_task(_run()) + should_block = future.result(timeout=35) + else: + should_block = loop.run_until_complete(_gate()) + + if should_block: + return AgentStep( + action=agent_action, + observation=_BLOCKED_OBSERVATION, + ) + except Exception: + logger.debug( + "AgentExecutor sync gate failed, falling through", + exc_info=True, + ) + + return original_perform( + self, name_to_tool_map, color_mapping, agent_action, run_manager, + ) + + AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] + AgentExecutor._perform_agent_action = patched_perform # type: ignore[assignment] + AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] + logger.debug("Patched AgentExecutor._aperform_agent_action / _perform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 1ab5df4..30f1ab5 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -513,6 +513,18 @@ async def connect(self) -> None: else: logger.info("WebSocket connected: %s", self._url) + # Eager login: send the SessionLogin frame immediately + # so the server responds with LoginAck before any tool + # gate fires. Previously login was deferred to the + # first _send_frame call, which meant frameworks that + # don't trigger paired events (AgentExecutor) would + # never receive LoginAck and the block gate would time + # out. Provider/model are best-effort at this point + # (empty until the first LLM event auto-detects them). + if not self._logged_in: + await self._send_login(self._ws) + self._logged_in = True + # Drain anything buffered while we were offline, even # on the very first connect. ``_send_mcp_inventory`` # and other init-time emitters queue frames before the From 912456d1a5ccae9b6e1c0327c3195e99aa7831fa Mon Sep 17 00:00:00 2001 From: netan-sa Date: Tue, 9 Jun 2026 14:48:10 +0200 Subject: [PATCH 2/5] Add: agent executor flow, and fixes --- sdk/adrian/__init__.py | 561 ++++++++++++++++++++--------------------- sdk/adrian/ws.py | 12 - 2 files changed, 276 insertions(+), 297 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 1d312ab..1b55d83 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -74,7 +74,7 @@ from adrian.types import ToolCallRecord, VerdictContext from adrian.ws import WebSocketClient -__version__ = "1.0.0" +__version__ = "1.0.2" __all__ = [ "init", "shutdown", @@ -231,10 +231,12 @@ def init( resolved_key = api_key or os.getenv("ADRIAN_API_KEY") or None resolved_file = Path(os.getenv("ADRIAN_LOG_FILE", str(log_file))) - # Default to a local self-hosted backend (the one `make dev` brings - # up at deploy/compose.yaml). OSS users pointing at a remote - # deployment override via ws_url= or ADRIAN_WS_URL. - resolved_ws_url = os.getenv("ADRIAN_WS_URL") or ws_url or "ws://localhost:8080/ws" + # Default to the hosted Adrian backend so `adrian.init(api_key=...)` + # Just Works for freemium users. Self-hosted users override via + # ws_url= or ADRIAN_WS_URL. + resolved_ws_url = ( + os.getenv("ADRIAN_WS_URL") or ws_url or "wss://adrian.secureagentics.ai/ws" + ) resolved_session = ( os.getenv("ADRIAN_SESSION_ID") or session_id or resolve_session_id() ) @@ -512,23 +514,6 @@ def _inject_callbacks(config: Any) -> Any: # noqa: ANN401 # ------------------------------------------------------------------ -def _warn_unsupported_frameworks() -> None: - """Warn when raw openai SDK is used without a patchable framework.""" - import importlib.util - - if importlib.util.find_spec("openai") is not None: - # Only warn if openai is present without langchain - if importlib.util.find_spec("langchain_core") is None: - logger.warning( - "Detected raw openai SDK without LangChain/LangGraph. " - "Adrian's pre-execution block gate requires a supported " - "framework (LangGraph ToolNode or LangChain AgentExecutor). " - "Tool calls from raw openai loops are OBSERVED but NOT " - "pre-blocked in MODE_BLOCK. " - "See https://docs.adrian.secureagentics.ai/supported-frameworks" - ) - - def _auto_instrument_langchain() -> None: """Apply all monkey-patches to LangChain / LangGraph.""" try: @@ -537,8 +522,8 @@ def _auto_instrument_langchain() -> None: _patch_chat_model() _patch_langgraph() _patch_tool_node() + _patch_base_tool() _patch_agent_executor() - _warn_unsupported_frameworks() logger.debug("LangChain auto-instrumentation applied") except ImportError: logger.debug("LangChain not found, skipping auto-instrumentation") @@ -583,7 +568,6 @@ async def patched_astream( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """AgentExecutor calls astream on the agent chain by default.""" config = _inject_callbacks(config) async for chunk in original_astream(self, input, config, **kwargs): yield chunk @@ -602,7 +586,7 @@ def patched_stream( Runnable.astream = patched_astream # type: ignore[assignment] Runnable.stream = patched_stream # type: ignore[assignment] Runnable._adrian_patched = True # type: ignore[attr-defined] - logger.debug("Patched Runnable.invoke / ainvoke / astream / stream") + logger.debug("Patched Runnable.invoke / ainvoke") # --- 2. CallbackManager --- @@ -724,7 +708,7 @@ def patched_stream( BaseChatModel.astream = patched_astream # type: ignore[assignment] BaseChatModel.stream = patched_stream # type: ignore[assignment] BaseChatModel._adrian_chat_model_patched = True # type: ignore[attr-defined] - logger.debug("Patched BaseChatModel.invoke / ainvoke / astream / stream") + logger.debug("Patched BaseChatModel.invoke / ainvoke") # --- 4. LangGraph Pregel --- @@ -815,157 +799,256 @@ async def patched_astream( logger.debug("Patched Pregel.invoke / ainvoke / astream") -# --- 5. ToolNode --- - - -def _extract_tool_calls( - state: dict[str, Any] | list[BaseMessage] | Any, -) -> list[dict[str, str]]: - """Extract tool_calls from ToolNode input (state dict, message list, or per-tool-call dict).""" - if isinstance(state, dict) and "tool_call" in state: - tc = state["tool_call"] - if isinstance(tc, dict) and tc.get("id"): - return [tc] - if hasattr(tc, "id") and tc.id: - return [{"id": tc.id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}] - - if isinstance(state, dict): - messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - elif isinstance(state, list): - messages = list(state) - else: - return [] - - for msg in reversed(messages): - if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): - return msg.tool_calls # type: ignore[no-any-return] - - return [] +# --- 5. ToolNode (callback injection only — gate is on BaseTool) --- def _should_halt(verdict: pb.Verdict) -> bool: """Decide whether a verdict should halt tool execution. - HITL resolutions override the per-MAD policy scope check. + HITL resolutions override per-MAD policy when present. """ if verdict.HasField("hitl"): return not verdict.hitl.continue_execution mad_prefix = verdict.mad_code[:2] - in_scope = { + return { "M0": verdict.policy.policy_m0, "M2": verdict.policy.policy_m2, "M3": verdict.policy.policy_m3, "M4": verdict.policy.policy_m4, }.get(mad_prefix, False) - return in_scope - -def _build_blocked_response( - tool_calls: list[dict[str, str]], -) -> dict[str, list[ToolMessage]]: - """Build synthetic ToolMessage responses for blocked tool calls.""" - blocked_messages: list[ToolMessage] = [ - ToolMessage( - content="[BLOCKED by security policy]", - tool_call_id=str(tc.get("id", "")), - name=str(tc.get("name", "")), - ) - for tc in tool_calls - ] +def _patch_tool_node() -> None: + """Patch ToolNode for callback injection + async verdict gate. - return {"messages": blocked_messages} + ToolNode dispatches tools via tool.invoke (sync) even within async + Pregel. BaseTool.invoke can't await a verdict from the event loop + thread, so we add the verdict gate here on ToolNode.ainvoke — the + entry point Pregel calls before tool dispatch begins. This is a + complementary gate to BaseTool (which covers direct callers). + """ + try: + from langgraph.prebuilt import ToolNode + except ImportError: + return + if getattr(ToolNode, "_adrian_tool_node_patched", False): + return -async def _adrian_tool_gate( - input: Any, # noqa: A002, ANN401 -) -> tuple[str, dict[str, Any] | None]: - """Pre-execution verdict gate. Returns ("halt", response), ("proceed", None), or ("skip", None).""" - ws = _ws_client + original_invoke = ToolNode.invoke + original_ainvoke = ToolNode.ainvoke + original_astream = getattr(ToolNode, "astream", None) - if ws is None: - return ("skip", None) + def _extract_tool_call_ids(state: Any) -> list[str]: # noqa: ANN401 + """Extract tool_call_ids from ToolNode input (any shape).""" + # Shape 3: per-tool-call dict from _afunc dispatch + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + return [tc_id] if tc_id else [] + # Shape 1/2: state dict or message list + messages = ( + list(state.get("messages") or []) + if isinstance(state, dict) + else list(state) + if isinstance(state, list) + else [] + ) + for msg in reversed(messages): + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): + return [tc.get("id") for tc in msg.tool_calls if tc.get("id")] + return [] - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: + async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 + """Returns True if tools should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] + except TimeoutError: + logger.warning("ToolNode: LoginAck not received within 5s; blocking") + return True + if not ws.policy_active(): + return False + + tc_ids = _extract_tool_call_ids(state) + if not tc_ids: + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + # Gate on the first tool_call_id (all come from the same LLM turn) + verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) + if verdict is None: + logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") + return True + if _should_halt(verdict): logger.warning( - "ToolNode: LoginAck not received within 5s; halting " - "(refusing to run a tool without a verified policy)" + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, ) - return ("halt", _build_blocked_response(_extract_tool_calls(input))) + return True + return False + + def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 + tc_ids = _extract_tool_call_ids(state) + return { + "messages": [ + ToolMessage( + content="[BLOCKED by security policy]", tool_call_id=tid, name="" + ) + for tid in tc_ids + ] + } - if not ws.policy_active(): - return ("skip", None) + def patched_invoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + return original_invoke(self, input, config=config, **kwargs) - tool_calls = _extract_tool_calls(input) - tool_call_id = next( - (tc.get("id") for tc in tool_calls if tc.get("id")), - None, - ) + async def patched_ainvoke( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + if await _gate_tool_calls(input): + return _build_blocked(input) + return await original_ainvoke(self, input, config=config, **kwargs) + + async def patched_astream( + self: Any, + input: Any, + config: Any = None, + **kwargs: Any, # noqa: A002, ANN401 + ) -> Any: # noqa: ANN401 + config = _inject_callbacks(config) + if await _gate_tool_calls(input): + yield _build_blocked(input) + return + async for chunk in original_astream(self, input, config=config, **kwargs): + yield chunk - if not tool_call_id: - return ("skip", None) + ToolNode.invoke = patched_invoke # type: ignore[assignment] + ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] + if original_astream is not None: + ToolNode.astream = patched_astream # type: ignore[assignment] + ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] + logger.debug("Patched ToolNode.invoke / ainvoke / astream") - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) +# --- 6. BaseTool (universal verdict gate) --- - if verdict is None: - logger.warning( - "verdict timeout for tool_call_id=%s, fail-open", - tool_call_id, - ) - return ("skip", None) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return ("halt", _build_blocked_response(tool_calls)) +_BLOCKED_CONTENT = "[BLOCKED by security policy]" - return ("proceed", None) +def _patch_base_tool() -> None: + """Patch ``BaseTool.invoke`` and ``BaseTool.ainvoke`` with the verdict gate. -def _patch_tool_node() -> None: - """Patch ToolNode._afunc with the verdict gate, and public methods for callback injection. + Every LangChain tool — whether dispatched by ToolNode, AgentExecutor, + create_react_agent, or a manual ``tool.invoke(tool_call)`` loop — + funnels through ``BaseTool.invoke`` (sync) or ``BaseTool.ainvoke`` + (async). Gating here covers all frameworks in one place. + + The gate extracts ``tool_call_id`` from the input (a ``ToolCall`` + TypedDict), awaits the classifier verdict for the producing LLM + event, and returns a ``[BLOCKED]`` string instead of running the + tool body when the verdict is in-scope (M3/M4 under MODE_BLOCK). - _afunc is the only reliable intercept -- Pregel bypasses ainvoke/astream entirely. + In MODE_BLOCK, verdict timeout is fail-closed (block the tool) + because the absence of a verdict in block mode is a policy violation. + In MODE_ALERT, no gate fires at all (skip). """ - try: - from langgraph.prebuilt import ToolNode - except ImportError: - return + from langchain_core.tools import BaseTool + from langchain_core.tools.base import _is_tool_call # pyright: ignore[reportPrivateUsage] - if getattr(ToolNode, "_adrian_tool_node_patched", False): + if getattr(BaseTool, "_adrian_base_tool_patched", False): return - original_invoke = ToolNode.invoke - original_ainvoke = ToolNode.ainvoke - original_astream = getattr(ToolNode, "astream", None) - original_stream = getattr(ToolNode, "stream", None) - original_afunc = ToolNode._afunc # type: ignore[attr-defined] + original_invoke = BaseTool.invoke + original_ainvoke = BaseTool.ainvoke - async def patched_afunc( - self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 - runtime: Any = None, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - """Verdict gate on ToolNode._afunc.""" - decision, blocked = await _adrian_tool_gate(input) - if decision == "halt": - return blocked + def _extract_tool_call_id(input: Any) -> str | None: # noqa: A002, ANN401 + """Extract tool_call_id from a ToolCall input, or None.""" + if isinstance(input, dict) and _is_tool_call(input): + return input.get("id") + return None + + async def _async_gate(tool_call_id: str) -> bool: + """Returns True if the tool should be BLOCKED.""" + ws = _ws_client + if ws is None: + return False + + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "BaseTool: LoginAck not received within 5s; " + "blocking tool (refusing to run without verified policy)" + ) + return True + + if not ws.policy_active(): + return False + + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tool_call_id, timeout) + + if verdict is None: + # Fail-closed in block mode: no verdict = block. + logger.warning( + "BaseTool: verdict timeout for tool_call_id=%s; " + "blocking (fail-closed in MODE_BLOCK)", + tool_call_id, + ) + return True + + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return True + + return False + + def _sync_gate(tool_call_id: str) -> bool: + """Sync verdict gate for pure-sync callers (no running event loop). + + When called from within an async loop (ToolNode._func dispatched + by Pregel), this cannot work — use the ToolNode.ainvoke gate + instead. Returns False (skip) when a running loop is detected. + """ + ws = _ws_client + if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] + return False - return await original_afunc(self, input, config=config, runtime=runtime) + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # Can't block the event loop thread. The ToolNode.ainvoke + # gate handles this path. + return False + return loop.run_until_complete(_async_gate(tool_call_id)) + except RuntimeError: + return False def patched_invoke( self: Any, # noqa: ANN401 @@ -973,9 +1056,10 @@ def patched_invoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.invoke.""" config = _inject_callbacks(config) - + tc_id = _extract_tool_call_id(input) + if tc_id and _sync_gate(tc_id): + return _BLOCKED_CONTENT return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -984,43 +1068,46 @@ async def patched_ainvoke( config: Any = None, # noqa: ANN401 **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.ainvoke.""" config = _inject_callbacks(config) - + tc_id = _extract_tool_call_id(input) + if tc_id and await _async_gate(tc_id): + return _BLOCKED_CONTENT return await original_ainvoke(self, input, config=config, **kwargs) - async def patched_astream( + original_arun = BaseTool.arun + + async def patched_arun( self: Any, # noqa: ANN401 - input: Any, # noqa: A002, ANN401 - config: Any = None, # noqa: ANN401 + tool_input: Any, # noqa: ANN401 + *args: Any, + tool_call_id: str | None = None, **kwargs: Any, ) -> Any: # noqa: ANN401 - """Inject Adrian callbacks into ToolNode.astream.""" - config = _inject_callbacks(config) - - async for chunk in original_astream(self, input, config=config, **kwargs): - yield chunk - - ToolNode._afunc = patched_afunc # type: ignore[attr-defined] - ToolNode.invoke = patched_invoke # type: ignore[assignment] - ToolNode.ainvoke = patched_ainvoke # type: ignore[assignment] - if original_astream is not None: - ToolNode.astream = patched_astream # type: ignore[assignment] - ToolNode._adrian_tool_node_patched = True # type: ignore[attr-defined] - logger.debug("Patched ToolNode._afunc / invoke / ainvoke / astream") - + """Gate on arun — AgentExecutor calls tool.arun directly.""" + if tool_call_id and await _async_gate(tool_call_id): + return _BLOCKED_CONTENT + return await original_arun( + self, tool_input, *args, tool_call_id=tool_call_id, **kwargs + ) -# --- 6. AgentExecutor (langchain / langchain-classic) --- + BaseTool.invoke = patched_invoke # type: ignore[assignment] + BaseTool.ainvoke = patched_ainvoke # type: ignore[assignment] + BaseTool.arun = patched_arun # type: ignore[assignment] + BaseTool._adrian_base_tool_patched = True # type: ignore[attr-defined] + logger.debug("Patched BaseTool.invoke / ainvoke / arun (universal verdict gate)") -_BLOCKED_OBSERVATION = "[BLOCKED by security policy]" +# --- 7. AgentExecutor (tool_call_id on agent_action, not on tool.arun) --- def _patch_agent_executor() -> None: - """Patch AgentExecutor tool dispatch with the verdict gate. + """Patch AgentExecutor._aperform_agent_action for the executor path. - Covers the legacy AgentExecutor path which bypasses ToolNode entirely. - Falls through for ReAct parsers that don't emit tool_call_id. + AgentExecutor calls tool.arun without forwarding tool_call_id, + so the BaseTool.arun gate can't extract it. The tool_call_id lives + on agent_action.tool_call_id (set by OpenAI-style parsers). We + intercept here, await the verdict, and return a blocked observation + instead of calling the tool. """ AgentExecutor = None AgentStep = None @@ -1036,142 +1123,46 @@ def _patch_agent_executor() -> None: if AgentExecutor is None or AgentStep is None: return - if getattr(AgentExecutor, "_adrian_executor_patched", False): return original_aperform = AgentExecutor._aperform_agent_action - original_perform = AgentExecutor._perform_agent_action async def patched_aperform( - self: Any, # noqa: ANN401 - name_to_tool_map: Any, # noqa: ANN401 - color_mapping: Any, # noqa: ANN401 - agent_action: Any, # noqa: ANN401 - run_manager: Any = None, # noqa: ANN401 - ) -> Any: # noqa: ANN401 - """Verdict gate before AgentExecutor dispatches a tool (async).""" - tool_call_id = getattr(agent_action, "tool_call_id", None) - - if tool_call_id: - ws = _ws_client - - if ws is not None: - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for( - ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] - timeout=5.0, - ) - except TimeoutError: - logger.warning( - "AgentExecutor: LoginAck not received within 5s; " - "blocking tool %s", - agent_action.tool, - ) - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - - if ws.policy_active(): - cfg = _get_config() - # Short timeout: AgentExecutor LLM callbacks may not propagate, - # so verdicts may never arrive. - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict( - tool_call_id, timeout, - ) - - if verdict is not None and _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s " - "mad_code=%s (AgentExecutor path)", - verdict.event_id, - verdict.mad_code, - ) - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - - if verdict is None: - logger.warning( - "AgentExecutor: verdict timeout for " - "tool_call_id=%s, fail-open", - tool_call_id, - ) - - return await original_aperform( - self, name_to_tool_map, color_mapping, agent_action, run_manager, - ) - - def patched_perform( - self: Any, # noqa: ANN401 - name_to_tool_map: Any, # noqa: ANN401 + self: Any, + name_to_tool_map: Any, color_mapping: Any, # noqa: ANN401 - agent_action: Any, # noqa: ANN401 + agent_action: Any, run_manager: Any = None, # noqa: ANN401 ) -> Any: # noqa: ANN401 - """Verdict gate before AgentExecutor dispatches a tool (sync).""" - tool_call_id = getattr(agent_action, "tool_call_id", None) - - if tool_call_id: + tc_id = getattr(agent_action, "tool_call_id", None) + if tc_id: ws = _ws_client - - if ws is not None and ws._login_ack_received.is_set() and ws.policy_active(): # pyright: ignore[reportPrivateUsage] - import concurrent.futures - - async def _gate() -> bool: - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict( - tool_call_id, timeout, + if ( + ws is not None + and ws._login_ack_received.is_set() + and ws.policy_active() + ): # pyright: ignore[reportPrivateUsage] + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, ) - if verdict is not None and _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s " - "mad_code=%s (AgentExecutor sync path)", - verdict.event_id, - verdict.mad_code, - ) - return True - return False - - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - future: concurrent.futures.Future[bool] = concurrent.futures.Future() - - async def _run() -> None: - try: - result = await _gate() - future.set_result(result) - except Exception as exc: - future.set_exception(exc) - - loop.create_task(_run()) - should_block = future.result(timeout=35) - else: - should_block = loop.run_until_complete(_gate()) - - if should_block: - return AgentStep( - action=agent_action, - observation=_BLOCKED_OBSERVATION, - ) - except Exception: - logger.debug( - "AgentExecutor sync gate failed, falling through", - exc_info=True, + return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, ) - - return original_perform( - self, name_to_tool_map, color_mapping, agent_action, run_manager, + return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + return await original_aperform( + self, name_to_tool_map, color_mapping, agent_action, run_manager ) AgentExecutor._aperform_agent_action = patched_aperform # type: ignore[assignment] - AgentExecutor._perform_agent_action = patched_perform # type: ignore[assignment] AgentExecutor._adrian_executor_patched = True # type: ignore[attr-defined] - logger.debug("Patched AgentExecutor._aperform_agent_action / _perform_agent_action") + logger.debug("Patched AgentExecutor._aperform_agent_action") diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 30f1ab5..1ab5df4 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -513,18 +513,6 @@ async def connect(self) -> None: else: logger.info("WebSocket connected: %s", self._url) - # Eager login: send the SessionLogin frame immediately - # so the server responds with LoginAck before any tool - # gate fires. Previously login was deferred to the - # first _send_frame call, which meant frameworks that - # don't trigger paired events (AgentExecutor) would - # never receive LoginAck and the block gate would time - # out. Provider/model are best-effort at this point - # (empty until the first LLM event auto-detects them). - if not self._logged_in: - await self._send_login(self._ws) - self._logged_in = True - # Drain anything buffered while we were offline, even # on the very first connect. ``_send_mcp_inventory`` # and other init-time emitters queue frames before the From b66f0bc619899509131ed6762766809fc6b762b3 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 22:22:27 +0200 Subject: [PATCH 3/5] Fix: remove double-gating --- sdk/adrian/__init__.py | 200 ++++++++++++++++++++--------- sdk/adrian/ws.py | 41 ++++-- sdk/tests/test_block_mode.py | 29 +++-- sdk/tests/test_block_mode_races.py | 22 ++-- sdk/tests/test_exec_modes.py | 2 +- 5 files changed, 201 insertions(+), 93 deletions(-) diff --git a/sdk/adrian/__init__.py b/sdk/adrian/__init__.py index 1b55d83..e50ddda 100644 --- a/sdk/adrian/__init__.py +++ b/sdk/adrian/__init__.py @@ -799,7 +799,46 @@ async def patched_astream( logger.debug("Patched Pregel.invoke / ainvoke / astream") -# --- 5. ToolNode (callback injection only — gate is on BaseTool) --- +# --- 5. ToolNode --- + + +def _extract_tool_calls( + state: dict[str, Any] | list[BaseMessage] | Any, +) -> list[dict[str, Any]]: + """Extract tool_calls from ToolNode input (all three dispatch shapes). + + Returns full tool_call dicts (with id, name, args) for backward + compat with tests and callers that need the full shape. + """ + # Shape 3: per-tool-call dict from _afunc dispatch + if isinstance(state, dict) and "tool_call" in state: + tc = state["tool_call"] + if isinstance(tc, dict) and tc.get("id"): + return [tc] + tc_id = getattr(tc, "id", None) + if tc_id: + return [ + { + "id": tc_id, + "name": getattr(tc, "name", ""), + "args": getattr(tc, "args", {}), + } + ] + return [] + + # Shape 1/2: state dict or message list + if isinstance(state, dict): + messages = list(state.get("messages") or []) # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] + elif isinstance(state, list): + messages = list(state) + else: + return [] + + for msg in reversed(messages): + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): + return msg.tool_calls # type: ignore[no-any-return] + + return [] def _should_halt(verdict: pb.Verdict) -> bool: @@ -840,26 +879,6 @@ def _patch_tool_node() -> None: original_ainvoke = ToolNode.ainvoke original_astream = getattr(ToolNode, "astream", None) - def _extract_tool_call_ids(state: Any) -> list[str]: # noqa: ANN401 - """Extract tool_call_ids from ToolNode input (any shape).""" - # Shape 3: per-tool-call dict from _afunc dispatch - if isinstance(state, dict) and "tool_call" in state: - tc = state["tool_call"] - tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) - return [tc_id] if tc_id else [] - # Shape 1/2: state dict or message list - messages = ( - list(state.get("messages") or []) - if isinstance(state, dict) - else list(state) - if isinstance(state, list) - else [] - ) - for msg in reversed(messages): - if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None): - return [tc.get("id") for tc in msg.tool_calls if tc.get("id")] - return [] - async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 """Returns True if tools should be BLOCKED.""" ws = _ws_client @@ -874,13 +893,14 @@ async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 if not ws.policy_active(): return False - tc_ids = _extract_tool_call_ids(state) + tc_ids: list[str] = [ + str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") + ] if not tc_ids: return False cfg = _get_config() timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - # Gate on the first tool_call_id (all come from the same LLM turn) verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) if verdict is None: logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") @@ -895,7 +915,7 @@ async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 return False def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 - tc_ids = _extract_tool_call_ids(state) + tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] return { "messages": [ ToolMessage( @@ -921,8 +941,11 @@ async def patched_ainvoke( **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - if await _gate_tool_calls(input): - return _build_blocked(input) + # Verdict gate removed — BaseTool.ainvoke/arun is the single + # gate layer. Gating here too caused double-gate: ToolNode + # consumed the verdict future, BaseTool's gate registered a + # fresh future that never resolved → 30s timeout on a benign + # verdict. Callback injection is kept so events still flow. return await original_ainvoke(self, input, config=config, **kwargs) async def patched_astream( @@ -932,9 +955,7 @@ async def patched_astream( **kwargs: Any, # noqa: A002, ANN401 ) -> Any: # noqa: ANN401 config = _inject_callbacks(config) - if await _gate_tool_calls(input): - yield _build_blocked(input) - return + assert original_astream is not None # guarded by line below async for chunk in original_astream(self, input, config=config, **kwargs): yield chunk @@ -970,7 +991,9 @@ def _patch_base_tool() -> None: In MODE_ALERT, no gate fires at all (skip). """ from langchain_core.tools import BaseTool - from langchain_core.tools.base import _is_tool_call # pyright: ignore[reportPrivateUsage] + from langchain_core.tools.base import ( + _is_tool_call, # pyright: ignore[reportPrivateUsage] + ) if getattr(BaseTool, "_adrian_base_tool_patched", False): return @@ -1030,11 +1053,19 @@ async def _async_gate(tool_call_id: str) -> bool: return False def _sync_gate(tool_call_id: str) -> bool: - """Sync verdict gate for pure-sync callers (no running event loop). + """Sync verdict gate — works for pure-sync and worker-thread callers. + + Pure-sync (no event loop): runs ``_async_gate`` via + ``loop.run_until_complete``. + + Worker-thread (Pregel dispatches sync tools on a thread-pool + worker while the event loop runs on the main thread): bridges + the async gate to the main loop via ``run_coroutine_threadsafe`` + and blocks the worker thread until the verdict resolves. - When called from within an async loop (ToolNode._func dispatched - by Pregel), this cannot work — use the ToolNode.ainvoke gate - instead. Returns False (skip) when a running loop is detected. + Event-loop thread (calling tool.invoke directly from async + code): cannot block — returns False (skip). The async path + (BaseTool.ainvoke) handles this case. """ ws = _ws_client if ws is None or not ws._login_ack_received.is_set() or not ws.policy_active(): # pyright: ignore[reportPrivateUsage] @@ -1042,14 +1073,45 @@ def _sync_gate(tool_call_id: str) -> bool: try: loop = asyncio.get_event_loop() - if loop.is_running(): - # Can't block the event loop thread. The ToolNode.ainvoke - # gate handles this path. - return False + except RuntimeError: + return False + + if not loop.is_running(): + # Pure-sync caller — safe to block return loop.run_until_complete(_async_gate(tool_call_id)) + + # Check if we're on a worker thread (no running loop on THIS + # thread) vs the event-loop thread itself. + try: + asyncio.get_running_loop() + # We ARE on the event-loop thread — can't block it. + return False except RuntimeError: + pass + + # Worker thread: bridge the async gate to the main loop. + main_loop = getattr(ws, "_loop", None) + if main_loop is None or not main_loop.is_running(): + return False + + try: + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + future = asyncio.run_coroutine_threadsafe( + _async_gate(tool_call_id), main_loop + ) + return future.result(timeout=timeout if timeout else 60.0) + except Exception: return False + def _blocked_response(tc_id: str) -> Any: # noqa: ANN401 + """Return a blocked response compatible with ToolNode (ToolMessage) + and legacy callers (falls back to bare string).""" + try: + return ToolMessage(content=_BLOCKED_CONTENT, tool_call_id=tc_id, name="") + except Exception: + return _BLOCKED_CONTENT + def patched_invoke( self: Any, # noqa: ANN401 input: Any, # noqa: A002, ANN401 @@ -1059,7 +1121,7 @@ def patched_invoke( config = _inject_callbacks(config) tc_id = _extract_tool_call_id(input) if tc_id and _sync_gate(tc_id): - return _BLOCKED_CONTENT + return _blocked_response(tc_id) return original_invoke(self, input, config=config, **kwargs) async def patched_ainvoke( @@ -1071,7 +1133,7 @@ async def patched_ainvoke( config = _inject_callbacks(config) tc_id = _extract_tool_call_id(input) if tc_id and await _async_gate(tc_id): - return _BLOCKED_CONTENT + return _blocked_response(tc_id) return await original_ainvoke(self, input, config=config, **kwargs) original_arun = BaseTool.arun @@ -1085,7 +1147,7 @@ async def patched_arun( ) -> Any: # noqa: ANN401 """Gate on arun — AgentExecutor calls tool.arun directly.""" if tool_call_id and await _async_gate(tool_call_id): - return _BLOCKED_CONTENT + return _blocked_response(tool_call_id) return await original_arun( self, tool_input, *args, tool_call_id=tool_call_id, **kwargs ) @@ -1138,27 +1200,41 @@ async def patched_aperform( tc_id = getattr(agent_action, "tool_call_id", None) if tc_id: ws = _ws_client - if ( - ws is not None - and ws._login_ack_received.is_set() - and ws.policy_active() - ): # pyright: ignore[reportPrivateUsage] - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) - if verdict is None: - logger.warning( - "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", - tc_id, - ) - return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return AgentStep(action=agent_action, observation=_BLOCKED_CONTENT) + if ws is not None: + if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] + try: + await asyncio.wait_for( + ws._login_ack_received.wait(), # pyright: ignore[reportPrivateUsage] + timeout=5.0, + ) + except TimeoutError: + logger.warning( + "AgentExecutor: LoginAck not received within 5s; blocking" + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if ws.policy_active(): + cfg = _get_config() + timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) + verdict = await ws.wait_for_tool_call_verdict(tc_id, timeout) + if verdict is None: + logger.warning( + "AgentExecutor: verdict timeout for tool_call_id=%s, blocking (fail-closed)", + tc_id, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) + if _should_halt(verdict): + logger.warning( + "halting tool execution for event_id=%s mad_code=%s", + verdict.event_id, + verdict.mad_code, + ) + return AgentStep( + action=agent_action, observation=_BLOCKED_CONTENT + ) return await original_aperform( self, name_to_tool_map, color_mapping, agent_action, run_manager ) diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py index 1ab5df4..169cbdc 100644 --- a/sdk/adrian/ws.py +++ b/sdk/adrian/ws.py @@ -52,6 +52,8 @@ _MAX_RUN_ID_MAP = 1024 # Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). _MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 _DEFAULT_REPLAY_BUFFER_FRAMES = 1000 @@ -254,6 +256,10 @@ def __init__( # Set by close() so _handle_disconnect knows not to spawn a reconnect # during a graceful shutdown. self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None # Futures awaited by the patched ToolNode.ainvoke when the # active mode requires a wait (BLOCK or HITL). Each resolves # with the matching ``Verdict`` proto. Futures survive a @@ -472,6 +478,7 @@ async def connect(self) -> None: backoff = _INITIAL_BACKOFF loop = asyncio.get_running_loop() + self._loop = loop headers: dict[str, str] = {} @@ -491,7 +498,6 @@ async def connect(self) -> None: disconnected_at = self._disconnected_at is_reconnect = disconnected_at is not None - if disconnected_at is not None: downtime = time.monotonic() - disconnected_at self._disconnected_at = None @@ -927,6 +933,18 @@ def register_pending( return fut + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + async def wait_for_verdict( self, event_id: str, @@ -939,25 +957,30 @@ async def wait_for_verdict( ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the verdict, or ``None`` on timeout (fail-open). - Cleans up the ``_pending_verdicts`` entry on either path: - ``_on_verdict_frame`` only resolves the future, the dict - ownership belongs here so a late ``register_pending`` after the - verdict has already arrived can still find the resolved future. + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. """ fut = self.register_pending(event_id) try: - return await asyncio.wait_for(fut, timeout=timeout) + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result except TimeoutError: logger.warning( "Verdict timeout for event_id=%s after %ss", event_id, timeout, ) - - return None - finally: + # Timed-out future is useless — remove so a retry can + # register a fresh one. self._pending_verdicts.pop(event_id, None) + return None async def wait_for_tool_verdict( self, diff --git a/sdk/tests/test_block_mode.py b/sdk/tests/test_block_mode.py index 0d1c352..0bbbdaf 100644 --- a/sdk/tests/test_block_mode.py +++ b/sdk/tests/test_block_mode.py @@ -142,10 +142,16 @@ async def test_looks_up_llm_event_id_and_resolves(self) -> None: class TestToolNodePatchBlocking: async def test_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None: - """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → halt with synthetic ToolMessage.""" + """MODE_BLOCK + policy_m4=true + mad_code='M4_a' → BaseTool.ainvoke gate blocks. - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + The verdict gate lives on BaseTool (the universal layer), not + ToolNode.ainvoke. Uses an async tool so BaseTool.ainvoke (not + BaseTool.invoke) is the entry point — matching the production + path for create_react_agent with async tools. + """ + + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" _real_tool.called = True # type: ignore[attr-defined] return x @@ -180,6 +186,7 @@ def _real_tool(x: str) -> str: result = await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] + # BaseTool.ainvoke gate blocks — tool body does NOT run. assert _real_tool.called is False # type: ignore[attr-defined] msgs = result["messages"] assert len(msgs) == 1 @@ -190,7 +197,7 @@ async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) @@ -226,11 +233,12 @@ def _real_tool(x: str) -> str: assert captured == ["hi"] - async def test_timeout_fail_open_runs_tool(self, tmp_path: Path) -> None: + async def test_timeout_fail_closed_blocks_tool(self, tmp_path: Path) -> None: + """Verdict timeout in MODE_BLOCK → fail-closed (tool does NOT run).""" captured: list[str] = [] - def _real_tool(x: str) -> str: - """Real tool stub for block-mode tests.""" + async def _real_tool(x: str) -> str: + """Real async tool stub for block-mode tests.""" captured.append(x) return x @@ -248,7 +256,7 @@ def _real_tool(x: str) -> str: _apply_mode(ws, pb.MODE_BLOCK, policy_m4=True) ws._connected.set() ws._tool_call_id_to_event_id["tc-1"] = "llm-evt" - # No pending future → wait_for_verdict times out → fail-open. + # No pending future → wait_for_verdict times out → fail-closed (MODE_BLOCK). tool_node = ToolNode([_real_tool]) ai = AIMessage( @@ -259,7 +267,8 @@ def _real_tool(x: str) -> str: await tool_node.ainvoke(state, config=_runtime_config()) # pyright: ignore[reportUnknownMemberType] - assert captured == ["hi"] + # Fail-closed: tool should NOT have run. + assert captured == [] class TestModeAlert: @@ -268,7 +277,7 @@ async def test_alert_mode_skips_wait(self, tmp_path: Path) -> None: captured: list[str] = [] - def _real_tool(x: str) -> str: + async def _real_tool(x: str) -> str: """Real tool stub for block-mode tests.""" captured.append(x) diff --git a/sdk/tests/test_block_mode_races.py b/sdk/tests/test_block_mode_races.py index fa0ad57..16d8e4a 100644 --- a/sdk/tests/test_block_mode_races.py +++ b/sdk/tests/test_block_mode_races.py @@ -5,17 +5,17 @@ LLM calls; no running backend. Scenarios mirror the validated shapes from the multi-agent work: - S1 subagents-as-tools - director → worker (nested) - S2 handoffs - triage → specialist (sequential) - S3 router - parallel fan-out via Send() - S4 hierarchical - 3-level deep (director → team-lead → worker) - S5 custom workflow - deterministic + LLM nodes mixed - S6 swarm - back-and-forth handoffs (Alice ↔ Bob) - S7 supervisor - central dispatcher to N workers - S8 deep research - parallel researchers via asyncio.gather + S1 subagents-as-tools , director → worker (nested) + S2 handoffs , triage → specialist (sequential) + S3 router , parallel fan-out via Send() + S4 hierarchical , 3-level deep (director → team-lead → worker) + S5 custom workflow , deterministic + LLM nodes mixed + S6 swarm , back-and-forth handoffs (Alice ↔ Bob) + S7 supervisor , central dispatcher to N workers + S8 deep research , parallel researchers via asyncio.gather The invariant under test: for EVERY pattern, each ToolNode invocation -blocks on the verdict of the LLM that emitted its specific tool_call.id - +blocks on the verdict of the LLM that emitted its specific tool_call.id , never a sibling, never a parent, never a stale global. """ @@ -117,9 +117,9 @@ def _init_block_mode(tmp_path: Path, block_timeout: float = 1.0) -> Any: def _tool(name: str, captured: list[str]) -> Any: - """Build a named stub tool that records its argument.""" + """Build a named async stub tool that records its argument.""" - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(f"{name}:{x}") diff --git a/sdk/tests/test_exec_modes.py b/sdk/tests/test_exec_modes.py index 1ea8ae1..f3f5e42 100644 --- a/sdk/tests/test_exec_modes.py +++ b/sdk/tests/test_exec_modes.py @@ -61,7 +61,7 @@ def _cleanup() -> Iterator[None]: # pyright: ignore[reportUnusedFunction] def _stub_tool(captured: list[str]) -> Any: # noqa: ANN401 - def _impl(x: str) -> str: + async def _impl(x: str) -> str: """Stub tool.""" captured.append(x) From 3b9218876d4924bbe4bdfa23ea948b0a0eb66404 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 22:55:47 +0200 Subject: [PATCH 4/5] Add: merge conflicts and merge --- sdk/adrian/ws.py | 1038 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1038 insertions(+) create mode 100644 sdk/adrian/ws.py diff --git a/sdk/adrian/ws.py b/sdk/adrian/ws.py new file mode 100644 index 0000000..169cbdc --- /dev/null +++ b/sdk/adrian/ws.py @@ -0,0 +1,1038 @@ +"""Async WebSocket ``EventHandler`` that streams ``PairedEvent`` to the worker core API. + +Converts each ``PairedEvent`` into a ``pb.PairedEvent`` protobuf, wraps it in a +``ClientFrame.paired_batch``, and sends it over a long-lived WebSocket +connection. Verdicts received back resolve block-mode futures and fire the +callback handler's verdict processing. + +Implements the ``EventHandler`` protocol so it slots into the SDK's hook +registry alongside ``JSONLHandler``. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import json +import logging +import time +from collections import OrderedDict, deque +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import websockets + +if TYPE_CHECKING: + from adrian.config import OnDisconnectCallback, OnReconnectCallback + from adrian.handler import AdrianCallbackHandler + +from adrian.format.types import ( + AgentContext, + LlmPairData, + PairedEvent, + ParentContext, +) +from adrian.proto import event_pb2 as pb + +logger = logging.getLogger("adrian.ws") + +SCHEMA_VERSION = 2 + +_INITIAL_BACKOFF = 1.0 +_MAX_BACKOFF = 30.0 +# Server close code: quota exhausted. Spec'd in +# server/internal/websocket/handler.go (closeQuotaExceeded). Returning +# every 30s would hammer the server while quota is depleted; one +# minute is slow enough to be cheap, fast enough that the next hourly +# / daily / monthly window-rollover is picked up within tolerance. +_QUOTA_EXHAUSTED_CLOSE_CODE = 4003 +_QUOTA_RECONNECT_DELAY = 60.0 +# Cap on in-flight LLM run_id → event_id mappings. Evicted LRU-style; +# block-mode lookups for evicted entries fail open. +_MAX_RUN_ID_MAP = 1024 +# Cap on in-flight tool_call_id → event_id mappings (block-mode correlation). +_MAX_TOOL_CALL_MAP = 1024 +# Cap on resolved verdict futures kept for late-waiter replay. +_MAX_PENDING_VERDICTS = 512 + +_DEFAULT_REPLAY_BUFFER_FRAMES = 1000 + +# Heartbeat tuning. 10s interval / 15s pong timeout detects half-open +# connections (ALB idle cut, NAT drop, dead remote process) without +# flooding the wire. Kept in sync with the backend's pingInterval / +# pongTimeout, if these change, update server/internal/websocket/handler.go. +_PING_INTERVAL = 10.0 +_PING_TIMEOUT = 15.0 + +_PROVIDER_PREFIXES: dict[str, str] = { + "chatanthropic": "anthropic", + "chatopenai": "openai", + "chatgooglegenai": "google", + "chatcohere": "cohere", + "chatmistralai": "mistral", +} + +_PAIR_TYPE_MAP: dict[str, pb.PairType.ValueType] = { + "llm": pb.PAIR_TYPE_LLM, + "tool": pb.PAIR_TYPE_TOOL, +} + + +def _derive_provider(model_class_name: str) -> str: + """Derive the LLM provider from the model class name. + + Args: + model_class_name: Class name like ``"ChatAnthropic"`` or ``"ChatOpenAI"``. + + Returns: + Provider string (e.g. ``"anthropic"``), or the class name lower-cased + if no known prefix matches. + """ + key = model_class_name.lower() + + return _PROVIDER_PREFIXES.get(key, key) + + +def _fill_agent_context( + pb_ctx: pb.AgentContext, src: AgentContext | ParentContext +) -> None: + """Copy an AgentContext / ParentContext dataclass into its proto counterpart.""" + pb_ctx.agent_id = src.agent_id + pb_ctx.system_prompt = src.system_prompt + pb_ctx.user_instruction = src.user_instruction + + +def _safe_cancel( + task_or_future: asyncio.Task[Any] | asyncio.Future[Any] | None, +) -> None: + """Cancel a task / future, ignoring closed-loop errors at shutdown. + + Adrian's ``atexit`` handler may run after the user's loop has been + closed; in that path ``adrian.shutdown`` spawns a new ``asyncio.run`` + and walks each handler's ``close()``. Tasks bound to the *old* loop + can no longer be cancelled (``call_soon`` raises ``Event loop is + closed``). Swallowing the error here keeps the cleanup path quiet, + the task will be reaped when the dead loop is GC'd. + """ + if task_or_future is None or task_or_future.done(): + return + # "Event loop is closed", old loop is gone, nothing to cancel. + with contextlib.suppress(RuntimeError): + task_or_future.cancel() + + +def _paired_event_to_proto(event: PairedEvent) -> pb.PairedEvent: + """Convert a ``PairedEvent`` dataclass into its protobuf form. + + ``parent.agent_id`` empty-string signals "no parent agent". + ``parent_run_id`` empty-string signals "no parent in run tree". + """ + proto = pb.PairedEvent( + event_id=event.event_id, + invocation_id=event.invocation_id, + session_id=event.session_id, + run_id=event.run_id, + parent_run_id=event.parent_run_id, + timestamp=event.timestamp, + pair_type=_PAIR_TYPE_MAP.get(event.pair_type, pb.PAIR_TYPE_UNSPECIFIED), + ) + + _fill_agent_context(proto.agent, event.agent) + + if event.parent is not None: + _fill_agent_context(proto.parent, event.parent) + + if isinstance(event.data, LlmPairData): + proto.llm.model = event.data.model + + for msg in event.data.messages: + pb_msg = proto.llm.messages.add() + pb_msg.role = msg["role"] + pb_msg.content = msg["content"] + + proto.llm.output = event.data.output + + for tc in event.data.tool_calls: + pb_tc = proto.llm.tool_calls.add() + pb_tc.name = tc["name"] + pb_tc.args = json.dumps(tc["args"], default=str) + pb_tc.id = tc["id"] + + if event.data.usage is not None: + proto.llm.usage.prompt_tokens = event.data.usage["prompt_tokens"] + proto.llm.usage.completion_tokens = event.data.usage["completion_tokens"] + proto.llm.usage.total_tokens = event.data.usage["total_tokens"] + else: + # Union is LlmPairData | ToolPairData; this branch is the + # ToolPairData case. + proto.tool.tool_name = event.data.tool_name + proto.tool.tool_call_id = event.data.tool_call_id or "" + proto.tool.input = event.data.input + proto.tool.output = event.data.output + + if event.metadata: + proto.metadata_json = json.dumps(event.metadata, default=str).encode() + + return proto + + +class WebSocketClient: + """Streams ``PairedEvent`` instances to the worker core API. + + Connects eagerly via :meth:`schedule_connect` with exponential backoff, + auto-detects the LLM provider on the first LLM pair, sends paired events + as protobuf frames, and resolves block-mode futures when verdicts arrive. + """ + + def __init__( + self, + url: str, + session_id: str, + api_key: str, + handler: AdrianCallbackHandler | None = None, + on_disconnect: OnDisconnectCallback | None = None, + on_reconnect: OnReconnectCallback | None = None, + on_login_ack: Callable[[], Awaitable[None]] | None = None, + replay_buffer_frames: int = _DEFAULT_REPLAY_BUFFER_FRAMES, + ) -> None: + """Initialise without connecting. + + Args: + url: WebSocket endpoint URL. + session_id: Session ID sent in the login frame. + api_key: Adrian API key for the ``Authorization`` header. + handler: Callback handler for verdict processing. + on_disconnect: Fired when the connection is lost (sync or async). + Receives a reason string. + on_reconnect: Fired when the connection re-establishes after a + prior disconnect (sync or async). Does not fire on initial + connect. + on_login_ack: Async hook fired after each ``LoginAck`` frame is + applied, once per (re)connect. Used internally to push a + fresh ``McpInventory`` on every login. Exceptions are + logged and swallowed. + replay_buffer_frames: Ring-buffer capacity (frame count, not + bytes). When the cap is reached each further append evicts + the oldest frame; a one-shot WARN fires on first fill, and + the cumulative drop count is logged at WARN on the next + reconnect. + """ + self._url = url + self._session_id = session_id + self._api_key = api_key + self._handler = handler + self._on_disconnect = on_disconnect + self._on_reconnect = on_reconnect + self._on_login_ack_cb = on_login_ack + self._provider = "" + self._model = "" + # Server-supplied execution-mode policy. Populated when the + # first ServerFrame{login_ack} arrives after each (re)connect. + # ``policy_active()`` and ``block_timeout()`` read this state + # to decide whether the patched ToolNode should wait for a + # verdict and how long. + self._mode: int = pb.MODE_UNSPECIFIED + self._policy: pb.PolicySnapshot | None = None + # Set the first time a ``ServerFrame{login_ack}`` is applied. + # Used in two places: + # 1. ``on_paired_event`` defensively pre-registers a + # verdict-wait future when this event is unset, so the + # very first tool-bearing LLM emission is covered even + # though the recv loop hasn't yet processed LoginAck and + # ``policy_active()`` reads False. + # 2. The patched ``ToolNode.ainvoke`` ``await``s this event + # (with a short timeout) before deciding whether to wait + # for a verdict, so the first ToolNode invocation cannot + # run-through-without-waiting in the same window. + # Stays set across disconnect/reconnect because mode/policy + # state survives, a fresh LoginAck on reconnect simply re-sets + # an already-set event. + self._login_ack_received: asyncio.Event = asyncio.Event() + self._ws: websockets.ClientConnection | None = None + self._logged_in = False + self._connected = asyncio.Event() + self._connect_task: asyncio.Task[None] | None = None + self._recv_task: asyncio.Task[None] | None = None + # Set by close() so _handle_disconnect knows not to spawn a reconnect + # during a graceful shutdown. + self._closing = False + # Event loop running the WebSocket tasks. Captured on first + # connect so _sync_gate can bridge async waits from worker + # threads via run_coroutine_threadsafe. + self._loop: asyncio.AbstractEventLoop | None = None + # Futures awaited by the patched ToolNode.ainvoke when the + # active mode requires a wait (BLOCK or HITL). Each resolves + # with the matching ``Verdict`` proto. Futures survive a + # disconnect: a late verdict after reconnect still resolves + # the wait; if none arrives, ``wait_for_verdict``'s timeout + # produces a natural fail-open in BLOCK mode. + self._pending_verdicts: dict[str, asyncio.Future[pb.Verdict]] = {} + # Maps LLM pair run_id → event_id so a subsequent tool call can + # look up the verdict by its parent_run_id (the LLM's run_id). + # LRU-capped at _MAX_RUN_ID_MAP to bound memory on long sessions. + self._run_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Verdict-correlation map: maps each tool_call.id emitted by + # an LLM to the event_id of the LLM pair that emitted it. + # Populated on every LLM PairedEvent that has tool_calls. + # Consulted by the patched ``ToolNode.ainvoke`` so each tool + # in a parallel fan-out waits on its own producing LLM's + # verdict, not a global "last" pointer. LRU-capped at + # ``_MAX_TOOL_CALL_MAP``. + self._tool_call_id_to_event_id: OrderedDict[str, str] = OrderedDict() + # Serialises the lazy login-then-send sequence so two concurrent + # on_paired_event calls (parallel agents) cannot both send a login. + # Reused by _replay_buffer_to_ws to coordinate with live sends. + self._login_lock = asyncio.Lock() + # Ring buffer of recently serialised ClientFrame bytes. Appended + # only from the offline-or-send-failure paths in _send_frame; the + # happy path bypasses the ring entirely. Drained on reconnect. + self._replay_buffer: deque[bytes] = deque(maxlen=replay_buffer_frames) + # Flips True on the first append that reaches maxlen. Gates the + # one-shot "buffer full" WARN so we don't flood logs. + self._replay_buffer_filled: bool = False + # Monotonic counter of frames dropped due to buffer overflow + # (oldest evicted when a new append arrives at a full ring). + # Logged at WARN on the next reconnect. + self._replay_buffer_dropped: int = 0 + # True while the reconnect path is draining the replay buffer. + # Live sends observed during this window are routed back into + # the same deque so they slot in AFTER the pre-outage tail + # rather than racing onto the wire ahead of older buffered + # frames. Flipped on as the first sync line of + # _replay_buffer_to_ws and cleared in its finally. + self._replaying: bool = False + # Set by _handle_disconnect, cleared on successful reconnect. + # Used to gate on_reconnect and measure downtime. + self._disconnected_at: float | None = None + # One-shot delay applied before the next ``connect()`` attempt. + # Set when the server closes with a code that requests a longer + # wait (currently only 4003 quota exhausted); cleared by + # ``connect()`` after honouring it. ``None`` means use the + # standard exponential schedule. + self._next_reconnect_delay: float | None = None + + # -- Mode / policy state (populated by LoginAck) -- + + def policy_active(self) -> bool: + """Whether the active server mode requires waiting on verdicts. + + Single predicate consulted by the patched ``ToolNode.ainvoke``. + Returns ``True`` for ``MODE_BLOCK`` and ``MODE_HITL``; ``False`` + for ``MODE_ALERT`` and unset (pre-login) state. + """ + return self._mode in (pb.MODE_BLOCK, pb.MODE_HITL) + + def block_timeout(self, kwarg_default: float) -> float | None: + """Effective per-tool-call wait timeout for the active mode. + + - ``MODE_BLOCK``: ``kwarg_default`` (typically 30s), fail-open + if the server doesn't classify in time. + - ``MODE_HITL``: ``None``, wait indefinitely for human review. + - ``MODE_ALERT`` / unset: ``0``, caller short-circuits before + registering a future. + """ + if self._mode == pb.MODE_BLOCK: + return kwarg_default + elif self._mode == pb.MODE_HITL: + return None + else: + return 0 + + # -- EventHandler protocol -- + + async def on_paired_event(self, event: PairedEvent) -> None: + """Send a paired event over the WebSocket. + + Auto-detects the LLM provider on the first LLM pair, updates the + run_id → event_id map for block mode, converts the dataclass to + protobuf, and sends a ``ClientFrame.paired_batch`` frame. + + For LLM pairs that carry tool_calls, registers the verdict-wait + future *before* the frame leaves the SDK. This closes the race + where a fast verdict roundtrip resolves and is dropped before + the patched ``ToolNode.ainvoke`` reaches its own + ``register_pending`` call. The matching ``register_pending`` + from the wait site is a get-or-create that returns the existing + future. + + Args: + event: The paired event to stream. + """ + if ( + event.pair_type == "llm" + and not self._provider + and isinstance(event.data, LlmPairData) + ): + self._model = event.data.model + self._provider = _derive_provider(event.data.model) + + if event.pair_type == "llm": + self._run_id_to_event_id[event.run_id] = event.event_id + self._run_id_to_event_id.move_to_end(event.run_id) + + if len(self._run_id_to_event_id) > _MAX_RUN_ID_MAP: + self._run_id_to_event_id.popitem(last=False) + + # Populate tool_call.id → event_id so each tool call can block + # on its own producing LLM's verdict under parallel fan-out. + if isinstance(event.data, LlmPairData) and event.data.tool_calls: + for tc in event.data.tool_calls: + tc_id = tc.get("id") or "" + + if not tc_id: + continue + + self._tool_call_id_to_event_id[tc_id] = event.event_id + self._tool_call_id_to_event_id.move_to_end(tc_id) + + if len(self._tool_call_id_to_event_id) > _MAX_TOOL_CALL_MAP: + self._tool_call_id_to_event_id.popitem(last=False) + + # Pre-register the wait future so an eager verdict + # cannot race ahead of the ToolNode patch. Gated on + # ``policy_active()`` so ALERT-mode sessions don't + # accumulate futures that will never be resolved or + # awaited, except for the very first event of the + # session, where ``LoginAck`` may not yet have been + # processed by the recv loop and ``policy_active()`` + # therefore reads False even when the mode will + # imminently be set to BLOCK or HITL. Pre-register + # defensively in that window; in ALERT mode the gate + # filters out every subsequent event so the leak is + # bounded to one orphan future per session. + if self.policy_active() or not self._login_ack_received.is_set(): + self.register_pending(event.event_id) + + proto = _paired_event_to_proto(event) + frame = pb.ClientFrame() + added = frame.paired_batch.events.add() + added.CopyFrom(proto) + + await self._send_frame(frame) + + async def close(self) -> None: + """Cancel background tasks and close the WebSocket. + + Sets ``_closing`` so any in-flight ``_handle_disconnect`` does not + spawn a reconnect during graceful shutdown. + + Defensive against the ``atexit`` shutdown path: ``adrian.shutdown`` + spawns a fresh ``asyncio.run`` loop after the user's loop has + already closed, so background tasks bound to the old loop can no + longer be cancelled cleanly (``call_soon`` raises + ``Event loop is closed``). Skip the cancel in that case, the + old loop is gone, the task will be reaped by GC. + """ + self._closing = True + + _safe_cancel(self._recv_task) + self._recv_task = None + _safe_cancel(self._connect_task) + self._connect_task = None + + if self._ws is not None: + with contextlib.suppress(Exception): + await asyncio.wait_for(self._ws.close(), timeout=2.0) + self._ws = None + + for fut in self._pending_verdicts.values(): + if not fut.done(): + _safe_cancel(fut) + self._pending_verdicts.clear() + + # -- Connection lifecycle -- + + def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None: + """Schedule :meth:`connect` as a background task on the given loop.""" + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def connect(self) -> None: + """Establish the WebSocket with exponential-backoff retry. + + Heartbeat (``ping_interval`` / ``ping_timeout``) is configured on + the underlying ``websockets`` client; if the server fails to pong + within ``_PING_TIMEOUT`` the library closes the connection and + ``_recv_loop`` surfaces the disconnect via ``_handle_disconnect``. + + On a reconnect (``_disconnected_at`` set by a prior disconnect), + drains the replay buffer and fires ``on_reconnect``. Login is + deferred to ``_send_frame`` / ``_replay_buffer_to_ws`` so the + auto-detected provider/model is included. An ``api_key``, if + configured, is sent as an ``Authorization: Bearer `` header. + + Honours ``_next_reconnect_delay`` if a previous disconnect set + it (e.g. 4003 quota exhausted requests a slower retry). The + delay is consumed on the first attempt; subsequent failures + fall back to the standard exponential schedule. + """ + initial_delay = self._next_reconnect_delay + self._next_reconnect_delay = None + + if initial_delay is not None: + logger.info( + "delaying reconnect by %.0fs (server-requested)", + initial_delay, + ) + await asyncio.sleep(initial_delay) + + backoff = _INITIAL_BACKOFF + loop = asyncio.get_running_loop() + self._loop = loop + + headers: dict[str, str] = {} + + if self._api_key: + headers["Authorization"] = f"Bearer {self._api_key}" + + while True: + try: + self._ws = await websockets.connect( + self._url, + additional_headers=headers, + ping_interval=_PING_INTERVAL, + ping_timeout=_PING_TIMEOUT, + ) + self._connected.set() + self._recv_task = loop.create_task(self._recv_loop()) + + disconnected_at = self._disconnected_at + is_reconnect = disconnected_at is not None + if disconnected_at is not None: + downtime = time.monotonic() - disconnected_at + self._disconnected_at = None + logger.warning( + "WebSocket reconnected: %s (session_id=%s, downtime=%.2fs)", + self._url, + self._session_id, + downtime, + ) + + if self._replay_buffer_dropped > 0: + logger.warning( + "replay buffer dropped %d frames due to overflow " + "before this reconnect (session_id=%s); " + "increase replay_buffer_frames if this recurs", + self._replay_buffer_dropped, + self._session_id, + ) + else: + logger.info("WebSocket connected: %s", self._url) + + # Drain anything buffered while we were offline, even + # on the very first connect. ``_send_mcp_inventory`` + # and other init-time emitters queue frames before the + # WS is open; without this drain those frames never + # ship until something else triggers a live send. + if self._replay_buffer: + logger.info( + "replaying %d buffered frames after connect", + len(self._replay_buffer), + ) + await self._replay_buffer_to_ws() + + if is_reconnect: + await self._fire_on_reconnect() + + return + except Exception: + logger.warning( + "WebSocket connect to %s failed, retrying in %.0fs", + self._url, + backoff, + ) + try: + await asyncio.sleep(backoff) + except RuntimeError: + # Loop closed mid-retry (atexit shutdown). Bail out + # quietly rather than dumping a traceback. + return + backoff = min(backoff * 2, _MAX_BACKOFF) + + async def _send_login(self, ws: websockets.ClientConnection) -> None: + """Send the mandatory SessionLogin frame.""" + frame = pb.ClientFrame() + frame.login.session_id = self._session_id + frame.login.llm_stack.provider = self._provider + frame.login.llm_stack.model = self._model + frame.login.schema_version = SCHEMA_VERSION + await ws.send(frame.SerializeToString()) + logger.debug( + "Sent login (session=%s, provider=%s, model=%s, schema=%d)", + self._session_id, + self._provider, + self._model, + SCHEMA_VERSION, + ) + + async def _send_frame(self, frame: pb.ClientFrame) -> None: + """Serialise and send a ``ClientFrame``, buffering on failure. + + Happy path (connected + healthy): send over WS, bypass the ring + entirely, zero overhead. Offline on entry: buffer for replay. + During reconnect replay: buffer as well, so the drain loop picks + this frame up after the pre-outage tail (preserves order across + the outage boundary). Send raises: buffer the in-flight frame + then trigger ``_handle_disconnect`` so state is cleared and + reconnect is spawned. + """ + frame_bytes = frame.SerializeToString() + kind = frame.WhichOneof("frame") + + if not self._connected.is_set() or self._replaying: + self._buffer_frame(frame_bytes) + reason = "disconnected" if not self._connected.is_set() else "replaying" + logger.info( + "buffered for replay (session_id=%s, kind=%s, " + "buffer_size=%d, reason=%s)", + self._session_id, + kind, + len(self._replay_buffer), + reason, + ) + + return + + ws = self._ws + + if ws is None: + self._buffer_frame(frame_bytes) + + return + + try: + async with self._login_lock: + if not self._logged_in: + await self._send_login(ws) + self._logged_in = True + + await ws.send(frame_bytes) + logger.debug("Sent %s frame", kind) + except Exception: + # Send raised, we cannot confirm the server received this frame. + # Buffer it so the reconnect replay ships it, then clean up state. + self._buffer_frame(frame_bytes) + await self._handle_disconnect("send_failure") + + async def _recv_loop(self) -> None: + """Read ``ServerFrame``s, dispatch by oneof kind. + + First frame after each (re)login MUST be ``login_ack``; anything + else is a protocol error and we tear the connection down so the + reconnect path can try again. Subsequent frames are + ``verdict``s. Unknown oneof kinds (future server additions like + a quota-exhausted signal) are logged and dropped rather than + crashing the loop. + + Any exit path (clean close, exception, cancellation) calls + ``_handle_disconnect`` via ``finally`` so state is cleared and a + reconnect is spawned. + """ + ws = self._ws + + if ws is None: + return + + awaiting_login_ack = True + try: + async for message in ws: + if not isinstance(message, bytes): + continue + + frame = pb.ServerFrame() + frame.ParseFromString(message) + kind = frame.WhichOneof("frame") + + if awaiting_login_ack: + awaiting_login_ack = False + if kind != "login_ack": + logger.error( + "expected ServerFrame{login_ack} as first frame, " + "got %r, closing connection", + kind, + ) + return + + if kind == "login_ack": + self._on_login_ack(frame.login_ack) + elif kind == "verdict": + await self._on_verdict_frame(frame.verdict) + else: + logger.warning( + "ignoring unknown ServerFrame kind %r " + "(future server addition?)", + kind, + ) + except asyncio.CancelledError: + # Expected on graceful shutdown or when _handle_disconnect cancels + # us from the send_failure path. Re-raise to honour cancellation. + raise + except Exception as exc: + logger.warning("recv_loop exited: %s", exc) + finally: + close_code = getattr(ws, "close_code", None) + + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE: + self._next_reconnect_delay = _QUOTA_RECONNECT_DELAY + + reason = ( + f"quota_exhausted (close={close_code})" + if close_code == _QUOTA_EXHAUSTED_CLOSE_CODE + else "recv_loop_exit" + ) + await self._handle_disconnect(reason) + + def _on_login_ack(self, ack: pb.LoginAck) -> None: + """Apply the org's effective execution-mode policy. + + Fires the ``on_login_ack`` hook (if configured) as a fire-and-forget + task on the running loop so the recv loop doesn't block waiting on it. + """ + self._mode = ack.policy.mode + self._policy = ack.policy + self._login_ack_received.set() + logger.info( + "LoginAck received: mode=%s policy_m0=%s policy_m2=%s " + "policy_m3=%s policy_m4=%s", + pb.Mode.Name(ack.policy.mode), + ack.policy.policy_m0, + ack.policy.policy_m2, + ack.policy.policy_m3, + ack.policy.policy_m4, + ) + + if self._on_login_ack_cb is not None: + asyncio.create_task(self._run_login_ack_cb()) + + async def _run_login_ack_cb(self) -> None: + """Invoke the on_login_ack hook, swallowing exceptions.""" + if self._on_login_ack_cb is None: + return + try: + await self._on_login_ack_cb() + except Exception: + logger.exception("on_login_ack hook raised") + + async def _on_verdict_frame(self, verdict: pb.Verdict) -> None: + """Fire callbacks then resolve the matching pending future, if any. + + The future is left in ``_pending_verdicts`` after ``set_result`` so + a later ``register_pending`` (e.g. from the patched ToolNode after + the verdict has already round-tripped) returns the resolved + future and the wait completes immediately. ``wait_for_verdict`` + owns the cleanup: its ``finally`` pops the entry after the await + returns. + """ + logger.info( + "Verdict received: event_id=%s mad_code=%s mode=%s hitl=%s", + verdict.event_id, + verdict.mad_code or "-", + pb.Mode.Name(verdict.policy.mode), + verdict.HasField("hitl"), + ) + + if self._handler is not None: + await self._handler.handle_verdict(verdict) + + fut = self._pending_verdicts.get(verdict.event_id) + + if fut is None: + if verdict.HasField("hitl"): + logger.warning( + "HITL resolution for unknown event_id=%s, ignoring " + "(stale resolution from a prior SDK process)", + verdict.event_id, + ) + return + + if not fut.done(): + fut.set_result(verdict) + + # -- Resilience: buffering, replay, disconnect/reconnect -- + + def _buffer_frame(self, frame_bytes: bytes) -> None: + """Append a serialised frame to the replay ring. + + Tracks overflow drops and fires the one-shot "buffer full" WARN. + Called only from the offline or send-failure paths in + ``_send_frame``, the happy path bypasses the ring entirely. + """ + if len(self._replay_buffer) == self._replay_buffer.maxlen: + self._replay_buffer_dropped += 1 + + self._replay_buffer.append(frame_bytes) + + if ( + not self._replay_buffer_filled + and len(self._replay_buffer) == self._replay_buffer.maxlen + ): + self._replay_buffer_filled = True + logger.warning( + "adrian replay buffer reached capacity (%d frames); " + "further frames will evict oldest. Tune via " + "replay_buffer_frames or ADRIAN_REPLAY_BUFFER_FRAMES.", + self._replay_buffer.maxlen, + ) + + async def _replay_buffer_to_ws(self) -> None: + """Reissue buffered frames over the current WebSocket. + + Sends ``SessionLogin`` first if not already logged in (the server + requires it as the first frame on every new connection). Uses + ``_login_lock`` so a concurrent live send does not race on the + login check. + + Drains the deque one frame at a time via ``popleft`` inside a + ``while`` loop, rather than taking a snapshot up front. That + way, a live ``_send_frame`` call observed during the drain + routes its frame to the back of the same deque (because + ``_replaying`` is set) and this loop picks it up in the next + iteration, preserving across-outage order + ``[pre-outage] → [live during replay] → [post-replay live]``. + + On a mid-drain send failure, the failed frame is put back at + the front with ``appendleft`` and the function returns; the + next reconnect resumes from exactly where this one stopped. + """ + ws = self._ws + + if ws is None: + return + + self._replaying = True + try: + async with self._login_lock: + if not self._logged_in: + try: + await self._send_login(ws) + self._logged_in = True + except Exception as exc: + logger.warning( + "replay aborted: login send failed: %s", + exc, + ) + + return + + sent = 0 + while self._replay_buffer: + frame_bytes = self._replay_buffer.popleft() + try: + await ws.send(frame_bytes) + except Exception as exc: + # Put the failed frame back at the front so the next + # reconnect's drain resumes from exactly this point. + self._replay_buffer.appendleft(frame_bytes) + logger.warning( + "replay aborted after %d frame(s), %d remaining: %s", + sent, + len(self._replay_buffer), + exc, + ) + + return + sent += 1 + + logger.info("replayed %d buffered frames", sent) + self._replay_buffer_dropped = 0 + self._replay_buffer_filled = False + finally: + self._replaying = False + + async def _handle_disconnect(self, reason: str) -> None: + """Clear connection state and spawn a reconnect. + + Idempotent: if already disconnected or closing, returns immediately. + Pending verdict futures are intentionally left pending across the + disconnect, a late verdict after reconnect resolves them; if none + arrives, ``wait_for_verdict``'s timeout fires naturally. + """ + if self._closing or not self._connected.is_set(): + return + + self._connected.clear() + self._disconnected_at = time.monotonic() + + # Only cancel the recv task if we are not currently running inside it. + # When _recv_loop's own finally calls us, self._recv_task IS the + # current task, cancelling it would raise CancelledError inside the + # finally and prevent us from finishing disconnect handling. + current = asyncio.current_task() + + if self._recv_task is not None and self._recv_task is not current: + self._recv_task.cancel() + + self._recv_task = None + self._ws = None + self._logged_in = False + + logger.warning( + "disconnected (session_id=%s, reason=%s, pending_verdicts=%d)", + self._session_id, + reason, + len(self._pending_verdicts), + ) + + await self._fire_on_disconnect(reason) + + if self._closing: + return + + loop = asyncio.get_running_loop() + + if self._connect_task is None or self._connect_task.done(): + self._connect_task = loop.create_task(self.connect()) + + async def _fire_on_disconnect(self, reason: str) -> None: + """Invoke the on_disconnect callback, catching any exception.""" + if self._on_disconnect is None: + return + + try: + result = self._on_disconnect(reason) + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_disconnect callback raised") + + async def _fire_on_reconnect(self) -> None: + """Invoke the on_reconnect callback, catching any exception.""" + if self._on_reconnect is None: + return + + try: + result = self._on_reconnect() + + if asyncio.iscoroutine(result): + await result + except Exception: + logger.exception("on_reconnect callback raised") + + # -- Verdict-wait support -- + + def register_pending( + self, + event_id: str, + ) -> asyncio.Future[pb.Verdict]: + """Return a future awaiting a verdict for ``event_id``. + + Reuses an existing pending future if one is already registered, + so concurrent callers waiting on the same event_id see the same + verdict once it arrives. Must be called BEFORE sending the event + to avoid the race where the verdict arrives before the future exists. + """ + existing = self._pending_verdicts.get(event_id) + + if existing is not None: + return existing + + loop = asyncio.get_running_loop() + fut: asyncio.Future[pb.Verdict] = loop.create_future() + self._pending_verdicts[event_id] = fut + + return fut + + def _evict_resolved_verdicts(self) -> None: + """Remove oldest resolved futures when the dict exceeds the cap.""" + while len(self._pending_verdicts) > _MAX_PENDING_VERDICTS: + # Evict the oldest entry (dict preserves insertion order). + oldest_id = next(iter(self._pending_verdicts)) + oldest_fut = self._pending_verdicts[oldest_id] + if oldest_fut.done(): + del self._pending_verdicts[oldest_id] + else: + # Don't evict an in-flight future; stop evicting. + break + + async def wait_for_verdict( + self, + event_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for a verdict for ``event_id``. + + ``timeout`` is mode-derived (see :meth:`block_timeout`): + a positive float for ``MODE_BLOCK`` (fail-open at timeout), + ``None`` for ``MODE_HITL`` (wait indefinitely). Returns the + verdict, or ``None`` on timeout (fail-open). + + Resolved futures are kept in ``_pending_verdicts`` so a second + waiter on the same event_id (e.g. BaseTool.ainvoke firing after + ToolNode.ainvoke already consumed the verdict) finds the already- + resolved future and returns instantly instead of timing out. + Timed-out (unconsumed) futures are removed immediately; resolved + futures are evicted when the dict exceeds ``_MAX_PENDING_VERDICTS``. + """ + fut = self.register_pending(event_id) + + try: + result = await asyncio.wait_for(fut, timeout=timeout) + # Keep resolved future in dict for late waiters; cap size. + self._evict_resolved_verdicts() + return result + except TimeoutError: + logger.warning( + "Verdict timeout for event_id=%s after %ss", + event_id, + timeout, + ) + # Timed-out future is useless — remove so a retry can + # register a fresh one. + self._pending_verdicts.pop(event_id, None) + return None + + async def wait_for_tool_verdict( + self, + parent_run_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that produced this tool call. + + Looks up the LLM event_id from the run_id map and awaits its verdict. + Returns ``None`` (fail-open) when the parent LLM has not been seen, + e.g. tools invoked outside an LLM flow. + """ + event_id = self._run_id_to_event_id.get(parent_run_id) + + if event_id is None: + logger.debug( + "No LLM context for parent_run_id=%s, skipping verdict wait", + parent_run_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) + + async def wait_for_tool_call_verdict( + self, + tool_call_id: str, + timeout: float | None, + ) -> pb.Verdict | None: + """Wait for the verdict of the LLM pair that emitted ``tool_call_id``. + + Every tool call in an AIMessage carries the id the LLM assigned + to it; that id is threaded through LangChain to the ToolNode + invocation. Looking it up against ``_tool_call_id_to_event_id`` + gives the producing LLM's event_id, correct under parallel + agents where a ``last_llm_event_id``-style global would race. + + Returns ``None`` (fail-open) when ``tool_call_id`` is empty or + unknown (direct ToolNode invocation, pre-LLM tool, or the LLM + pair that produced it was evicted from the LRU map). + """ + if not tool_call_id: + return None + + event_id = self._tool_call_id_to_event_id.get(tool_call_id) + + if event_id is None: + logger.debug( + "No LLM context for tool_call_id=%s, skipping verdict wait", + tool_call_id, + ) + + return None + + return await self.wait_for_verdict(event_id, timeout) From 087f47fa23e8ae5158853573e372bdf06ced2145 Mon Sep 17 00:00:00 2001 From: netan-sa Date: Mon, 15 Jun 2026 23:03:20 +0200 Subject: [PATCH 5/5] Fix: linter --- sdk/python/adrian/__init__.py | 48 +---------------------------------- 1 file changed, 1 insertion(+), 47 deletions(-) diff --git a/sdk/python/adrian/__init__.py b/sdk/python/adrian/__init__.py index 8a264cf..d0d6d81 100644 --- a/sdk/python/adrian/__init__.py +++ b/sdk/python/adrian/__init__.py @@ -802,7 +802,7 @@ async def patched_astream( # --- 5. ToolNode --- -def _extract_tool_calls( +def _extract_tool_calls( # pyright: ignore[reportUnusedFunction] state: dict[str, Any] | list[BaseMessage] | Any, ) -> list[dict[str, Any]]: """Extract tool_calls from ToolNode input (all three dispatch shapes). @@ -879,52 +879,6 @@ def _patch_tool_node() -> None: original_ainvoke = ToolNode.ainvoke original_astream = getattr(ToolNode, "astream", None) - async def _gate_tool_calls(state: Any) -> bool: # noqa: ANN401 - """Returns True if tools should be BLOCKED.""" - ws = _ws_client - if ws is None: - return False - if not ws._login_ack_received.is_set(): # pyright: ignore[reportPrivateUsage] - try: - await asyncio.wait_for(ws._login_ack_received.wait(), timeout=5.0) # pyright: ignore[reportPrivateUsage] - except TimeoutError: - logger.warning("ToolNode: LoginAck not received within 5s; blocking") - return True - if not ws.policy_active(): - return False - - tc_ids: list[str] = [ - str(tc.get("id")) for tc in _extract_tool_calls(state) if tc.get("id") - ] - if not tc_ids: - return False - - cfg = _get_config() - timeout = ws.block_timeout(cfg.block_timeout if cfg else 30.0) - verdict = await ws.wait_for_tool_call_verdict(tc_ids[0], timeout) - if verdict is None: - logger.warning("ToolNode: verdict timeout, blocking (fail-closed)") - return True - if _should_halt(verdict): - logger.warning( - "halting tool execution for event_id=%s mad_code=%s", - verdict.event_id, - verdict.mad_code, - ) - return True - return False - - def _build_blocked(state: Any) -> dict[str, list[ToolMessage]]: # noqa: ANN401 - tc_ids = [tc.get("id") for tc in _extract_tool_calls(state) if tc.get("id")] - return { - "messages": [ - ToolMessage( - content="[BLOCKED by security policy]", tool_call_id=tid, name="" - ) - for tid in tc_ids - ] - } - def patched_invoke( self: Any, input: Any,