12 changes: 12 additions & 0 deletions py/noxfile.py
@@ -206,6 +206,18 @@ def test_btx_openai(session, version):
_run_tests(session, "braintrust/btx", version=version, env={"BTX_PROVIDER": "openai", "BTX_CLIENT": "openai"})


@nox.session()
@nox.parametrize("version", ANTHROPIC_VERSIONS, ids=ANTHROPIC_VERSIONS)
def test_btx_anthropic(session, version):
"""Run the BTX cross-language LLM-span spec tests (Anthropic provider)."""
_install_test_deps(session)
_install_matrix_dep(session, "anthropic", version)
session.install("pyyaml")
_run_tests(
session, "braintrust/btx", version=version, env={"BTX_PROVIDER": "anthropic", "BTX_CLIENT": "anthropic"}
)


@nox.session()
def test_openai_ddtrace(session):
_install_test_deps(session)
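Note: @nox.parametrize with ids generates one selectable session per matrix entry, so a single Anthropic version can be exercised with, e.g., nox -s "test_btx_anthropic-0.39.0" (assuming 0.39.0 is one of the ANTHROPIC_VERSIONS entries).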
14 changes: 14 additions & 0 deletions py/src/braintrust/btx/span_validator.py
@@ -56,10 +56,21 @@ def _is_reasoning_message(value: Any) -> bool:
return True


def _is_positive_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool) and value > 0


def _undefined_or_null(value: Any) -> bool:
"""True if the value is absent (None/null) — used for fields that must not be populated."""
return value is None


_NAMED_MATCHERS: dict[str, Any] = {
"is_non_negative_number": _is_non_negative_number,
"is_positive_number": _is_positive_number,
"is_non_empty_string": _is_non_empty_string,
"is_reasoning_message": _is_reasoning_message,
"undefined_or_null": _undefined_or_null,
}


@@ -146,6 +157,9 @@ def _validate_value(actual: Any, expected: Any, path: str, errors: list[str]) ->
return
for key, exp_val in expected.items():
if key not in actual:
# A missing key satisfies undefined_or_null — treat absence as None.
if isinstance(exp_val, FnMatcher) and exp_val.expr == "undefined_or_null":
continue
errors.append(f"{path}.{key}: key not found in actual span")
else:
_validate_value(actual[key], exp_val, f"{path}.{key}", errors)
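A minimal sketch of the three cases the matcher now covers, calling the validator directly (the span payloads here are hypothetical, not taken from a real spec):

    errors: list[str] = []
    expected = {"prompt_cache_creation_tokens": FnMatcher(expr="undefined_or_null")}

    _validate_value({}, expected, "metrics", errors)  # key absent: skipped, no error
    _validate_value({"prompt_cache_creation_tokens": None}, expected, "metrics", errors)  # null: passes
    _validate_value({"prompt_cache_creation_tokens": 7}, expected, "metrics", errors)  # populated: error appended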
2 changes: 1 addition & 1 deletion py/src/braintrust/btx/spec-ref.txt
@@ -1 +1 @@
v0.0.1
v0.0.5
24 changes: 16 additions & 8 deletions py/src/braintrust/btx/spec_executor.py
@@ -122,7 +122,7 @@ def _dispatch(spec: LlmSpanSpec, client: Any) -> None:
_execute_responses(spec.requests, client)

elif provider == "anthropic" and endpoint == "/v1/messages":
_execute_anthropic_messages(spec.requests, client)
_execute_anthropic_messages(spec.requests, client, extra_headers=spec.headers or {})

else:
raise NotImplementedError(f"BTX executor: provider={provider!r} endpoint={endpoint!r} not implemented")
@@ -191,7 +191,9 @@ def _execute_responses(requests: list[dict[str, Any]], client: Any) -> None:
# ---------------------------------------------------------------------------


def _execute_anthropic_messages(requests: list[dict[str, Any]], client: Any) -> None:
def _execute_anthropic_messages(
requests: list[dict[str, Any]], client: Any, extra_headers: dict[str, str] | None = None
) -> None:
"""Execute Anthropic messages requests.

Handles streaming (stream=True) by consuming the stream context manager.
@@ -209,14 +211,20 @@ def _execute_anthropic_messages(requests: list[dict[str, Any]], client: Any) ->
is_streaming = full_req.get("stream", False)
conversation_history.extend(req.get("messages", []))

create_kwargs: dict[str, Any] = dict(full_req)
if extra_headers:
create_kwargs["extra_headers"] = extra_headers

if is_streaming:
with client.messages.create(**full_req) as stream:
final = stream.get_final_message()
if hasattr(final, "content") and final.content:
text_blocks = [b.text for b in final.content if hasattr(b, "text")]
conversation_history.append({"role": "assistant", "content": " ".join(text_blocks)})
# Iterate the stream to exhaustion — the Braintrust TracedMessageStream
# context manager captures metrics and logs the span on __exit__.
# We can't call get_final_message() on the traced wrapper, so we
# skip history accumulation for streaming (no multi-turn streaming specs).
with client.messages.create(**create_kwargs) as stream:
for _ in stream:
pass
else:
response = client.messages.create(**full_req)
response = client.messages.create(**create_kwargs)
if hasattr(response, "content") and response.content:
text_blocks = [b.text for b in response.content if hasattr(b, "text")]
conversation_history.append({"role": "assistant", "content": " ".join(text_blocks)})
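The Anthropic SDK accepts extra_headers on messages.create and forwards them onto the underlying HTTP request, so spec-level headers reach the API verbatim. A sketch with a hypothetical cache-busting header (header name and model are illustrative):

    client.messages.create(
        model="claude-3-5-haiku-latest",
        max_tokens=64,
        messages=[{"role": "user", "content": "ping"}],
        extra_headers={"x-btx-nonce": "d41f2c9a"},  # hypothetical spec-supplied header
    )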
100 changes: 79 additions & 21 deletions py/src/braintrust/btx/spec_loader.py
@@ -1,35 +1,31 @@
"""Load BTX LLM-span spec YAML files.

Handles the three custom YAML tags used in the spec:
Handles the custom YAML tags used in the spec:
!fn <name-or-lambda> — named predicate or arbitrary lambda (eval'd in Python)
!starts_with <prefix> — string prefix check
!or [...] — at-least-one-of validator
!gen <generator-name> — value generated by the test runner (e.g. test_runner_client)
"""

from __future__ import annotations

import dataclasses
import os
import uuid
from pathlib import Path
from typing import Any

import yaml


# ---------------------------------------------------------------------------
# Matcher types (parallel to SpecMatcher.java)
# Matcher / generator types
# ---------------------------------------------------------------------------


@dataclasses.dataclass
class FnMatcher:
"""A named or lambda-expression validator.

For well-known names (is_non_negative_number, etc.) the span_validator
module dispatches them to dedicated functions. For arbitrary Python
expressions the expression string is stored and eval()'d at validation
time.
"""
"""A named or lambda-expression validator."""

expr: str # e.g. "is_non_negative_number" or "lambda value: value > 0"

@@ -44,6 +40,18 @@ class OrMatcher:
alternatives: list[Any]


@dataclasses.dataclass
class GenValue:
"""A value generated by the test runner at execution time.

The generator name determines what value is produced:
test_runner_client — a string identifying this SDK/client (e.g. "python-openai")
vcr_nonce — a random string that changes every run (busts caches)
"""

generator: str # e.g. "test_runner_client", "vcr_nonce"


# ---------------------------------------------------------------------------
# YAML custom constructors
# ---------------------------------------------------------------------------
@@ -64,6 +72,11 @@ def _or_constructor(loader: yaml.SafeLoader, node: yaml.Node) -> OrMatcher:
return OrMatcher(alternatives=alternatives)


def _gen_constructor(loader: yaml.SafeLoader, node: yaml.Node) -> GenValue:
generator = loader.construct_scalar(node) # type: ignore[arg-type]
return GenValue(generator=generator)


def _make_loader() -> type:
"""Return a SafeLoader subclass with BTX custom tags registered."""

@@ -73,9 +86,57 @@ class BtxLoader(yaml.SafeLoader):
BtxLoader.add_constructor("!fn", _fn_constructor)
BtxLoader.add_constructor("!starts_with", _starts_with_constructor)
BtxLoader.add_constructor("!or", _or_constructor)
BtxLoader.add_constructor("!gen", _gen_constructor)
return BtxLoader


# ---------------------------------------------------------------------------
# Generator resolution
# ---------------------------------------------------------------------------

# Stable client identifier for this SDK implementation.
_CLIENT_ID = "python-btx"

# Per-process nonce — constant within a run so cassette body matching is stable,
# but differs across runs so cache-busting specs actually bust caches.
_VCR_NONCE = str(uuid.uuid4())[:8]

_GENERATORS: dict[str, str] = {
"test_runner_client": _CLIENT_ID,
"vcr_nonce": _VCR_NONCE,
}


def _resolve_gen(value: GenValue) -> str:
if value.generator in _GENERATORS:
return _GENERATORS[value.generator]
raise ValueError(f"Unknown !gen generator: {value.generator!r}")


def _resolve_variables(variables: dict[str, Any]) -> dict[str, str]:
"""Resolve all !gen values in the variables map to concrete strings."""
resolved: dict[str, str] = {}
for key, val in variables.items():
if isinstance(val, GenValue):
resolved[key] = _resolve_gen(val)
else:
resolved[key] = str(val)
return resolved


def _substitute_templates(obj: Any, variables: dict[str, str]) -> Any:
"""Recursively substitute {{var}} placeholders in strings."""
if isinstance(obj, str):
for key, value in variables.items():
obj = obj.replace(f"{{{{{key}}}}}", value)
return obj
if isinstance(obj, dict):
return {k: _substitute_templates(v, variables) for k, v in obj.items()}
if isinstance(obj, list):
return [_substitute_templates(item, variables) for item in obj]
return obj
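For instance, the quintuple braces in the f-string render as a literal {{key}} token, so a hypothetical payload resolves like this:

    _substitute_templates(
        {"messages": [{"role": "user", "content": "echo {{nonce}}"}]},
        {"nonce": "ab12cd34"},
    )
    # -> {"messages": [{"role": "user", "content": "echo ab12cd34"}]}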


# ---------------------------------------------------------------------------
# Spec dataclass
# ---------------------------------------------------------------------------
@@ -89,6 +150,7 @@ class LlmSpanSpec:
endpoint: str
requests: list[dict[str, Any]]
expected_brainstore_spans: list[dict[str, Any]]
headers: dict[str, str]
source_path: Path

@property
Expand All @@ -98,13 +160,19 @@ def display_name(self) -> str:

@classmethod
def from_dict(cls, data: dict[str, Any], source_path: Path) -> "LlmSpanSpec":
# Resolve variables and substitute templates in requests
raw_variables = data.get("variables", {})
variables = _resolve_variables(raw_variables)
requests = _substitute_templates(data.get("requests", []), variables)

return cls(
name=data["name"],
type=data["type"],
provider=data["provider"],
endpoint=data["endpoint"],
requests=data.get("requests", []),
requests=requests,
expected_brainstore_spans=data.get("expected_brainstore_spans", []),
headers=data.get("headers", {}),
source_path=source_path,
)

@@ -117,13 +185,6 @@ def from_dict(cls, data: dict[str, Any], source_path: Path) -> "LlmSpanSpec":


def _spec_root(override: str | None = None) -> Path:
"""Return the llm_span spec root directory.

Priority:
1. ``override`` argument (used by the pytest fixture after fetching specs)
2. ``BTX_SPEC_ROOT`` environment variable
3. ``<btx-dir>/spec/test/llm_span`` (local dev snapshot)
"""
if override:
return Path(override)
env = os.environ.get("BTX_SPEC_ROOT")
@@ -139,10 +200,8 @@ def load_specs(
"""Load all YAML spec files under *spec_root*.

Args:
spec_root: Path to the ``test/llm_span`` directory. Falls back to
:func:`_spec_root` resolution if ``None``.
spec_root: Path to the ``test/llm_span`` directory.
providers: Optional allow-list of provider names (e.g. ``["openai"]``).
If ``None``, all providers are loaded.

Returns:
Sorted list of :class:`LlmSpanSpec` instances.
@@ -159,7 +218,6 @@ def load_specs(
specs: list[LlmSpanSpec] = []

for yaml_path in sorted(root.rglob("*.yaml")):
# Filter by provider directory if requested
provider_dir = yaml_path.parent.name
if providers is not None and provider_dir not in providers:
continue
14 changes: 9 additions & 5 deletions py/src/braintrust/integrations/anthropic/_utils.py
@@ -82,14 +82,18 @@ def extract_anthropic_usage(usage: Any) -> tuple[dict[str, float], dict[str, Any]
for source_name, value in server_tool_use.items():
_set_numeric_metric(metrics, f"server_tool_use_{source_name}", value)

if "prompt_cache_creation_tokens" not in metrics and cache_creation_breakdown:
metrics["prompt_cache_creation_tokens"] = sum(cache_creation_breakdown)
if any(v > 0 for v in cache_creation_breakdown):
# Per-TTL breakdown has non-zero values — omit the aggregate so consumers
# can rely on the breakdown fields exclusively (spec: undefined_or_null).
metrics.pop("prompt_cache_creation_tokens", None)
cache_creation_total = sum(cache_creation_breakdown)
else:
# No breakdown or all-zero breakdown — keep the aggregate.
cache_creation_total = metrics.get("prompt_cache_creation_tokens", 0)

if metrics:
total_prompt_tokens = (
metrics.get("prompt_tokens", 0)
+ metrics.get("prompt_cached_tokens", 0)
+ metrics.get("prompt_cache_creation_tokens", 0)
metrics.get("prompt_tokens", 0) + metrics.get("prompt_cached_tokens", 0) + cache_creation_total
)
metrics["prompt_tokens"] = total_prompt_tokens
metrics["tokens"] = total_prompt_tokens + metrics.get("completion_tokens", 0)
12 changes: 7 additions & 5 deletions py/src/braintrust/integrations/anthropic/test_anthropic.py
@@ -337,7 +337,6 @@ def to_dict(self):
"prompt_tokens": 21.0,
"completion_tokens": 7.0,
"prompt_cached_tokens": 3.0,
"prompt_cache_creation_tokens": 7.0,
"prompt_cache_creation_5m_tokens": 2.0,
"prompt_cache_creation_1h_tokens": 5.0,
"server_tool_use_web_search_requests": 2.0,
@@ -371,7 +370,7 @@ def test_anthropic_messages_create_prompt_cache_5m_metrics(memory_logger):

span = find_span_by_name(memory_logger.pop(), "anthropic.messages.create")
assert span["output"]["role"] == response.role
assert span["metrics"]["prompt_cache_creation_tokens"] == response.usage.cache_creation_input_tokens
assert "prompt_cache_creation_tokens" not in span["metrics"]
assert (
span["metrics"]["prompt_cache_creation_5m_tokens"] == response.usage.cache_creation.ephemeral_5m_input_tokens
)
@@ -403,7 +402,7 @@ def test_anthropic_messages_create_prompt_cache_1h_metrics(memory_logger):

span = find_span_by_name(memory_logger.pop(), "anthropic.messages.create")
assert span["output"]["role"] == response.role
assert span["metrics"]["prompt_cache_creation_tokens"] == response.usage.cache_creation_input_tokens
assert "prompt_cache_creation_tokens" not in span["metrics"]
assert (
span["metrics"]["prompt_cache_creation_5m_tokens"] == response.usage.cache_creation.ephemeral_5m_input_tokens
)
@@ -1342,7 +1341,10 @@ def test_setup_creates_spans(memory_logger):
usage.input_tokens + usage.cache_read_input_tokens + usage.cache_creation_input_tokens
)
assert metrics["completion_tokens"] == usage.output_tokens
assert metrics["prompt_cache_creation_tokens"] == usage.cache_creation_input_tokens
# When breakdown is all zeros, the aggregate is still reported.
# When breakdown has non-zero values, the aggregate is omitted.
if ephemeral_5m > 0 or ephemeral_1h > 0:
assert "prompt_cache_creation_tokens" not in metrics
assert metrics["prompt_cache_creation_5m_tokens"] == ephemeral_5m
assert metrics["prompt_cache_creation_1h_tokens"] == ephemeral_1h
assert "service_tier" not in metrics
@@ -1373,7 +1375,7 @@ def test_extract_anthropic_usage_preserves_nested_numeric_fields():
assert metrics["prompt_tokens"] == 15
assert metrics["completion_tokens"] == 12
assert metrics["tokens"] == 27
assert metrics["prompt_cache_creation_tokens"] == 7
assert "prompt_cache_creation_tokens" not in metrics
assert metrics["prompt_cache_creation_5m_tokens"] == 3
assert metrics["prompt_cache_creation_1h_tokens"] == 4
assert metrics["server_tool_use_web_search_requests"] == 2