Skip to content

Commit ffe942e

Browse files
author
Dylan Huang
committed
fix rollout
1 parent c4b9b3c commit ffe942e

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

eval_protocol/pytest/default_agent_rollout_processor.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,13 @@ async def call_agent(self) -> str:
5252

5353
message = await self._call_model(self.messages, tools)
5454
self.append_message_and_log(message)
55-
if message["tool_calls"]:
55+
if message.tool_calls:
5656
# Create tasks for all tool calls to run them in parallel
5757
tool_tasks = []
58-
for tool_call in message["tool_calls"]:
59-
tool_call_id = tool_call["id"]
60-
tool_name = tool_call["function"]["name"]
61-
tool_args = tool_call["function"]["arguments"]
58+
for tool_call in message.tool_calls:
59+
tool_call_id = tool_call.id
60+
tool_name = tool_call.function.name
61+
tool_args = tool_call.function.arguments
6262
tool_args_dict = json.loads(tool_args)
6363

6464
# Create a task for each tool call
@@ -69,7 +69,7 @@ async def call_agent(self) -> str:
6969
tool_results: List[tuple[str, List[TextContent]]] = await asyncio.gather(*tool_tasks)
7070

7171
# Add all tool results to messages (they will be in the same order as tool_calls)
72-
for tool_call, (tool_call_id, content) in zip(message["tool_calls"], tool_results):
72+
for tool_call, (tool_call_id, content) in zip(message.tool_calls, tool_results):
7373
self.append_message_and_log(
7474
Message(
7575
role="tool",
@@ -80,7 +80,7 @@ async def call_agent(self) -> str:
8080
)
8181
)
8282
return await self.call_agent()
83-
return message["content"]
83+
return message.content
8484

8585
async def _call_model(self, messages: list[Message], tools: Optional[list[ChatCompletionToolParam]]) -> Message:
8686
messages = [message.model_dump() if hasattr(message, "model_dump") else message for message in messages]

0 commit comments

Comments
 (0)