diff --git a/dreadnode/agent/agent.py b/dreadnode/agent/agent.py index a680b877..79fbb965 100644 --- a/dreadnode/agent/agent.py +++ b/dreadnode/agent/agent.py @@ -14,7 +14,7 @@ from rigging.message import inject_system_content from ulid import ULID # can't access via rg -from dreadnode.agent.error import MaxStepsError +from dreadnode.agent.error import MaxStepsError, MaxToolCallsError from dreadnode.agent.events import ( AgentEnd, AgentError, @@ -89,7 +89,9 @@ class Agent(Model): ) """The agent's core instructions.""" max_steps: int = Config(default=10) - """The maximum number of steps (generation + tool calls).""" + """The maximum number of steps (generations).""" + max_tool_calls: int = Config(default=-1) + """The maximum number of tool calls. Defaults to infinite.""" caching: rg.caching.CacheMode | None = Config(default=None, repr=False) """How to handle cache_control entries on inference messages.""" @@ -488,10 +490,16 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]: # noqa: raise winning_reaction # Tool calling + tool_calls = 0 async def _process_tool_call( tool_call: "rg.tools.ToolCall", ) -> t.AsyncGenerator[AgentEvent, None]: + nonlocal tool_calls + + if self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls: + raise Finish("Reached maximum allowed tool calls.") + async for event in _dispatch( ToolStart( session_id=session_id, @@ -513,6 +521,7 @@ async def _process_tool_call( tool = next((t for t in self.all_tools if t.name == tool_call.name), None) if tool is not None: + tool_calls += 1 try: message, stop = await tool.handle_tool_call(tool_call) except Reaction: @@ -690,6 +699,9 @@ async def _process_tool_call( if step >= self.max_steps: error = MaxStepsError(max_steps=self.max_steps) stop_reason = "max_steps_reached" + elif self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls: + error = MaxToolCallsError(max_tool_calls=self.max_tool_calls) + stop_reason = "max_tool_calls_reached" elif error is not None: stop_reason = "error" elif events and isinstance(events[-1], AgentStalled): diff --git a/dreadnode/agent/error.py b/dreadnode/agent/error.py index feda14c7..f132cce0 100644 --- a/dreadnode/agent/error.py +++ b/dreadnode/agent/error.py @@ -4,3 +4,11 @@ class MaxStepsError(Exception): def __init__(self, max_steps: int): super().__init__(f"Maximum steps reached ({max_steps}).") self.max_steps = max_steps + + +class MaxToolCallsError(Exception): + """Raise from a hook to stop the agent's run due to reaching the maximum number of tool calls.""" + + def __init__(self, max_tool_calls: int): + super().__init__(f"Maximum tool calls reached ({max_tool_calls}).") + self.max_tool_calls = max_tool_calls diff --git a/dreadnode/agent/result.py b/dreadnode/agent/result.py index f8d952c1..9a5fe01d 100644 --- a/dreadnode/agent/result.py +++ b/dreadnode/agent/result.py @@ -8,7 +8,9 @@ if t.TYPE_CHECKING: from dreadnode.agent.agent import Agent -AgentStopReason = t.Literal["finished", "max_steps_reached", "error", "stalled"] +AgentStopReason = t.Literal[ + "finished", "max_steps_reached", "max_tool_calls_reached", "error", "stalled" +] @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) diff --git a/tests/test_agent.py b/tests/test_agent.py index 5ffe3ded..dcb41ff7 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -8,7 +8,7 @@ from rigging.generator.base import GeneratedMessage from dreadnode.agent.agent import Agent, TaskAgent -from dreadnode.agent.error import MaxStepsError +from dreadnode.agent.error import MaxStepsError, MaxToolCallsError from dreadnode.agent.events import AgentEnd, AgentEvent, AgentStalled, Reacted, ToolStart from dreadnode.agent.hooks.base import retry_with_feedback from dreadnode.agent.reactions import RetryWithFeedback @@ -298,6 +298,31 @@ async def test_run_stops_on_max_steps(mock_generator: MockGenerator, simple_tool assert result.steps == 1 +@pytest.mark.asyncio +async def test_run_stops_on_max_tool_calls( + mock_generator: MockGenerator, simple_tool: AnyTool +) -> None: + """Ensure the agent run terminates with a MaxToolCallsError when exceeding max_tool_calls.""" + # The agent will just keep calling the tool. + mock_generator._responses = [ + MockGenerator.tool_response("get_weather", {"city": "A"}), + MockGenerator.tool_response("get_weather", {"city": "B"}), + MockGenerator.tool_response("get_weather", {"city": "C"}), + ] + + agent = Agent( + name="MaxToolCallsAgent", + model=mock_generator, + tools=[simple_tool], + max_tool_calls=2, + ) + result = await agent.run("...") + + assert result.failed + assert result.stop_reason == "max_tool_calls_reached" + assert isinstance(result.error, MaxToolCallsError) + + @pytest.mark.asyncio async def test_run_stops_on_stop_condition( mock_generator: MockGenerator, simple_tool: AnyTool