Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions tests/code_executors/container/test_container_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,12 @@ async def test_exec_run_exception_propagates(self, mock_cc_cls, mock_ctx):
class TestCodeBlockDelimiter:

@patch("trpc_agent_sdk.code_executors.container._container_code_executor.ContainerClient")
def test_returns_correct_delimiter(self, mock_cc_cls):
def test_returns_default_delimiters(self, mock_cc_cls):
mock_cc_cls.return_value = Mock()
executor = ContainerCodeExecutor(image="img")
delim = executor.code_block_delimiter()
delims = executor.code_block_delimiters

assert isinstance(delim, CodeBlockDelimiter)
assert delim.start == "```tool_code\n"
assert delim.end == "\n```"
assert isinstance(delims, list)
assert all(isinstance(d, CodeBlockDelimiter) for d in delims)
assert delims[0].start == "```tool_code\n"
assert delims[0].end == "\n```"
25 changes: 10 additions & 15 deletions tests/code_executors/local/test_unsafe_local_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,22 @@ def test_custom_clean_temp_files(self):
executor = UnsafeLocalCodeExecutor(clean_temp_files=False)
assert executor.clean_temp_files is False

def test_custom_delimiter(self):
delim = CodeBlockDelimiter(start="<<<", end=">>>")
executor = UnsafeLocalCodeExecutor(delimiter=delim)
assert executor.delimiter.start == "<<<"
assert executor.delimiter.end == ">>>"


class TestCodeBlockDelimiter:
"""Tests for code_block_delimiter method."""
"""Tests for code block delimiter configuration."""

def test_returns_default_delimiter(self):
def test_returns_default_delimiters(self):
executor = UnsafeLocalCodeExecutor()
delimiter = executor.code_block_delimiter()
assert isinstance(delimiter, CodeBlockDelimiter)
assert delimiter.start == "```"
assert delimiter.end == "```"
delimiters = executor.code_block_delimiters
assert isinstance(delimiters, list)
assert all(isinstance(d, CodeBlockDelimiter) for d in delimiters)
assert delimiters[0].start == "```tool_code\n"
assert delimiters[0].end == "\n```"

def test_returns_custom_delimiter(self):
def test_returns_custom_delimiters(self):
custom = CodeBlockDelimiter(start="---", end="---")
executor = UnsafeLocalCodeExecutor(delimiter=custom)
assert executor.code_block_delimiter() == custom
executor = UnsafeLocalCodeExecutor(code_block_delimiters=[custom])
assert executor.code_block_delimiters == [custom]


