diff --git a/py/noxfile.py b/py/noxfile.py index 61ea0aec..5e51312b 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -187,7 +187,7 @@ def test_anthropic(session, version): def test_google_genai(session, version): _install_test_deps(session) _install(session, "google-genai", version) - _run_tests(session, f"{WRAPPER_DIR}/test_google_genai.py") + _run_tests(session, f"{INTEGRATION_DIR}/google_genai/test_google_genai.py") _run_core_tests(session) diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 25dd436a..48276ea3 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -7,7 +7,13 @@ import logging from contextlib import contextmanager -from braintrust.integrations import ADKIntegration, AgnoIntegration, AnthropicIntegration, ClaudeAgentSDKIntegration +from braintrust.integrations import ( + ADKIntegration, + AgnoIntegration, + AnthropicIntegration, + ClaudeAgentSDKIntegration, + GoogleGenAIIntegration, +) __all__ = ["auto_instrument"] @@ -113,7 +119,7 @@ def auto_instrument( if pydantic_ai: results["pydantic_ai"] = _instrument_pydantic_ai() if google_genai: - results["google_genai"] = _instrument_google_genai() + results["google_genai"] = _instrument_integration(GoogleGenAIIntegration) if agno: results["agno"] = _instrument_integration(AgnoIntegration) if claude_agent_sdk: @@ -156,14 +162,6 @@ def _instrument_pydantic_ai() -> bool: return False -def _instrument_google_genai() -> bool: - with _try_patch(): - from braintrust.wrappers.google_genai import setup_genai - - return setup_genai() - return False - - def _instrument_dspy() -> bool: with _try_patch(): from braintrust.wrappers.dspy import patch_dspy diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index 35324c1c..db4f048f 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -2,6 +2,13 @@ from .agno import AgnoIntegration from .anthropic import AnthropicIntegration from .claude_agent_sdk import ClaudeAgentSDKIntegration +from .google_genai import GoogleGenAIIntegration -__all__ = ["ADKIntegration", "AgnoIntegration", "AnthropicIntegration", "ClaudeAgentSDKIntegration"] +__all__ = [ + "ADKIntegration", + "AgnoIntegration", + "AnthropicIntegration", + "ClaudeAgentSDKIntegration", + "GoogleGenAIIntegration", +] diff --git a/py/src/braintrust/integrations/google_genai/__init__.py b/py/src/braintrust/integrations/google_genai/__init__.py new file mode 100644 index 00000000..1b8a24dd --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/__init__.py @@ -0,0 +1,39 @@ +"""Braintrust integration for Google GenAI.""" + +import logging + +from braintrust.logger import NOOP_SPAN, current_span, init_logger + +from .integration import GoogleGenAIIntegration + + +logger = logging.getLogger(__name__) + +__all__ = [ + "GoogleGenAIIntegration", + "setup_genai", +] + + +def setup_genai( + api_key: str | None = None, + project_id: str | None = None, + project_name: str | None = None, +) -> bool: + """Setup Braintrust integration with Google GenAI. + + Will automatically patch Google GenAI models for automatic tracing. + + Args: + api_key: Braintrust API key. + project_id: Braintrust project ID. + project_name: Braintrust project name. + + Returns: + True if setup was successful, False if google-genai is not installed. + """ + span = current_span() + if span == NOOP_SPAN: + init_logger(project=project_name, api_key=api_key, project_id=project_id) + + return GoogleGenAIIntegration.setup() diff --git a/py/src/braintrust/wrappers/cassettes/test_basic_completion[stream].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion[stream].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_basic_completion[stream].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion[stream].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_basic_completion[sync].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion[sync].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_basic_completion[sync].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion[sync].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_basic_completion_async[async].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion_async[async].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_basic_completion_async[async].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion_async[async].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_basic_completion_async[async_stream].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion_async[async_stream].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_basic_completion_async[async_stream].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_basic_completion_async[async_stream].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_embed_content.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_embed_content.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_embed_content.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_embed_content.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_embed_content_async.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_embed_content_async.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_embed_content_async.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_embed_content_async.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_error_handling.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_error_handling.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_error_handling.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_error_handling.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_multi_turn.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_multi_turn.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_multi_turn.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_multi_turn.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_stop_sequences.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_stop_sequences.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_stop_sequences.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_stop_sequences.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_system_prompt.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_system_prompt.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_system_prompt.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_system_prompt.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_temperature_and_top_p.yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_temperature_and_top_p.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_temperature_and_top_p.yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_temperature_and_top_p.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_tool_use[stream].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_tool_use[stream].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_tool_use[stream].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_tool_use[stream].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_tool_use[sync].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_tool_use[sync].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_tool_use[sync].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_tool_use[sync].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_tool_use_async[async].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_tool_use_async[async].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_tool_use_async[async].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_tool_use_async[async].yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_tool_use_async[async_stream].yaml b/py/src/braintrust/integrations/google_genai/cassettes/test_tool_use_async[async_stream].yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_tool_use_async[async_stream].yaml rename to py/src/braintrust/integrations/google_genai/cassettes/test_tool_use_async[async_stream].yaml diff --git a/py/src/braintrust/integrations/google_genai/integration.py b/py/src/braintrust/integrations/google_genai/integration.py new file mode 100644 index 00000000..eaa5a201 --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/integration.py @@ -0,0 +1,32 @@ +"""Google GenAI integration — orchestration class and setup entry-point.""" + +import logging + +from braintrust.integrations.base import BaseIntegration + +from .patchers import ( + AsyncModelsEmbedContentPatcher, + AsyncModelsGenerateContentPatcher, + AsyncModelsGenerateContentStreamPatcher, + ModelsEmbedContentPatcher, + ModelsGenerateContentPatcher, + ModelsGenerateContentStreamPatcher, +) + + +logger = logging.getLogger(__name__) + + +class GoogleGenAIIntegration(BaseIntegration): + """Braintrust instrumentation for the Google GenAI Python SDK.""" + + name = "google_genai" + import_names = ("google.genai",) + patchers = ( + ModelsGenerateContentPatcher, + ModelsGenerateContentStreamPatcher, + ModelsEmbedContentPatcher, + AsyncModelsGenerateContentPatcher, + AsyncModelsGenerateContentStreamPatcher, + AsyncModelsEmbedContentPatcher, + ) diff --git a/py/src/braintrust/integrations/google_genai/patchers.py b/py/src/braintrust/integrations/google_genai/patchers.py new file mode 100644 index 00000000..600a604b --- /dev/null +++ b/py/src/braintrust/integrations/google_genai/patchers.py @@ -0,0 +1,76 @@ +"""Google GenAI patchers — one patcher per coherent patch target.""" + +from braintrust.integrations.base import FunctionWrapperPatcher + +from .tracing import ( + _async_embed_content_wrapper, + _async_generate_content_stream_wrapper, + _async_generate_content_wrapper, + _embed_content_wrapper, + _generate_content_stream_wrapper, + _generate_content_wrapper, +) + + +# --------------------------------------------------------------------------- +# Sync Models patchers +# --------------------------------------------------------------------------- + + +class ModelsGenerateContentPatcher(FunctionWrapperPatcher): + """Patch ``Models._generate_content`` for tracing.""" + + name = "google_genai.models.generate_content" + target_module = "google.genai.models" + target_path = "Models._generate_content" + wrapper = _generate_content_wrapper + + +class ModelsGenerateContentStreamPatcher(FunctionWrapperPatcher): + """Patch ``Models.generate_content_stream`` for tracing.""" + + name = "google_genai.models.generate_content_stream" + target_module = "google.genai.models" + target_path = "Models.generate_content_stream" + wrapper = _generate_content_stream_wrapper + + +class ModelsEmbedContentPatcher(FunctionWrapperPatcher): + """Patch ``Models.embed_content`` for tracing.""" + + name = "google_genai.models.embed_content" + target_module = "google.genai.models" + target_path = "Models.embed_content" + wrapper = _embed_content_wrapper + + +# --------------------------------------------------------------------------- +# Async Models patchers +# --------------------------------------------------------------------------- + + +class AsyncModelsGenerateContentPatcher(FunctionWrapperPatcher): + """Patch ``AsyncModels.generate_content`` for tracing.""" + + name = "google_genai.async_models.generate_content" + target_module = "google.genai.models" + target_path = "AsyncModels.generate_content" + wrapper = _async_generate_content_wrapper + + +class AsyncModelsGenerateContentStreamPatcher(FunctionWrapperPatcher): + """Patch ``AsyncModels.generate_content_stream`` for tracing.""" + + name = "google_genai.async_models.generate_content_stream" + target_module = "google.genai.models" + target_path = "AsyncModels.generate_content_stream" + wrapper = _async_generate_content_stream_wrapper + + +class AsyncModelsEmbedContentPatcher(FunctionWrapperPatcher): + """Patch ``AsyncModels.embed_content`` for tracing.""" + + name = "google_genai.async_models.embed_content" + target_module = "google.genai.models" + target_path = "AsyncModels.embed_content" + wrapper = _async_embed_content_wrapper diff --git a/py/src/braintrust/wrappers/test_google_genai.py b/py/src/braintrust/integrations/google_genai/test_google_genai.py similarity index 99% rename from py/src/braintrust/wrappers/test_google_genai.py rename to py/src/braintrust/integrations/google_genai/test_google_genai.py index 73a31e71..9834fe30 100644 --- a/py/src/braintrust/wrappers/test_google_genai.py +++ b/py/src/braintrust/integrations/google_genai/test_google_genai.py @@ -4,8 +4,8 @@ import pytest from braintrust import logger +from braintrust.integrations.google_genai import setup_genai from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.google_genai import setup_genai from braintrust.wrappers.test_utils import verify_autoinstrument_script from google.genai import types from google.genai.client import Client diff --git a/py/src/braintrust/wrappers/google_genai/__init__.py b/py/src/braintrust/integrations/google_genai/tracing.py similarity index 73% rename from py/src/braintrust/wrappers/google_genai/__init__.py rename to py/src/braintrust/integrations/google_genai/tracing.py index 3bdae565..ed7f572e 100644 --- a/py/src/braintrust/wrappers/google_genai/__init__.py +++ b/py/src/braintrust/integrations/google_genai/tracing.py @@ -1,9 +1,13 @@ +"""Google GenAI-specific span creation, metadata extraction, stream handling, and output normalization.""" + import logging import time from collections.abc import Awaitable, Callable, Iterable from typing import TYPE_CHECKING, Any from braintrust.bt_json import bt_safe_deep_copy +from braintrust.logger import Attachment, start_span +from braintrust.span_types import SpanTypeAttribute if TYPE_CHECKING: @@ -12,140 +16,16 @@ GenerateContentResponse, GenerateContentResponseUsageMetadata, ) -from braintrust.logger import NOOP_SPAN, Attachment, current_span, init_logger, start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - logger = logging.getLogger(__name__) -def setup_genai( - api_key: str | None = None, - project_id: str | None = None, - project_name: str | None = None, -) -> bool: - """ - Setup Braintrust integration with Google GenAI. - - Returns: - True if setup was successful, False if google-genai is not installed. - """ - span = current_span() - if span == NOOP_SPAN: - init_logger(project=project_name, api_key=api_key, project_id=project_id) - - try: - import google.genai as genai # pyright: ignore - from google.genai import models - - genai.Client = wrap_client(genai.Client) - models.Models = wrap_models(models.Models) - models.AsyncModels = wrap_async_models(models.AsyncModels) - return True - except ImportError: - return False - - -def wrap_client(Client: Any): - if is_patched(Client): - return Client - - # noop for now, but may be useful in the future - - mark_patched(Client) - return Client - - -def wrap_models(Models: Any): - if is_patched(Models): - return Models - - def wrap_generate_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return _run_traced_call( - instance._api_client, - args, - kwargs, - name="generate_content", - invoke=lambda: wrapped(*args, **kwargs), - process_result=_gc_process_result, - ) - - wrap_function_wrapper(Models, "_generate_content", wrap_generate_content) - - def wrap_generate_content_stream(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return _run_stream_traced_call( - instance._api_client, - args, - kwargs, - name="generate_content_stream", - invoke=lambda: wrapped(*args, **kwargs), - aggregate=_aggregate_generate_content_chunks, - ) - - wrap_function_wrapper(Models, "generate_content_stream", wrap_generate_content_stream) - - def wrap_embed_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return _run_traced_call( - instance._api_client, - args, - kwargs, - name="embed_content", - invoke=lambda: wrapped(*args, **kwargs), - process_result=_embed_process_result, - ) - - wrap_function_wrapper(Models, "embed_content", wrap_embed_content) - - mark_patched(Models) - return Models - - -def wrap_async_models(AsyncModels: Any): - if is_patched(AsyncModels): - return AsyncModels - - async def wrap_generate_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return await _run_async_traced_call( - instance._api_client, - args, - kwargs, - name="generate_content", - invoke=lambda: wrapped(*args, **kwargs), - process_result=_gc_process_result, - ) - - wrap_function_wrapper(AsyncModels, "generate_content", wrap_generate_content) - - async def wrap_generate_content_stream(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return _run_async_stream_traced_call( - instance._api_client, - args, - kwargs, - name="generate_content_stream", - invoke=lambda: wrapped(*args, **kwargs), - aggregate=_aggregate_generate_content_chunks, - ) - - wrap_function_wrapper(AsyncModels, "generate_content_stream", wrap_generate_content_stream) - - async def wrap_embed_content(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return await _run_async_traced_call( - instance._api_client, - args, - kwargs, - name="embed_content", - invoke=lambda: wrapped(*args, **kwargs), - process_result=_embed_process_result, - ) +# --------------------------------------------------------------------------- +# Serialization helpers +# --------------------------------------------------------------------------- - wrap_function_wrapper(AsyncModels, "embed_content", wrap_embed_content) - mark_patched(AsyncModels) - return AsyncModels - - -def _serialize_input(api_client: Any, input: dict[str, Any]): +def _serialize_input(api_client: Any, input: dict[str, Any]) -> dict[str, Any]: config = bt_safe_deep_copy(input.get("config")) if config is not None: @@ -163,113 +43,6 @@ def _serialize_input(api_client: Any, input: dict[str, Any]): return input -def _gc_process_result(result: "GenerateContentResponse", start: float) -> tuple[Any, dict[str, Any]]: - return result, _extract_generate_content_metrics(result, start) - - -def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple[Any, dict[str, Any]]: - return _extract_embed_content_output(result), _extract_embed_content_metrics(result, start) - - -def _prepare_traced_call( - api_client: Any, args: list[Any], kwargs: dict[str, Any] -) -> tuple[dict[str, Any], dict[str, Any]]: - input, clean_kwargs = get_args_kwargs(args, kwargs, ["model", "contents", "config"], ["contents", "config"]) - return _serialize_input(api_client, input), clean_kwargs - - -def _run_traced_call( - api_client: Any, - args: list[Any], - kwargs: dict[str, Any], - *, - name: str, - invoke: Callable[[], Any], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], -): - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) - - start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: - result = invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) - return result - - -async def _run_async_traced_call( - api_client: Any, - args: list[Any], - kwargs: dict[str, Any], - *, - name: str, - invoke: Callable[[], Awaitable[Any]], - process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], -): - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) - - start = time.time() - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: - result = await invoke() - output, metrics = process_result(result, start) - span.log(output=output, metrics=metrics) - return result - - -def _run_stream_traced_call( - api_client: Any, - args: list[Any], - kwargs: dict[str, Any], - *, - name: str, - invoke: Callable[[], Any], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], -): - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) - - start = time.time() - first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: - chunks = [] - for chunk in invoke(): - if first_token_time is None: - first_token_time = time.time() - chunks.append(chunk) - yield chunk - - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) - return output - - -def _run_async_stream_traced_call( - api_client: Any, - args: list[Any], - kwargs: dict[str, Any], - *, - name: str, - invoke: Callable[[], Awaitable[Any]], - aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], -): - input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) - - async def stream_generator(): - start = time.time() - first_token_time = None - with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: - chunks = [] - async for chunk in await invoke(): - if first_token_time is None: - first_token_time = time.time() - chunks.append(chunk) - yield chunk - - output, metrics = aggregate(chunks, start, first_token_time) - span.log(output=output, metrics=metrics) - - return stream_generator() - - def _serialize_contents(contents: Any) -> Any: """Serialize contents, converting binary data to base64-encoded data URLs.""" if contents is None: @@ -327,7 +100,7 @@ def _serialize_content_item(item: Any) -> Any: return item -def _serialize_tools(api_client: Any, input: Any | None): +def _serialize_tools(api_client: Any, input: Any | None) -> Any | None: try: from google.genai.models import ( _GenerateContentParameters_to_mldev, # pyright: ignore [reportPrivateUsage] @@ -346,22 +119,35 @@ def _serialize_tools(api_client: Any, input: Any | None): return None -def omit(obj: dict[str, Any], keys: Iterable[str]): +# --------------------------------------------------------------------------- +# Argument extraction helpers +# --------------------------------------------------------------------------- + + +def _omit(obj: dict[str, Any], keys: Iterable[str]) -> dict[str, Any]: return {k: v for k, v in obj.items() if k not in keys} -def is_patched(obj: Any): - return getattr(obj, "_braintrust_patched", False) +def _get_args_kwargs( + args: list[str], kwargs: dict[str, Any], keys: Iterable[str], omit_keys: Iterable[str] | None = None +) -> tuple[dict[str, Any], dict[str, Any]]: + return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, _omit(kwargs, omit_keys or keys) -def mark_patched(obj: Any): - return setattr(obj, "_braintrust_patched", True) +def _clean(obj: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in obj.items() if v is not None} -def get_args_kwargs( - args: list[str], kwargs: dict[str, Any], keys: Iterable[str], omit_keys: Iterable[str] | None = None -): - return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, omit_keys or keys) +def _prepare_traced_call( + api_client: Any, args: list[Any], kwargs: dict[str, Any] +) -> tuple[dict[str, Any], dict[str, Any]]: + input, clean_kwargs = _get_args_kwargs(args, kwargs, ["model", "contents", "config"], ["contents", "config"]) + return _serialize_input(api_client, input), clean_kwargs + + +# --------------------------------------------------------------------------- +# Metric extraction helpers +# --------------------------------------------------------------------------- def _extract_usage_metadata_metrics( @@ -392,7 +178,7 @@ def _extract_generate_content_metrics(response: "GenerateContentResponse", start if hasattr(response, "usage_metadata") and response.usage_metadata: _extract_usage_metadata_metrics(response.usage_metadata, metrics) - return clean(dict(metrics)) + return _clean(dict(metrics)) def _extract_embed_content_output(response: "EmbedContentResponse") -> dict[str, Any]: @@ -400,7 +186,7 @@ def _extract_embed_content_output(response: "EmbedContentResponse") -> dict[str, first_embedding = embeddings[0] if embeddings else None first_values = getattr(first_embedding, "values", None) or [] - return clean( + return _clean( { "embedding_length": len(first_values) if first_values else None, "embeddings_count": len(embeddings) if embeddings else None, @@ -433,7 +219,25 @@ def _extract_embed_content_metrics(response: "EmbedContentResponse", start: floa if billable_character_count is not None: metrics["billable_characters"] = billable_character_count - return clean(metrics) + return _clean(metrics) + + +# --------------------------------------------------------------------------- +# Result processing helpers +# --------------------------------------------------------------------------- + + +def _gc_process_result(result: "GenerateContentResponse", start: float) -> tuple[Any, dict[str, Any]]: + return result, _extract_generate_content_metrics(result, start) + + +def _embed_process_result(result: "EmbedContentResponse", start: float) -> tuple[Any, dict[str, Any]]: + return _extract_embed_content_output(result), _extract_embed_content_metrics(result, start) + + +# --------------------------------------------------------------------------- +# Stream aggregation +# --------------------------------------------------------------------------- def _aggregate_generate_content_chunks( @@ -524,22 +328,174 @@ def _aggregate_generate_content_chunks( if text: aggregated["text"] = text - clean_metrics = clean(dict(metrics)) + clean_metrics = _clean(dict(metrics)) return aggregated, clean_metrics -def clean(obj: dict[str, Any]) -> dict[str, Any]: - return {k: v for k, v in obj.items() if v is not None} +# --------------------------------------------------------------------------- +# Traced call orchestration +# --------------------------------------------------------------------------- -def get_path(obj: dict[str, Any], path: str, default: Any = None) -> Any | None: - keys = path.split(".") - current = obj +def _run_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Any], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], +) -> Any: + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + result = invoke() + output, metrics = process_result(result, start) + span.log(output=output, metrics=metrics) + return result - for key in keys: - if not (isinstance(current, dict) and key in current): - return default - current = current[key] - return current +async def _run_async_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Awaitable[Any]], + process_result: Callable[[Any, float], tuple[Any, dict[str, Any]]], +) -> Any: + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + result = await invoke() + output, metrics = process_result(result, start) + span.log(output=output, metrics=metrics) + return result + + +def _run_stream_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Any], + aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], +) -> Any: + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + start = time.time() + first_token_time = None + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + chunks = [] + for chunk in invoke(): + if first_token_time is None: + first_token_time = time.time() + chunks.append(chunk) + yield chunk + + output, metrics = aggregate(chunks, start, first_token_time) + span.log(output=output, metrics=metrics) + return output + + +def _run_async_stream_traced_call( + api_client: Any, + args: list[Any], + kwargs: dict[str, Any], + *, + name: str, + invoke: Callable[[], Awaitable[Any]], + aggregate: Callable[[list[Any], float, float | None], tuple[Any, dict[str, Any]]], +) -> Any: + input, clean_kwargs = _prepare_traced_call(api_client, args, kwargs) + + async def stream_generator(): + start = time.time() + first_token_time = None + with start_span(name=name, type=SpanTypeAttribute.LLM, input=input, metadata=clean_kwargs) as span: + chunks = [] + async for chunk in await invoke(): + if first_token_time is None: + first_token_time = time.time() + chunks.append(chunk) + yield chunk + + output, metrics = aggregate(chunks, start, first_token_time) + span.log(output=output, metrics=metrics) + + return stream_generator() + + +# --------------------------------------------------------------------------- +# wrapt wrapper functions (used by patchers) +# --------------------------------------------------------------------------- + + +def _generate_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_gc_process_result, + ) + + +def _generate_content_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_stream_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content_stream", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_generate_content_chunks, + ) + + +def _embed_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_traced_call( + instance._api_client, + args, + kwargs, + name="embed_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_embed_process_result, + ) + + +async def _async_generate_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_gc_process_result, + ) + + +async def _async_generate_content_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return _run_async_stream_traced_call( + instance._api_client, + args, + kwargs, + name="generate_content_stream", + invoke=lambda: wrapped(*args, **kwargs), + aggregate=_aggregate_generate_content_chunks, + ) + + +async def _async_embed_content_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + return await _run_async_traced_call( + instance._api_client, + args, + kwargs, + name="embed_content", + invoke=lambda: wrapped(*args, **kwargs), + process_result=_embed_process_result, + ) diff --git a/py/src/braintrust/wrappers/google_genai.py b/py/src/braintrust/wrappers/google_genai.py new file mode 100644 index 00000000..853449aa --- /dev/null +++ b/py/src/braintrust/wrappers/google_genai.py @@ -0,0 +1,6 @@ +"""Compatibility shim — real implementation lives in braintrust.integrations.google_genai.""" + +from braintrust.integrations.google_genai import setup_genai # noqa: F401 + + +__all__ = ["setup_genai"]