Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CLA.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@ To accept this Agreement, open a pull request that adds an entry to the table be
| _example placeholder_ | _@example_ | _2026-01-01_ |
| Dhrit Timinkumar Patel | @d180 | 2026-05-20 |
| Adarsh Tiwari | @adarsh9977 | 2026-05-22 |
| Muhammad usman | @Muhammad-usman92 | 2026-06-11 |

Once a CLA-bot (cla-assistant.io or equivalent) is wired up, this manual table will be replaced by the bot's status check on each pull request. Existing signatures in this table remain valid; the bot reads from a separate signers list.
118 changes: 110 additions & 8 deletions sdk/adrian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,9 +337,11 @@ def init(
if loop is not None:
_ws_client.schedule_connect(loop)
else:
logger.debug(
"No running event loop at init(); WebSocket will connect on "
"first send from within an async context."
logger.warning(
"Adrian initialised without a running event loop. WebSocket "
"transport and BLOCK/HITL verdict handling may not be active "
"yet; sync ToolNode.invoke will fail closed until an event "
"loop connects the WebSocket and receives a policy LoginAck."
)

if auto_instrument:
Expand Down Expand Up @@ -849,13 +851,110 @@ def _build_blocked_response(
return {"messages": blocked_messages}


def _resolved_tool_call_verdict(
    ws: WebSocketClient,
    tool_call_id: str,
) -> tuple[pb.Verdict | None, bool]:
    """Fetch the verdict for ``tool_call_id`` only if it has already arrived.

    The tool call id is translated to an event id through the client's
    correlation table, and that event id is looked up among the pending
    verdict futures. Nothing is awaited: a future that is missing or not
    yet done counts as unresolved.

    Args:
        ws: WebSocket client holding the correlation table and the pending
            verdict futures.
        tool_call_id: Id of the tool call that is about to be executed.

    Returns:
        ``(verdict, True)`` when a completed future yielded a result, and
        ``(None, False)`` in every other case: no correlated event, no
        future, future still pending, future cancelled, or future failed.
        A completed future is removed from the pending table regardless of
        its outcome.
    """
    correlated_event = ws._tool_call_id_to_event_id.get(tool_call_id)  # pyright: ignore[reportPrivateUsage]
    if correlated_event is None:
        return None, False

    verdict_future = ws._pending_verdicts.get(correlated_event)  # pyright: ignore[reportPrivateUsage]
    if not (verdict_future is not None and verdict_future.done()):
        return None, False

    outcome: tuple[pb.Verdict | None, bool]
    try:
        delivered = verdict_future.result()
    except asyncio.CancelledError:
        # CancelledError is not an ``Exception`` subclass on modern Python,
        # so it needs its own clause to be reported as unresolved.
        logger.warning(
            "ToolNode: resolved sync verdict future was cancelled; halting "
            "tool_call_id=%s event_id=%s",
            tool_call_id,
            correlated_event,
        )
        outcome = (None, False)
    except Exception:
        logger.exception(
            "ToolNode: resolved sync verdict future failed; halting "
            "tool_call_id=%s event_id=%s",
            tool_call_id,
            correlated_event,
        )
        outcome = (None, False)
    else:
        outcome = (delivered, True)
    finally:
        # A finished future is consumed exactly once, whatever it held.
        ws._pending_verdicts.pop(correlated_event, None)  # pyright: ignore[reportPrivateUsage]

    return outcome


def _sync_tool_node_policy_gate(input: Any) -> dict[str, list[ToolMessage]] | None:  # noqa: ANN401
    """Decide whether a synchronous ``ToolNode.invoke`` may run its tools.

    The checks run in this order:

    1. No WebSocket client configured: let the tools run.
    2. LoginAck not yet received: halt.
    3. Policy not active: let the tools run.
    4. No tool call carries an id: let the tools run.
    5. No already-resolved verdict for the first tool call id, an empty
       verdict, or a verdict that ``_should_halt`` rejects: halt.

    Args:
        input: The ToolNode input from which the tool calls are extracted.

    Returns:
        The synthetic blocked response when execution must halt, or
        ``None`` when the original ToolNode should run.
    """
    client = _ws_client
    if client is None:
        return None

    requested_calls = _extract_tool_calls(input)

    if not client._login_ack_received.is_set():  # pyright: ignore[reportPrivateUsage]
        logger.warning(
            "ToolNode: LoginAck not received in sync invoke; halting "
            "(refusing to run a tool without a verified policy)"
        )
        return _build_blocked_response(requested_calls)

    if not client.policy_active():
        return None

    # The first tool call that carries a truthy id is used for correlation.
    first_call_id = None
    for call in requested_calls:
        candidate_id = call.get("id")
        if candidate_id:
            first_call_id = candidate_id
            break

    if not first_call_id:
        return None

    verdict, resolved = _resolved_tool_call_verdict(client, first_call_id)

    # Each halting reason logs its own message; all of them share the same
    # blocked response below.
    if not resolved:
        logger.warning(
            "ToolNode: sync invoke cannot wait for a BLOCK/HITL verdict; "
            "halting tool_call_id=%s",
            first_call_id,
        )
    elif verdict is None:
        logger.warning(
            "ToolNode: sync invoke resolved an empty verdict; halting "
            "tool_call_id=%s",
            first_call_id,
        )
    elif _should_halt(verdict):
        logger.warning(
            "halting tool execution for event_id=%s mad_code=%s",
            verdict.event_id,
            verdict.mad_code,
        )
    else:
        return None

    return _build_blocked_response(requested_calls)


def _patch_tool_node() -> None:
"""Patch ``ToolNode.invoke`` / ``ainvoke``.

In block mode, the async patch waits for the preceding LLM's verdict
before executing tools. On BLOCK (unless overridden by ``on_block``)
it returns synthetic ``ToolMessage`` responses instead of running the
tools. On timeout it fails open.
The async path waits for the preceding LLM's verdict before executing
tools. The sync path consumes already-resolved verdicts only; when
policy/verdict state is unavailable it fails closed because the SDK
cannot safely run the WebSocket wait without a running event loop.
"""
try:
from langgraph.prebuilt import ToolNode
Expand All @@ -874,9 +973,12 @@ def patched_invoke(
config: Any = None, # noqa: ANN401
**kwargs: Any,
) -> Any: # noqa: ANN401
"""Inject Adrian callbacks into sync ToolNode invocation."""
"""Inject Adrian callbacks; in BLOCK / HITL modes gate sync tools."""
config = _inject_callbacks(config)
blocked = _sync_tool_node_policy_gate(input)

if blocked is not None:
return blocked
return original_invoke(self, input, config=config, **kwargs)

async def patched_ainvoke(
Expand Down
163 changes: 163 additions & 0 deletions sdk/tests/test_block_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,17 @@ def _tool_pair() -> PairedEvent:
)


def _resolved_verdict_future(verdict: pb.Verdict) -> asyncio.Future[pb.Verdict]:
    """Return an already-completed future carrying ``verdict``.

    The future is created on a throwaway event loop that is closed before
    returning, so the sync ``ToolNode.invoke`` tests can seed a resolved
    verdict without running a loop themselves.
    """
    scratch_loop = asyncio.new_event_loop()
    try:
        completed: asyncio.Future[pb.Verdict] = scratch_loop.create_future()
        completed.set_result(verdict)
    finally:
        scratch_loop.close()
    return completed


class TestRunIdCorrelation:
async def test_llm_pair_populates_run_id_map(self) -> None:
mock_ws = AsyncMock()
Expand Down Expand Up @@ -185,6 +196,120 @@ def _real_tool(x: str) -> str:
assert len(msgs) == 1
assert "BLOCKED" in msgs[0].content

def test_sync_in_scope_block_verdict_halts_tool(self, tmp_path: Path) -> None:
    """A resolved in-scope verdict under MODE_BLOCK stops the sync invoke.

    A completed verdict future with ``mad_code="M4_a"`` is seeded for the
    tool call while the policy has ``policy_m4=True``. The tool must not
    run and the result must be a single message containing ``BLOCKED``.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )
    client = adrian._ws_client
    assert client is not None

    active_policy = _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)
    client._connected.set()
    blocking_verdict = pb.Verdict(
        event_id="llm-evt",
        mad_code="M4_a",
        policy=active_policy,
    )
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"
    client._pending_verdicts["llm-evt"] = _resolved_verdict_future(blocking_verdict)

    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}
    node = ToolNode([_real_tool])

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

def test_sync_missing_login_ack_halts_tool(self, tmp_path: Path) -> None:
    """Sync ``ToolNode.invoke`` halts when no server policy has been applied.

    Only ``adrian.init`` runs before the invoke; no policy mode is applied
    to the client. The tool must not run and the result must be a single
    message containing ``BLOCKED``.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )

    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}
    node = ToolNode([_real_tool])

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

def test_sync_unresolved_active_policy_halts_tool(self, tmp_path: Path) -> None:
    """Sync ``ToolNode.invoke`` halts when no resolved verdict is available.

    MODE_BLOCK is applied and the tool call id is correlated to an event,
    but no verdict future is registered for that event. The tool must not
    run and the result must be a single message containing ``BLOCKED``.
    """
    invocations: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        invocations.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )
    client = adrian._ws_client
    assert client is not None

    _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)
    client._connected.set()
    # Correlation exists, but deliberately no entry in _pending_verdicts.
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"

    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}
    node = ToolNode([_real_tool])

    outcome = node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    assert invocations == []
    replies = outcome["messages"]
    assert len(replies) == 1
    assert "BLOCKED" in replies[0].content

async def test_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None:
"""MODE_BLOCK with policy_m2=false + mad_code='M2' → continue (out-of-scope)."""

Expand Down Expand Up @@ -226,6 +351,44 @@ def _real_tool(x: str) -> str:

assert captured == ["hi"]

def test_sync_out_of_scope_verdict_runs_tool(self, tmp_path: Path) -> None:
    """A resolved verdict outside the policy's scope lets the sync tool run.

    The policy is applied with only ``policy_m4=True`` while the seeded
    verdict carries ``mad_code="M2"``; the tool is expected to execute
    once with the argument ``"hi"``.
    """
    seen_args: list[str] = []

    def _real_tool(x: str) -> str:
        """Real tool stub for sync block-mode tests."""
        seen_args.append(x)
        return x

    adrian.init(
        api_key="k",
        log_file=str(tmp_path / "events.jsonl"),
        auto_instrument=True,
        ws_url="ws://x",
        block_timeout=1.0,
    )
    client = adrian._ws_client
    assert client is not None

    # Only the M4 family is switched on; M2 is not passed to _apply_mode.
    active_policy = _apply_mode(client, pb.MODE_BLOCK, policy_m4=True)
    client._connected.set()
    out_of_scope_verdict = pb.Verdict(
        event_id="llm-evt",
        mad_code="M2",
        policy=active_policy,
    )
    client._tool_call_id_to_event_id["tc-1"] = "llm-evt"
    client._pending_verdicts["llm-evt"] = _resolved_verdict_future(out_of_scope_verdict)

    request = AIMessage(
        content="",
        tool_calls=[{"id": "tc-1", "name": "_real_tool", "args": {"x": "hi"}}],
    )
    graph_state: dict[str, Any] = {"messages": [request]}
    node = ToolNode([_real_tool])

    node.invoke(graph_state, config=_runtime_config())  # pyright: ignore[reportUnknownMemberType]

    assert seen_args == ["hi"]

async def test_timeout_fail_open_runs_tool(self, tmp_path: Path) -> None:
captured: list[str] = []

Expand Down
19 changes: 19 additions & 0 deletions sdk/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import logging
import os
from collections.abc import Iterator
from pathlib import Path
Expand Down Expand Up @@ -66,6 +67,24 @@ def test_creates_jsonl_file(self, tmp_path: Path) -> None:

assert log.exists()

def test_warns_when_ws_init_has_no_running_loop(
    self,
    tmp_path: Path,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """``init()`` with a ``ws_url`` from a sync test logs the no-loop warning.

    Records from the ``adrian`` logger are captured at WARNING level and
    the captured text must mention the missing running event loop.
    """
    events_file = tmp_path / "events.jsonl"

    with caplog.at_level(logging.WARNING, logger="adrian"):
        adrian.init(
            api_key="k",
            log_file=str(events_file),
            auto_instrument=False,
            ws_url="ws://x",
        )

    assert "without a running event loop" in caplog.text


class TestShutdown:
"""Tests for adrian.shutdown()."""
Expand Down
Loading