Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/sdk-lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Lint the bundled Python SDK on pull requests.
#
# Mirrors the local pre-commit hooks (ruff format, ruff check,
# basedpyright) so what passes `pre-commit run` also passes CI. Scoped
# to the SDK and scripts so it only runs when those change.

name: SDK Lint

on:
pull_request:
paths:
- "sdk/**"
- "scripts/**"
- ".github/workflows/sdk-lint.yml"

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: astral-sh/setup-uv@v6
with:
enable-cache: true
# Install into a repo-root .venv (same as `make sdk-install`), which
# is where sdk/pyrightconfig.json's `venvPath: ".."` expects it. A
# sdk/.venv (uv sync --project) would make basedpyright fail to find
# the venv and exit 3.
- name: Install SDK + dev deps
run: |
uv venv
uv pip install -e "./sdk[dev]"
- name: Ruff format
run: .venv/bin/ruff format --check sdk scripts
- name: Ruff lint
run: .venv/bin/ruff check sdk scripts
- name: Basedpyright
run: .venv/bin/basedpyright -p sdk sdk/adrian
10 changes: 8 additions & 2 deletions sdk/adrian/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ async def patched_astream(

def _extract_tool_calls(
state: dict[str, Any] | list[BaseMessage],
) -> list[dict[str, str]]:
) -> list[dict[str, Any]]:
"""Extract tool_calls from the ToolNode input.

``ToolNode`` is reached with three input shapes:
Expand All @@ -789,7 +789,13 @@ def _extract_tool_calls(
return [tc]
tc_id = getattr(tc, "id", None)
if tc_id:
return [{"id": tc_id, "name": getattr(tc, "name", ""), "args": getattr(tc, "args", {})}]
return [
{
"id": tc_id,
"name": getattr(tc, "name", ""),
"args": getattr(tc, "args", {}),
}
]
return []

if isinstance(state, dict):
Expand Down
4 changes: 2 additions & 2 deletions sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.24.0",
"pytest-cov>=6.0.0",
"ruff>=0.15.0",
"basedpyright>=1.38.0",
"ruff==0.15.12",
"basedpyright==1.39.3",
"pre-commit>=4.6.0",
"langgraph==1.1.2",
"langgraph-prebuilt==1.0.8",
Expand Down
18 changes: 14 additions & 4 deletions sdk/tests/test_extract_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,18 @@
Covers all three ToolNode input shapes. Shape 3 (per-tool-call dispatch) is the
one that previously returned ``[]`` and silently un-gated the tool.
"""
from __future__ import annotations

from langchain_core.messages import AIMessage, HumanMessage
from __future__ import annotations

from adrian import _extract_tool_calls
from langchain_core.messages import AIMessage, HumanMessage

_TC = {"name": "read_file", "args": {"path": "/etc/shadow"}, "id": "call_1", "type": "tool_call"}
_TC = {
"name": "read_file",
"args": {"path": "/etc/shadow"},
"id": "call_1",
"type": "tool_call",
}


def test_shape1_state_dict_with_messages() -> None:
Expand All @@ -36,4 +41,9 @@ def test_shape3_per_tool_call_dict() -> None:

def test_no_tool_calls_returns_empty() -> None:
assert _extract_tool_calls({"messages": [HumanMessage("hi")]}) == []
assert _extract_tool_calls({"__type": "tool_call", "tool_call": {"name": "x"}, "state": {}}) == []
assert (
_extract_tool_calls(
{"__type": "tool_call", "tool_call": {"name": "x"}, "state": {}}
)
== []
)
Loading