Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_pydantic_ai_wrap_openai(session, version):
"""Test pydantic_ai with wrap_openai() approach - supports older versions."""
_install_test_deps(session)
_install(session, "pydantic_ai", version)
_run_tests(session, f"{WRAPPER_DIR}/test_pydantic_ai_wrap_openai.py")
_run_tests(session, f"{INTEGRATION_DIR}/pydantic_ai/test_pydantic_ai_wrap_openai.py")
_run_core_tests(session)


Expand All @@ -137,7 +137,7 @@ def test_pydantic_ai_integration(session, version):
session.skip("pydantic_ai integration tests require Python >= 3.10 (pydantic_ai 1.10.0+)")
_install_test_deps(session)
_install(session, "pydantic_ai", version)
_run_tests(session, f"{WRAPPER_DIR}/test_pydantic_ai_integration.py")
_run_tests(session, f"{INTEGRATION_DIR}/pydantic_ai/test_pydantic_ai_integration.py")
_run_core_tests(session)


Expand All @@ -149,7 +149,7 @@ def test_pydantic_ai_logfire(session):
_install_test_deps(session)
_install(session, "pydantic_ai")
_install(session, "logfire")
_run_tests(session, f"{WRAPPER_DIR}/test_pydantic_ai_logfire.py")
_run_tests(session, f"{INTEGRATION_DIR}/pydantic_ai/test_pydantic_ai_logfire.py")


@nox.session()
Expand Down
6 changes: 3 additions & 3 deletions py/src/braintrust/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def is_equal(expected, output):
from .integrations.openrouter import (
wrap_openrouter, # noqa: F401 # type: ignore[reportUnusedImport]
)
from .integrations.pydantic_ai import (
setup_pydantic_ai, # noqa: F401 # type: ignore[reportUnusedImport]
)
from .logger import *
from .logger import (
_internal_get_global_state, # noqa: F401 # type: ignore[reportUnusedImport]
Expand All @@ -98,6 +101,3 @@ def is_equal(expected, output):
from .wrappers.litellm import (
wrap_litellm, # noqa: F401 # type: ignore[reportUnusedImport]
)
from .wrappers.pydantic_ai import (
setup_pydantic_ai, # noqa: F401 # type: ignore[reportUnusedImport]
)
11 changes: 2 additions & 9 deletions py/src/braintrust/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
DSPyIntegration,
GoogleGenAIIntegration,
OpenRouterIntegration,
PydanticAIIntegration,
)