class TestPrepareWorkDir:
Expand Down
11 changes: 6 additions & 5 deletions tests/code_executors/test_container_container_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,17 @@ async def test_execute_code_exception_handling(self, mock_container_client_class
await executor.execute_code(self.mock_ctx, code_input)

@patch('trpc_agent_sdk.code_executors.container._container_code_executor.ContainerClient')
def test_code_block_delimiter(self, mock_container_client_class):
"""Test code_block_delimiter method."""
def test_code_block_delimiters(self, mock_container_client_class):
"""Test default code_block_delimiters value."""
mock_container_client = Mock()
mock_container_client_class.return_value = mock_container_client

executor = ContainerCodeExecutor(image="python:3-slim")
delimiter = executor.code_block_delimiter()
delimiters = executor.code_block_delimiters

assert delimiter.start == "```tool_code\n"
assert delimiter.end == "\n```"
assert isinstance(delimiters, list)
assert delimiters[0].start == "```tool_code\n"
assert delimiters[0].end == "\n```"

@patch('trpc_agent_sdk.code_executors.container._container_code_executor.ContainerClient')
async def test_execute_code_empty_language_defaults_to_python(self, mock_container_client_class):
Expand Down
9 changes: 5 additions & 4 deletions tests/code_executors/test_local_unsafe_local_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ def test_init_with_custom_values(self):
assert executor.timeout == 30.0
assert executor.clean_temp_files is False

def test_code_block_delimiter(self):
"""Test code_block_delimiter method."""
def test_code_block_delimiters(self):
"""Test default code_block_delimiters value."""
executor = UnsafeLocalCodeExecutor()
delimiter = executor.code_block_delimiter()
delimiters = executor.code_block_delimiters

assert isinstance(delimiter, CodeBlockDelimiter)
assert isinstance(delimiters, list)
assert all(isinstance(delimiter, CodeBlockDelimiter) for delimiter in delimiters)

@patch('trpc_agent_sdk.code_executors.local._unsafe_local_code_executor.async_execute_command')
async def test_execute_code_python(self, mock_async_execute):
Expand Down
42 changes: 31 additions & 11 deletions trpc_agent_sdk/agents/core/_code_execution_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,19 +248,20 @@ async def _run_post_processor(
if not llm_response or not llm_response.content:
return

# For container and unsafe local executors, we handle execution in post-processing
if isinstance(code_executor, (ContainerCodeExecutor, UnsafeLocalCodeExecutor)):
# Continue with post-processing for these executors
pass

code_executor_context = CodeExecutorContext(invocation_context.session.state)
if (code_executor.execute_once_per_invocation
and code_executor_context.has_executed_in_invocation(invocation_context.invocation_id)):
return

# Skip if the error count exceeds the max retry attempts.
if code_executor_context.get_error_count(invocation_context.invocation_id) >= code_executor.error_retry_attempts:
return

# [Step 1] Extract code from the model predict response and truncate the
# content to the part with the first code block.
response_content = llm_response.content
# [Step 1] Extract code from a cloned response content.
# IMPORTANT: extract_code_and_truncate_content mutates the input Content in place.
# Clone first to avoid mutating shared references that are later used by
# telemetry tracing and session persistence.
response_content = copy.deepcopy(llm_response.content)
code_blocks = CodeExecutionUtils.extract_code_and_truncate_content(response_content,
code_executor.code_block_delimiters,
code_executor.ignore_codes)
Expand All @@ -284,6 +285,7 @@ async def _run_post_processor(
code_blocks,
code_execution_result,
)
code_executor_context.mark_executed_in_invocation(invocation_context.invocation_id)

# Generate events for code execution results
# Event 1: Code execution event
Expand All @@ -302,9 +304,27 @@ async def _run_post_processor(
code_execution_result)
yield result_event

# [Step 3] Skip processing the original model response
# to continue code generation loop.
llm_response.content = None
# [Step 3] Skip executable code parts to continue the code generation loop,
# while preserving:
# 1) text parts (after truncation/extraction) for conversation memory
# 2) function_call parts for downstream function_call/function_response pairing
retained_parts: list[Part] = []

# Keep text parts from the transformed response content (code stripped out).
if response_content and response_content.parts:
retained_parts.extend([copy.deepcopy(part) for part in response_content.parts if part.text])

# Keep original function_call parts from the original response payload.
if llm_response.content and llm_response.content.parts:
retained_parts.extend([copy.deepcopy(part) for part in llm_response.content.parts if part.function_call])

if retained_parts:
llm_response.content = Content(
role=llm_response.content.role if llm_response.content else "model",
parts=retained_parts,
)
else:
llm_response.content = None


def _extract_and_replace_inline_files(
Expand Down
12 changes: 10 additions & 2 deletions trpc_agent_sdk/code_executors/_base_code_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ class BaseCodeExecutor(BaseModel):
error_retry_attempts: int = 2
"""The number of attempts to retry on consecutive code execution errors. Default to 2."""

execute_once_per_invocation: bool = False
"""Whether to execute model-extracted code at most once per invocation.

When enabled, post-processing code execution runs only for the first
detected code block in a single ``invocation_id`` and skips subsequent
auto-execution attempts for that invocation.
"""

code_block_delimiters: list[CodeBlockDelimiter] = [
CodeBlockDelimiter(start="```tool_code\n", end="\n```"),
CodeBlockDelimiter(start="```python\n", end="\n```"),
Expand Down Expand Up @@ -100,10 +108,10 @@ async def execute_code(
The code execution result.
"""

@abc.abstractmethod
def code_block_delimiter(self) -> CodeBlockDelimiter:
def code_block_delimiter(self) -> list[CodeBlockDelimiter]:
"""Return the code block delimiter used by this executor.

Returns:
CodeBlockDelimiter instance
"""
return self.code_block_delimiters
9 changes: 9 additions & 0 deletions trpc_agent_sdk/code_executors/_code_executor_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def _ensure_code_execution_state(self) -> None:
"execution_id": None,
"error_counts": {},
"code_execution_results": {},
"executed_invocations": {},
}

def get_input_files(self) -> List[CodeFile]:
Expand Down Expand Up @@ -138,6 +139,14 @@ def update_code_execution_result(self, invocation_id: str, code_blocks: List[Cod
code_execution_result.model_dump(),
})

def has_executed_in_invocation(self, invocation_id: str) -> bool:
"""Whether code has already been executed in a given invocation."""
return bool(self.session_state["code_execution"]["executed_invocations"].get(invocation_id, False))

def mark_executed_in_invocation(self, invocation_id: str) -> None:
"""Mark that code execution has happened in a given invocation."""
self.session_state["code_execution"]["executed_invocations"][invocation_id] = True

def get_state_delta(self) -> Dict:
"""Get state delta for the current execution.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from trpc_agent_sdk.context import InvocationContext

from .._base_code_executor import BaseCodeExecutor
from .._types import CodeBlockDelimiter
from .._types import CodeExecutionInput
from .._types import CodeExecutionResult
from .._types import create_code_execution_result
Expand Down Expand Up @@ -149,12 +148,3 @@ async def execute_code(
output = "".join(all_output)
err_str = "".join(all_errors)
return create_code_execution_result(stdout=output, stderr=err_str)

@override
def code_block_delimiter(self) -> CodeBlockDelimiter:
"""Return the code block delimiter used by this executor.

Returns:
CodeBlockDelimiter instance
"""
return CodeBlockDelimiter(start="```tool_code\n", end="\n```")
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from .._base_code_executor import BaseCodeExecutor
from .._types import CodeBlock
from .._types import CodeBlockDelimiter
from .._types import CodeExecutionInput
from .._types import CodeExecutionResult
from .._types import create_code_execution_result
Expand All @@ -48,9 +47,6 @@ class UnsafeLocalCodeExecutor(BaseCodeExecutor):
clean_temp_files: bool = Field(default=True,
description="Whether to clean temporary files after the code execution.")

delimiter: CodeBlockDelimiter = Field(default_factory=CodeBlockDelimiter,
description="The delimiter for the code execution.")

def __init__(self, **data):
"""Initialize the UnsafeLocalCodeExecutor."""
if "stateful" in data and data["stateful"]:
Expand Down Expand Up @@ -97,11 +93,6 @@ async def execute_code(self, invocation_context: InvocationContext,
return create_code_execution_result(stdout="\n".join(output_parts) if output_parts else "",
stderr="\n".join(error_parts) if error_parts else "")

@override
def code_block_delimiter(self) -> CodeBlockDelimiter:
"""Return the code block delimiter used by this executor."""
return self.delimiter

def _prepare_work_dir(self, execution_id: str) -> tuple[Path, bool]:
"""Prepare working directory for execution.

Expand Down
Loading