Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 133 additions & 1 deletion tests/storage/test_sql_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
import base64
import json
import pickle
from copy import deepcopy
from types import SimpleNamespace
from typing import Iterator
from unittest.mock import MagicMock

import pytest

from sqlalchemy import Text
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import postgresql
Expand All @@ -31,9 +35,11 @@
UTF8MB4String,
decode_content,
decode_grounding_metadata,
decode_grounding_metadata,
GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY,
TypeDecoratorHookRegistry,
)


# ---------------------------------------------------------------------------
# decode_content
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -459,3 +465,129 @@ def test_all_symbols_reexported(self):
assert _U is UTF8MB4String
assert _dc is decode_content
assert _dg is decode_grounding_metadata


def _build_dialect(name: str) -> SimpleNamespace:
return SimpleNamespace(name=name, type_descriptor=lambda t: t)


@pytest.fixture(autouse=True)
def reset_hook_registry() -> Iterator[None]:
    """Snapshot the global hook registry, clear it for the test, restore after.

    Deep copies are taken so a test cannot mutate the saved state through
    shared nested containers.
    """
    _HOOK_ATTRS = ("_load_dialect_hooks", "_process_bind_hooks", "_process_result_hooks")
    saved = {attr: deepcopy(getattr(TypeDecoratorHookRegistry, attr)) for attr in _HOOK_ATTRS}
    try:
        for attr in _HOOK_ATTRS:
            setattr(TypeDecoratorHookRegistry, attr, {})
        yield
    finally:
        for attr, original in saved.items():
            setattr(TypeDecoratorHookRegistry, attr, original)


def test_dynamic_json_all_hooks_can_override() -> None:
    """DynamicJSON honours registered load/bind/result hook overrides."""
    json_type = DynamicJSON()
    dialect = _build_dialect("sqlite")
    sentinel_impl = object()

    def on_load(decorator, dia):  # noqa: ANN001
        assert decorator is json_type
        assert dia.name == "sqlite"
        return sentinel_impl

    def on_bind(decorator, value, dia):  # noqa: ANN001
        assert decorator is json_type
        assert dia.name == "sqlite"
        return f"hooked-bind-{value}"

    def on_result(decorator, value, dia):  # noqa: ANN001
        assert decorator is json_type
        assert dia.name == "sqlite"
        return {"hooked_result": value}

    registry = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY
    registry.register_load_dialect_hook(DynamicJSON, on_load)
    registry.register_process_bind_hook(DynamicJSON, on_bind)
    registry.register_process_result_hook(DynamicJSON, on_result)

    assert json_type.load_dialect_impl(dialect) is sentinel_impl
    assert json_type.process_bind_param({"k": "v"}, dialect) == "hooked-bind-{'k': 'v'}"
    assert json_type.process_result_value('{"k":"v"}', dialect) == {"hooked_result": '{"k":"v"}'}


def test_dynamic_json_hook_none_falls_back_to_default_logic() -> None:
    """A hook that returns None defers to DynamicJSON's built-in JSON codec."""
    json_type = DynamicJSON()
    dialect = _build_dialect("sqlite")

    def skip_hook(_decorator, _value, _dialect):  # noqa: ANN001
        return None

    GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicJSON, skip_hook)
    GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_result_hook(DynamicJSON, skip_hook)

    serialized = json_type.process_bind_param({"a": 1}, dialect)
    assert serialized == '{"a": 1}'
    assert json_type.process_result_value(serialized, dialect) == {"a": 1}


def test_precise_timestamp_supports_all_three_hooks() -> None:
    """PreciseTimestamp honours load/bind/result hooks."""
    ts_type = PreciseTimestamp()
    dialect = _build_dialect("sqlite")
    impl_marker = object()

    registry = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY
    registry.register_load_dialect_hook(
        PreciseTimestamp, lambda _d, _dialect: impl_marker)
    registry.register_process_bind_hook(
        PreciseTimestamp, lambda _d, value, _dialect: f"bind-{value}")
    registry.register_process_result_hook(
        PreciseTimestamp, lambda _d, value, _dialect: f"result-{value}")

    assert ts_type.load_dialect_impl(dialect) is impl_marker
    assert ts_type.process_bind_param("2026-01-01", dialect) == "bind-2026-01-01"
    assert ts_type.process_result_value("2026-01-01", dialect) == "result-2026-01-01"


def test_utf8mb4_string_supports_all_three_hooks() -> None:
    """UTF8MB4String honours load/bind/result hooks."""
    string_type = UTF8MB4String(length=128)
    dialect = _build_dialect("sqlite")
    impl_marker = object()

    registry = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY
    registry.register_load_dialect_hook(
        UTF8MB4String, lambda _d, _dialect: impl_marker)
    registry.register_process_bind_hook(
        UTF8MB4String, lambda _d, value, _dialect: f"bind-{value}")
    registry.register_process_result_hook(
        UTF8MB4String, lambda _d, value, _dialect: f"result-{value}")

    assert string_type.load_dialect_impl(dialect) is impl_marker
    assert string_type.process_bind_param("hello", dialect) == "bind-hello"
    assert string_type.process_result_value("hello", dialect) == "result-hello"