Expand Down Expand Up @@ -124,7 +125,7 @@ def auto_instrument(
if litellm:
results["litellm"] = _instrument_litellm()
if pydantic_ai:
results["pydantic_ai"] = _instrument_pydantic_ai()
results["pydantic_ai"] = _instrument_integration(PydanticAIIntegration)
if google_genai:
results["google_genai"] = _instrument_integration(GoogleGenAIIntegration)
if openrouter:
Expand Down Expand Up @@ -163,11 +164,3 @@ def _instrument_litellm() -> bool:

return patch_litellm()
return False


def _instrument_pydantic_ai() -> bool:
with _try_patch():
from braintrust.wrappers.pydantic_ai import setup_pydantic_ai

return setup_pydantic_ai()
return False
2 changes: 2 additions & 0 deletions py/src/braintrust/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .dspy import DSPyIntegration
from .google_genai import GoogleGenAIIntegration
from .openrouter import OpenRouterIntegration
from .pydantic_ai import PydanticAIIntegration


__all__ = [
Expand All @@ -17,4 +18,5 @@
"DSPyIntegration",
"GoogleGenAIIntegration",
"OpenRouterIntegration",
"PydanticAIIntegration",
]
148 changes: 140 additions & 8 deletions py/src/braintrust/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@ class BasePatcher(ABC):
patch_id: ClassVar[str | None] = None
version_spec: ClassVar[str | None] = None
priority: ClassVar[int] = 100
rescan_on_setup: ClassVar[bool] = False

@classmethod
def patch_marker_attr(cls) -> str:
"""Return the sentinel attribute used to mark this patcher as applied."""
suffix = re.sub(r"\W+", "_", cls.identifier()).strip("_")
return f"__braintrust_patched_{suffix}__"

@classmethod
def has_patch_marker(cls, obj: Any) -> bool:
"""Return whether *obj* is marked as patched by this patcher.

For classes, read ``__dict__`` directly so markers inherited via the
MRO do not make subclasses appear locally patched.
"""
if obj is None:
return False
if isinstance(obj, type):
return bool(obj.__dict__.get(cls.patch_marker_attr(), False))
return bool(getattr(obj, cls.patch_marker_attr(), False))

@classmethod
def mark_patched(cls, obj: Any) -> None:
"""Mark an object as patched by this patcher."""
setattr(obj, cls.patch_marker_attr(), True)

@classmethod
def identifier(cls) -> str:
Expand All @@ -44,6 +69,115 @@ def patch(cls, module: Any | None, version: str | None, *, target: Any | None =
raise NotImplementedError


class ClassScanPatcher(BasePatcher):
    """Base patcher for rescanning and patching discovered class hierarchies.

    Subclasses configure a scan root (``target_module`` / ``root_class_path``,
    or override ``iter_root_classes``) and implement ``patch_class``. New
    subclasses of the root may appear after the first setup pass, so
    ``rescan_on_setup`` is True: each setup re-walks the subclass tree and
    patches only classes not yet carrying this patcher's marker.
    """

    # Force re-scanning on every setup() call so classes defined after the
    # first pass still get patched.
    rescan_on_setup: ClassVar[bool] = True
    # When False (the default), classes with unimplemented abstract methods
    # are skipped; only concrete descendants are patched.
    include_abstract_classes: ClassVar[bool] = False
    # Dotted module name to import as the scan root when no explicit target
    # is supplied.
    target_module: ClassVar[str | None] = None
    # Attribute path from the scan root to the base class whose subclass
    # tree is walked (resolved via _resolve_attr_path).
    root_class_path: ClassVar[str | None] = None

    @classmethod
    def resolve_scan_root(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None:
        """Return the object from which this patcher resolves its root class.

        Precedence: an explicit *target* wins, then ``target_module``
        (imported lazily), then the *module* passed by the caller.
        """
        if target is not None:
            return target
        if cls.target_module is not None:
            try:
                return importlib.import_module(cls.target_module)
            except ImportError:
                # Optional dependency not installed — nothing to scan.
                return None
        return module

    @classmethod
    def iter_root_classes(
        cls,
        module: Any | None,
        version: str | None,
        *,
        target: Any | None = None,
    ) -> Iterable[type[Any]]:
        """Yield root classes whose subclass trees should be scanned.

        Default implementation resolves ``root_class_path`` against the scan
        root; yields nothing when either is unavailable.
        """
        if cls.root_class_path is None:
            return ()
        root = cls.resolve_scan_root(module, version, target=target)
        if root is None:
            return ()
        root_class = _resolve_attr_path(root, cls.root_class_path)
        if root_class is None:
            return ()
        return (root_class,)

    @classmethod
    def resolve_root_classes(
        cls,
        module: Any | None,
        version: str | None,
        *,
        target: Any | None = None,
    ) -> tuple[type[Any], ...]:
        """Return the currently discoverable root classes for this patcher.

        Materializing the iterable here lets ImportError raised lazily inside
        an overridden ``iter_root_classes`` be treated as "not discoverable".
        """
        try:
            return tuple(cls.iter_root_classes(module, version, target=target))
        except ImportError:
            return ()

    @classmethod
    def applies(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
        """Return whether any root classes are currently discoverable."""
        # Base checks (e.g. version constraints) must pass AND at least one
        # root class must resolve.
        return super().applies(module, version, target=target) and bool(
            cls.resolve_root_classes(module, version, target=target)
        )

    @classmethod
    @abstractmethod
    def patch_class(cls, target_class: type[Any]) -> bool | None:
        """Patch one discovered class.

        Return ``False`` to skip marking the class as patched. Any other return
        value is treated as a successful patch.
        """
        raise NotImplementedError

    @classmethod
    def iter_classes(
        cls,
        module: Any | None,
        version: str | None,
        *,
        target: Any | None = None,
    ) -> Iterable[type[Any]]:
        """Yield discovered subclasses under the configured root classes."""

        # Depth-first walk of the live subclass tree via type.__subclasses__().
        # NOTE(review): diamond inheritance could yield a class more than once;
        # patch() tolerates this via its marker check.
        def walk(base_class: type[Any]) -> Iterable[type[Any]]:
            for subclass in base_class.__subclasses__():
                if cls.include_abstract_classes or not getattr(subclass, "__abstractmethods__", None):
                    yield subclass
                yield from walk(subclass)

        for root_class in cls.resolve_root_classes(module, version, target=target):
            yield from walk(root_class)

    @classmethod
    def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
        """Return ``True`` when every currently discovered class is patched."""
        classes = tuple(cls.iter_classes(module, version, target=target))
        # An empty scan is reported as not patched so setup() will retry later.
        return bool(classes) and all(cls.has_patch_marker(class_) for class_ in classes)

    @classmethod
    def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
        """Patch all newly discovered classes under the configured roots.

        Returns ``True`` if at least one class was newly patched this pass.
        """
        success = False
        for class_ in cls.iter_classes(module, version, target=target):
            if cls.has_patch_marker(class_):
                # Already patched on a previous scan — idempotent skip.
                continue
            if cls.patch_class(class_) is False:
                # Subclass explicitly declined; leave unmarked for a retry.
                continue
            cls.mark_patched(class_)
            success = True
        return success


class FunctionWrapperPatcher(BasePatcher):
"""Base patcher for single-target `wrap_function_wrapper` instrumentation.

Expand Down Expand Up @@ -125,14 +259,13 @@ def mark_patched(cls, obj: Any) -> None:
@classmethod
def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
"""Return whether this patcher's target has already been instrumented."""
marker = cls.patch_marker_attr()
resolved_target = cls.resolve_target(module, version, target=target)
if resolved_target is not None and getattr(resolved_target, marker, False):
if cls.has_patch_marker(resolved_target):
return True
# Fall back to checking the root — the marker may live there when the
# resolved target does not support setattr (e.g. bound methods).
root = cls.resolve_root(module, version, target=target)
if root is not None and root is not resolved_target and getattr(root, marker, False):
if root is not None and root is not resolved_target and cls.has_patch_marker(root):
return True
return False

Expand All @@ -152,7 +285,7 @@ def patch(cls, module: Any | None, version: str | None, *, target: Any | None =
cls.mark_patched(resolved_target)
# If mark_patched could not store the marker on the target (e.g. bound
# methods), store it on the root so is_patched() can still find it.
if not getattr(resolved_target, marker, False):
if not cls.has_patch_marker(resolved_target):
setattr(root, marker, True)
return True

Expand All @@ -174,8 +307,7 @@ def wrap_target(cls, target: Any) -> Any:
``superseded_by`` has a target that exists on *target*. Returns
*target* for convenient chaining.
"""
marker = cls.patch_marker_attr()
if getattr(target, marker, False):
if cls.has_patch_marker(target):
return target
attr = cls.target_path.rsplit(".", 1)[-1]
if _resolve_attr_path(target, attr) is None:
Expand Down Expand Up @@ -241,7 +373,7 @@ def mark_patched(cls, obj: Any) -> None:
def is_patched(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
"""Return whether this patcher's replacement class is already installed."""
resolved_target = cls.resolve_target(module, version, target=target)
return bool(resolved_target is not None and getattr(resolved_target, cls.patch_marker_attr(), False))
return bool(resolved_target is not None and cls.has_patch_marker(resolved_target))

@classmethod
def patch(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool:
Expand Down Expand Up @@ -370,7 +502,7 @@ def setup(
for patcher in sorted(selected_patchers, key=lambda patcher: patcher.priority):
if not patcher.applies(module, version, target=target):
continue
if patcher.is_patched(module, version, target=target):
if not patcher.rescan_on_setup and patcher.is_patched(module, version, target=target):
success = True
continue
success = patcher.patch(module, version, target=target) or success
Expand Down
52 changes: 52 additions & 0 deletions py/src/braintrust/integrations/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Braintrust integration for Pydantic AI."""

import logging

from braintrust.logger import NOOP_SPAN, current_span, init_logger

from .integration import PydanticAIIntegration
from .patchers import wrap_agent, wrap_model_classes
from .tracing import (
wrap_model_request,
wrap_model_request_stream,
wrap_model_request_stream_sync,
wrap_model_request_sync,
)


logger = logging.getLogger(__name__)

__all__ = [
"PydanticAIIntegration",
"setup_pydantic_ai",
"wrap_agent",
"wrap_model_classes",
"wrap_model_request",
"wrap_model_request_sync",
"wrap_model_request_stream",
"wrap_model_request_stream_sync",
]


def setup_pydantic_ai(
    api_key: str | None = None,
    project_id: str | None = None,
    project_name: str | None = None,
) -> bool:
    """Enable Braintrust tracing for Pydantic AI.

    Patches Pydantic AI agents and direct API functions so their calls are
    traced automatically. When no span is currently active, a Braintrust
    logger is initialized first from the supplied credentials.

    Args:
        api_key: Braintrust API key.
        project_id: Braintrust project ID.
        project_name: Braintrust project name.

    Returns:
        True if setup was successful, False otherwise.
    """
    # Bootstrap a logger only when nothing is being traced yet; an active
    # span means a logging destination is already configured.
    if current_span() == NOOP_SPAN:
        init_logger(project=project_name, api_key=api_key, project_id=project_id)
    return PydanticAIIntegration.setup()
Loading
Loading