From fdd37a7d5815a13cce94d5e1ea537e754d015226 Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 24 Mar 2026 13:55:15 -0700 Subject: [PATCH] ref(agno): Make agno use new integrations API --- .agents/skills/sdk-integrations/SKILL.md | 171 ++- .../skills/sdk-wrapper-migrations/SKILL.md | 257 ++-- py/noxfile.py | 4 +- py/src/braintrust/auto.py | 12 +- py/src/braintrust/integrations/__init__.py | 3 +- .../braintrust/integrations/agno/__init__.py | 50 + .../agno/_test_agno_helpers.py | 2 +- .../test_agno_simple_agent_execution.yaml | 0 .../test_agno_workflow_with_agent.yaml | 0 .../agno/cassettes/test_auto_agno.yaml | 147 ++ .../integrations/agno/integration.py | 31 + .../braintrust/integrations/agno/patchers.py | 379 +++++ .../agno/test_agno.py | 25 +- .../agno/test_workflow.py | 6 +- .../braintrust/integrations/agno/tracing.py | 1316 +++++++++++++++++ py/src/braintrust/integrations/base.py | 35 +- py/src/braintrust/wrappers/agno/__init__.py | 83 +- py/src/braintrust/wrappers/agno/agent.py | 216 --- .../braintrust/wrappers/agno/function_call.py | 67 - py/src/braintrust/wrappers/agno/model.py | 318 ---- .../braintrust/wrappers/agno/run_helpers.py | 139 -- py/src/braintrust/wrappers/agno/team.py | 216 --- py/src/braintrust/wrappers/agno/utils.py | 520 ------- py/src/braintrust/wrappers/agno/workflow.py | 353 ----- 24 files changed, 2227 insertions(+), 2123 deletions(-) create mode 100644 py/src/braintrust/integrations/agno/__init__.py rename py/src/braintrust/{wrappers => integrations}/agno/_test_agno_helpers.py (99%) rename py/src/braintrust/{wrappers => integrations/agno}/cassettes/test_agno_simple_agent_execution.yaml (100%) rename py/src/braintrust/{wrappers => integrations/agno}/cassettes/test_agno_workflow_with_agent.yaml (100%) create mode 100644 py/src/braintrust/integrations/agno/cassettes/test_auto_agno.yaml create mode 100644 py/src/braintrust/integrations/agno/integration.py create mode 100644 py/src/braintrust/integrations/agno/patchers.py rename py/src/braintrust/{wrappers => integrations}/agno/test_agno.py (92%) rename py/src/braintrust/{wrappers => integrations}/agno/test_workflow.py (97%) create mode 100644 py/src/braintrust/integrations/agno/tracing.py delete mode 100644 py/src/braintrust/wrappers/agno/agent.py delete mode 100644 py/src/braintrust/wrappers/agno/function_call.py delete mode 100644 py/src/braintrust/wrappers/agno/model.py delete mode 100644 py/src/braintrust/wrappers/agno/run_helpers.py delete mode 100644 py/src/braintrust/wrappers/agno/team.py delete mode 100644 py/src/braintrust/wrappers/agno/utils.py delete mode 100644 py/src/braintrust/wrappers/agno/workflow.py diff --git a/.agents/skills/sdk-integrations/SKILL.md b/.agents/skills/sdk-integrations/SKILL.md index e487fc1b..754b61f7 100644 --- a/.agents/skills/sdk-integrations/SKILL.md +++ b/.agents/skills/sdk-integrations/SKILL.md @@ -1,24 +1,23 @@ --- name: sdk-integrations -description: Create or update Braintrust Python SDK integrations built on the integrations API. Use for work in `py/src/braintrust/integrations/`, including new providers, patchers, tracing, `auto_instrument()` updates, integration exports, and integration tests. +description: Create or update Braintrust Python SDK integrations built on the integrations API under `py/src/braintrust/integrations/`. 
Use when adding a new integration package, extending an existing provider integration, changing patchers, tracing, manual `wrap_*()` helpers, integration exports, `auto_instrument()` wiring, `py/noxfile.py` sessions, integration tests, or cassettes. Do not use when migrating an existing legacy wrapper from `py/src/braintrust/wrappers/` into the integrations API; use `sdk-wrapper-migrations` for that.
---

# SDK Integrations

Use this skill for integrations API work under `py/src/braintrust/integrations/`.

-Start from the nearest existing provider instead of designing from scratch:
+If the provider already has a real implementation under `py/src/braintrust/wrappers/<provider>/` and the task is to move that implementation into the integrations API, switch to `sdk-wrapper-migrations` instead of treating it like a fresh integration.

-- ADK (`py/src/braintrust/integrations/adk/`) is the best reference for direct method patching, `target_module`, `CompositeFunctionWrapperPatcher`, and public `wrap_*()` helpers.
-- Anthropic (`py/src/braintrust/integrations/anthropic/`) is the best reference for constructor patching with `FunctionWrapperPatcher`.
+## Pick The Nearest Example

-## Workflow
+Start from one structural reference and one patching reference instead of designing from scratch:

-1. Read the shared primitives and the nearest provider example.
-2. Decide whether the task is a new provider, an existing provider update, or an `auto_instrument()` change.
-3. Change only the affected integration, patchers, tracing, exports, and tests.
-4. Update tests and cassettes only where behavior changed intentionally.
-5. Run the narrowest provider session first, then expand only if shared code changed.
+- ADK (`py/src/braintrust/integrations/adk/`) for direct method patching, `target_module`, `CompositeFunctionWrapperPatcher`, manual `wrap_*()` helpers, and priority-based context propagation.
+- Agno (`py/src/braintrust/integrations/agno/`) for multi-target patching, version-conditional fallbacks with `superseded_by`, and providers that need several related patchers.
+- Anthropic (`py/src/braintrust/integrations/anthropic/`) for constructor patching and a compact provider package with a small public surface.
+
+Match an existing repo pattern unless the target provider forces a different shape.

## Read First

Always read:

- `py/src/braintrust/integrations/base.py`
- `py/src/braintrust/integrations/versioning.py`
+- `py/src/braintrust/integrations/__init__.py`
+- `py/noxfile.py`
+
+Read when updating an existing integration:
+
+- `py/src/braintrust/integrations/<provider>/__init__.py`
+- `py/src/braintrust/integrations/<provider>/integration.py`
+- `py/src/braintrust/integrations/<provider>/patchers.py`
+- `py/src/braintrust/integrations/<provider>/tracing.py`
+- `py/src/braintrust/integrations/<provider>/test_*.py`

Read when relevant:

- `py/src/braintrust/auto.py` for `auto_instrument()` work
- `py/src/braintrust/conftest.py` for VCR behavior
-- `py/src/braintrust/integrations/adk/test_adk.py` for integration test patterns
- `py/src/braintrust/integrations/auto_test_scripts/` for subprocess auto-instrument tests
+- `py/src/braintrust/integrations/adk/test_adk.py` and `py/src/braintrust/integrations/anthropic/test_anthropic.py` for test layout patterns
+
+## Route The Task
+
+### New provider integration
+
+1. Create `py/src/braintrust/integrations/<provider>/`.
+2. Add the normal split unless the provider is exceptionally small:
   - `__init__.py`
   - `integration.py`
   - `patchers.py`
   - `tracing.py`
   - `test_<provider>.py`
   - `cassettes/` when the provider uses HTTP
+3. Export the integration from `py/src/braintrust/integrations/__init__.py`.
+4. Add or update the provider session in `py/noxfile.py`.
+5. Update `py/src/braintrust/auto.py` only if the integration should participate in `auto_instrument()`.
+6. Add subprocess coverage in `py/src/braintrust/integrations/auto_test_scripts/` when `auto_instrument()` changes.
+
+### Existing integration update
+
+1. Read the current provider package before editing.
+2. Change only the affected patchers, tracing helpers, exports, tests, and cassettes.
+3. Preserve the provider's public setup and `wrap_*()` surface unless the task explicitly changes it.
+4. Keep repo-level changes narrow; do not touch `auto.py`, `integrations/__init__.py`, or `py/noxfile.py` unless the task actually requires it.
+
+### `auto_instrument()` only
+
+1. Update `py/src/braintrust/auto.py`.
+2. Use `_instrument_integration(...)` instead of adding a custom `_instrument_*` helper when the integration fits the standard pattern.
+3. Add the integration import near the other integration imports.
+4. Add or update the relevant subprocess auto-instrument test.

## Package Layout

-Create new providers under `py/src/braintrust/integrations/<provider>/`. Keep the existing layout for provider updates unless the current structure is the problem.
+Keep provider-local code inside `py/src/braintrust/integrations/<provider>/`.

-Typical files:
+Typical file ownership:

- `__init__.py`: export the integration class, `setup_<provider>()`, and public `wrap_*()` helpers
- `integration.py`: define the `BaseIntegration` subclass and register patchers
-- `patchers.py`: define patchers and `wrap_*()` helpers
-- `tracing.py`: keep provider-specific tracing, stream handling, and normalization
-- `test_<provider>.py`: keep provider behavior tests next to the integration
+- `patchers.py`: define patchers and manual `wrap_*()` helpers
+- `tracing.py`: keep provider-specific tracing, stream handling, normalization, and metadata extraction
+- `test_*.py`: keep provider behavior tests next to the integration
- `cassettes/`: keep VCR recordings next to the integration tests when the provider uses HTTP

+Keep `integration.py` thin. Put provider behavior in provider-local modules, not in shared integration primitives, unless the shared abstraction is genuinely missing.
+
## Integration Rules

-Keep `integration.py` thin. Set:
+Set the integration class up declaratively:
+
+- set `name`
+- set `import_names`
+- set `patchers`
+- set `min_version` or `max_version` only when feature detection is not enough

-- `name`
-- `import_names`
-- `patchers`
-- `min_version` and `max_version` only when needed
+Keep span creation, metadata extraction, stream aggregation, error logging, and output normalization in `tracing.py`.

-Keep provider behavior in the provider package, not in shared integration code. Put span creation, metadata extraction, stream aggregation, error logging, and output normalization in `tracing.py`.
+Preserve provider behavior. Do not let tracing-only code change provider return values, control flow, or error behavior except where the task explicitly requires it.

-Preserve provider behavior. Do not let tracing-only code break the provider call.
+Prefer feature detection first and version checks second. 
Use: + +- `detect_module_version(...)` +- `version_satisfies(...)` +- `make_specifier(...)` + +Let `BaseIntegration.resolve_patchers()` reject duplicate patcher ids; do not silently paper over duplicates. ## Patcher Rules -Create one patcher per coherent patch target. If targets are unrelated, split them. +Create one patcher per coherent patch target. Split unrelated targets into separate patchers. Use `FunctionWrapperPatcher` for one import path or one constructor/method surface, for example: @@ -72,60 +122,29 @@ Use `FunctionWrapperPatcher` for one import path or one constructor/method surfa Use `CompositeFunctionWrapperPatcher` when several closely related targets should appear as one patcher, for example: - sync and async variants of the same method -- the same function patched across multiple modules +- the same logical surface patched across multiple modules -Set `target_module` when the patch target lives outside the module named by `import_names`, especially for optional or deep submodules. Failed `target_module` imports should cause the patcher to skip cleanly through `applies()`. +Set `target_module` when the patch target lives outside the module named by `import_names`, especially for optional or deep submodules. Failed `target_module` imports should make the patcher skip cleanly through `applies()`. + +Use `superseded_by` for version-conditional mutual exclusion. Express fallback relationships declaratively instead of reproducing `hasattr` logic in custom `applies()` methods whenever possible. Expose manual wrapping helpers through `wrap_target()`: ```python def wrap_agent(Agent: Any) -> Any: - return AgentRunAsyncPatcher.wrap_target(Agent) + return AgentPatcher.wrap_target(Agent) ``` -Use lower `priority` values only when ordering matters, such as context propagation before tracing. +Use lower `priority` values only when ordering matters, such as context propagation before tracing patchers. -Patchers must provide: +Require every patcher to provide: -- stable `name` values +- a stable `name` - version gating only when needed -- existence checks +- clean existence checks - idempotence through the base patcher marker -Let `BaseIntegration.resolve_patchers()` reject duplicate patcher ids instead of silently ignoring them. - -## Patching Patterns - -Use constructor patching when the goal is to instrument future clients created by the provider SDK. Patch the constructor, then attach traced surfaces after the real constructor runs. - -Use direct method patching with `target_module` when the provider exposes a flatter API and there is no useful constructor patch point. - -Keep public `wrap_*()` helpers in `patchers.py` and export them from the integration package. - -## Versioning - -Prefer feature detection first and version checks second. - -Use: - -- `detect_module_version(...)` -- `version_satisfies(...)` -- `make_specifier(...)` - -## `auto_instrument()` - -Update `py/src/braintrust/auto.py` only if the integration should be auto-patched. - -All `auto_instrument()` parameters are plain `bool` flags. Use `_instrument_integration(...)` instead of adding a custom `_instrument_*` function: - -```python -if provider: - results["provider"] = _instrument_integration(ProviderIntegration) -``` - -Add the integration import near the other integration imports in `auto.py`. - -## Tests +## Testing Keep integration tests in the provider package. @@ -137,8 +156,8 @@ Use `@pytest.mark.vcr` for real provider network behavior. 
Prefer recorded provi Cover the surfaces that changed: -- direct `wrap(...)` behavior -- `setup()` patching new clients +- direct `wrap_*()` behavior +- `setup()` patching for newly created clients or classes - sync behavior - async behavior - streaming behavior @@ -146,14 +165,16 @@ Cover the surfaces that changed: - failure and error logging - patcher resolution and duplicate detection -Keep VCR cassettes in `py/src/braintrust/integrations//cassettes/`. Re-record them only for intentional behavior changes. +Keep VCR cassettes in `py/src/braintrust/integrations//cassettes/`. Re-record only when the behavior change is intentional. + +When choosing commands, confirm the real session name in `py/noxfile.py` instead of assuming it matches the provider folder. Examples in this repo include `test_agno`, `test_anthropic`, and `test_google_adk`. ## Commands ```bash -cd py && nox -s "test_(latest)" -cd py && nox -s "test_(latest)" -- -k "test_name" -cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" +cd py && nox -s "test_(latest)" +cd py && nox -s "test_(latest)" -- -k "test_name" +cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" cd py && make test-core cd py && make lint ``` @@ -161,14 +182,16 @@ cd py && make lint ## Validation - Run the narrowest provider session first. +- Run the relevant auto-instrument subprocess test if `auto.py` changed. - Run `cd py && make test-core` if shared integration code changed. -- Run `cd py && make lint` before handing off broader integration changes. -- Run the relevant auto-instrument subprocess tests if `auto_instrument()` changed. +- Run `cd py && make lint` before handoff when shared files or repo-level wiring changed. ## Pitfalls -- Moving provider-specific behavior into shared integration code. +- Treating a wrapper migration as fresh integration work. +- Changing shared integration primitives when the provider-specific package should own the behavior. - Combining unrelated targets into one patcher. +- Forgetting repo-level touch points for new providers: `integrations/__init__.py`, `py/noxfile.py`, and sometimes `auto.py`. - Forgetting async or streaming coverage. - Re-recording cassettes when behavior did not intentionally change. - Adding a custom `_instrument_*` helper where `_instrument_integration()` already fits. diff --git a/.agents/skills/sdk-wrapper-migrations/SKILL.md b/.agents/skills/sdk-wrapper-migrations/SKILL.md index f7c1852a..0aff31da 100644 --- a/.agents/skills/sdk-wrapper-migrations/SKILL.md +++ b/.agents/skills/sdk-wrapper-migrations/SKILL.md @@ -1,178 +1,223 @@ --- name: sdk-wrapper-migrations -description: Migrate Braintrust Python SDK legacy wrapper implementations to the integrations API. Use when moving a provider from `py/src/braintrust/wrappers/` into `py/src/braintrust/integrations/`, preserving backward compatibility while relocating tracing, patchers, tests, cassettes, auto-instrument hooks, and test sessions. +description: Migrate Braintrust Python SDK legacy wrapper implementations to the integrations API. Use when moving an existing provider from `py/src/braintrust/wrappers/` into `py/src/braintrust/integrations/` while preserving old import paths, public helpers, tests, cassettes, tracing behavior, auto-instrument hooks, and nox coverage. --- # SDK Wrapper Migrations -Use this skill when a provider already exists under `py/src/braintrust/wrappers/` and needs to be migrated to the integrations API. 
+Migrate an existing wrapper-backed provider to the integrations API without breaking the old wrapper import path.

-Use current repo examples, not old commit history:
+Prefer this skill only when both of these are true:

-- `py/src/braintrust/integrations/adk/` for full integration package structure, test placement, auto-instrument coverage, and wrapper delegation
-- `py/src/braintrust/integrations/anthropic/` for constructor patching and a minimal compatibility wrapper
+- Find an existing provider implementation under `py/src/braintrust/wrappers/`.
+- Need the end state to be an integration package under `py/src/braintrust/integrations/` plus a thin compatibility wrapper.

-The target end state is:
+Use `sdk-integrations` instead when the task is integration work that does not start from a legacy wrapper.

-- provider logic lives in `py/src/braintrust/integrations/<provider>/`
-- tests and cassettes live with the integration
-- `auto_instrument()` uses the integration
-- the legacy wrapper becomes a thin compatibility layer
+Do not reconstruct migrations from old commit history. Start from the current tree and copy the nearest current pattern.

## Target End State

Finish with this structure:

- provider logic in `py/src/braintrust/integrations/<provider>/`
- provider tests in `py/src/braintrust/integrations/<provider>/`
- provider cassettes in `py/src/braintrust/integrations/<provider>/cassettes/` when applicable
- `auto_instrument()` pointing at the integration when the provider participates in auto patching
- the wrapper reduced to compatibility re-exports with the old public surface intact

Do not leave tracing helpers, patchers, or setup orchestration behind in the wrapper.

## Current References

Use the nearest current provider instead of inventing a layout:

- ADK: use `py/src/braintrust/integrations/adk/` as the main structural reference for package layout, patchers, tracing split, tests, cassettes, auto-test scripts, and thin wrapper delegation.
- Agno: use `py/src/braintrust/integrations/agno/` for multi-method patching, `CompositeFunctionWrapperPatcher`, raw wrapt wrappers in `tracing.py`, and version-conditional fallbacks using `superseded_by`.
- Anthropic: use `py/src/braintrust/integrations/anthropic/` for compact constructor patching and a minimal compatibility wrapper.

Match one of those patterns unless the provider has a concrete reason to differ.
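For concreteness, a package-style wrapper's end state looks roughly like the sketch below. It reuses the Agno surface migrated in this patch; treat it as a shape to copy, not a verbatim file, and make the re-export list mirror the wrapper's own pre-migration `__all__`.

```python
# Sketch: py/src/braintrust/wrappers/agno/__init__.py after migration.
# Every name is re-exported from the integration package, which is now
# the single source of truth; no logic lives here.
from braintrust.integrations.agno import (
    AgnoIntegration,
    setup_agno,
    wrap_agent,
    wrap_function_call,
    wrap_model,
    wrap_team,
    wrap_workflow,
)

__all__ = [
    "AgnoIntegration",
    "setup_agno",
    "wrap_agent",
    "wrap_function_call",
    "wrap_model",
    "wrap_team",
    "wrap_workflow",
]
```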
## Read First

Always read:

-- the existing legacy wrapper under `py/src/braintrust/wrappers/<provider>/`
-- `py/src/braintrust/integrations/anthropic/__init__.py`
-- `py/src/braintrust/integrations/anthropic/integration.py`
-- `py/src/braintrust/integrations/anthropic/patchers.py`
-- `py/src/braintrust/integrations/anthropic/tracing.py`
+- the existing legacy wrapper under `py/src/braintrust/wrappers/<provider>/` or `py/src/braintrust/wrappers/<provider>.py`
- `py/src/braintrust/integrations/base.py`
+- `py/src/braintrust/integrations/versioning.py`
- `py/src/braintrust/auto.py`
- `py/noxfile.py`

+Read these current migration examples before editing:
+
+- `py/src/braintrust/integrations/adk/__init__.py`
+- `py/src/braintrust/integrations/adk/integration.py`
+- `py/src/braintrust/integrations/adk/patchers.py`
+- `py/src/braintrust/integrations/adk/tracing.py`
+- `py/src/braintrust/integrations/agno/patchers.py`
+- `py/src/braintrust/integrations/agno/tracing.py`
+- `py/src/braintrust/integrations/anthropic/__init__.py`
+- `py/src/braintrust/integrations/anthropic/integration.py`
+- `py/src/braintrust/integrations/anthropic/patchers.py`
+
Read when relevant:

-- `py/src/braintrust/integrations/auto_test_scripts/`
+- `py/src/braintrust/conftest.py` for VCR behavior
+- `py/src/braintrust/integrations/auto_test_scripts/` for auto-instrument subprocess coverage
- the provider's existing wrapper tests and cassettes

-## Workflow
+## Migration Playbook

-1. Inventory the wrapper's public API, patch targets, tests, and cassettes.
-2. Create an integration package that preserves the wrapper's behavior and public helper surface.
-3. Move provider-specific tracing and patching into the integration package.
+1. Inventory the wrapper before moving code.
+2. Create the integration package with the public API and layout you intend to keep.
+3. Move tracing helpers and wrapper logic into provider-local integration modules.
4. Move tests, cassettes, and auto-instrument subprocess coverage next to the integration.
-5. Wire the integration into exports, `auto.py`, and `py/noxfile.py`.
-6. Replace the wrapper with a thin re-export layer.
-7. Run the narrowest provider session first, then expand if shared code changed.
-
-## Migration Checklist
+5. Wire exports, `auto.py`, and `py/noxfile.py` to the new integration location.
+6. Collapse the wrapper to compatibility imports and re-exports.
+7. Run the narrowest provider test session first, then expand only if shared code changed.

-### 1. Preserve the public surface
+## Inventory First

-Before moving code, list the public names exposed by the wrapper:
+Before editing, list the wrapper's user-visible and repo-visible surface:

-- setup functions
-- `wrap_*()` helpers
-- deprecated aliases that still need to work
+- setup entry points such as `setup_<provider>()`
+- public `wrap_*()` helpers
+- deprecated aliases that must still import correctly
- `__all__`
+- patch targets and target modules
+- sync, async, and streaming code paths
+- test files, cassette directories, and auto-test scripts
+- any version-routing logic, especially `hasattr`-based fallback behavior

-The integration package should own that public surface after the migration. The wrapper should only delegate to it.
+Do not start moving files until this inventory is explicit. The migration succeeds only if the integration preserves the same behavior and import surface.

-### 2. Create the integration package
-Create `py/src/braintrust/integrations/<provider>/` with the same split used by ADK:
+## Package Layout

-- `__init__.py`: public API, setup entry point, deprecated aliases if needed
-- `integration.py`: `BaseIntegration` subclass and patcher registration
-- `patchers.py`: one patcher per coherent patch target, plus public `wrap_*()` helpers
-- `tracing.py`: provider-specific tracing, stream handling, normalization, and helper code
-- `test_<provider>.py`: provider behavior tests
-- `cassettes/`: VCR recordings when the provider uses HTTP
+Create `py/src/braintrust/integrations/<provider>/` and keep provider-specific behavior there.

-Keep provider-specific behavior out of shared modules unless the provider truly needs a shared change.
+Use this layout unless the provider already has a better current variant:

-### 3. Move tracing and patching out of the wrapper
+- `__init__.py`: export the public API, `setup_<provider>()`, and compatibility aliases
+- `integration.py`: define the `BaseIntegration` subclass and register patchers
+- `patchers.py`: define patchers and public `wrap_*()` helpers
+- `tracing.py`: keep spans, metadata extraction, stream handling, normalization, and helper code
+- `test_<provider>.py` or split test files: keep provider behavior tests next to the integration
+- `cassettes/`: keep VCR recordings next to the provider tests when the provider uses HTTP

-Extract wrapper internals into:
+Keep `integration.py` thin. Do not move provider behavior into shared integration primitives unless the provider truly needs a shared change.

-- `tracing.py` for spans, metadata extraction, stream aggregation, and output normalization
-- `patchers.py` for patcher classes and `wrap_*()` helpers
-- `integration.py` for the orchestration layer only
+## Public API Rules

-Prefer one patcher per coherent patch target. Use composite patchers only when several related targets should be user-visible as one patcher.
+Preserve the wrapper's public surface exactly unless the task explicitly changes it.

-### 4. Preserve setup behavior
+Keep or migrate:

-The new integration package should preserve the wrapper's setup semantics:
+- setup function names
+- `wrap_*()` helper names
+- deprecated aliases
+- `__all__`

-- keep the same setup function names where possible
-- keep deprecated aliases that users may still import
-- keep logger initialization or other setup-time side effects aligned with prior behavior
+Make the integration package the source of truth. Make the wrapper import from the integration package, not the other way around.

-The integration package is the new source of truth. Do not leave setup logic duplicated in the wrapper.
+When the legacy wrapper is a single module such as `py/src/braintrust/wrappers/anthropic.py`, reduce that module to compatibility re-exports in place. When the wrapper is a package directory, reduce its `__init__.py` to compatibility re-exports and delete or stop importing the old implementation modules if they are no longer used.

-### 5. Move tests and cassettes
+## Patching And Tracing Rules

-Move provider tests from `py/src/braintrust/wrappers/` into the integration package.
+Move raw tracing behavior into `tracing.py`.

+Keep tracing wrappers as plain wrapt wrapper functions. 
Do not carry wrapper-era patch-state logic into tracing code:

-Move or rename:
+- no `is_patched`
+- no `mark_patched`
+- no `hasattr` branching to choose targets

-- provider behavior tests to `py/src/braintrust/integrations/<provider>/`
-- cassettes to `py/src/braintrust/integrations/<provider>/cassettes/`
-- auto-instrument subprocess tests to `py/src/braintrust/integrations/auto_test_scripts/`
+Move patch target selection into `patchers.py`.

-Update imports and cassette paths during the move. Preserve coverage for:
+Prefer:

-- direct `wrap_*()` behavior
-- setup-time patching
-- sync paths
-- async paths
-- streaming paths
-- idempotence
-- failure and logging behavior
+- one `FunctionWrapperPatcher` per coherent target
+- `CompositeFunctionWrapperPatcher` only when several related targets should appear as one patcher
+- `superseded_by` for version-conditional fallback relationships
+
+When the legacy wrapper does "wrap `_run` if present, otherwise wrap `run`", convert that to separate patchers instead of reproducing the branching:
+
+- point the preferred patcher at the higher-priority target directly
+- point the fallback patcher at the fallback target
+- set `superseded_by` on the fallback patcher
+
+Use `py/src/braintrust/integrations/agno/patchers.py` as the reference pattern for this conversion, and see the sketch after this section.

-### 6. Wire repo-level integration points
+Expose manual wrapping through thin public helpers in `patchers.py`, then re-export them from `__init__.py`.

-Update the minimum shared surfaces required by the migration:
+## Test And Cassette Moves

-- `py/src/braintrust/integrations/__init__.py`
-- `py/src/braintrust/auto.py` if the provider participates in `auto_instrument()`
-- `py/noxfile.py` so provider sessions run against the integration tests
+Move provider tests with the implementation. Do not strand coverage under `wrappers/`.

-Only change shared integration primitives when the provider actually needs it.
+Move or update:

-### 7. Reduce the wrapper to compatibility imports
+- provider behavior tests into `py/src/braintrust/integrations/<provider>/`
+- cassette directories into `py/src/braintrust/integrations/<provider>/cassettes/`
+- auto-instrument subprocess tests into `py/src/braintrust/integrations/auto_test_scripts/` when relevant

-After the integration package is working, replace the legacy wrapper implementation with a thin `__init__.py` that re-exports the migrated surface from `braintrust.integrations.<provider>`.
+Update imports, cassette paths, and fixtures during the move.

-Keep `__all__` aligned with the pre-migration public API. Do not leave business logic, tracing helpers, or patchers behind in the wrapper package.
+Preserve coverage for the changed surfaces:

-## Current Examples
+- direct `wrap_*()` behavior
+- setup-time patching
+- sync behavior
+- async behavior
+- streaming behavior
+- idempotence
+- failure and logging behavior
+- version-routing behavior when applicable

-Use ADK as the main structural reference:
+Re-record cassettes only when behavior intentionally changes. 
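The `superseded_by` conversion described above looks like this in practice. The sketch mirrors the Agent run patchers that this patch adds in `py/src/braintrust/integrations/agno/patchers.py`; the tracing wrappers are assumed to already live in the provider's `tracing.py`.

```python
from typing import ClassVar

from braintrust.integrations.base import FunctionWrapperPatcher

# Plain wrapt-style wrapper functions defined in the provider's tracing module.
from .tracing import _agent_run_private_wrapper, _agent_run_public_wrapper


class _AgentRunPrivatePatcher(FunctionWrapperPatcher):
    # Preferred target: wrap Agent._run directly when the installed
    # agno version exposes it.
    name = "agno.agent.run.private"
    target_module = "agno.agent"
    target_path = "Agent._run"
    wrapper = _agent_run_private_wrapper
    priority: ClassVar[int] = 50


class _AgentRunPublicPatcher(FunctionWrapperPatcher):
    # Fallback target: declared via superseded_by, so it applies only
    # when the private patcher cannot. No hasattr branching is written here.
    name = "agno.agent.run.public"
    target_module = "agno.agent"
    target_path = "Agent.run"
    wrapper = _agent_run_public_wrapper
    priority: ClassVar[int] = 100
    superseded_by = (_AgentRunPrivatePatcher,)
```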
-- tracing moved into `py/src/braintrust/integrations/adk/tracing.py` -- patchers moved into `py/src/braintrust/integrations/adk/patchers.py` -- orchestration moved into `py/src/braintrust/integrations/adk/integration.py` -- public exports live in `py/src/braintrust/integrations/adk/__init__.py` -- wrapper tests and cassettes moved under `py/src/braintrust/integrations/adk/` -- auto-instrument subprocess coverage moved to `py/src/braintrust/integrations/auto_test_scripts/test_auto_adk.py` -- `py/src/braintrust/wrappers/adk/__init__.py` became a thin compatibility layer +## Repo Wiring -Use Anthropic as the compact constructor-patching reference: +Update only the shared surfaces required by the migration: -- `py/src/braintrust/integrations/anthropic/integration.py` registers sync and async constructor patchers -- `py/src/braintrust/integrations/anthropic/patchers.py` keeps one patcher per constructor target -- `py/src/braintrust/wrappers/anthropic.py` is a minimal compatibility re-export +- `py/src/braintrust/integrations/__init__.py` when the provider should be exported there +- `py/src/braintrust/auto.py` when `auto_instrument()` should use the integration +- `py/noxfile.py` so provider sessions point at the migrated integration tests -Match those patterns unless the provider has a clear reason to differ. +Prefer narrow repo-level changes. Do not broaden shared integration code unless the migration cannot work without it. -## Commands +## Validation + +Run the smallest relevant session first: ```bash cd py && nox -s "test_(latest)" cd py && nox -s "test_(latest)" -- -k "test_name" cd py && nox -s "test_(latest)" -- --vcr-record=all -k "test_name" +``` + +Expand only when the migration touches shared code: + +```bash cd py && make test-core cd py && make lint ``` -## Validation +Also verify: + +- the old wrapper import path still works +- the old `wrap_*()` helpers still work +- deprecated aliases still resolve +- the relevant auto-instrument subprocess tests still pass if `auto.py` changed + +## Migration-Specific Pitfalls + +Avoid these failures: -- Run the narrowest provider session first. -- Run `cd py && make test-core` if shared integration code changed. -- Run `cd py && make lint` before handoff when the migration touches shared files. -- Run the relevant auto-instrument subprocess tests if `auto.py` changed. -- Verify the old wrapper import path still works through compatibility re-exports. - -## Pitfalls - -- Copying wrapper code into the integration package without restructuring it around `integration.py`, `patchers.py`, and `tracing.py`. -- Leaving real logic behind in the wrapper after the migration. -- Breaking deprecated aliases or `__all__` exports that users still import. -- Moving tests without moving their cassettes or auto-instrument scripts. -- Forgetting to update `py/noxfile.py` to point at the new integration test paths. -- Changing shared integration code more broadly than the provider requires. -- Re-recording cassettes when behavior did not intentionally change. 
+- copying wrapper code into `integrations/` without restructuring it around `__init__.py`, `integration.py`, `patchers.py`, and `tracing.py` +- leaving business logic or tracing helpers in the wrapper after the migration +- preserving wrapper-era `hasattr` or patch-state logic in tracing wrappers instead of using patcher primitives +- re-implementing target precedence with custom branching instead of `superseded_by` +- forgetting to move cassettes or auto-test scripts with the tests +- updating tests but forgetting `py/noxfile.py` +- breaking deprecated aliases, `__all__`, or old import paths +- changing shared integration code more broadly than the provider requires +- re-recording cassettes when behavior did not intentionally change diff --git a/py/noxfile.py b/py/noxfile.py index 0e32cf57..81ece0f5 100644 --- a/py/noxfile.py +++ b/py/noxfile.py @@ -167,8 +167,8 @@ def test_agno(session, version): _install(session, "agno", version) _install(session, "openai") # Required for agno.models.openai _install(session, "fastapi") # Required for agno.workflow - _run_tests(session, f"{WRAPPER_DIR}/agno/test_agno.py") - _run_tests(session, f"{WRAPPER_DIR}/agno/test_workflow.py") + _run_tests(session, f"{INTEGRATION_DIR}/agno/test_agno.py") + _run_tests(session, f"{INTEGRATION_DIR}/agno/test_workflow.py") _run_core_tests(session) diff --git a/py/src/braintrust/auto.py b/py/src/braintrust/auto.py index 7cd51870..f91feb6e 100644 --- a/py/src/braintrust/auto.py +++ b/py/src/braintrust/auto.py @@ -7,7 +7,7 @@ import logging from contextlib import contextmanager -from braintrust.integrations import ADKIntegration, AnthropicIntegration +from braintrust.integrations import ADKIntegration, AgnoIntegration, AnthropicIntegration __all__ = ["auto_instrument"] @@ -115,7 +115,7 @@ def auto_instrument( if google_genai: results["google_genai"] = _instrument_google_genai() if agno: - results["agno"] = _instrument_agno() + results["agno"] = _instrument_integration(AgnoIntegration) if claude_agent_sdk: results["claude_agent_sdk"] = _instrument_claude_agent_sdk() if dspy: @@ -164,14 +164,6 @@ def _instrument_google_genai() -> bool: return False -def _instrument_agno() -> bool: - with _try_patch(): - from braintrust.wrappers.agno import setup_agno - - return setup_agno() - return False - - def _instrument_claude_agent_sdk() -> bool: with _try_patch(): from braintrust.wrappers.claude_agent_sdk import setup_claude_agent_sdk diff --git a/py/src/braintrust/integrations/__init__.py b/py/src/braintrust/integrations/__init__.py index d8c51617..e87b2dd6 100644 --- a/py/src/braintrust/integrations/__init__.py +++ b/py/src/braintrust/integrations/__init__.py @@ -1,5 +1,6 @@ from .adk import ADKIntegration +from .agno import AgnoIntegration from .anthropic import AnthropicIntegration -__all__ = ["ADKIntegration", "AnthropicIntegration"] +__all__ = ["ADKIntegration", "AgnoIntegration", "AnthropicIntegration"] diff --git a/py/src/braintrust/integrations/agno/__init__.py b/py/src/braintrust/integrations/agno/__init__.py new file mode 100644 index 00000000..8860a67a --- /dev/null +++ b/py/src/braintrust/integrations/agno/__init__.py @@ -0,0 +1,50 @@ +"""Braintrust integration for Agno.""" + +import logging + +from braintrust.logger import NOOP_SPAN, current_span, init_logger + +from .integration import AgnoIntegration +from .patchers import ( + wrap_agent, + wrap_function_call, + wrap_model, + wrap_team, + wrap_workflow, +) + + +logger = logging.getLogger(__name__) + +__all__ = [ + "AgnoIntegration", + "setup_agno", + "wrap_agent", + 
"wrap_function_call", + "wrap_model", + "wrap_team", + "wrap_workflow", +] + + +def setup_agno( + api_key: str | None = None, + project_id: str | None = None, + project_name: str | None = None, +) -> bool: + """ + Setup Braintrust integration with Agno. Will automatically patch Agno agents, models, and function calls for tracing. + + Args: + api_key: Braintrust API key (optional, can use env var BRAINTRUST_API_KEY) + project_id: Braintrust project ID (optional) + project_name: Braintrust project name (optional, can use env var BRAINTRUST_PROJECT) + + Returns: + True if setup was successful, False otherwise + """ + span = current_span() + if span == NOOP_SPAN: + init_logger(project=project_name, api_key=api_key, project_id=project_id) + + return AgnoIntegration.setup() diff --git a/py/src/braintrust/wrappers/agno/_test_agno_helpers.py b/py/src/braintrust/integrations/agno/_test_agno_helpers.py similarity index 99% rename from py/src/braintrust/wrappers/agno/_test_agno_helpers.py rename to py/src/braintrust/integrations/agno/_test_agno_helpers.py index fcb926e1..2c7d4b65 100644 --- a/py/src/braintrust/wrappers/agno/_test_agno_helpers.py +++ b/py/src/braintrust/integrations/agno/_test_agno_helpers.py @@ -5,7 +5,7 @@ from inspect import isawaitable -from braintrust.wrappers.agno.agent import wrap_agent +from braintrust.integrations.agno.patchers import wrap_agent PROJECT_NAME = "test-agno-app" diff --git a/py/src/braintrust/wrappers/cassettes/test_agno_simple_agent_execution.yaml b/py/src/braintrust/integrations/agno/cassettes/test_agno_simple_agent_execution.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_agno_simple_agent_execution.yaml rename to py/src/braintrust/integrations/agno/cassettes/test_agno_simple_agent_execution.yaml diff --git a/py/src/braintrust/wrappers/cassettes/test_agno_workflow_with_agent.yaml b/py/src/braintrust/integrations/agno/cassettes/test_agno_workflow_with_agent.yaml similarity index 100% rename from py/src/braintrust/wrappers/cassettes/test_agno_workflow_with_agent.yaml rename to py/src/braintrust/integrations/agno/cassettes/test_agno_workflow_with_agent.yaml diff --git a/py/src/braintrust/integrations/agno/cassettes/test_auto_agno.yaml b/py/src/braintrust/integrations/agno/cassettes/test_auto_agno.yaml new file mode 100644 index 00000000..338d68c2 --- /dev/null +++ b/py/src/braintrust/integrations/agno/cassettes/test_auto_agno.yaml @@ -0,0 +1,147 @@ +interactions: +- request: + body: '{"messages":[{"role":"developer","content":"You are a helpful assistant. 
+ Be brief."},{"role":"user","content":"Say hi"}],"model":"gpt-4o-mini"}' + headers: + Accept: + - application/json + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '143' + Content-Type: + - application/json + Host: + - api.openai.com + User-Agent: + - PatchedOpenAI/Python 2.15.0 + X-Stainless-Arch: + - arm64 + X-Stainless-Async: + - 'false' + X-Stainless-Lang: + - python + X-Stainless-OS: + - MacOS + X-Stainless-Package-Version: + - 2.15.0 + X-Stainless-Raw-Response: + - 'true' + X-Stainless-Runtime: + - CPython + X-Stainless-Runtime-Version: + - 3.13.3 + x-stainless-read-timeout: + - '600' + x-stainless-retry-count: + - '0' + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + body: + string: !!binary | + H4sIAAAAAAAAAwAAAP//jFJBbtswELzrFVuerUKWjdjyJZcWSBG0QC9F0SIQaHIls6G4LLlqawT+ + e0HJsZQ0BXLRYWdnNDPchwxAGC12INRBsuq8zd8Vunj/8RPef9k3n5tN87W3+ue3Zh1u+TaKRWLQ + /gcqfmS9VdR5i2zIjbAKKBmT6nJzVRXrsqy2A9CRRptored8TXlnnMnLolznxSZfbs/sAxmFUezg + ewYA8DB8k0+n8Y/YQbF4nHQYo2xR7C5LACKQTRMhYzSRpWOxmEBFjtEN1m/MG7ih36Ckgw8wbsOR + emDS8ng9ZwVs+iiTc9dbOwOkc8QyJR/83p2R08WhpdYH2sdnVNEYZ+KhDigjueQmMnkxoKcM4G5o + on8STvhAneea6R6H35XlKCem/iewOmNMLO00Xi0XL4jVGlkaG2dFCiXVAfXEnFqXvTY0A7JZ5H+9 + vKQ9xjaufY38BCiFnlHXPqA26mneaS1gOs7/rV0qHgyLiOGXUVizwZCeQWMjezuejIjHyNjVjXEt + Bh/MeDeNr8tqtSpkdbXdiuyU/QUAAP//AwBBs78WRQMAAA== + headers: + CF-RAY: + - 9c1afcdd3f36b231-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Thu, 22 Jan 2026 00:38:19 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=Jw95ZRGfTr6qO8YVvMCpB1aMAiti.HWb9WM0o.EAG4M-1769042299-1.0.1.1-F0ol4YtLGC1.t2DHb1Hj435gvyQ_nGNudwYUErS.pg4aWKbU4O68f4wJthw2GUCv2BYU7cC4ZcIA0B6TvaUN7VYsBM5OS7Ccc46cnb7zQ9Y; + path=/; expires=Thu, 22-Jan-26 01:08:19 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=gxrFvllhyUbQeecWVXMHkFhdg_IAJ7CO467JJDSyVA8-1769042299331-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Strict-Transport-Security: + - max-age=31536000; includeSubDomains; preload + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-organization: + - braintrust-data + openai-processing-ms: + - '438' + openai-project: + - proj_vsCSXafhhByzWOThMrJcZiw9 + openai-version: + - '2020-10-01' + x-envoy-upstream-service-time: + - '490' + x-openai-proxy-wasm: + - v0.1 + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999985' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_c68836a69b1549819fb6a5eecfd10be7 + status: + code: 200 + message: OK +- request: + body: '{"session_id":"3ed01154-18cc-4648-b766-73f60e3e08c2","run_id":"4ebf7a0f-31fa-4a69-9500-f3f3f21d350d","data":{"agent_id":"test-agent","db_type":null,"model_provider":"OpenAI","model_name":"OpenAIChat","model_id":"gpt-4o-mini","parser_model":null,"output_model":null,"has_tools":true,"has_memory":false,"has_learnings":false,"has_culture":false,"has_reasoning":false,"has_knowledge":false,"has_input_schema":false,"has_output_schema":false,"has_team":false},"sdk_version":"2.4.1","type":"agent"}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, zstd + Connection: + - keep-alive + Content-Length: + - '493' + 
Content-Type: + - application/json + Host: + - os-api.agno.com + user-agent: + - agno/2.4.1 + method: POST + uri: https://os-api.agno.com/telemetry/runs + response: + body: + string: '{"message":"Run creation acknowledged: 4ebf7a0f-31fa-4a69-9500-f3f3f21d350d","status":"success"}' + headers: + content-length: + - '96' + content-type: + - application/json + date: + - Thu, 22 Jan 2026 00:38:19 GMT + server: + - uvicorn + status: + code: 201 + message: null +version: 1 diff --git a/py/src/braintrust/integrations/agno/integration.py b/py/src/braintrust/integrations/agno/integration.py new file mode 100644 index 00000000..94d0d9c4 --- /dev/null +++ b/py/src/braintrust/integrations/agno/integration.py @@ -0,0 +1,31 @@ +"""Agno integration — orchestration class and setup entry-point.""" + +import logging + +from braintrust.integrations.base import BaseIntegration + +from .patchers import ( + AgentPatcher, + FunctionCallPatcher, + ModelPatcher, + TeamPatcher, + WorkflowPatcher, +) + + +logger = logging.getLogger(__name__) + + +class AgnoIntegration(BaseIntegration): + """Braintrust instrumentation for Agno.""" + + name = "agno" + import_names = ("agno",) + min_version = "2.1.0" + patchers = ( + AgentPatcher, + TeamPatcher, + ModelPatcher, + FunctionCallPatcher, + WorkflowPatcher, + ) diff --git a/py/src/braintrust/integrations/agno/patchers.py b/py/src/braintrust/integrations/agno/patchers.py new file mode 100644 index 00000000..112c947f --- /dev/null +++ b/py/src/braintrust/integrations/agno/patchers.py @@ -0,0 +1,379 @@ +from typing import Any, ClassVar + +from braintrust.integrations.base import CompositeFunctionWrapperPatcher, FunctionWrapperPatcher + +from .tracing import ( + _agent_arun_private_wrapper, + _agent_arun_public_wrapper, + _agent_arun_stream_wrapper, + _agent_run_private_wrapper, + _agent_run_public_wrapper, + _agent_run_stream_wrapper, + _function_call_aexecute_wrapper, + _function_call_execute_wrapper, + _model_ainvoke_stream_wrapper, + _model_ainvoke_wrapper, + _model_aresponse_stream_wrapper, + _model_aresponse_wrapper, + _model_invoke_stream_wrapper, + _model_invoke_wrapper, + _model_response_stream_wrapper, + _model_response_wrapper, + _team_arun_private_wrapper, + _team_arun_public_wrapper, + _team_arun_stream_wrapper, + _team_run_private_wrapper, + _team_run_public_wrapper, + _team_run_stream_wrapper, + _workflow_aexecute_stream_wrapper, + _workflow_aexecute_workflow_agent_wrapper, + _workflow_aexecute_wrapper, + _workflow_execute_stream_wrapper, + _workflow_execute_workflow_agent_wrapper, + _workflow_execute_wrapper, +) + + +# --------------------------------------------------------------------------- +# Agent patchers +# --------------------------------------------------------------------------- + +# Private methods have higher priority (lower number) so they are tried first. +# The public fallback patchers override applies() to yield when the private +# variant exists. 
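+# (Concretely, each fallback declares ``superseded_by`` below, so the patcher
+# machinery skips it whenever a superseding private-method patcher applies;
+# no per-patcher ``hasattr`` checks are written in this module.)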
+ + +class _AgentRunPrivatePatcher(FunctionWrapperPatcher): + name = "agno.agent.run.private" + target_module = "agno.agent" + target_path = "Agent._run" + wrapper = _agent_run_private_wrapper + priority: ClassVar[int] = 50 + + +class _AgentRunPublicPatcher(FunctionWrapperPatcher): + """Fallback: wrap ``Agent.run`` only when ``Agent._run`` does not exist.""" + + name = "agno.agent.run.public" + target_module = "agno.agent" + target_path = "Agent.run" + wrapper = _agent_run_public_wrapper + priority: ClassVar[int] = 100 + superseded_by = (_AgentRunPrivatePatcher,) + + +class _AgentArunPrivatePatcher(FunctionWrapperPatcher): + name = "agno.agent.arun.private" + target_module = "agno.agent" + target_path = "Agent._arun" + wrapper = _agent_arun_private_wrapper + priority: ClassVar[int] = 50 + + +class _AgentRunStreamPatcher(FunctionWrapperPatcher): + name = "agno.agent.run_stream" + target_module = "agno.agent" + target_path = "Agent._run_stream" + wrapper = _agent_run_stream_wrapper + + +class _AgentArunStreamPatcher(FunctionWrapperPatcher): + name = "agno.agent.arun_stream" + target_module = "agno.agent" + target_path = "Agent._arun_stream" + wrapper = _agent_arun_stream_wrapper + priority: ClassVar[int] = 50 + + +class _AgentArunPublicPatcher(FunctionWrapperPatcher): + """Fallback: wrap ``Agent.arun`` only when neither ``_arun`` nor ``_arun_stream`` exist.""" + + name = "agno.agent.arun.public" + target_module = "agno.agent" + target_path = "Agent.arun" + wrapper = _agent_arun_public_wrapper + priority: ClassVar[int] = 100 + superseded_by = (_AgentArunPrivatePatcher, _AgentArunStreamPatcher) + + +class AgentPatcher(CompositeFunctionWrapperPatcher): + """Patch ``agno.agent.Agent`` for tracing.""" + + name = "agno.agent" + sub_patchers = ( + _AgentRunPrivatePatcher, + _AgentRunPublicPatcher, + _AgentArunPrivatePatcher, + _AgentRunStreamPatcher, + _AgentArunStreamPatcher, + _AgentArunPublicPatcher, + ) + + +# --------------------------------------------------------------------------- +# Team patchers +# --------------------------------------------------------------------------- + + +class _TeamRunPrivatePatcher(FunctionWrapperPatcher): + name = "agno.team.run.private" + target_module = "agno.team" + target_path = "Team._run" + wrapper = _team_run_private_wrapper + priority: ClassVar[int] = 50 + + +class _TeamRunPublicPatcher(FunctionWrapperPatcher): + """Fallback: wrap ``Team.run`` only when ``Team._run`` does not exist.""" + + name = "agno.team.run.public" + target_module = "agno.team" + target_path = "Team.run" + wrapper = _team_run_public_wrapper + priority: ClassVar[int] = 100 + superseded_by = (_TeamRunPrivatePatcher,) + + +class _TeamArunPrivatePatcher(FunctionWrapperPatcher): + name = "agno.team.arun.private" + target_module = "agno.team" + target_path = "Team._arun" + wrapper = _team_arun_private_wrapper + priority: ClassVar[int] = 50 + + +class _TeamRunStreamPatcher(FunctionWrapperPatcher): + name = "agno.team.run_stream" + target_module = "agno.team" + target_path = "Team._run_stream" + wrapper = _team_run_stream_wrapper + + +class _TeamArunStreamPatcher(FunctionWrapperPatcher): + name = "agno.team.arun_stream" + target_module = "agno.team" + target_path = "Team._arun_stream" + wrapper = _team_arun_stream_wrapper + priority: ClassVar[int] = 50 + + +class _TeamArunPublicPatcher(FunctionWrapperPatcher): + """Fallback: wrap ``Team.arun`` only when neither ``_arun`` nor ``_arun_stream`` exist.""" + + name = "agno.team.arun.public" + target_module = "agno.team" + target_path = "Team.arun" 
+ wrapper = _team_arun_public_wrapper + priority: ClassVar[int] = 100 + superseded_by = (_TeamArunPrivatePatcher, _TeamArunStreamPatcher) + + +class TeamPatcher(CompositeFunctionWrapperPatcher): + """Patch ``agno.team.Team`` for tracing.""" + + name = "agno.team" + sub_patchers = ( + _TeamRunPrivatePatcher, + _TeamRunPublicPatcher, + _TeamArunPrivatePatcher, + _TeamRunStreamPatcher, + _TeamArunStreamPatcher, + _TeamArunPublicPatcher, + ) + + +# --------------------------------------------------------------------------- +# Model patchers +# --------------------------------------------------------------------------- + + +class _ModelInvokePatcher(FunctionWrapperPatcher): + name = "agno.model.invoke" + target_module = "agno.models.base" + target_path = "Model.invoke" + wrapper = _model_invoke_wrapper + + +class _ModelAinvokePatcher(FunctionWrapperPatcher): + name = "agno.model.ainvoke" + target_module = "agno.models.base" + target_path = "Model.ainvoke" + wrapper = _model_ainvoke_wrapper + + +class _ModelInvokeStreamPatcher(FunctionWrapperPatcher): + name = "agno.model.invoke_stream" + target_module = "agno.models.base" + target_path = "Model.invoke_stream" + wrapper = _model_invoke_stream_wrapper + + +class _ModelAinvokeStreamPatcher(FunctionWrapperPatcher): + name = "agno.model.ainvoke_stream" + target_module = "agno.models.base" + target_path = "Model.ainvoke_stream" + wrapper = _model_ainvoke_stream_wrapper + + +class _ModelResponsePatcher(FunctionWrapperPatcher): + name = "agno.model.response" + target_module = "agno.models.base" + target_path = "Model.response" + wrapper = _model_response_wrapper + + +class _ModelAresponsePatcher(FunctionWrapperPatcher): + name = "agno.model.aresponse" + target_module = "agno.models.base" + target_path = "Model.aresponse" + wrapper = _model_aresponse_wrapper + + +class _ModelResponseStreamPatcher(FunctionWrapperPatcher): + name = "agno.model.response_stream" + target_module = "agno.models.base" + target_path = "Model.response_stream" + wrapper = _model_response_stream_wrapper + + +class _ModelAresponseStreamPatcher(FunctionWrapperPatcher): + name = "agno.model.aresponse_stream" + target_module = "agno.models.base" + target_path = "Model.aresponse_stream" + wrapper = _model_aresponse_stream_wrapper + + +class ModelPatcher(CompositeFunctionWrapperPatcher): + """Patch ``agno.models.base.Model`` for tracing.""" + + name = "agno.model" + sub_patchers = ( + _ModelInvokePatcher, + _ModelAinvokePatcher, + _ModelInvokeStreamPatcher, + _ModelAinvokeStreamPatcher, + _ModelResponsePatcher, + _ModelAresponsePatcher, + _ModelResponseStreamPatcher, + _ModelAresponseStreamPatcher, + ) + + +# --------------------------------------------------------------------------- +# FunctionCall patchers +# --------------------------------------------------------------------------- + + +class _FunctionCallExecutePatcher(FunctionWrapperPatcher): + name = "agno.function_call.execute" + target_module = "agno.tools.function" + target_path = "FunctionCall.execute" + wrapper = _function_call_execute_wrapper + + +class _FunctionCallAexecutePatcher(FunctionWrapperPatcher): + name = "agno.function_call.aexecute" + target_module = "agno.tools.function" + target_path = "FunctionCall.aexecute" + wrapper = _function_call_aexecute_wrapper + + +class FunctionCallPatcher(CompositeFunctionWrapperPatcher): + """Patch ``agno.tools.function.FunctionCall`` for tracing.""" + + name = "agno.function_call" + sub_patchers = ( + _FunctionCallExecutePatcher, + _FunctionCallAexecutePatcher, + ) + + +# 
--------------------------------------------------------------------------- +# Workflow patchers (optional — requires fastapi) +# --------------------------------------------------------------------------- + + +class _WorkflowExecutePatcher(FunctionWrapperPatcher): + name = "agno.workflow.execute" + target_module = "agno.workflow" + target_path = "Workflow._execute" + wrapper = _workflow_execute_wrapper + + +class _WorkflowExecuteStreamPatcher(FunctionWrapperPatcher): + name = "agno.workflow.execute_stream" + target_module = "agno.workflow" + target_path = "Workflow._execute_stream" + wrapper = _workflow_execute_stream_wrapper + + +class _WorkflowAexecutePatcher(FunctionWrapperPatcher): + name = "agno.workflow.aexecute" + target_module = "agno.workflow" + target_path = "Workflow._aexecute" + wrapper = _workflow_aexecute_wrapper + + +class _WorkflowAexecuteStreamPatcher(FunctionWrapperPatcher): + name = "agno.workflow.aexecute_stream" + target_module = "agno.workflow" + target_path = "Workflow._aexecute_stream" + wrapper = _workflow_aexecute_stream_wrapper + + +class _WorkflowExecuteWorkflowAgentPatcher(FunctionWrapperPatcher): + name = "agno.workflow.execute_workflow_agent" + target_module = "agno.workflow" + target_path = "Workflow._execute_workflow_agent" + wrapper = _workflow_execute_workflow_agent_wrapper + + +class _WorkflowAexecuteWorkflowAgentPatcher(FunctionWrapperPatcher): + name = "agno.workflow.aexecute_workflow_agent" + target_module = "agno.workflow" + target_path = "Workflow._aexecute_workflow_agent" + wrapper = _workflow_aexecute_workflow_agent_wrapper + + +class WorkflowPatcher(CompositeFunctionWrapperPatcher): + """Patch ``agno.workflow.Workflow`` for tracing (optional — requires fastapi).""" + + name = "agno.workflow" + sub_patchers = ( + _WorkflowExecutePatcher, + _WorkflowExecuteStreamPatcher, + _WorkflowAexecutePatcher, + _WorkflowAexecuteStreamPatcher, + _WorkflowExecuteWorkflowAgentPatcher, + _WorkflowAexecuteWorkflowAgentPatcher, + ) + + +# --------------------------------------------------------------------------- +# Public wrap_*() helpers — thin wrappers around patcher.wrap_target() +# --------------------------------------------------------------------------- + + +def wrap_agent(Agent: Any) -> Any: + """Manually patch an Agent class for tracing.""" + return AgentPatcher.wrap_target(Agent) + + +def wrap_team(Team: Any) -> Any: + """Manually patch a Team class for tracing.""" + return TeamPatcher.wrap_target(Team) + + +def wrap_model(Model: Any) -> Any: + """Manually patch a Model class for tracing.""" + return ModelPatcher.wrap_target(Model) + + +def wrap_function_call(FunctionCall: Any) -> Any: + """Manually patch a FunctionCall class for tracing.""" + return FunctionCallPatcher.wrap_target(FunctionCall) + + +def wrap_workflow(Workflow: Any) -> Any: + """Manually patch a Workflow class for tracing.""" + return WorkflowPatcher.wrap_target(Workflow) diff --git a/py/src/braintrust/wrappers/agno/test_agno.py b/py/src/braintrust/integrations/agno/test_agno.py similarity index 92% rename from py/src/braintrust/wrappers/agno/test_agno.py rename to py/src/braintrust/integrations/agno/test_agno.py index 25c1406e..412e9421 100644 --- a/py/src/braintrust/wrappers/agno/test_agno.py +++ b/py/src/braintrust/integrations/agno/test_agno.py @@ -8,15 +8,11 @@ import pytest from braintrust import logger +from braintrust.integrations.agno import setup_agno +from braintrust.integrations.agno import tracing as agno_tracing_module +from braintrust.integrations.agno.patchers import 
wrap_agent, wrap_team from braintrust.logger import start_span from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.agno import agent as agno_agent_module -from braintrust.wrappers.agno import model as agno_model_module -from braintrust.wrappers.agno import run_helpers as agno_run_helpers_module -from braintrust.wrappers.agno import setup_agno -from braintrust.wrappers.agno import team as agno_team_module -from braintrust.wrappers.agno.agent import wrap_agent -from braintrust.wrappers.agno.team import wrap_team from braintrust.wrappers.test_utils import verify_autoinstrument_script from ._test_agno_helpers import ( @@ -40,7 +36,7 @@ def memory_logger(): @pytest.fixture(scope="module") def vcr_config(): return { - "cassette_library_dir": str(Path(__file__).parent.parent / "cassettes"), + "cassette_library_dir": str(Path(__file__).parent / "cassettes"), } @@ -122,7 +118,7 @@ class FakeModel: def get_provider(self): return "OpenAI Chat" - assert agno_model_module._get_model_name(FakeModel()) == "OpenAI" + assert agno_tracing_module._get_model_name(FakeModel()) == "OpenAI" class TestAutoInstrumentAgno: @@ -209,16 +205,15 @@ async def test_agno_public_arun_awaited_async_iterator_compat(memory_logger, wra @pytest.mark.asyncio @pytest.mark.parametrize( - "module,wrapper,name", + "wrapper,name", [ - (agno_agent_module, wrap_agent, "StrictAgentAwaitedAsync"), - (agno_team_module, wrap_team, "StrictTeamAwaitedAsync"), + (wrap_agent, "StrictAgentAwaitedAsync"), + (wrap_team, "StrictTeamAwaitedAsync"), ], ) -async def test_agno_public_arun_awaited_async_iterator_span_lifecycle(monkeypatch, module, wrapper, name): +async def test_agno_public_arun_awaited_async_iterator_span_lifecycle(monkeypatch, wrapper, name): strict_span = StrictSpan() - monkeypatch.setattr(module, "start_span", lambda **kwargs: strict_span) - monkeypatch.setattr(agno_run_helpers_module, "start_span", lambda **kwargs: strict_span) + monkeypatch.setattr(agno_tracing_module, "start_span", lambda **kwargs: strict_span) Component = wrapper(make_fake_async_dispatch_component(name)) instance = Component() diff --git a/py/src/braintrust/wrappers/agno/test_workflow.py b/py/src/braintrust/integrations/agno/test_workflow.py similarity index 97% rename from py/src/braintrust/wrappers/agno/test_workflow.py rename to py/src/braintrust/integrations/agno/test_workflow.py index 199a52f9..2b2b622c 100644 --- a/py/src/braintrust/wrappers/agno/test_workflow.py +++ b/py/src/braintrust/integrations/agno/test_workflow.py @@ -8,9 +8,9 @@ import pytest from braintrust import logger +from braintrust.integrations.agno import setup_agno +from braintrust.integrations.agno.patchers import wrap_workflow from braintrust.test_helpers import init_test_logger -from braintrust.wrappers.agno import setup_agno -from braintrust.wrappers.agno.workflow import wrap_workflow from ._test_agno_helpers import ( PROJECT_NAME, @@ -34,7 +34,7 @@ def memory_logger(): @pytest.fixture(scope="module") def vcr_config(): return { - "cassette_library_dir": str(Path(__file__).parent.parent / "cassettes"), + "cassette_library_dir": str(Path(__file__).parent / "cassettes"), } diff --git a/py/src/braintrust/integrations/agno/tracing.py b/py/src/braintrust/integrations/agno/tracing.py new file mode 100644 index 00000000..65c45dba --- /dev/null +++ b/py/src/braintrust/integrations/agno/tracing.py @@ -0,0 +1,1316 @@ +import time +from inspect import isawaitable +from typing import Any + +from braintrust.logger import start_span +from braintrust.span_types import 
+from braintrust.util import is_numeric
+
+
+# ---------------------------------------------------------------------------
+# Small helpers
+# ---------------------------------------------------------------------------
+
+
+def omit(obj: dict[str, Any], keys: list[str]) -> dict[str, Any]:
+    return {k: v for k, v in obj.items() if k not in keys}
+
+
+def clean(obj: dict[str, Any]) -> dict[str, Any]:
+    return {k: v for k, v in obj.items() if v is not None}
+
+
+def get_args_kwargs(args: Any, kwargs: dict[str, Any], keys: list[str]):
+    # Map positional args onto their parameter names, falling back to kwargs;
+    # guard the index so short arg tuples do not raise IndexError.
+    return {k: args[i] if i < len(args) else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys)
+
+
+def _try_to_dict(obj: Any) -> Any:
+    """Convert object to dict, handling different object types like OpenAI wrapper."""
+    if isinstance(obj, dict):
+        return obj
+    if hasattr(obj, "model_dump") and callable(obj.model_dump):
+        try:
+            return obj.model_dump()
+        except Exception:
+            pass
+    if hasattr(obj, "dict") and callable(obj.dict):
+        try:
+            return obj.dict()
+        except Exception:
+            pass
+    if hasattr(obj, "__dict__"):
+        try:
+            return obj.__dict__.copy()
+        except Exception:
+            pass
+    return obj
+
+
+def is_sync_iterator(result: Any) -> bool:
+    return hasattr(result, "__iter__") and hasattr(result, "__next__")
+
+
+def is_async_iterator(result: Any) -> bool:
+    return hasattr(result, "__aiter__") and hasattr(result, "__anext__")
+
+
+# ---------------------------------------------------------------------------
+# Metrics mapping & extraction
+# ---------------------------------------------------------------------------
+
+AGNO_METRICS_MAP = {
+    "input_tokens": "prompt_tokens",
+    "output_tokens": "completion_tokens",
+    "total_tokens": "tokens",
+    "reasoning_tokens": "completion_reasoning_tokens",
+    "audio_input_tokens": "prompt_audio_tokens",
+    "audio_output_tokens": "completion_audio_tokens",
+    "cache_read_tokens": "prompt_cached_tokens",
+    "cache_write_tokens": "prompt_cache_creation_tokens",
+    "duration": "duration",
+    "time_to_first_token": "time_to_first_token",
+}
+
+
+def extract_metadata(instance: Any, component: str) -> dict[str, Any]:
+    """Extract metadata from any component (model, agent, team, or workflow)."""
+    metadata = {"component": component}
+
+    if component == "model":
+        if hasattr(instance, "id") and instance.id:
+            metadata["model"] = instance.id
+            metadata["model_id"] = instance.id
+        if hasattr(instance, "provider") and instance.provider:
+            metadata["provider"] = instance.provider
+        if hasattr(instance, "name") and instance.name:
+            metadata["model_name"] = instance.name
+        if hasattr(instance, "__class__"):
+            metadata["model_class"] = instance.__class__.__name__
+    elif component == "agent":
+        metadata["agent_name"] = getattr(instance, "name", None)
+        model = getattr(instance, "model", None)
+        if model:
+            metadata["model"] = getattr(model, "id", None) or model.__class__.__name__
+    elif component == "team":
+        metadata["team_name"] = getattr(instance, "name", None)
+        model = getattr(instance, "model", None)
+        if model:
+            metadata["model"] = getattr(model, "id", None) or model.__class__.__name__
+    elif component == "workflow":
+        metadata["workflow_id"] = getattr(instance, "id", None)
+        metadata["workflow_name"] = getattr(instance, "name", None)
+        steps = getattr(instance, "steps", None)
+        if steps:
+            metadata["steps_count"] = len(steps)
+
+    return metadata
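+
+
+# Illustrative example (not executed): given an Agno usage payload such as
+#     {"input_tokens": 12, "output_tokens": 34, "total_tokens": 46}
+# parse_metrics_from_agno() below returns
+#     {"prompt_tokens": 12, "completion_tokens": 34, "tokens": 46}
+# Zero-valued and non-numeric entries are dropped, so empty usage objects
+# never log metrics.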
+
+
+def parse_metrics_from_agno(usage: Any) -> dict[str, Any]:
+    """Parse metrics from Agno usage object, following OpenAI wrapper pattern."""
+    metrics = {}
+    if not usage:
+        return metrics
+    usage_dict = _try_to_dict(usage)
+    if not isinstance(usage_dict, dict):
+        return metrics
+    for agno_name, value in usage_dict.items():
+        if agno_name in AGNO_METRICS_MAP and is_numeric(value) and value != 0:
+            braintrust_name = AGNO_METRICS_MAP[agno_name]
+            metrics[braintrust_name] = value
+    return metrics
+
+
+def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any] | None:
+    """Unified metrics extraction for all components."""
+    if hasattr(result, "response_usage") and result.response_usage:
+        return parse_metrics_from_agno(result.response_usage)
+    if hasattr(result, "metrics") and result.metrics:
+        metrics = parse_metrics_from_agno(result.metrics)
+        return metrics if metrics else None
+    if messages:
+        for msg in messages:
+            if hasattr(msg, "role") and msg.role == "assistant" and hasattr(msg, "metrics") and msg.metrics:
+                return parse_metrics_from_agno(msg.metrics)
+    return {}
+
+
+def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> dict[str, Any] | None:
+    """Extract metrics from aggregated streaming response."""
+    metrics = {}
+    if aggregated.get("metrics") and isinstance(aggregated["metrics"], dict):
+        metrics.update(aggregated["metrics"])
+    elif aggregated.get("metrics"):
+        parsed_metrics = parse_metrics_from_agno(aggregated["metrics"])
+        if parsed_metrics:
+            metrics.update(parsed_metrics)
+    elif aggregated.get("response_usage"):
+        response_metrics = parse_metrics_from_agno(aggregated["response_usage"])
+        if response_metrics:
+            metrics.update(response_metrics)
+    metrics["duration"] = time.time() - start_time
+    return metrics if metrics else None
+
+
+# ---------------------------------------------------------------------------
+# Chunk aggregation
+# ---------------------------------------------------------------------------
+
+
+def _aggregate_metrics(target: dict[str, Any], source: dict[str, Any]) -> None:
+    """Aggregate metrics from source into target dict."""
+    for key, value in source.items():
+        if is_numeric(value):
+            if key in target:
+                if "time" in key.lower() or "duration" in key.lower():
+                    target[key] = value
+                elif "token" in key.lower() or key == "tokens":
+                    target[key] = (target.get(key, 0) or 0) + value
+                else:
+                    target[key] = value
+            else:
+                target[key] = value
+        else:
+            target[key] = value
+
+
+def _aggregate_model_chunks(chunks: list[Any]) -> dict[str, Any]:
+    """Aggregate ModelResponse chunks from invoke_stream into a complete response."""
+    aggregated = {
+        "content": "",
+        "reasoning_content": "",
+        "tool_calls": [],
+        "role": None,
+        "audio": None,
+        "images": [],
+        "videos": [],
+        "files": [],
+        "citations": None,
+        "metrics": {},
+    }
+
+    for chunk in chunks:
+        if hasattr(chunk, "content") and chunk.content:
+            aggregated["content"] += str(chunk.content)
+        if hasattr(chunk, "reasoning_content") and chunk.reasoning_content:
+            aggregated["reasoning_content"] += chunk.reasoning_content
+        if hasattr(chunk, "role") and chunk.role and not aggregated["role"]:
+            aggregated["role"] = chunk.role
+        if hasattr(chunk, "tool_calls") and chunk.tool_calls:
+            aggregated["tool_calls"].extend(chunk.tool_calls)
+        if hasattr(chunk, "audio") and chunk.audio:
+            aggregated["audio"] = chunk.audio
+        if hasattr(chunk, "images") and chunk.images:
+            aggregated["images"].extend(chunk.images)
+        if hasattr(chunk, "videos") and chunk.videos:
+            aggregated["videos"].extend(chunk.videos)
+        if hasattr(chunk, "files") and chunk.files:
+            aggregated["files"].extend(chunk.files)
+        if hasattr(chunk, "citations") and chunk.citations:
+            aggregated["citations"] = chunk.citations
+        if hasattr(chunk, 
"response_usage") and chunk.response_usage: + chunk_metrics = parse_metrics_from_agno(chunk.response_usage) + if chunk_metrics: + _aggregate_metrics(aggregated["metrics"], chunk_metrics) + + if aggregated["metrics"]: + aggregated["response_usage"] = aggregated["metrics"] + else: + aggregated["metrics"] = None + + return aggregated + + +def _aggregate_response_stream_chunks(chunks: list[Any]) -> dict[str, Any]: + """Aggregate chunks from response_stream (ModelResponse, RunOutputEvent, etc.).""" + aggregated = { + "content": "", + "reasoning_content": "", + "tool_calls": [], + "role": None, + "audio": None, + "images": [], + "videos": [], + "files": [], + "citations": None, + "metrics": {}, + } + + for chunk in chunks: + if hasattr(chunk, "__class__") and "ModelResponse" in chunk.__class__.__name__: + if hasattr(chunk, "content") and chunk.content: + aggregated["content"] += str(chunk.content) + if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: + aggregated["reasoning_content"] += chunk.reasoning_content + if hasattr(chunk, "role") and chunk.role and not aggregated["role"]: + aggregated["role"] = chunk.role + if hasattr(chunk, "tool_calls") and chunk.tool_calls: + aggregated["tool_calls"].extend(chunk.tool_calls) + if hasattr(chunk, "audio") and chunk.audio: + aggregated["audio"] = chunk.audio + if hasattr(chunk, "images") and chunk.images: + aggregated["images"].extend(chunk.images) + if hasattr(chunk, "videos") and chunk.videos: + aggregated["videos"].extend(chunk.videos) + if hasattr(chunk, "files") and chunk.files: + aggregated["files"].extend(chunk.files) + if hasattr(chunk, "citations") and chunk.citations: + aggregated["citations"] = chunk.citations + if hasattr(chunk, "response_usage") and chunk.response_usage: + chunk_metrics = parse_metrics_from_agno(chunk.response_usage) + if chunk_metrics: + _aggregate_metrics(aggregated["metrics"], chunk_metrics) + elif hasattr(chunk, "metrics") and chunk.metrics: + chunk_metrics = parse_metrics_from_agno(chunk.metrics) + if chunk_metrics: + _aggregate_metrics(aggregated["metrics"], chunk_metrics) + elif hasattr(chunk, "content"): + if chunk.content: + aggregated["content"] += str(chunk.content) + + if hasattr(chunk, "metrics") and chunk.metrics and "metrics" not in str(type(chunk)): + chunk_metrics = parse_metrics_from_agno(chunk.metrics) + if chunk_metrics: + _aggregate_metrics(aggregated["metrics"], chunk_metrics) + + if aggregated["metrics"]: + aggregated["response_usage"] = aggregated["metrics"] + else: + aggregated["metrics"] = None + + return aggregated + + +def _aggregate_agent_chunks(chunks: list[Any]) -> dict[str, Any]: + """Aggregate BaseAgentRunEvent/BaseTeamRunEvent chunks into a complete response.""" + aggregated = { + "content": "", + "reasoning_content": "", + "model": "", + "model_provider": "", + "tool_calls": [], + "citations": None, + "references": None, + "metrics": None, + "finish_reason": None, + } + + for chunk in chunks: + event = getattr(chunk, "event", None) + + if event == "RunStarted": + if hasattr(chunk, "model"): + aggregated["model"] = chunk.model + if hasattr(chunk, "model_provider"): + aggregated["model_provider"] = chunk.model_provider + elif event == "RunContent": + if hasattr(chunk, "content") and chunk.content: + aggregated["content"] += str(chunk.content) # type: ignore + if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: + aggregated["reasoning_content"] += chunk.reasoning_content + if hasattr(chunk, "citations"): + aggregated["citations"] = chunk.citations + if 
hasattr(chunk, "references"): + aggregated["references"] = chunk.references + elif event == "RunCompleted": + if hasattr(chunk, "metrics"): + parsed_metrics = parse_metrics_from_agno(chunk.metrics) + aggregated["metrics"] = parsed_metrics if parsed_metrics else chunk.metrics + aggregated["finish_reason"] = "stop" + elif event == "RunError": + aggregated["finish_reason"] = "error" + elif event == "ToolCallStarted": + if hasattr(chunk, "tool_call"): + aggregated["tool_calls"].append( # type:ignore + { + "id": getattr(chunk.tool_call, "id", None), + "type": "function", + "function": { + "name": getattr(chunk.tool_call, "name", None), + "arguments": getattr(chunk.tool_call, "arguments", ""), + }, + } + ) + + return {k: v for k, v in aggregated.items() if v not in (None, "")} + + +def _aggregate_workflow_chunks(chunks: list[Any], workflow_run_response: Any | None = None) -> dict[str, Any]: + """Aggregate workflow/step events into a final workflow-style response.""" + aggregated = { + "content": "", + "status": None, + "metrics": None, + } + final_workflow_content = None + + for chunk in chunks: + event = getattr(chunk, "event", None) + + if hasattr(chunk, "content") and chunk.content: + if event == "WorkflowCompleted": + final_workflow_content = str(chunk.content) + elif final_workflow_content is None: + aggregated["content"] += str(chunk.content) + + if hasattr(chunk, "status") and chunk.status: + aggregated["status"] = chunk.status + + if hasattr(chunk, "metrics") and chunk.metrics: + parsed_metrics = parse_metrics_from_agno(chunk.metrics) + aggregated["metrics"] = parsed_metrics if parsed_metrics else chunk.metrics + + if final_workflow_content is not None: + accumulated_content = aggregated["content"] + if not accumulated_content: + aggregated["content"] = final_workflow_content + elif accumulated_content.endswith(final_workflow_content): + aggregated["content"] = accumulated_content + else: + aggregated["content"] = f"{accumulated_content}{final_workflow_content}" + + if workflow_run_response is not None: + if not aggregated["content"] and hasattr(workflow_run_response, "content") and workflow_run_response.content: + aggregated["content"] = str(workflow_run_response.content) + if not aggregated["status"] and hasattr(workflow_run_response, "status") and workflow_run_response.status: + aggregated["status"] = workflow_run_response.status + if not aggregated["metrics"] and hasattr(workflow_run_response, "metrics") and workflow_run_response.metrics: + parsed_metrics = parse_metrics_from_agno(workflow_run_response.metrics) + aggregated["metrics"] = parsed_metrics if parsed_metrics else workflow_run_response.metrics + + return {k: v for k, v in aggregated.items() if v not in (None, "")} + + +# --------------------------------------------------------------------------- +# Stream tracing helpers +# --------------------------------------------------------------------------- + + +def _trace_sync_stream(result: Any, span: Any, start: float): + def _inner(): + should_unset = True + try: + first = True + all_chunks = [] + for chunk in result: + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + aggregated = _aggregate_agent_chunks(all_chunks) + span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _inner() + 
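+# Shared by the public dispatch wrappers: the caller opens the span and makes
+# it current before handing the iterator to these helpers. First-chunk latency
+# is logged as time_to_first_token, and on GeneratorExit (the consumer closed
+# the stream early) unset_current() is skipped because the ambient context may
+# already have changed.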
+
+
+def _trace_async_stream(result: Any, span: Any, start: float):
+    async def _inner():
+        should_unset = True
+        try:
+            first = True
+            all_chunks = []
+            async for chunk in result:
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                all_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_agent_chunks(all_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+        except GeneratorExit:
+            should_unset = False
+            raise
+        except Exception as e:
+            span.log(error=str(e))
+            raise
+        finally:
+            if should_unset:
+                span.unset_current()
+            span.end()
+
+    return _inner()
+
+
+# ===========================================================================
+# Raw wrapt wrapper functions — used by FunctionWrapperPatcher in patchers.py
+# ===========================================================================
+
+
+# ---------------------------------------------------------------------------
+# Agent / Team private wrappers
+# ---------------------------------------------------------------------------
+
+
+def _agent_run_private_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Agent._run(run_response, run_messages)."""
+    run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
+    run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages")
+    input_data = {"run_response": run_response, "run_messages": run_messages}
+    agent_name = getattr(instance, "name", None) or "Agent"
+    with start_span(
+        name=f"{agent_name}.run",
+        type=SpanTypeAttribute.TASK,
+        input=input_data,
+        metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "agent")},
+    ) as span:
+        result = wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result))
+        return result
+
+
+async def _agent_arun_private_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Agent._arun(run_response, input)."""
+    run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
+    input_arg = args[1] if len(args) > 1 else kwargs.get("input")
+    input_data = {"run_response": run_response, "input": input_arg}
+    agent_name = getattr(instance, "name", None) or "Agent"
+    with start_span(
+        name=f"{agent_name}.arun",
+        type=SpanTypeAttribute.TASK,
+        input=input_data,
+        metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "agent")},
+    ) as span:
+        result = await wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result))
+        return result
+
+
+def _agent_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Agent._run_stream."""
+    agent_name = getattr(instance, "name", None) or "Agent"
+    run_response = args[0] if args else kwargs.get("run_response")
+    run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages")
+
+    def _trace_stream():
+        start = time.time()
+        span = start_span(
+            name=f"{agent_name}.run_stream",
+            type=SpanTypeAttribute.TASK,
+            input={"run_response": run_response, "run_messages": run_messages},
+            metadata={**omit(kwargs, ["run_response", "run_messages"]), **extract_metadata(instance, "agent")},
+        )
+        span.set_current()
+        should_unset = True
+        try:
+            first = True
+            all_chunks = []
+            for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                all_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_agent_chunks(all_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+        except GeneratorExit:
+            should_unset = False
+            raise
+        except Exception as e:
+            span.log(error=str(e))
+            raise
+        finally:
+            if should_unset:
+                span.unset_current()
+            span.end()
+
+    return _trace_stream()
+
+
+def _agent_arun_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Agent._arun_stream."""
+    agent_name = getattr(instance, "name", None) or "Agent"
+    run_response = args[0] if args else kwargs.get("run_response")
+    input_arg = args[2] if len(args) > 2 else kwargs.get("input")
+
+    async def _trace_stream():
+        start = time.time()
+        span = start_span(
+            name=f"{agent_name}.arun_stream",
+            type=SpanTypeAttribute.TASK,
+            input={"run_response": run_response, "input": input_arg},
+            metadata={**omit(kwargs, ["run_response", "input"]), **extract_metadata(instance, "agent")},
+        )
+        span.set_current()
+        should_unset = True
+        try:
+            first = True
+            all_chunks = []
+            async for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                all_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_agent_chunks(all_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+        except GeneratorExit:
+            should_unset = False
+            raise
+        except Exception as e:
+            span.log(error=str(e))
+            raise
+        finally:
+            if should_unset:
+                span.unset_current()
+            span.end()
+
+    return _trace_stream()
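+
+
+# The Team wrappers below mirror the Agent wrappers one-for-one; they are kept
+# separate rather than parameterized so each class keeps an explicit span name
+# and metadata component.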
+
+
+def _team_run_private_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Team._run(run_response, run_messages)."""
+    run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
+    run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages")
+    input_data = {"run_response": run_response, "run_messages": run_messages}
+    team_name = getattr(instance, "name", None) or "Team"
+    with start_span(
+        name=f"{team_name}.run",
+        type=SpanTypeAttribute.TASK,
+        input=input_data,
+        metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "team")},
+    ) as span:
+        result = wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result))
+        return result
+
+
+async def _team_arun_private_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Team._arun(run_response, input)."""
+    run_response = args[0] if len(args) > 0 else kwargs.get("run_response")
+    input_arg = args[1] if len(args) > 1 else kwargs.get("input")
+    input_data = {"run_response": run_response, "input": input_arg}
+    team_name = getattr(instance, "name", None) or "Team"
+    with start_span(
+        name=f"{team_name}.arun",
+        type=SpanTypeAttribute.TASK,
+        input=input_data,
+        metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "team")},
+    ) as span:
+        result = await wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result))
+        return result
+
+
+def _team_run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Team._run_stream."""
+    team_name = getattr(instance, "name", None) or "Team"
+    run_response = args[0] if args else kwargs.get("run_response")
+    run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages")
+
+    def _trace_stream():
+        start = time.time()
+        span = start_span(
+            name=f"{team_name}.run_stream",
+            type=SpanTypeAttribute.TASK,
+            input={"run_response": run_response, "run_messages": run_messages},
+            metadata={**omit(kwargs, ["run_response", "run_messages"]), **extract_metadata(instance, "team")},
+        )
+        span.set_current()
+        should_unset = True
+        try:
+            first = True
+            all_chunks = []
+            for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                all_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_agent_chunks(all_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+        except GeneratorExit:
+            should_unset = False
+            raise
+        except Exception as e:
+            span.log(error=str(e))
+            raise
+        finally:
+            if should_unset:
+                span.unset_current()
+            span.end()
+
+    return _trace_stream()
+
+
+def _team_arun_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    """Wrapper for Team._arun_stream."""
+    team_name = getattr(instance, "name", None) or "Team"
+    run_response = args[0] if args else kwargs.get("run_response")
+    input_arg = args[2] if len(args) > 2 else kwargs.get("input")
+
+    async def _trace_stream():
+        start = time.time()
+        span = start_span(
+            name=f"{team_name}.arun_stream",
+            type=SpanTypeAttribute.TASK,
+            input={"run_response": run_response, "input": input_arg},
+            metadata={**omit(kwargs, ["run_response", "input"]), **extract_metadata(instance, "team")},
+        )
+        span.set_current()
+        should_unset = True
+        try:
+            first = True
+            all_chunks = []
+            async for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                all_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_agent_chunks(all_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+        except GeneratorExit:
+            should_unset = False
+            raise
+        except Exception as e:
+            span.log(error=str(e))
+            raise
+        finally:
+            if should_unset:
+                span.unset_current()
+            span.end()
+
+    return _trace_stream()
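+
+
+# On newer Agno releases the public run()/arun() entry points are the stable
+# hook and the private _run/_arun helpers may be absent. The dispatch wrappers
+# below are registered in patchers.py behind superseded_by, so they only take
+# effect when the private-method patchers do not apply.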
+
+
+# ---------------------------------------------------------------------------
+# Agent / Team public dispatch wrappers (Agno >= 2.5)
+# ---------------------------------------------------------------------------
+
+
+def _run_public_dispatch_wrapper(
+    wrapped: Any,
+    instance: Any,
+    args: Any,
+    kwargs: Any,
+    *,
+    default_name: str,
+    metadata_component: str,
+) -> Any:
+    """Trace a public synchronous `run(...)` dispatch method."""
+    component_name = getattr(instance, "name", None) or default_name
+    input_arg = args[0] if len(args) > 0 else kwargs.get("input")
+    input_data = {"input": input_arg}
+    metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}
+
+    span = start_span(
+        name=f"{component_name}.run",
+        type=SpanTypeAttribute.TASK,
+        input=input_data,
+        metadata=metadata,
+    )
+    span.set_current()
+    start = time.time()
+    try:
+        result = wrapped(*args, **kwargs)
+        if is_sync_iterator(result):
+            return _trace_sync_stream(result, span, start)
+        span.log(output=result, metrics=extract_metrics(result))
+        span.unset_current()
+        span.end()
+        return result
+    except Exception as e:
+        span.log(error=str(e))
+        span.unset_current()
+        span.end()
+        raise
+
+
+def _arun_public_dispatch_wrapper(
+    wrapped: Any,
+    instance: Any,
+    args: Any,
+    kwargs: Any,
+    *,
+    default_name: str,
+    metadata_component: str,
+) -> Any:
+    """Trace a public `arun(...)` dispatch method across async return contracts."""
+    component_name = getattr(instance, "name", None) or default_name
+    input_arg = args[0] if len(args) > 0 else kwargs.get("input")
+    input_data = {"input": input_arg}
+    metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)}
+
+    span = start_span(
name=f"{component_name}.arun", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=metadata, + ) + span.set_current() + start = time.time() + try: + result = wrapped(*args, **kwargs) + + if isawaitable(result): + + async def _trace_awaitable(): + should_end_span = True + try: + awaited = await result + if is_async_iterator(awaited): + should_end_span = False + return _trace_async_stream(awaited, span, start) + span.log(output=awaited, metrics=extract_metrics(awaited)) + return awaited + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_end_span: + span.unset_current() + span.end() + + return _trace_awaitable() + + if is_async_iterator(result): + return _trace_async_stream(result, span, start) + + span.log(output=result, metrics=extract_metrics(result)) + span.unset_current() + span.end() + return result + except Exception as e: + span.log(error=str(e)) + span.unset_current() + span.end() + raise + + +def _agent_run_public_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return _run_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" + ) + + +def _agent_arun_public_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return _arun_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" + ) + + +def _team_run_public_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return _run_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" + ) + + +def _team_arun_public_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + return _arun_public_dispatch_wrapper( + wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" + ) + + +# --------------------------------------------------------------------------- +# Model wrappers +# --------------------------------------------------------------------------- + + +def _get_model_name(instance: Any) -> str: + provider = getattr(instance, "provider", None) + if provider: + return str(provider) + if hasattr(instance, "get_provider") and callable(instance.get_provider): + return str(instance.get_provider()) + return getattr(instance.__class__, "__name__", "Model") + + +def _model_invoke_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + model_name = _get_model_name(instance) + input, clean_kwargs = get_args_kwargs( + args, kwargs, ["assistant_message", "messages", "response_format", "tools", "tool_choice"] + ) + with start_span( + name=f"{model_name}.invoke", + type=SpanTypeAttribute.LLM, + input=input, + metadata={**clean_kwargs, **extract_metadata(instance, "model")}, + ) as span: + result = wrapped(*args, **kwargs) + span.log(output=result, metrics=extract_metrics(result, kwargs.get("messages", []))) + return result + + +async def _model_ainvoke_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + model_name = _get_model_name(instance) + input, clean_kwargs = get_args_kwargs( + args, kwargs, ["messages", "assistant_message", "response_format", "tools", "tool_choice"] + ) + with start_span( + name=f"{model_name}.ainvoke", + type=SpanTypeAttribute.LLM, + input=input, + metadata={**clean_kwargs, **extract_metadata(instance, "model")}, + ) as span: + result = await wrapped(*args, **kwargs) + span.log(output=result, metrics=extract_metrics(result, kwargs.get("messages", []))) + return result + + +def _model_invoke_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: 
Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "assistant_messages", "response_format", "tools", "tool_choice"]
+    )
+
+    def _trace_stream():
+        start = time.time()
+        with start_span(
+            name=f"{model_name}.invoke_stream",
+            type=SpanTypeAttribute.LLM,
+            input=input,
+            metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+        ) as span:
+            first = True
+            collected_chunks = []
+            for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                collected_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_model_chunks(collected_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+
+    return _trace_stream()
+
+
+def _model_ainvoke_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "assistant_messages", "response_format", "tools", "tool_choice"]
+    )
+
+    async def _trace_astream():
+        start = time.time()
+        with start_span(
+            name=f"{model_name}.ainvoke_stream",
+            type=SpanTypeAttribute.LLM,
+            input=input,
+            metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+        ) as span:
+            first = True
+            collected_chunks = []
+            async for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                collected_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_model_chunks(collected_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+
+    return _trace_astream()
+
+
+def _model_response_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"]
+    )
+    with start_span(
+        name=f"{model_name}.response",
+        type=SpanTypeAttribute.LLM,
+        input=input,
+        metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+    ) as span:
+        result = wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result, kwargs.get("messages", [])))
+        return result
+
+
+async def _model_aresponse_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"]
+    )
+    with start_span(
+        name=f"{model_name}.aresponse",
+        type=SpanTypeAttribute.LLM,
+        input=input,
+        metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+    ) as span:
+        result = await wrapped(*args, **kwargs)
+        span.log(output=result, metrics=extract_metrics(result, kwargs.get("messages", [])))
+        return result
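+
+
+# Illustrative consumption (model and messages are hypothetical): the streaming
+# span opens when iteration starts, not when the method is called.
+#
+#     stream = model.response_stream(messages=messages)  # no span yet
+#     for chunk in stream:  # span begins on the first next()
+#         ...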
+
+
+def _model_response_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"]
+    )
+
+    def _trace_stream():
+        start = time.time()
+        with start_span(
+            name=f"{model_name}.response_stream",
+            type=SpanTypeAttribute.LLM,
+            input=input,
+            metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+        ) as span:
+            first = True
+            collected_chunks = []
+            for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                collected_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_response_stream_chunks(collected_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+
+    return _trace_stream()
+
+
+def _model_aresponse_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    model_name = _get_model_name(instance)
+    input, clean_kwargs = get_args_kwargs(
+        args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"]
+    )
+
+    async def _trace_astream():
+        start = time.time()
+        with start_span(
+            name=f"{model_name}.aresponse_stream",
+            type=SpanTypeAttribute.LLM,
+            input=input,
+            metadata={**clean_kwargs, **extract_metadata(instance, "model")},
+        ) as span:
+            first = True
+            collected_chunks = []
+            async for chunk in wrapped(*args, **kwargs):
+                if first:
+                    span.log(metrics={"time_to_first_token": time.time() - start})
+                    first = False
+                collected_chunks.append(chunk)
+                yield chunk
+            aggregated = _aggregate_response_stream_chunks(collected_chunks)
+            span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start))
+
+    return _trace_astream()
+
+
+# ---------------------------------------------------------------------------
+# FunctionCall wrappers
+# ---------------------------------------------------------------------------
+
+
+def _get_function_name(instance) -> str:
+    if hasattr(instance, "function") and hasattr(instance.function, "name"):
+        return instance.function.name
+    return "Unknown"
+
+
+def _function_call_execute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    function_name = _get_function_name(instance)
+    entrypoint_args = instance._build_entrypoint_args()
+    with start_span(
+        name=f"{function_name}.execute",
+        type=SpanTypeAttribute.TOOL,
+        input=(instance.arguments or {}),
+        metadata={
+            "name": instance.function.name,
+            "entrypoint": instance.function.entrypoint.__name__,
+            **(entrypoint_args or {}),
+        },
+    ) as span:
+        result = wrapped(*args, **kwargs)
+        span.log(output=result)
+        return result
+
+
+async def _function_call_aexecute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
+    function_name = _get_function_name(instance)
+    entrypoint_args = instance._build_entrypoint_args()
+    with start_span(
+        name=f"{function_name}.aexecute",
+        type=SpanTypeAttribute.TOOL,
+        input=(instance.arguments or {}),
+        metadata={
+            "name": instance.function.name,
+            "entrypoint": instance.function.entrypoint.__name__,
+            **(entrypoint_args or {}),
+        },
+    ) as span:
+        result = await wrapped(*args, **kwargs)
+        span.log(output=result)
+        return result
+
+
+# ---------------------------------------------------------------------------
+# Workflow wrappers
+# ---------------------------------------------------------------------------
+
+
+def _extract_workflow_input(
+    args: Any,
+    kwargs: Any,
+    *,
+    execution_input_index: int,
+    workflow_run_response_index: int,
+) -> dict[str, Any]:
+    execution_input = (
+        args[execution_input_index] if len(args) > execution_input_index else kwargs.get("execution_input")
+    )
+    workflow_run_response = (
+        args[workflow_run_response_index]
+        if len(args) > workflow_run_response_index
+        else kwargs.get("workflow_run_response")
+    )
+    result: dict[str, Any] = {}
+    if execution_input:
+        if hasattr(execution_input, "input"):
+            result["input"] = execution_input.input
+        result["execution_input"] = _try_to_dict(execution_input)
+    if workflow_run_response:
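+        # _try_to_dict handles pydantic models, plain objects, and dicts, so the
+        # initial run state stays JSON-friendly in the span input.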
result["run_response"] = _try_to_dict(workflow_run_response) + return result + + +def _extract_workflow_agent_input(args: Any, kwargs: Any) -> dict[str, Any]: + user_input = args[0] if len(args) > 0 else kwargs.get("user_input") + execution_input = args[2] if len(args) > 2 else kwargs.get("execution_input") + result: dict[str, Any] = {"input": user_input} + if execution_input: + result["execution_input"] = _try_to_dict(execution_input) + return result + + +def _workflow_execute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + input_data = _extract_workflow_input(args, kwargs, execution_input_index=1, workflow_run_response_index=2) + workflow_metadata = extract_metadata(instance, "workflow") + with start_span( + name=f"{workflow_name}.run", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) as span: + result = wrapped(*args, **kwargs) + span.log(output=result, metrics=extract_metrics(result)) + return result + + +def _workflow_execute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + input_data = _extract_workflow_input(args, kwargs, execution_input_index=1, workflow_run_response_index=2) + workflow_metadata = extract_metadata(instance, "workflow") + workflow_run_response = args[2] if len(args) > 2 else kwargs.get("workflow_run_response") + + def _trace_stream(): + start = time.time() + span = start_span( + name=f"{workflow_name}.run_stream", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) + span.set_current() + should_unset = True + try: + first = True + all_chunks = [] + for chunk in wrapped(*args, **kwargs): + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + aggregated = _aggregate_workflow_chunks(all_chunks, workflow_run_response) + span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _trace_stream() + + +async def _workflow_aexecute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + input_data = _extract_workflow_input(args, kwargs, execution_input_index=2, workflow_run_response_index=3) + workflow_metadata = extract_metadata(instance, "workflow") + with start_span( + name=f"{workflow_name}.arun", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) as span: + result = await wrapped(*args, **kwargs) + span.log(output=result, metrics=extract_metrics(result)) + return result + + +def _workflow_aexecute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + input_data = _extract_workflow_input(args, kwargs, execution_input_index=2, workflow_run_response_index=3) + workflow_metadata = extract_metadata(instance, "workflow") + workflow_run_response = args[3] if len(args) > 3 else kwargs.get("workflow_run_response") + + async def _trace_stream(): + start = time.time() + span = start_span( + 
name=f"{workflow_name}.arun_stream", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) + span.set_current() + should_unset = True + try: + first = True + all_chunks = [] + async for chunk in wrapped(*args, **kwargs): + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + aggregated = _aggregate_workflow_chunks(all_chunks, workflow_run_response) + span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _trace_stream() + + +def _workflow_execute_workflow_agent_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + stream = kwargs.get("stream", False) + span_suffix = "run_stream" if stream else "run" + workflow_metadata = extract_metadata(instance, "workflow") + input_data = _extract_workflow_agent_input(args, kwargs) + + span = start_span( + name=f"{workflow_name}.{span_suffix}", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) + span.set_current() + start = time.time() + try: + result = wrapped(*args, **kwargs) + if stream and is_sync_iterator(result): + + def _trace_stream(): + should_unset = True + try: + first = True + all_chunks = [] + for chunk in result: + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + aggregated = _aggregate_workflow_chunks(all_chunks) + span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if should_unset: + span.unset_current() + span.end() + + return _trace_stream() + + span.log(output=result, metrics=extract_metrics(result)) + span.unset_current() + span.end() + return result + except Exception as e: + span.log(error=str(e)) + span.unset_current() + span.end() + raise + + +async def _workflow_aexecute_workflow_agent_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): + workflow_name = getattr(instance, "name", None) or "Workflow" + stream = kwargs.get("stream", False) + span_suffix = "arun_stream" if stream else "arun" + workflow_metadata = extract_metadata(instance, "workflow") + input_data = _extract_workflow_agent_input(args, kwargs) + + span = start_span( + name=f"{workflow_name}.{span_suffix}", + type=SpanTypeAttribute.TASK, + input=input_data, + metadata=workflow_metadata, + propagated_event={"metadata": workflow_metadata}, + ) + span.set_current() + start = time.time() + try: + result = await wrapped(*args, **kwargs) + if stream and is_async_iterator(result): + + async def _trace_stream(): + should_unset = True + try: + first = True + all_chunks = [] + async for chunk in result: + if first: + span.log(metrics={"time_to_first_token": time.time() - start}) + first = False + all_chunks.append(chunk) + yield chunk + aggregated = _aggregate_workflow_chunks(all_chunks) + span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) + except GeneratorExit: + should_unset = False + raise + except Exception as e: + span.log(error=str(e)) + raise + finally: + if 
should_unset: + span.unset_current() + span.end() + + return _trace_stream() + + span.log(output=result, metrics=extract_metrics(result)) + span.unset_current() + span.end() + return result + except Exception as e: + span.log(error=str(e)) + span.unset_current() + span.end() + raise diff --git a/py/src/braintrust/integrations/base.py b/py/src/braintrust/integrations/base.py index e3deaabb..3d021f91 100644 --- a/py/src/braintrust/integrations/base.py +++ b/py/src/braintrust/integrations/base.py @@ -50,11 +50,19 @@ class FunctionWrapperPatcher(BasePatcher): different module than the one provided by the integration (e.g. a deep submodule that may or may not be installed). The module is imported lazily when the patcher is evaluated. + + Set ``superseded_by`` to a tuple of other ``FunctionWrapperPatcher`` + subclasses that take priority over this patcher. If any of them apply + (i.e. their target exists), this patcher yields — both in the + ``setup()`` path (via ``applies()``) and in the manual ``wrap_target()`` + path. This is useful for version-conditional mutual exclusion, e.g. + wrapping a public ``run()`` only when the private ``_run()`` is absent. """ target_path: ClassVar[str] wrapper: ClassVar[Any] target_module: ClassVar[str | None] = None + superseded_by: ClassVar[tuple[type["FunctionWrapperPatcher"], ...]] = () @classmethod def resolve_root(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> Any | None: @@ -83,11 +91,18 @@ def resolve_target(cls, module: Any | None, version: str | None, *, target: Any @classmethod def applies(cls, module: Any | None, version: str | None, *, target: Any | None = None) -> bool: - """Return whether the target exists and this patcher's version gate passes.""" - return ( - super().applies(module, version, target=target) - and cls.resolve_target(module, version, target=target) is not None - ) + """Return whether the target exists and this patcher's version gate passes. + + Returns ``False`` if any patcher listed in ``superseded_by`` applies. + """ + if not super().applies(module, version, target=target): + return False + if cls.resolve_target(module, version, target=target) is None: + return False + for superior in cls.superseded_by: + if superior.applies(module, version, target=target): + return False + return True @classmethod def patch_marker_attr(cls) -> str: @@ -135,8 +150,9 @@ def wrap_target(cls, target: Any) -> Any: whether the patch has already been applied. Returns *target* unchanged if the leaf attribute does not exist on - *target* or the patch has already been applied. Returns *target* - for convenient chaining. + *target*, the patch has already been applied, or a patcher in + ``superseded_by`` has a target that exists on *target*. Returns + *target* for convenient chaining. """ marker = cls.patch_marker_attr() if getattr(target, marker, False): @@ -144,6 +160,11 @@ def wrap_target(cls, target: Any) -> Any: attr = cls.target_path.rsplit(".", 1)[-1] if _resolve_attr_path(target, attr) is None: return target + # Check superseded_by against the target object directly. 
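+        # Only the class is available here (no module or version), so the
+        # presence of the superior patcher's leaf attribute on the target is
+        # used as the proxy for "that patcher would apply".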
+ for superior in cls.superseded_by: + superior_attr = superior.target_path.rsplit(".", 1)[-1] + if _resolve_attr_path(target, superior_attr) is not None: + return target wrap_function_wrapper(target, attr, cls.wrapper) cls.mark_patched(target) return target diff --git a/py/src/braintrust/wrappers/agno/__init__.py b/py/src/braintrust/wrappers/agno/__init__.py index 3345f28f..59ad441d 100644 --- a/py/src/braintrust/wrappers/agno/__init__.py +++ b/py/src/braintrust/wrappers/agno/__init__.py @@ -1,78 +1,11 @@ -""" -Braintrust wrapper for Agno - provides observability for agent workflows. +from braintrust.integrations.agno import ( # noqa: F401 + setup_agno, + wrap_agent, + wrap_function_call, + wrap_model, + wrap_team, + wrap_workflow, +) -This integration provides: -- Agent execution tracing with proper root spans -- LLM call tracing with proper nesting -- Tool call tracing with correct parent-child relationships - -Usage: - from braintrust.wrappers.agno import setup_agno - - # Initialize the integration - setup_agno(project_name="my-project") - - # Your Agno agent code will now be automatically traced - import agno - agent = agno.Agent(...) - response = agent.run(...) -""" __all__ = ["setup_agno", "wrap_agent", "wrap_function_call", "wrap_model", "wrap_team", "wrap_workflow"] - -import logging - -from braintrust.logger import NOOP_SPAN, current_span, init_logger - -from .agent import wrap_agent -from .function_call import wrap_function_call -from .model import wrap_model -from .team import wrap_team -from .workflow import wrap_workflow - - -logger = logging.getLogger(__name__) - - -def setup_agno( - api_key: str | None = None, - project_id: str | None = None, - project_name: str | None = None, -) -> bool: - """ - Setup Braintrust integration with Agno. Will automatically patch Agno agents, models, and function calls for tracing. - - This function is called by init_agno() and can also be used directly for more control. 
- - Args: - api_key: Braintrust API key (optional, can use env var BRAINTRUST_API_KEY) - project_id: Braintrust project ID (optional) - project_name: Braintrust project name (optional, can use env var BRAINTRUST_PROJECT) - - Returns: - True if setup was successful, False otherwise - """ - span = current_span() - if span == NOOP_SPAN: - init_logger(project=project_name, api_key=api_key, project_id=project_id) - - try: - from agno import agent, models, team, tools # pyright: ignore - - agent.Agent = wrap_agent(agent.Agent) # pyright: ignore[reportUnknownMemberType] - team.Team = wrap_team(team.Team) # pyright: ignore[reportUnknownMemberType] - models.base.Model = wrap_model(models.base.Model) # pyright: ignore[reportUnknownMemberType] - tools.function.FunctionCall = wrap_function_call(tools.function.FunctionCall) # pyright: ignore[reportUnknownMemberType] - except ImportError: - # Not installed - this is expected when using auto_instrument() - return False - - try: - from agno import workflow # pyright: ignore - - workflow.Workflow = wrap_workflow(workflow.Workflow) # pyright: ignore[reportUnknownMemberType] - except ImportError: - # agno.workflow requires fastapi which may not be installed - pass - - return True diff --git a/py/src/braintrust/wrappers/agno/agent.py b/py/src/braintrust/wrappers/agno/agent.py deleted file mode 100644 index cb63cc49..00000000 --- a/py/src/braintrust/wrappers/agno/agent.py +++ /dev/null @@ -1,216 +0,0 @@ -import time -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - -from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper -from .utils import ( - _aggregate_agent_chunks, - extract_metadata, - extract_metrics, - extract_streaming_metrics, - is_patched, - mark_patched, - omit, -) - - -def wrap_agent(Agent: Any) -> Any: - if is_patched(Agent): - return Agent - - def _create_run_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - """Shared logic to create span and execute run method.""" - agent_name = getattr(instance, "name", None) or "Agent" - span_name = f"{agent_name}.run" - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "agent")}, - ) as span: - result = wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for private _run(run_response, run_messages).""" - run_response = args[0] if len(args) > 0 else kwargs.get("run_response") - run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages") - input_data = {"run_response": run_response, "run_messages": run_messages} - return _create_run_span(wrapped, instance, args, kwargs, input_data) - - def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return run_public_dispatch_wrapper( - wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" - ) - - # Wrap private method if it exists, otherwise wrap public method - if hasattr(Agent, "_run"): - wrap_function_wrapper(Agent, "_run", _run_wrapper_private) - elif hasattr(Agent, "run"): - wrap_function_wrapper(Agent, "run", _run_wrapper_public) - - async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - """Shared logic to 
create span and execute async private _arun method.""" - agent_name = getattr(instance, "name", None) or "Agent" - span_name = f"{agent_name}.arun" - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "agent")}, - ) as span: - result = await wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for private _arun(run_response, input).""" - run_response = args[0] if len(args) > 0 else kwargs.get("run_response") - input_arg = args[1] if len(args) > 1 else kwargs.get("input") - input_data = {"run_response": run_response, "input": input_arg} - return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data) - - def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return arun_public_dispatch_wrapper( - wrapped, instance, args, kwargs, default_name="Agent", metadata_component="agent" - ) - - # Wrap private method if it exists, otherwise wrap public method - if hasattr(Agent, "_arun"): - wrap_function_wrapper(Agent, "_arun", _arun_wrapper_private) - - def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - agent_name = getattr(instance, "name", None) or "Agent" - span_name = f"{agent_name}.run_stream" - - run_response = args[0] if args else kwargs.get("run_response") - run_messages = args[1] if args else kwargs.get("run_messages") - - def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input={"run_response": run_response, "run_messages": run_messages}, - metadata={**omit(kwargs, ["run_response", "run_messages"]), **extract_metadata(instance, "agent")}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - # Generator was closed early (e.g., break from for loop) - # Don't call unset_current() as context may have changed - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Agent, "_run_stream"): - wrap_function_wrapper(Agent, "_run_stream", run_stream_wrapper) - - def arun_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - agent_name = getattr(instance, "name", None) or "Agent" - span_name = f"{agent_name}.arun_stream" - - run_response = args[0] if args else kwargs.get("run_response") - input = args[2] if args else kwargs.get("input") - - async def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input={"run_response": run_response, "input": input}, - metadata={**omit(kwargs, ["run_response", "input"]), **extract_metadata(instance, "agent")}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - async for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first 
= False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - # Generator was closed early (e.g., break from async for loop) - # Don't call unset_current() as context may have changed - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Agent, "_arun_stream"): - wrap_function_wrapper(Agent, "_arun_stream", arun_stream_wrapper) - elif not hasattr(Agent, "_arun") and hasattr(Agent, "arun"): - # Agno >= 2.5 routes through public arun(..., stream=...) - wrap_function_wrapper(Agent, "arun", _arun_wrapper_public) - - mark_patched(Agent) - return Agent diff --git a/py/src/braintrust/wrappers/agno/function_call.py b/py/src/braintrust/wrappers/agno/function_call.py deleted file mode 100644 index e4104866..00000000 --- a/py/src/braintrust/wrappers/agno/function_call.py +++ /dev/null @@ -1,67 +0,0 @@ -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - -from .utils import is_patched - - -def wrap_function_call(FunctionCall: Any) -> Any: - if is_patched(FunctionCall): - return FunctionCall - - def execute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - function_name = _get_function_name(instance) - span_name = f"{function_name}.execute" - - entrypoint_args = instance._build_entrypoint_args() - - with start_span( - name=span_name, - type=SpanTypeAttribute.TOOL, - input=(instance.arguments or {}), - metadata={ - "name": instance.function.name, - "entrypoint": instance.function.entrypoint.__name__, - **(entrypoint_args or {}), - }, - ) as span: - result = wrapped(*args, **kwargs) - span.log(output=result) - return result - - if hasattr(FunctionCall, "execute"): - wrap_function_wrapper(FunctionCall, "execute", execute_wrapper) - - async def aexecute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - function_name = _get_function_name(instance) - span_name = f"{function_name}.aexecute" - - entrypoint_args = instance._build_entrypoint_args() - - with start_span( - name=span_name, - type=SpanTypeAttribute.TOOL, - input=(instance.arguments or {}), - metadata={ - "name": instance.function.name, - "entrypoint": instance.function.entrypoint.__name__, - **(entrypoint_args or {}), - }, - ) as span: - result = await wrapped(*args, **kwargs) - span.log(output=result) - return result - - if hasattr(FunctionCall, "aexecute"): - wrap_function_wrapper(FunctionCall, "aexecute", aexecute_wrapper) - - FunctionCall._braintrust_patched = True - return FunctionCall - - -def _get_function_name(instance) -> str: - if hasattr(instance, "function") and hasattr(instance.function, "name"): - return instance.function.name - return "Unknown" diff --git a/py/src/braintrust/wrappers/agno/model.py b/py/src/braintrust/wrappers/agno/model.py deleted file mode 100644 index 3e474b44..00000000 --- a/py/src/braintrust/wrappers/agno/model.py +++ /dev/null @@ -1,318 +0,0 @@ -""" -ModelWrapper class for Braintrust-Agno model observability. 
-""" - -import time -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - -from .utils import ( - _aggregate_model_chunks, - _aggregate_response_stream_chunks, - extract_metadata, - extract_metrics, - extract_streaming_metrics, - get_args_kwargs, - is_patched, - mark_patched, -) - - -def wrap_model(Model: Any) -> Any: - if is_patched(Model): - return Model - - def invoke_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.invoke" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["assistant_message", "messages", "response_format", "tools", "tool_choice"] - ) - - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={ - **clean_kwargs, - **extract_metadata(instance, "model"), - }, - ) as span: - result = wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result, kwargs.get("messages", [])), - ) - return result - - if hasattr(Model, "invoke"): - wrap_function_wrapper(Model, "invoke", invoke_wrapper) - - async def ainvoke_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.ainvoke" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "assistant_message", "response_format", "tools", "tool_choice"] - ) - - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={ - **clean_kwargs, - **extract_metadata(instance, "model"), - }, - ) as span: - result = await wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result, kwargs.get("messages", [])), - ) - return result - - if hasattr(Model, "ainvoke"): - wrap_function_wrapper(Model, "ainvoke", ainvoke_wrapper) - - def invoke_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.invoke_stream" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "assistant_messages", "response_format", "tools", "tool_choice"] - ) - - def _trace_stream(): - start = time.time() - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={ - **clean_kwargs, - **extract_metadata(instance, "model"), - }, - ) as span: - first = True - collected_chunks = [] - for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - - collected_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_model_chunks(collected_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - - return _trace_stream() - - if hasattr(Model, "invoke_stream"): - wrap_function_wrapper(Model, "invoke_stream", invoke_stream_wrapper) - - def ainvoke_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.ainvoke_stream" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "assistant_messages", "response_format", "tools", "tool_choice"] - ) - - async def _trace_astream(): - start = time.time() - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={ - **clean_kwargs, - **extract_metadata(instance, "model"), - }, - ) as span: - first = True - collected_chunks = [] - async for chunk in 
wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - - collected_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_model_chunks(collected_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - - return _trace_astream() - - if hasattr(Model, "ainvoke_stream"): - wrap_function_wrapper(Model, "ainvoke_stream", ainvoke_stream_wrapper) - - def response_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.response" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"] - ) - - with start_span( - name=span_name, - # TODO: confirm LLM is the right span type for response() - type=SpanTypeAttribute.LLM, - input=input, - metadata={**clean_kwargs, **extract_metadata(instance, "model")}, - ) as span: - result = wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result, kwargs.get("messages", [])), - ) - return result - - if hasattr(Model, "response"): - wrap_function_wrapper(Model, "response", response_wrapper) - - async def aresponse_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.aresponse" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"] - ) - - with start_span( - name=span_name, - # TODO: confirm LLM is the right span type for aresponse() - type=SpanTypeAttribute.LLM, - input=input, - metadata={**clean_kwargs, **extract_metadata(instance, "model")}, - ) as span: - result = await wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result, kwargs.get("messages", [])), - ) - return result - - if hasattr(Model, "aresponse"): - wrap_function_wrapper(Model, "aresponse", aresponse_wrapper) - - def response_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.response_stream" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"] - ) - - def _trace_stream(): - start = time.time() - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={**clean_kwargs, **extract_metadata(instance, "model")}, - ) as span: - first = True - collected_chunks = [] - - for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - - collected_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_response_stream_chunks(collected_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - - return _trace_stream() - - if hasattr(Model, "response_stream"): - wrap_function_wrapper(Model, "response_stream", response_stream_wrapper) - - def aresponse_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - model_name = _get_model_name(instance) - span_name = f"{model_name}.aresponse_stream" - - input, clean_kwargs = get_args_kwargs( - args, kwargs, ["messages", "response_format", "tools", "functions", "tool_choice", "tool_call_limit"] - ) - - async def _trace_astream(): - start = time.time() - with start_span( - name=span_name, - type=SpanTypeAttribute.LLM, - input=input, - metadata={**clean_kwargs,
**extract_metadata(instance, "model")}, - ) as span: - first = True - collected_chunks = [] - - async for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - - collected_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_response_stream_chunks(collected_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - - return _trace_astream() - - if hasattr(Model, "aresponse_stream"): - wrap_function_wrapper(Model, "aresponse_stream", aresponse_stream_wrapper) - - mark_patched(Model) - return Model - - -def _get_model_name(instance: Any) -> str: - provider = getattr(instance, "provider", None) - if provider: - return str(provider) - if hasattr(instance, "get_provider") and callable(instance.get_provider): - return str(instance.get_provider()) - return getattr(instance.__class__, "__name__", "Model") diff --git a/py/src/braintrust/wrappers/agno/run_helpers.py b/py/src/braintrust/wrappers/agno/run_helpers.py deleted file mode 100644 index 3be7f587..00000000 --- a/py/src/braintrust/wrappers/agno/run_helpers.py +++ /dev/null @@ -1,139 +0,0 @@ -import time -from inspect import isawaitable -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute - -from .utils import ( - extract_metadata, - extract_metrics, - is_async_iterator, - is_sync_iterator, - omit, - trace_async_stream_result, - trace_sync_stream_result, -) - - -def run_public_dispatch_wrapper( - wrapped: Any, - instance: Any, - args: Any, - kwargs: Any, - *, - default_name: str, - metadata_component: str, -) -> Any: - """Trace a public synchronous `run(...)` dispatch method. - - Handles both non-streaming return values and synchronous streaming iterators. - For iterator results, span lifecycle is delegated to `trace_sync_stream_result`. - """ - component_name = getattr(instance, "name", None) or default_name - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)} - - span = start_span( - name=f"{component_name}.run", - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=metadata, - ) - span.set_current() - start = time.time() - try: - result = wrapped(*args, **kwargs) - if is_sync_iterator(result): - return trace_sync_stream_result(result, span, start) - span.log( - output=result, - metrics=extract_metrics(result), - ) - span.unset_current() - span.end() - return result - except Exception as e: - span.log(error=str(e)) - span.unset_current() - span.end() - raise - - -def arun_public_dispatch_wrapper( - wrapped: Any, - instance: Any, - args: Any, - kwargs: Any, - *, - default_name: str, - metadata_component: str, -) -> Any: - """Trace a public `arun(...)` dispatch method across async return contracts. - - Supports all observed `arun` dispatcher behaviors: - - immediate return value - - awaitable returning a value - - direct async iterator - - awaitable returning an async iterator - - If an async iterator is returned (directly or after await), span lifecycle is - delegated to `trace_async_stream_result` so the span remains open until stream - consumption completes. 
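Non-streaming results (immediate or awaited) are logged and the span is closed before returning.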
- """ - component_name = getattr(instance, "name", None) or default_name - input_arg = args[0] if len(args) > 0 else kwargs.get("input") - input_data = {"input": input_arg} - metadata = {**omit(kwargs, ["input"]), **extract_metadata(instance, metadata_component)} - - span = start_span( - name=f"{component_name}.arun", - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=metadata, - ) - span.set_current() - start = time.time() - try: - result = wrapped(*args, **kwargs) - - if isawaitable(result): - - async def _trace_awaitable(): - should_end_span = True - try: - awaited = await result - if is_async_iterator(awaited): - should_end_span = False - return trace_async_stream_result(awaited, span, start) - span.log( - output=awaited, - metrics=extract_metrics(awaited), - ) - return awaited - except Exception as e: - span.log(error=str(e)) - raise - finally: - if should_end_span: - span.unset_current() - span.end() - - return _trace_awaitable() - - if is_async_iterator(result): - return trace_async_stream_result(result, span, start) - - span.log( - output=result, - metrics=extract_metrics(result), - ) - span.unset_current() - span.end() - return result - except Exception as e: - span.log(error=str(e)) - span.unset_current() - span.end() - raise diff --git a/py/src/braintrust/wrappers/agno/team.py b/py/src/braintrust/wrappers/agno/team.py deleted file mode 100644 index f82fc9b5..00000000 --- a/py/src/braintrust/wrappers/agno/team.py +++ /dev/null @@ -1,216 +0,0 @@ -import time -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - -from .run_helpers import arun_public_dispatch_wrapper, run_public_dispatch_wrapper -from .utils import ( - _aggregate_agent_chunks, - extract_metadata, - extract_metrics, - extract_streaming_metrics, - is_patched, - mark_patched, - omit, -) - - -def wrap_team(Team: Any) -> Any: - if is_patched(Team): - return Team - - def _create_run_span(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - """Shared logic to create span and execute run method.""" - agent_name = getattr(instance, "name", None) or "Team" - span_name = f"{agent_name}.run" - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "team")}, - ) as span: - result = wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - def _run_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for private _run(run_response, run_messages).""" - run_response = args[0] if len(args) > 0 else kwargs.get("run_response") - run_messages = args[1] if len(args) > 1 else kwargs.get("run_messages") - input_data = {"run_response": run_response, "run_messages": run_messages} - return _create_run_span(wrapped, instance, args, kwargs, input_data) - - def _run_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return run_public_dispatch_wrapper( - wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" - ) - - # Wrap private method if it exists, otherwise wrap public method - if hasattr(Team, "_run"): - wrap_function_wrapper(Team, "_run", _run_wrapper_private) - elif hasattr(Team, "run"): - wrap_function_wrapper(Team, "run", _run_wrapper_public) - - async def _create_arun_span_private(wrapped: Any, instance: Any, args: Any, kwargs: Any, input_data: dict): - 
"""Shared logic to create span and execute async private _arun method.""" - agent_name = getattr(instance, "name", None) or "Team" - span_name = f"{agent_name}.arun" - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata={**omit(kwargs, list(input_data.keys())), **extract_metadata(instance, "team")}, - ) as span: - result = await wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - async def _arun_wrapper_private(wrapped: Any, instance: Any, args: Any, kwargs: Any): - """Entry point for private _arun(run_response, input).""" - run_response = args[0] if len(args) > 0 else kwargs.get("run_response") - input_arg = args[1] if len(args) > 1 else kwargs.get("input") - input_data = {"run_response": run_response, "input": input_arg} - return await _create_arun_span_private(wrapped, instance, args, kwargs, input_data) - - def _arun_wrapper_public(wrapped: Any, instance: Any, args: Any, kwargs: Any): - return arun_public_dispatch_wrapper( - wrapped, instance, args, kwargs, default_name="Team", metadata_component="team" - ) - - # Wrap private method if it exists, otherwise wrap public method - if hasattr(Team, "_arun"): - wrap_function_wrapper(Team, "_arun", _arun_wrapper_private) - - def run_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - agent_name = getattr(instance, "name", None) or "Team" - span_name = f"{agent_name}.run_stream" - - run_response = args[0] if args else kwargs.get("run_response") - run_messages = args[1] if args else kwargs.get("run_messages") - - def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input={"run_response": run_response, "run_messages": run_messages}, - metadata={**omit(kwargs, ["run_response", "run_messages"]), **extract_metadata(instance, "team")}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - # Generator was closed early (e.g., break from for loop) - # Don't call unset_current() as context may have changed - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Team, "_run_stream"): - wrap_function_wrapper(Team, "_run_stream", run_stream_wrapper) - - def arun_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - agent_name = getattr(instance, "name", None) or "Team" - span_name = f"{agent_name}.arun_stream" - - run_response = args[0] if args else kwargs.get("run_response") - input = args[2] if args else kwargs.get("input") - - async def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input={"run_response": run_response, "input": input}, - metadata={**omit(kwargs, ["run_response", "input"]), **extract_metadata(instance, "team")}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - async for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) 
- first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - # Generator was closed early (e.g., break from async for loop) - # Don't call unset_current() as context may have changed - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Team, "_arun_stream"): - wrap_function_wrapper(Team, "_arun_stream", arun_stream_wrapper) - elif not hasattr(Team, "_arun") and hasattr(Team, "arun"): - # Agno >= 2.5 routes through public arun(..., stream=...) - wrap_function_wrapper(Team, "arun", _arun_wrapper_public) - - mark_patched(Team) - return Team diff --git a/py/src/braintrust/wrappers/agno/utils.py b/py/src/braintrust/wrappers/agno/utils.py deleted file mode 100644 index 7951ac7c..00000000 --- a/py/src/braintrust/wrappers/agno/utils.py +++ /dev/null @@ -1,520 +0,0 @@ -import time -from typing import Any - -from braintrust.util import is_numeric - - -def omit(obj: dict[str, Any], keys: list[str]): - return {k: v for k, v in obj.items() if k not in keys} - - -def is_patched(obj: Any) -> bool: - return getattr(obj, "_braintrust_patched", False) - - -def mark_patched(obj: Any): - setattr(obj, "_braintrust_patched", True) - - -def clean(obj: dict[str, Any]) -> dict[str, Any]: - return {k: v for k, v in obj.items() if v is not None} - - -def get_args_kwargs(args: list[str], kwargs: dict[str, Any], keys: list[str]): - return {k: args[i] if args else kwargs.get(k) for i, k in enumerate(keys)}, omit(kwargs, keys) - - -def _try_to_dict(obj: Any) -> Any: - """Convert object to dict, handling different object types like OpenAI wrapper.""" - if isinstance(obj, dict): - return obj - # convert a pydantic object to a dict - if hasattr(obj, "model_dump") and callable(obj.model_dump): - try: - return obj.model_dump() - except Exception: - pass - # deprecated pydantic method, try model_dump first. 
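# .dict() is the pydantic v1 spelling of model_dump(); this branch is only
# reached when the v2 call above is unavailable or raises.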
- if hasattr(obj, "dict") and callable(obj.dict): - try: - return obj.dict() - except Exception: - pass - # handle objects with __dict__ (like Agno Metrics objects) - if hasattr(obj, "__dict__"): - try: - return obj.__dict__.copy() - except Exception: - pass - return obj - - -# Agno field names to canonical Braintrust field names (following OpenAI wrapper pattern) -AGNO_METRICS_MAP = { - # Core token metrics - using OpenAI wrapper naming - "input_tokens": "prompt_tokens", - "output_tokens": "completion_tokens", - "total_tokens": "tokens", - # Reasoning and audio tokens - "reasoning_tokens": "completion_reasoning_tokens", - "audio_input_tokens": "prompt_audio_tokens", - "audio_output_tokens": "completion_audio_tokens", - # Cache tokens - "cache_read_tokens": "prompt_cached_tokens", - "cache_write_tokens": "prompt_cache_creation_tokens", - # Timing metrics - "duration": "duration", - "time_to_first_token": "time_to_first_token", -} - - -def extract_metadata(instance: Any, component: str) -> dict[str, Any]: - """Extract metadata from any component (model, agent, team).""" - metadata = {"component": component} - - # Component-specific name fields - if component == "model": - if hasattr(instance, "id") and instance.id: - metadata["model"] = instance.id - metadata["model_id"] = instance.id - if hasattr(instance, "provider") and instance.provider: - metadata["provider"] = instance.provider - if hasattr(instance, "name") and instance.name: - metadata["model_name"] = instance.name - if hasattr(instance, "__class__"): - metadata["model_class"] = instance.__class__.__name__ - elif component == "agent": - metadata["agent_name"] = getattr(instance, "name", None) - model = getattr(instance, "model", None) - if model: - metadata["model"] = getattr(model, "id", None) or model.__class__.__name__ - elif component == "team": - metadata["team_name"] = getattr(instance, "name", None) - model = getattr(instance, "model", None) - if model: - metadata["model"] = getattr(model, "id", None) or model.__class__.__name__ - elif component == "workflow": - metadata["workflow_id"] = getattr(instance, "id", None) - metadata["workflow_name"] = getattr(instance, "name", None) - steps = getattr(instance, "steps", None) - if steps: - metadata["steps_count"] = len(steps) - - return metadata - - -def parse_metrics_from_agno(usage: Any) -> dict[str, Any]: - """Parse metrics from Agno usage object, following OpenAI wrapper pattern.""" - metrics = {} - - if not usage: - return metrics - - # Convert to dict like OpenAI wrapper - usage_dict = _try_to_dict(usage) - if not isinstance(usage_dict, dict): - return metrics - - # Simple loop through Agno fields and map to Braintrust names - for agno_name, value in usage_dict.items(): - if agno_name in AGNO_METRICS_MAP and is_numeric(value) and value != 0: - braintrust_name = AGNO_METRICS_MAP[agno_name] - metrics[braintrust_name] = value - - return metrics - - -def extract_metrics(result: Any, messages: list | None = None) -> dict[str, Any]: - """ - Unified metrics extraction for all components. 
- - Handles: - - Model responses with response_usage - - Agent/Team responses with metrics - - Messages with metrics (for model responses) - """ - # For model responses with response_usage - if hasattr(result, "response_usage") and result.response_usage: - return parse_metrics_from_agno(result.response_usage) - - # For agent/team responses with metrics - if hasattr(result, "metrics") and result.metrics: - metrics = parse_metrics_from_agno(result.metrics) - return metrics if metrics else None - - # If no metrics found and we have messages, look for metrics in assistant messages (model-specific) - if messages: - for msg in messages: - # Look for assistant messages with metrics - if hasattr(msg, "role") and msg.role == "assistant" and hasattr(msg, "metrics") and msg.metrics: - return parse_metrics_from_agno(msg.metrics) - - return {} - - -def extract_streaming_metrics(aggregated: dict[str, Any], start_time: float) -> dict[str, Any] | None: - """Extract metrics from aggregated streaming response.""" - metrics = {} - - # Extract metrics from aggregated data - # The metrics are already in Braintrust format from _aggregate_model_chunks - if aggregated.get("metrics") and isinstance(aggregated["metrics"], dict): - # Merge the aggregated metrics - metrics.update(aggregated["metrics"]) - # Handle object-like metrics payloads (e.g. RunCompletedEvent.metrics) - elif aggregated.get("metrics"): - parsed_metrics = parse_metrics_from_agno(aggregated["metrics"]) - if parsed_metrics: - metrics.update(parsed_metrics) - # Also check response_usage for backward compatibility - elif aggregated.get("response_usage"): - response_metrics = parse_metrics_from_agno(aggregated["response_usage"]) - if response_metrics: - metrics.update(response_metrics) - - # Ensure we have the duration calculated from start_time - metrics["duration"] = time.time() - start_time - - return metrics if metrics else None - - -def _aggregate_metrics(target: dict[str, Any], source: dict[str, Any]) -> None: - """Aggregate metrics from source into target dict.""" - for key, value in source.items(): - if is_numeric(value): - if key in target: - # For timing metrics, we keep the latest - if "time" in key.lower() or "duration" in key.lower(): - target[key] = value - # For token counts, we sum them - elif "token" in key.lower() or key == "tokens": - target[key] = (target.get(key, 0) or 0) + value - # For other metrics, keep the latest - else: - target[key] = value - else: - target[key] = value - - -def _aggregate_model_chunks(chunks: list[Any]) -> dict[str, Any]: - """Aggregate ModelResponse chunks from invoke_stream into a complete response.""" - aggregated = { - "content": "", - "reasoning_content": "", - "tool_calls": [], - "role": None, - "audio": None, - "images": [], - "videos": [], - "files": [], - "citations": None, - "metrics": {}, - } - - for chunk in chunks: - if hasattr(chunk, "content") and chunk.content: - aggregated["content"] += str(chunk.content) - - if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: - aggregated["reasoning_content"] += chunk.reasoning_content - - if hasattr(chunk, "role") and chunk.role and not aggregated["role"]: - aggregated["role"] = chunk.role - - if hasattr(chunk, "tool_calls") and chunk.tool_calls: - aggregated["tool_calls"].extend(chunk.tool_calls) - - if hasattr(chunk, "audio") and chunk.audio: - aggregated["audio"] = chunk.audio - - if hasattr(chunk, "images") and chunk.images: - aggregated["images"].extend(chunk.images) - - if hasattr(chunk, "videos") and chunk.videos: - 
aggregated["videos"].extend(chunk.videos) - - if hasattr(chunk, "files") and chunk.files: - aggregated["files"].extend(chunk.files) - - if hasattr(chunk, "citations") and chunk.citations: - aggregated["citations"] = chunk.citations - - if hasattr(chunk, "response_usage") and chunk.response_usage: - # Parse and aggregate metrics from each chunk - chunk_metrics = parse_metrics_from_agno(chunk.response_usage) - if chunk_metrics: - _aggregate_metrics(aggregated["metrics"], chunk_metrics) - - # Convert aggregated metrics dict to the response_usage format for backward compatibility - if aggregated["metrics"]: - aggregated["response_usage"] = aggregated["metrics"] - else: - aggregated["metrics"] = None - - return aggregated - - -def _aggregate_response_stream_chunks(chunks: list[Any]) -> dict[str, Any]: - """ - Aggregate chunks from response_stream which can be ModelResponse, RunOutputEvent, or TeamRunOutputEvent. - - This is more robust than _aggregate_model_chunks as it handles different event types. - """ - aggregated = { - "content": "", - "reasoning_content": "", - "tool_calls": [], - "role": None, - "audio": None, - "images": [], - "videos": [], - "files": [], - "citations": None, - "metrics": {}, - } - - for chunk in chunks: - # Handle ModelResponse chunks - if hasattr(chunk, "__class__") and "ModelResponse" in chunk.__class__.__name__: - if hasattr(chunk, "content") and chunk.content: - aggregated["content"] += str(chunk.content) - - if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: - aggregated["reasoning_content"] += chunk.reasoning_content - - if hasattr(chunk, "role") and chunk.role and not aggregated["role"]: - aggregated["role"] = chunk.role - - if hasattr(chunk, "tool_calls") and chunk.tool_calls: - aggregated["tool_calls"].extend(chunk.tool_calls) - - if hasattr(chunk, "audio") and chunk.audio: - aggregated["audio"] = chunk.audio - - if hasattr(chunk, "images") and chunk.images: - aggregated["images"].extend(chunk.images) - - if hasattr(chunk, "videos") and chunk.videos: - aggregated["videos"].extend(chunk.videos) - - if hasattr(chunk, "files") and chunk.files: - aggregated["files"].extend(chunk.files) - - if hasattr(chunk, "citations") and chunk.citations: - aggregated["citations"] = chunk.citations - - if hasattr(chunk, "response_usage") and chunk.response_usage: - # Parse and aggregate metrics from each chunk - chunk_metrics = parse_metrics_from_agno(chunk.response_usage) - if chunk_metrics: - _aggregate_metrics(aggregated["metrics"], chunk_metrics) - - # Also check for metrics attribute directly (for some response types) - elif hasattr(chunk, "metrics") and chunk.metrics: - chunk_metrics = parse_metrics_from_agno(chunk.metrics) - if chunk_metrics: - _aggregate_metrics(aggregated["metrics"], chunk_metrics) - - # Handle RunOutputEvent/TeamRunOutputEvent chunks - these typically contain content - elif hasattr(chunk, "content"): - if chunk.content: - aggregated["content"] += str(chunk.content) - - # Handle other event types that might have metrics - if hasattr(chunk, "metrics") and chunk.metrics and "metrics" not in str(type(chunk)): - chunk_metrics = parse_metrics_from_agno(chunk.metrics) - if chunk_metrics: - _aggregate_metrics(aggregated["metrics"], chunk_metrics) - - # Convert aggregated metrics dict to the response_usage format for backward compatibility - if aggregated["metrics"]: - aggregated["response_usage"] = aggregated["metrics"] - else: - aggregated["metrics"] = None - - return aggregated - - -def _aggregate_agent_chunks(chunks: list[Any]) -> 
dict[str, Any]: - """Aggregate BaseAgentRunEvent/BaseTeamRunEvent chunks into a complete response.""" - aggregated = { - "content": "", - "reasoning_content": "", - "model": "", - "model_provider": "", - "tool_calls": [], - "citations": None, - "references": None, - "metrics": None, - "finish_reason": None, - } - - for chunk in chunks: - event = getattr(chunk, "event", None) - - if event == "RunStarted": - if hasattr(chunk, "model"): - aggregated["model"] = chunk.model - if hasattr(chunk, "model_provider"): - aggregated["model_provider"] = chunk.model_provider - - elif event == "RunContent": - if hasattr(chunk, "content") and chunk.content: - aggregated["content"] += str(chunk.content) # type: ignore - if hasattr(chunk, "reasoning_content") and chunk.reasoning_content: - aggregated["reasoning_content"] += chunk.reasoning_content - if hasattr(chunk, "citations"): - aggregated["citations"] = chunk.citations - if hasattr(chunk, "references"): - aggregated["references"] = chunk.references - - elif event == "RunCompleted": - if hasattr(chunk, "metrics"): - parsed_metrics = parse_metrics_from_agno(chunk.metrics) - aggregated["metrics"] = parsed_metrics if parsed_metrics else chunk.metrics - aggregated["finish_reason"] = "stop" - - elif event == "RunError": - aggregated["finish_reason"] = "error" - - elif event == "ToolCallStarted": - if hasattr(chunk, "tool_call"): - aggregated["tool_calls"].append( # type:ignore - { - "id": getattr(chunk.tool_call, "id", None), - "type": "function", - "function": { - "name": getattr(chunk.tool_call, "name", None), - "arguments": getattr(chunk.tool_call, "arguments", ""), - }, - } - ) - - return {k: v for k, v in aggregated.items() if v not in (None, "")} - - -def _aggregate_workflow_chunks(chunks: list[Any], workflow_run_response: Any | None = None) -> dict[str, Any]: - """Aggregate workflow/step events into a final workflow-style response.""" - aggregated = { - "content": "", - "status": None, - "metrics": None, - } - final_workflow_content = None - - for chunk in chunks: - event = getattr(chunk, "event", None) - - if hasattr(chunk, "content") and chunk.content: - if event == "WorkflowCompleted": - final_workflow_content = str(chunk.content) - elif final_workflow_content is None: - aggregated["content"] += str(chunk.content) - - if hasattr(chunk, "status") and chunk.status: - aggregated["status"] = chunk.status - - if hasattr(chunk, "metrics") and chunk.metrics: - parsed_metrics = parse_metrics_from_agno(chunk.metrics) - aggregated["metrics"] = parsed_metrics if parsed_metrics else chunk.metrics - - if final_workflow_content is not None: - accumulated_content = aggregated["content"] - if not accumulated_content: - aggregated["content"] = final_workflow_content - elif accumulated_content.endswith(final_workflow_content): - aggregated["content"] = accumulated_content - else: - aggregated["content"] = f"{accumulated_content}{final_workflow_content}" - - if workflow_run_response is not None: - if not aggregated["content"] and hasattr(workflow_run_response, "content") and workflow_run_response.content: - aggregated["content"] = str(workflow_run_response.content) - - if not aggregated["status"] and hasattr(workflow_run_response, "status") and workflow_run_response.status: - aggregated["status"] = workflow_run_response.status - - if not aggregated["metrics"] and hasattr(workflow_run_response, "metrics") and workflow_run_response.metrics: - parsed_metrics = parse_metrics_from_agno(workflow_run_response.metrics) - aggregated["metrics"] = parsed_metrics if 
parsed_metrics else workflow_run_response.metrics - - return {k: v for k, v in aggregated.items() if v not in (None, "")} - - -def is_sync_iterator(result: Any) -> bool: - return hasattr(result, "__iter__") and hasattr(result, "__next__") - - -def is_async_iterator(result: Any) -> bool: - return hasattr(result, "__aiter__") and hasattr(result, "__anext__") - - -def trace_sync_stream_result(result: Any, span: Any, start: float): - def _trace_stream(): - should_unset = True - try: - first = True - all_chunks = [] - for chunk in result: - if first: - span.log(metrics={"time_to_first_token": time.time() - start}) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log(error=str(e)) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - -def trace_async_stream_result(result: Any, span: Any, start: float): - async def _trace_astream(): - should_unset = True - try: - first = True - all_chunks = [] - async for chunk in result: - if first: - span.log(metrics={"time_to_first_token": time.time() - start}) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_agent_chunks(all_chunks) - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log(error=str(e)) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_astream() - - -# Legacy aliases for backward compatibility -_extract_run_metrics = extract_metrics -_extract_streaming_metrics = extract_streaming_metrics -_extract_model_metrics = extract_metrics -_parse_metrics_from_agno = parse_metrics_from_agno diff --git a/py/src/braintrust/wrappers/agno/workflow.py b/py/src/braintrust/wrappers/agno/workflow.py deleted file mode 100644 index 57866a4b..00000000 --- a/py/src/braintrust/wrappers/agno/workflow.py +++ /dev/null @@ -1,353 +0,0 @@ -import time -from typing import Any - -from braintrust.logger import start_span -from braintrust.span_types import SpanTypeAttribute -from wrapt import wrap_function_wrapper - -from .utils import ( - _aggregate_workflow_chunks, - _try_to_dict, - extract_metadata, - extract_metrics, - extract_streaming_metrics, - is_async_iterator, - is_patched, - is_sync_iterator, - mark_patched, -) - - -def _extract_workflow_input( - args: Any, - kwargs: Any, - *, - execution_input_index: int, - workflow_run_response_index: int, -) -> dict[str, Any]: - """Extract workflow input from execution method parameters.""" - execution_input = ( - args[execution_input_index] if len(args) > execution_input_index else kwargs.get("execution_input") - ) - workflow_run_response = ( - args[workflow_run_response_index] - if len(args) > workflow_run_response_index - else kwargs.get("workflow_run_response") - ) - - result: dict[str, Any] = {} - - if execution_input: - if hasattr(execution_input, "input"): - result["input"] = execution_input.input - result["execution_input"] = _try_to_dict(execution_input) - - if workflow_run_response: - result["run_response"] = _try_to_dict(workflow_run_response) - - return result - - -def wrap_workflow(Workflow: Any) -> Any: - if is_patched(Workflow): - return Workflow - - def _workflow_span_config(instance: Any, suffix: str) -> tuple[str, dict[str, 
Any]]: - workflow_name = getattr(instance, "name", None) or "Workflow" - return f"{workflow_name}.{suffix}", extract_metadata(instance, "workflow") - - def _extract_workflow_agent_input(args: Any, kwargs: Any) -> dict[str, Any]: - user_input = args[0] if len(args) > 0 else kwargs.get("user_input") - execution_input = args[2] if len(args) > 2 else kwargs.get("execution_input") - - result: dict[str, Any] = {"input": user_input} - if execution_input: - result["execution_input"] = _try_to_dict(execution_input) - return result - - def execute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - workflow_name = getattr(instance, "name", None) or "Workflow" - span_name = f"{workflow_name}.run" - - input_data = _extract_workflow_input(args, kwargs, execution_input_index=1, workflow_run_response_index=2) - workflow_metadata = extract_metadata(instance, "workflow") - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) as span: - result = wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - if hasattr(Workflow, "_execute"): - wrap_function_wrapper(Workflow, "_execute", execute_wrapper) - - def execute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - workflow_name = getattr(instance, "name", None) or "Workflow" - span_name = f"{workflow_name}.run_stream" - - input_data = _extract_workflow_input(args, kwargs, execution_input_index=1, workflow_run_response_index=2) - workflow_metadata = extract_metadata(instance, "workflow") - workflow_run_response = args[2] if len(args) > 2 else kwargs.get("workflow_run_response") - - def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_workflow_chunks(all_chunks, workflow_run_response) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Workflow, "_execute_stream"): - wrap_function_wrapper(Workflow, "_execute_stream", execute_stream_wrapper) - - async def aexecute_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - workflow_name = getattr(instance, "name", None) or "Workflow" - span_name = f"{workflow_name}.arun" - - input_data = _extract_workflow_input(args, kwargs, execution_input_index=2, workflow_run_response_index=3) - workflow_metadata = extract_metadata(instance, "workflow") - - with start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) as span: - result = await wrapped(*args, **kwargs) - span.log( - output=result, - metrics=extract_metrics(result), - ) - return result - - if hasattr(Workflow, "_aexecute"): - wrap_function_wrapper(Workflow, "_aexecute", aexecute_wrapper) - - def 
aexecute_stream_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - workflow_name = getattr(instance, "name", None) or "Workflow" - span_name = f"{workflow_name}.arun_stream" - - input_data = _extract_workflow_input(args, kwargs, execution_input_index=2, workflow_run_response_index=3) - workflow_metadata = extract_metadata(instance, "workflow") - workflow_run_response = args[3] if len(args) > 3 else kwargs.get("workflow_run_response") - - async def _trace_stream(): - start = time.time() - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) - span.set_current() - - should_unset = True - try: - first = True - all_chunks = [] - - async for chunk in wrapped(*args, **kwargs): - if first: - span.log( - metrics={ - "time_to_first_token": time.time() - start, - } - ) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_workflow_chunks(all_chunks, workflow_run_response) - - span.log( - output=aggregated, - metrics=extract_streaming_metrics(aggregated, start), - ) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log( - error=str(e), - ) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - if hasattr(Workflow, "_aexecute_stream"): - wrap_function_wrapper(Workflow, "_aexecute_stream", aexecute_stream_wrapper) - - def execute_workflow_agent_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - stream = kwargs.get("stream", False) - span_suffix = "run_stream" if stream else "run" - span_name, workflow_metadata = _workflow_span_config(instance, span_suffix) - input_data = _extract_workflow_agent_input(args, kwargs) - - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) - span.set_current() - start = time.time() - try: - result = wrapped(*args, **kwargs) - if stream and is_sync_iterator(result): - - def _trace_stream(): - should_unset = True - try: - first = True - all_chunks = [] - for chunk in result: - if first: - span.log(metrics={"time_to_first_token": time.time() - start}) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_workflow_chunks(all_chunks) - span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log(error=str(e)) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - span.log(output=result, metrics=extract_metrics(result)) - span.unset_current() - span.end() - return result - except Exception as e: - span.log(error=str(e)) - span.unset_current() - span.end() - raise - - if hasattr(Workflow, "_execute_workflow_agent"): - wrap_function_wrapper(Workflow, "_execute_workflow_agent", execute_workflow_agent_wrapper) - - async def aexecute_workflow_agent_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any): - stream = kwargs.get("stream", False) - span_suffix = "arun_stream" if stream else "arun" - span_name, workflow_metadata = _workflow_span_config(instance, span_suffix) - input_data = _extract_workflow_agent_input(args, kwargs) - - span = start_span( - name=span_name, - type=SpanTypeAttribute.TASK, - input=input_data, - metadata=workflow_metadata, - propagated_event={"metadata": workflow_metadata}, - ) - 
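# The span is managed manually (set_current/unset_current/end) instead of a
# with-block: when stream=True the async generator below outlives this call,
# so the span must stay open until the stream is fully consumed.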
span.set_current() - start = time.time() - try: - result = await wrapped(*args, **kwargs) - if stream and is_async_iterator(result): - - async def _trace_stream(): - should_unset = True - try: - first = True - all_chunks = [] - async for chunk in result: - if first: - span.log(metrics={"time_to_first_token": time.time() - start}) - first = False - all_chunks.append(chunk) - yield chunk - - aggregated = _aggregate_workflow_chunks(all_chunks) - span.log(output=aggregated, metrics=extract_streaming_metrics(aggregated, start)) - except GeneratorExit: - should_unset = False - raise - except Exception as e: - span.log(error=str(e)) - raise - finally: - if should_unset: - span.unset_current() - span.end() - - return _trace_stream() - - span.log(output=result, metrics=extract_metrics(result)) - span.unset_current() - span.end() - return result - except Exception as e: - span.log(error=str(e)) - span.unset_current() - span.end() - raise - - if hasattr(Workflow, "_aexecute_workflow_agent"): - wrap_function_wrapper(Workflow, "_aexecute_workflow_agent", aexecute_workflow_agent_wrapper) - - mark_patched(Workflow) - return Workflow
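Every module deleted above follows the same wrapt pattern, presumably the shape that `py/src/braintrust/integrations/agno/patchers.py` now carries: an idempotency guard, a `hasattr` check per method, `wrap_function_wrapper`, and a Braintrust span around the wrapped call. A minimal, self-contained sketch of that shared pattern; `MyClient` and `wrap_my_client` are hypothetical illustration names, not Agno or Braintrust APIs:

from typing import Any

from braintrust.logger import start_span
from braintrust.span_types import SpanTypeAttribute
from wrapt import wrap_function_wrapper


class MyClient:
    """Hypothetical stand-in for an Agno component; not a real Agno class."""

    def complete(self, prompt: str) -> str:
        return prompt.upper()


def wrap_my_client(client_cls: Any) -> Any:
    # Idempotency guard, mirroring is_patched()/mark_patched() above.
    if getattr(client_cls, "_braintrust_patched", False):
        return client_cls

    def complete_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any):
        # One span per call, logging the raw inputs and the returned output.
        with start_span(
            name=f"{type(instance).__name__}.complete",
            type=SpanTypeAttribute.LLM,
            input={"args": args, "kwargs": kwargs},
        ) as span:
            result = wrapped(*args, **kwargs)
            span.log(output=result)
            return result

    # Guard with hasattr before wrapping, as the modules above do, so the
    # patch tolerates API drift across library versions.
    if hasattr(client_cls, "complete"):
        wrap_function_wrapper(client_cls, "complete", complete_wrapper)

    client_cls._braintrust_patched = True
    return client_cls

After `wrap_my_client(MyClient)`, each `MyClient().complete("hi")` call would emit one LLM span, assuming Braintrust logging has been initialized.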