def test_dynamic_pickle_hook_order_uses_first_override_result() -> None:
    """Bind hooks fire in registration order; the first non-None result wins."""
    pickle_type = DynamicPickleType()
    dialect = _build_dialect("sqlite")
    invoked: list[str] = []

    def make_hook(label: str, result):  # noqa: ANN001
        # Factory binds label/result per hook, avoiding late-binding closures.
        def hook(_decorator, _value, _dialect):  # noqa: ANN001
            invoked.append(label)
            return result
        return hook

    registry = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY
    registry.register_process_bind_hook(DynamicPickleType, make_hook("a", None))
    registry.register_process_bind_hook(DynamicPickleType, make_hook("b", "override"))
    registry.register_process_bind_hook(DynamicPickleType, make_hook("c", "should-not-run"))

    assert pickle_type.process_bind_param({"k": "v"}, dialect) == "override"
    # Hook "c" must never run once "b" produced a non-None override.
    assert invoked == ["a", "b"]
32 changes: 24 additions & 8 deletions trpc_agent_sdk/agents/core/_llm_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,36 @@ async def call_llm_async(self,
return

# Step 2: Call the model and process responses with telemetry tracing.
# Avoid start_as_current_span in async generators because cancellation can
# close the generator from a different context, which may trigger
# "Token was created in a different Context" during detach.
span = tracer.start_span('call_llm')
try:
with tracer.start_as_current_span('call_llm'):
event_id = Event.new_id()
final_llm_response = None
aggregated_raw_function_calls: list[dict] = []
aggregated_event_function_calls: list[dict] = []

def _append_function_calls(target: list[dict], calls: list) -> None:
for call in calls or []:
# Keep only telemetry-safe fields for trace attributes.
target.append({
"id": getattr(call, "id", None),
"name": getattr(call, "name", None),
"args": getattr(call, "args", None),
})

async for llm_response in self.model.generate_async(request, stream=stream, ctx=context):
# Collect raw model-level function calls from every chunk.
raw_calls = []
if llm_response.content and llm_response.content.parts:
for part in llm_response.content.parts:
if part.function_call:
raw_calls.append(part.function_call)
_append_function_calls(aggregated_raw_function_calls, raw_calls)

# Create Event directly from LlmResponse
event = self._create_event_from_response(context, event_id, llm_response)

# Process response with planner if available
event = self._process_planning_response(event, context)
_append_function_calls(aggregated_event_function_calls, event.get_function_calls())

# Track the latest non-partial response for tracing
# In streaming mode, only the final (non-partial) response
Expand All @@ -111,9 +127,9 @@ async def call_llm_async(self,
event_id,
request,
final_llm_response,
instruction_metadata=instruction_metadata)
finally:
span.end()
instruction_metadata=instruction_metadata,
stream_function_calls_raw=aggregated_raw_function_calls,
stream_function_calls_post_planner=aggregated_event_function_calls)

except Exception as ex: # pylint: disable=broad-except
logger.error("LLM call failed for agent %s: %s", author, ex)
Expand Down
18 changes: 15 additions & 3 deletions trpc_agent_sdk/agents/core/_tools_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ async def execute_tools_async(
state_end.update(merged_event.actions.state_delta)

# Add merged tool call tracing
with tracer.start_as_current_span("execute_tool (merged)"):
with tracer.start_as_current_span(
"execute_tool (merged)",
attributes={"gen_ai.operation.name": "execute_tool"},
):
trace_merged_tool_calls(
response_event_id=merged_event.id,
function_response_event=merged_event,
Expand Down Expand Up @@ -285,8 +288,17 @@ async def _execute_tool(self, tool_call: FunctionCall, tool: BaseTool, context:
Event: The result of tool execution
"""

# Wrap tool execution in telemetry span
with tracer.start_as_current_span(f"execute_tool {tool.name}"):
# Wrap tool execution in telemetry span.
# Pass initial attributes so the Galileo sampler can make a sampling
# decision at span-creation time (before trace_tool_call sets them).
with tracer.start_as_current_span(
f"execute_tool {tool.name}",
attributes={
"gen_ai.operation.name": "execute_tool",
"gen_ai.tool.name": tool.name,
"gen_ai.tool.description": tool.description or "",
},
):
# Capture state before tool execution
state_begin = dict(context.session.state)

Expand Down
6 changes: 2 additions & 4 deletions trpc_agent_sdk/filter/_run_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ async def run_stream_filters(ctx: AgentContext, req: Any, filters: list[BaseFilt
if handle is None:
raise ValueError("handle must be provided")
current_handle = partial(stream_handler_adapter, handle)
filters.reverse()
for filter in filters:
for filter in reversed(filters):
current_handle = partial(filter.run_stream, ctx, req, current_handle)
async for event in current_handle():
yield event.rsp
Expand Down Expand Up @@ -95,9 +94,8 @@ async def run_filters(ctx: AgentContext, req: Any, filters: list[BaseFilter],
"""
if handle is None:
raise ValueError("handle must be provided")
filters.reverse()
current_handle = partial(coroutine_handler_adapter, handle)
for filter in filters:
for filter in reversed(filters):
current_handle = partial(filter.run, ctx, req, current_handle)
rsp, error = await current_handle()
if error:
Expand Down
4 changes: 4 additions & 0 deletions trpc_agent_sdk/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
from ._sql_common import decode_content
from ._sql_common import decode_grounding_metadata
from ._sql_common import decode_usage_metadata
from ._sql_common import GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY
from ._sql_common import TypeDecoratorHookRegistry

__all__ = [
"EXPIRE_METHOD",
Expand Down Expand Up @@ -57,4 +59,6 @@
"decode_content",
"decode_grounding_metadata",
"decode_usage_metadata",
"GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY",
"TypeDecoratorHookRegistry",
]
Loading
Loading