From 1073c0ecadc2062706b929cf2cf6d30b1efbd372 Mon Sep 17 00:00:00 2001 From: raychen <815315825@qq.com> Date: Fri, 24 Apr 2026 11:29:02 +0800 Subject: [PATCH] =?UTF-8?q?bugfix:=20=E4=BF=AE=E5=A4=8Dtool=E7=9B=91?= =?UTF-8?q?=E6=8E=A7=E4=B8=A2=E5=A4=B1=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/storage/test_sql_common.py | 134 +++++++++++++++++- trpc_agent_sdk/agents/core/_llm_processor.py | 32 +++-- .../agents/core/_tools_processor.py | 18 ++- trpc_agent_sdk/filter/_run_filter.py | 6 +- trpc_agent_sdk/storage/__init__.py | 4 + trpc_agent_sdk/storage/_sql_common.py | 106 ++++++++++++++ trpc_agent_sdk/telemetry/_trace.py | 33 ++++- 7 files changed, 310 insertions(+), 23 deletions(-) diff --git a/tests/storage/test_sql_common.py b/tests/storage/test_sql_common.py index a00a91a..639c654 100644 --- a/tests/storage/test_sql_common.py +++ b/tests/storage/test_sql_common.py @@ -10,9 +10,13 @@ import base64 import json import pickle +from copy import deepcopy +from types import SimpleNamespace +from typing import Iterator from unittest.mock import MagicMock import pytest + from sqlalchemy import Text from sqlalchemy.dialects import mysql from sqlalchemy.dialects import postgresql @@ -31,9 +35,11 @@ UTF8MB4String, decode_content, decode_grounding_metadata, + decode_grounding_metadata, + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY, + TypeDecoratorHookRegistry, ) - # --------------------------------------------------------------------------- # decode_content # --------------------------------------------------------------------------- @@ -459,3 +465,129 @@ def test_all_symbols_reexported(self): assert _U is UTF8MB4String assert _dc is decode_content assert _dg is decode_grounding_metadata + + +def _build_dialect(name: str) -> SimpleNamespace: + return SimpleNamespace(name=name, type_descriptor=lambda t: t) + + +@pytest.fixture(autouse=True) +def reset_hook_registry() -> Iterator[None]: + """Reset global hook registry around each test.""" + old_load = deepcopy(TypeDecoratorHookRegistry._load_dialect_hooks) + old_bind = deepcopy(TypeDecoratorHookRegistry._process_bind_hooks) + old_result = deepcopy(TypeDecoratorHookRegistry._process_result_hooks) + try: + TypeDecoratorHookRegistry._load_dialect_hooks = {} + TypeDecoratorHookRegistry._process_bind_hooks = {} + TypeDecoratorHookRegistry._process_result_hooks = {} + yield + finally: + TypeDecoratorHookRegistry._load_dialect_hooks = old_load + TypeDecoratorHookRegistry._process_bind_hooks = old_bind + TypeDecoratorHookRegistry._process_result_hooks = old_result + + +def test_dynamic_json_all_hooks_can_override() -> None: + """DynamicJSON supports load/bind/result hook overrides.""" + json_type = DynamicJSON() + sqlite = _build_dialect("sqlite") + load_marker = object() + + def load_hook(decorator, dialect): # noqa: ANN001 + assert decorator is json_type + assert dialect.name == "sqlite" + return load_marker + + def bind_hook(decorator, value, dialect): # noqa: ANN001 + assert decorator is json_type + assert dialect.name == "sqlite" + return f"hooked-bind-{value}" + + def result_hook(decorator, value, dialect): # noqa: ANN001 + assert decorator is json_type + assert dialect.name == "sqlite" + return {"hooked_result": value} + + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_load_dialect_hook(DynamicJSON, load_hook) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicJSON, bind_hook) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_result_hook(DynamicJSON, result_hook) + + assert json_type.load_dialect_impl(sqlite) is load_marker + assert json_type.process_bind_param({"k": "v"}, sqlite) == "hooked-bind-{'k': 'v'}" + assert json_type.process_result_value('{"k":"v"}', sqlite) == {"hooked_result": '{"k":"v"}'} + + +def test_dynamic_json_hook_none_falls_back_to_default_logic() -> None: + """Hook skips when returning None.""" + json_type = DynamicJSON() + sqlite = _build_dialect("sqlite") + + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicJSON, lambda _d, _v, _dialect: None) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_result_hook(DynamicJSON, lambda _d, _v, _dialect: None) + + encoded = json_type.process_bind_param({"a": 1}, sqlite) + decoded = json_type.process_result_value(encoded, sqlite) + + assert encoded == '{"a": 1}' + assert decoded == {"a": 1} + + +def test_precise_timestamp_supports_all_three_hooks() -> None: + """PreciseTimestamp supports load/bind/result hooks.""" + ts_type = PreciseTimestamp() + sqlite = _build_dialect("sqlite") + load_marker = object() + + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_load_dialect_hook(PreciseTimestamp, lambda _d, _dialect: load_marker) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(PreciseTimestamp, + lambda _d, value, _dialect: f"bind-{value}") + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_result_hook(PreciseTimestamp, + lambda _d, value, _dialect: f"result-{value}") + + assert ts_type.load_dialect_impl(sqlite) is load_marker + assert ts_type.process_bind_param("2026-01-01", sqlite) == "bind-2026-01-01" + assert ts_type.process_result_value("2026-01-01", sqlite) == "result-2026-01-01" + + +def test_utf8mb4_string_supports_all_three_hooks() -> None: + """UTF8MB4String supports load/bind/result hooks.""" + str_type = UTF8MB4String(length=128) + sqlite = _build_dialect("sqlite") + load_marker = object() + + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_load_dialect_hook(UTF8MB4String, lambda _d, _dialect: load_marker) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(UTF8MB4String, + lambda _d, value, _dialect: f"bind-{value}") + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_result_hook(UTF8MB4String, + lambda _d, value, _dialect: f"result-{value}") + + assert str_type.load_dialect_impl(sqlite) is load_marker + assert str_type.process_bind_param("hello", sqlite) == "bind-hello" + assert str_type.process_result_value("hello", sqlite) == "result-hello" + + +def test_dynamic_pickle_hook_order_uses_first_override_result() -> None: + """First non-None hook result wins for DynamicPickleType.""" + pickle_type = DynamicPickleType() + sqlite = _build_dialect("sqlite") + calls: list[str] = [] + + def hook_a(_d, _v, _dialect): # noqa: ANN001 + calls.append("a") + return None + + def hook_b(_d, _v, _dialect): # noqa: ANN001 + calls.append("b") + return "override" + + def hook_c(_d, _v, _dialect): # noqa: ANN001 + calls.append("c") + return "should-not-run" + + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicPickleType, hook_a) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicPickleType, hook_b) + GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.register_process_bind_hook(DynamicPickleType, hook_c) + + assert pickle_type.process_bind_param({"k": "v"}, sqlite) == "override" + assert calls == ["a", "b"] diff --git a/trpc_agent_sdk/agents/core/_llm_processor.py b/trpc_agent_sdk/agents/core/_llm_processor.py index 50bd5c5..834766d 100644 --- a/trpc_agent_sdk/agents/core/_llm_processor.py +++ b/trpc_agent_sdk/agents/core/_llm_processor.py @@ -79,20 +79,36 @@ async def call_llm_async(self, return # Step 2: Call the model and process responses with telemetry tracing. - # Avoid start_as_current_span in async generators because cancellation can - # close the generator from a different context, which may trigger - # "Token was created in a different Context" during detach. - span = tracer.start_span('call_llm') - try: + with tracer.start_as_current_span('call_llm'): event_id = Event.new_id() final_llm_response = None + aggregated_raw_function_calls: list[dict] = [] + aggregated_event_function_calls: list[dict] = [] + + def _append_function_calls(target: list[dict], calls: list) -> None: + for call in calls or []: + # Keep only telemetry-safe fields for trace attributes. + target.append({ + "id": getattr(call, "id", None), + "name": getattr(call, "name", None), + "args": getattr(call, "args", None), + }) async for llm_response in self.model.generate_async(request, stream=stream, ctx=context): + # Collect raw model-level function calls from every chunk. + raw_calls = [] + if llm_response.content and llm_response.content.parts: + for part in llm_response.content.parts: + if part.function_call: + raw_calls.append(part.function_call) + _append_function_calls(aggregated_raw_function_calls, raw_calls) + # Create Event directly from LlmResponse event = self._create_event_from_response(context, event_id, llm_response) # Process response with planner if available event = self._process_planning_response(event, context) + _append_function_calls(aggregated_event_function_calls, event.get_function_calls()) # Track the latest non-partial response for tracing # In streaming mode, only the final (non-partial) response @@ -111,9 +127,9 @@ async def call_llm_async(self, event_id, request, final_llm_response, - instruction_metadata=instruction_metadata) - finally: - span.end() + instruction_metadata=instruction_metadata, + stream_function_calls_raw=aggregated_raw_function_calls, + stream_function_calls_post_planner=aggregated_event_function_calls) except Exception as ex: # pylint: disable=broad-except logger.error("LLM call failed for agent %s: %s", author, ex) diff --git a/trpc_agent_sdk/agents/core/_tools_processor.py b/trpc_agent_sdk/agents/core/_tools_processor.py index bea101a..9977d54 100644 --- a/trpc_agent_sdk/agents/core/_tools_processor.py +++ b/trpc_agent_sdk/agents/core/_tools_processor.py @@ -221,7 +221,10 @@ async def execute_tools_async( state_end.update(merged_event.actions.state_delta) # Add merged tool call tracing - with tracer.start_as_current_span("execute_tool (merged)"): + with tracer.start_as_current_span( + "execute_tool (merged)", + attributes={"gen_ai.operation.name": "execute_tool"}, + ): trace_merged_tool_calls( response_event_id=merged_event.id, function_response_event=merged_event, @@ -285,8 +288,17 @@ async def _execute_tool(self, tool_call: FunctionCall, tool: BaseTool, context: Event: The result of tool execution """ - # Wrap tool execution in telemetry span - with tracer.start_as_current_span(f"execute_tool {tool.name}"): + # Wrap tool execution in telemetry span. + # Pass initial attributes so the Galileo sampler can make a sampling + # decision at span-creation time (before trace_tool_call sets them). + with tracer.start_as_current_span( + f"execute_tool {tool.name}", + attributes={ + "gen_ai.operation.name": "execute_tool", + "gen_ai.tool.name": tool.name, + "gen_ai.tool.description": tool.description or "", + }, + ): # Capture state before tool execution state_begin = dict(context.session.state) diff --git a/trpc_agent_sdk/filter/_run_filter.py b/trpc_agent_sdk/filter/_run_filter.py index 0a407bf..98a7364 100644 --- a/trpc_agent_sdk/filter/_run_filter.py +++ b/trpc_agent_sdk/filter/_run_filter.py @@ -51,8 +51,7 @@ async def run_stream_filters(ctx: AgentContext, req: Any, filters: list[BaseFilt if handle is None: raise ValueError("handle must be provided") current_handle = partial(stream_handler_adapter, handle) - filters.reverse() - for filter in filters: + for filter in reversed(filters): current_handle = partial(filter.run_stream, ctx, req, current_handle) async for event in current_handle(): yield event.rsp @@ -95,9 +94,8 @@ async def run_filters(ctx: AgentContext, req: Any, filters: list[BaseFilter], """ if handle is None: raise ValueError("handle must be provided") - filters.reverse() current_handle = partial(coroutine_handler_adapter, handle) - for filter in filters: + for filter in reversed(filters): current_handle = partial(filter.run, ctx, req, current_handle) rsp, error = await current_handle() if error: diff --git a/trpc_agent_sdk/storage/__init__.py b/trpc_agent_sdk/storage/__init__.py index 56c2432..a7d38bb 100644 --- a/trpc_agent_sdk/storage/__init__.py +++ b/trpc_agent_sdk/storage/__init__.py @@ -30,6 +30,8 @@ from ._sql_common import decode_content from ._sql_common import decode_grounding_metadata from ._sql_common import decode_usage_metadata +from ._sql_common import GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY +from ._sql_common import TypeDecoratorHookRegistry __all__ = [ "EXPIRE_METHOD", @@ -57,4 +59,6 @@ "decode_content", "decode_grounding_metadata", "decode_usage_metadata", + "GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY", + "TypeDecoratorHookRegistry", ] diff --git a/trpc_agent_sdk/storage/_sql_common.py b/trpc_agent_sdk/storage/_sql_common.py index 2a9c6fd..165e991 100644 --- a/trpc_agent_sdk/storage/_sql_common.py +++ b/trpc_agent_sdk/storage/_sql_common.py @@ -44,6 +44,64 @@ from trpc_agent_sdk.types import GroundingMetadata from trpc_agent_sdk.types import GenerateContentResponseUsageMetadata +LoadDialectHook = Callable[[TypeDecorator, Dialect], Any] +ProcessBindHook = Callable[[TypeDecorator, Any, Dialect], Any] +ProcessResultHook = Callable[[TypeDecorator, Any, Dialect], Any] + + +class TypeDecoratorHookRegistry: + """Global hook registry for SQLAlchemy TypeDecorator callbacks.""" + + _load_dialect_hooks: dict[type[TypeDecorator], list[LoadDialectHook]] = {} + _process_bind_hooks: dict[type[TypeDecorator], list[ProcessBindHook]] = {} + _process_result_hooks: dict[type[TypeDecorator], list[ProcessResultHook]] = {} + + @classmethod + def register_load_dialect_hook(cls, decorator_cls: type[TypeDecorator], hook: LoadDialectHook) -> None: + """Register hook for ``load_dialect_impl``.""" + cls._load_dialect_hooks.setdefault(decorator_cls, []).append(hook) + + @classmethod + def register_process_bind_hook(cls, decorator_cls: type[TypeDecorator], hook: ProcessBindHook) -> None: + """Register hook for ``process_bind_param``.""" + cls._process_bind_hooks.setdefault(decorator_cls, []).append(hook) + + @classmethod + def register_process_result_hook(cls, decorator_cls: type[TypeDecorator], hook: ProcessResultHook) -> None: + """Register hook for ``process_result_value``.""" + cls._process_result_hooks.setdefault(decorator_cls, []).append(hook) + + @classmethod + def run_load_dialect_hooks(cls, decorator: TypeDecorator, dialect: Dialect) -> Any: + """Run load hooks and return first override result.""" + for hook in cls._load_dialect_hooks.get(type(decorator), []): + result = hook(decorator, dialect) + if result is not None: + return result + return None + + @classmethod + def run_process_bind_hooks(cls, decorator: TypeDecorator, value: Any, dialect: Dialect) -> Any: + """Run bind hooks and return first override result.""" + for hook in cls._process_bind_hooks.get(type(decorator), []): + result = hook(decorator, value, dialect) + if result is not None: + return result + return None + + @classmethod + def run_process_result_hooks(cls, decorator: TypeDecorator, value: Any, dialect: Dialect) -> Any: + """Run result hooks and return first override result.""" + for hook in cls._process_result_hooks.get(type(decorator), []): + result = hook(decorator, value, dialect) + if result is not None: + return result + return None + + +# Global class object used as unified registration entry. +GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY = TypeDecoratorHookRegistry + def decode_content(content: Optional[dict[str, Any]]) -> Optional[Content]: """Decode a content object from a JSON dictionary. @@ -119,6 +177,9 @@ class DynamicJSON(TypeDecorator): impl = Text # Default implementation is TEXT def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_load_dialect_hooks(self, dialect) + if hook_result is not None: + return hook_result if dialect.name == "postgresql": return dialect.type_descriptor(postgresql.JSONB) # type: ignore if dialect.name == "mysql": @@ -126,6 +187,9 @@ def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: return dialect.type_descriptor(Text) # Default to Text for other dialects # type: ignore def process_bind_param(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_bind_hooks(self, value, dialect) + if hook_result is not None: + return hook_result if value is not None: if dialect.name == "postgresql": return value # JSONB handles dict directly @@ -134,6 +198,9 @@ def process_bind_param(self, value: Any, dialect: Dialect) -> Any: return value def process_result_value(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_result_hooks(self, value, dialect) + if hook_result is not None: + return hook_result if value is not None: if dialect.name == "postgresql": return value # JSONB returns dict directly @@ -157,6 +224,9 @@ def __init__(self, length: Optional[int] = None, *args: Any, **kwargs: Any) -> N self.length = length def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_load_dialect_hooks(self, dialect) + if hook_result is not None: + return hook_result if dialect.name == "mysql": # Use VARCHAR with utf8mb4 charset and utf8mb4_unicode_ci collation return dialect.type_descriptor(mysql.VARCHAR(self.length, charset='utf8mb4', @@ -166,6 +236,18 @@ def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: return dialect.type_descriptor(String(self.length)) return dialect.type_descriptor(String()) + def process_bind_param(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_bind_hooks(self, value, dialect) + if hook_result is not None: + return hook_result + return value + + def process_result_value(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_result_hooks(self, value, dialect) + if hook_result is not None: + return hook_result + return value + class PreciseTimestamp(TypeDecorator): """Represents a timestamp precise to the microsecond.""" @@ -174,10 +256,25 @@ class PreciseTimestamp(TypeDecorator): cache_ok = True def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_load_dialect_hooks(self, dialect) + if hook_result is not None: + return hook_result if dialect.name == "mysql": return dialect.type_descriptor(mysql.DATETIME(fsp=6)) return self.impl + def process_bind_param(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_bind_hooks(self, value, dialect) + if hook_result is not None: + return hook_result + return value + + def process_result_value(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_result_hooks(self, value, dialect) + if hook_result is not None: + return hook_result + return value + class DynamicPickleType(TypeDecorator): """Represents a type that can be pickled.""" @@ -185,6 +282,9 @@ class DynamicPickleType(TypeDecorator): impl = PickleType def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_load_dialect_hooks(self, dialect) + if hook_result is not None: + return hook_result if dialect.name == "spanner+spanner": return dialect.type_descriptor(SpannerPickleType) # type: ignore if dialect.name == "mysql": @@ -192,12 +292,18 @@ def load_dialect_impl(self, dialect: Dialect) -> TypeDecorator: return self.impl def process_bind_param(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_bind_hooks(self, value, dialect) + if hook_result is not None: + return hook_result if value is not None: if dialect.name == "spanner+spanner": return pickle.dumps(value) return value def process_result_value(self, value: Any, dialect: Dialect) -> Any: + hook_result = GLOBAL_TYPE_DECORATOR_HOOK_REGISTRY.run_process_result_hooks(self, value, dialect) + if hook_result is not None: + return hook_result if value is not None: if dialect.name == "spanner+spanner": return pickle.loads(value) diff --git a/trpc_agent_sdk/telemetry/_trace.py b/trpc_agent_sdk/telemetry/_trace.py index 51b4721..aac6657 100644 --- a/trpc_agent_sdk/telemetry/_trace.py +++ b/trpc_agent_sdk/telemetry/_trace.py @@ -58,7 +58,6 @@ def get_trpc_agent_span_name() -> str: """ Get the span name for the trpc agent. """ - global _trpc_agent_span_name # pylint: disable=invalid-name return _trpc_agent_span_name @@ -104,7 +103,6 @@ def trace_runner( state_begin: The state before the runner execution. state_end: The state after the runner execution. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() span.set_attribute("gen_ai.system", _trpc_agent_span_name) span.set_attribute("gen_ai.operation.name", "run_runner") @@ -161,7 +159,6 @@ def trace_cancellation( state_begin: The state before the runner execution. state_partial: The partial state at cancellation point. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() # Set span status to ERROR for cancellation @@ -225,7 +222,6 @@ def trace_agent( state_begin: The state before the agent run. state_end: The state after the agent run. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() span.set_attribute("gen_ai.system", _trpc_agent_span_name) span.set_attribute("gen_ai.operation.name", "run_agent") @@ -265,7 +261,6 @@ def trace_tool_call( state_begin: The state before the tool execution. state_end: The state after the tool execution. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() span.set_attribute("gen_ai.system", _trpc_agent_span_name) span.set_attribute("gen_ai.operation.name", "execute_tool") @@ -331,7 +326,6 @@ def trace_merged_tool_calls( state_begin: The state before the tool execution. state_end: The state after the tool execution. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() span.set_attribute("gen_ai.system", _trpc_agent_span_name) span.set_attribute("gen_ai.operation.name", "execute_tool") @@ -372,6 +366,8 @@ def trace_call_llm( llm_request: LlmRequest, llm_response: LlmResponse, instruction_metadata: Optional[InstructionMetadata] = None, + stream_function_calls_raw: Optional[list[dict[str, Any]]] = None, + stream_function_calls_post_planner: Optional[list[dict[str, Any]]] = None, ): """Traces a call to the LLM. @@ -387,8 +383,11 @@ def trace_call_llm( with ``name``, ``version``, and ``labels`` attributes. When provided, these are written directly to the call_llm span for precise instruction-to-generation association. + stream_function_calls_raw: Optional function calls collected from all + raw LLM stream chunks. + stream_function_calls_post_planner: Optional function calls collected + from post-planner events emitted during stream processing. """ - global _trpc_agent_span_name # pylint: disable=invalid-name span = trace.get_current_span() # Special standard Open Telemetry GenaI attributes that indicate # that this is a span related to a Generative AI system. @@ -414,6 +413,26 @@ def trace_call_llm( llm_response_json, ) + if stream_function_calls_raw: + span.set_attribute( + f"{_trpc_agent_span_name}.stream_function_calls.raw", + _safe_json_serialize(stream_function_calls_raw), + ) + span.set_attribute( + f"{_trpc_agent_span_name}.stream_function_calls.raw_count", + len(stream_function_calls_raw), + ) + + if stream_function_calls_post_planner: + span.set_attribute( + f"{_trpc_agent_span_name}.stream_function_calls.post_planner", + _safe_json_serialize(stream_function_calls_post_planner), + ) + span.set_attribute( + f"{_trpc_agent_span_name}.stream_function_calls.post_planner_count", + len(stream_function_calls_post_planner), + ) + if llm_response.usage_metadata is not None: usage = llm_response.usage_metadata if usage.prompt_token_count and usage.total_token_count: