Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions dreadnode/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from rigging.message import inject_system_content
from ulid import ULID # can't access via rg

from dreadnode.agent.error import MaxStepsError
from dreadnode.agent.error import MaxStepsError, MaxToolCallsError
from dreadnode.agent.events import (
AgentEnd,
AgentError,
Expand Down Expand Up @@ -89,7 +89,9 @@ class Agent(Model):
)
"""The agent's core instructions."""
max_steps: int = Config(default=10)
"""The maximum number of steps (generation + tool calls)."""
"""The maximum number of steps (generations)."""
max_tool_calls: int = Config(default=-1)
"""The maximum number of tool calls. Defaults to infinite."""
Comment thread
mkultraWasHere marked this conversation as resolved.
caching: rg.caching.CacheMode | None = Config(default=None, repr=False)
"""How to handle cache_control entries on inference messages."""

Expand Down Expand Up @@ -488,10 +490,16 @@ async def _dispatch(event: AgentEvent) -> t.AsyncIterator[AgentEvent]: # noqa:
raise winning_reaction

# Tool calling
tool_calls = 0

async def _process_tool_call(
tool_call: "rg.tools.ToolCall",
) -> t.AsyncGenerator[AgentEvent, None]:
nonlocal tool_calls

if self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls:
raise Finish("Reached maximum allowed tool calls.")
Comment thread
rdheekonda marked this conversation as resolved.

async for event in _dispatch(
ToolStart(
session_id=session_id,
Expand All @@ -513,6 +521,7 @@ async def _process_tool_call(
tool = next((t for t in self.all_tools if t.name == tool_call.name), None)

if tool is not None:
tool_calls += 1
try:
message, stop = await tool.handle_tool_call(tool_call)
except Reaction:
Expand Down Expand Up @@ -690,6 +699,9 @@ async def _process_tool_call(
if step >= self.max_steps:
error = MaxStepsError(max_steps=self.max_steps)
stop_reason = "max_steps_reached"
elif self.max_tool_calls != -1 and tool_calls >= self.max_tool_calls:
error = MaxToolCallsError(max_tool_calls=self.max_tool_calls)
stop_reason = "max_tool_calls_reached"
elif error is not None:
stop_reason = "error"
elif events and isinstance(events[-1], AgentStalled):
Expand Down
8 changes: 8 additions & 0 deletions dreadnode/agent/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,11 @@ class MaxStepsError(Exception):
def __init__(self, max_steps: int):
super().__init__(f"Maximum steps reached ({max_steps}).")
self.max_steps = max_steps


class MaxToolCallsError(Exception):
"""Raise from a hook to stop the agent's run due to reaching the maximum number of tool calls."""

def __init__(self, max_tool_calls: int):
super().__init__(f"Maximum tool calls reached ({max_tool_calls}).")
self.max_tool_calls = max_tool_calls
4 changes: 3 additions & 1 deletion dreadnode/agent/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
if t.TYPE_CHECKING:
from dreadnode.agent.agent import Agent

AgentStopReason = t.Literal["finished", "max_steps_reached", "error", "stalled"]
AgentStopReason = t.Literal[
"finished", "max_steps_reached", "max_tool_calls_reached", "error", "stalled"
]


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
Expand Down
27 changes: 26 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rigging.generator.base import GeneratedMessage

from dreadnode.agent.agent import Agent, TaskAgent
from dreadnode.agent.error import MaxStepsError
from dreadnode.agent.error import MaxStepsError, MaxToolCallsError
from dreadnode.agent.events import AgentEnd, AgentEvent, AgentStalled, Reacted, ToolStart
from dreadnode.agent.hooks.base import retry_with_feedback
from dreadnode.agent.reactions import RetryWithFeedback
Expand Down Expand Up @@ -298,6 +298,31 @@ async def test_run_stops_on_max_steps(mock_generator: MockGenerator, simple_tool
assert result.steps == 1


@pytest.mark.asyncio
async def test_run_stops_on_max_tool_calls(
mock_generator: MockGenerator, simple_tool: AnyTool
) -> None:
"""Ensure the agent run terminates with a MaxToolCallsError when exceeding max_tool_calls."""
# The agent will just keep calling the tool.
mock_generator._responses = [
MockGenerator.tool_response("get_weather", {"city": "A"}),
MockGenerator.tool_response("get_weather", {"city": "B"}),
MockGenerator.tool_response("get_weather", {"city": "C"}),
]

agent = Agent(
name="MaxToolCallsAgent",
model=mock_generator,
tools=[simple_tool],
max_tool_calls=2,
)
result = await agent.run("...")

assert result.failed
assert result.stop_reason == "max_tool_calls_reached"
assert isinstance(result.error, MaxToolCallsError)
Comment thread
rdheekonda marked this conversation as resolved.


@pytest.mark.asyncio
async def test_run_stops_on_stop_condition(
mock_generator: MockGenerator, simple_tool: AnyTool
Expand Down