From d7d4d18ad48e00117e21f69af2f009016d8d6aad Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Thu, 30 Apr 2026 07:44:27 -0700 Subject: [PATCH 1/2] initial commit --- py/src/braintrust/__init__.py | 1 + py/src/braintrust/dataset_pipeline.py | 98 +++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 py/src/braintrust/dataset_pipeline.py diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 72252a1d..4b50b9b9 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -62,6 +62,7 @@ def is_equal(expected, output): from .audit import * from .auto import auto_instrument as auto_instrument +from .dataset_pipeline import * from .framework import * from .framework2 import * from .functions.invoke import * diff --git a/py/src/braintrust/dataset_pipeline.py b/py/src/braintrust/dataset_pipeline.py new file mode 100644 index 00000000..a458c45d --- /dev/null +++ b/py/src/braintrust/dataset_pipeline.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import Any, Generic, Literal, TypeAlias, TypeVar + +from typing_extensions import NotRequired, TypedDict + +from .generated_types import ObjectReference +from .logger import Metadata +from .trace import Trace + + +DatasetPipelineScope: TypeAlias = Literal["span", "trace"] + + +class DatasetPipelineSource(TypedDict, total=False): + project_id: str + project_name: str + org_name: str + filter: str + scope: DatasetPipelineScope + limit: int + + +class DatasetPipelineTarget(TypedDict, total=False): + project_id: str + project_name: str + org_name: str + dataset_name: str + description: str + metadata: Metadata + + +class DatasetPipelineRow(TypedDict, total=False): + id: str + input: Any | None + expected: Any | None + tags: Sequence[str] | None + metadata: Metadata | None + origin: ObjectReference + + +class DatasetPipelineCandidate(TypedDict): + trace: Trace + id: NotRequired[str] + origin: NotRequired[ObjectReference] + + +Candidate = TypeVar("Candidate", bound=DatasetPipelineCandidate) +Row = TypeVar("Row", bound=DatasetPipelineRow) + + +class DatasetPipelineTransformContext(TypedDict): + pipeline: "DatasetPipelineDefinition[Any, Any]" + + +DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None +DatasetPipelineTransform: TypeAlias = Callable[ + [Candidate, DatasetPipelineTransformContext], + DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]], +] + + +@dataclass(frozen=True) +class DatasetPipelineDefinition(Generic[Candidate, Row]): + source: DatasetPipelineSource + transform: DatasetPipelineTransform[Candidate, Row] + target: DatasetPipelineTarget + name: str | None = None + + +_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any, Any]] = [] + + +def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any, Any]]: + return list(_DATASET_PIPELINES) + + +def is_dataset_pipeline_definition(value: object) -> bool: + return isinstance(value, DatasetPipelineDefinition) + + +def DatasetPipeline( + name: str | None = None, + *, + source: DatasetPipelineSource, + transform: DatasetPipelineTransform[DatasetPipelineCandidate, DatasetPipelineRow], + target: DatasetPipelineTarget, +) -> DatasetPipelineDefinition[DatasetPipelineCandidate, DatasetPipelineRow]: + definition = DatasetPipelineDefinition( + name=name, + source=source, + transform=transform, + target=target, + ) + _DATASET_PIPELINES.append(definition) + return definition From da2f03175281c668e565d19cf3f4af6118855b53 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Sun, 3 May 2026 15:34:48 -0400 Subject: [PATCH 2/2] more updates --- py/src/braintrust/dataset_pipeline.py | 125 +++++++++++++++++++------- py/src/braintrust/trace.py | 116 ++++++++++++++++-------- 2 files changed, 175 insertions(+), 66 deletions(-) diff --git a/py/src/braintrust/dataset_pipeline.py b/py/src/braintrust/dataset_pipeline.py index a458c45d..4cbbbf4f 100644 --- a/py/src/braintrust/dataset_pipeline.py +++ b/py/src/braintrust/dataset_pipeline.py @@ -1,8 +1,8 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable, Sequence +from collections.abc import Awaitable, Sequence from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeAlias, TypeVar +from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar from typing_extensions import NotRequired, TypedDict @@ -20,16 +20,57 @@ class DatasetPipelineSource(TypedDict, total=False): org_name: str filter: str scope: DatasetPipelineScope - limit: int -class DatasetPipelineTarget(TypedDict, total=False): - project_id: str - project_name: str - org_name: str +@dataclass(frozen=True) +class PipelineSource: + filter: str | None = None + scope: DatasetPipelineScope | None = None + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + + def as_dict(self) -> DatasetPipelineSource: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "filter": self.filter, + "scope": self.scope, + } + ) + + +class DatasetPipelineTarget(TypedDict): dataset_name: str - description: str - metadata: Metadata + project_id: NotRequired[str] + project_name: NotRequired[str] + org_name: NotRequired[str] + description: NotRequired[str] + metadata: NotRequired[Metadata] + + +@dataclass(frozen=True) +class PipelineTarget: + dataset_name: str + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + description: str | None = None + metadata: Metadata | None = None + + def as_dict(self) -> DatasetPipelineTarget: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "dataset_name": self.dataset_name, + "description": self.description, + "metadata": self.metadata, + } + ) class DatasetPipelineRow(TypedDict, total=False): @@ -41,39 +82,61 @@ class DatasetPipelineRow(TypedDict, total=False): origin: ObjectReference -class DatasetPipelineCandidate(TypedDict): +Row = TypeVar("Row", bound=DatasetPipelineRow) + + +class DatasetPipelineTransformArgs(TypedDict, total=False): + input: Any | None + output: Any | None + metadata: Metadata | None + expected: Any | None trace: Trace - id: NotRequired[str] - origin: NotRequired[ObjectReference] -Candidate = TypeVar("Candidate", bound=DatasetPipelineCandidate) -Row = TypeVar("Row", bound=DatasetPipelineRow) +DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None +DatasetPipelineSourceLike: TypeAlias = DatasetPipelineSource | PipelineSource +DatasetPipelineTargetLike: TypeAlias = DatasetPipelineTarget | PipelineTarget -class DatasetPipelineTransformContext(TypedDict): - pipeline: "DatasetPipelineDefinition[Any, Any]" +def _drop_none(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} -DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None -DatasetPipelineTransform: TypeAlias = Callable[ - [Candidate, DatasetPipelineTransformContext], - DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]], -] +def _normalize_source(source: DatasetPipelineSourceLike) -> DatasetPipelineSource: + if isinstance(source, PipelineSource): + return source.as_dict() + return dict(source) + + +def _normalize_target(target: DatasetPipelineTargetLike) -> DatasetPipelineTarget: + if isinstance(target, PipelineTarget): + return target.as_dict() + return dict(target) + + +class DatasetPipelineTransform(Protocol[Row]): + def __call__( + self, + input: Any | None = None, + output: Any | None = None, + metadata: Metadata | None = None, + expected: Any | None = None, + trace: Trace | None = None, + ) -> DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]]: ... @dataclass(frozen=True) -class DatasetPipelineDefinition(Generic[Candidate, Row]): +class DatasetPipelineDefinition(Generic[Row]): source: DatasetPipelineSource - transform: DatasetPipelineTransform[Candidate, Row] + transform: DatasetPipelineTransform[Row] target: DatasetPipelineTarget name: str | None = None -_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any, Any]] = [] +_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = [] -def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any, Any]]: +def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]: return list(_DATASET_PIPELINES) @@ -84,15 +147,15 @@ def is_dataset_pipeline_definition(value: object) -> bool: def DatasetPipeline( name: str | None = None, *, - source: DatasetPipelineSource, - transform: DatasetPipelineTransform[DatasetPipelineCandidate, DatasetPipelineRow], - target: DatasetPipelineTarget, -) -> DatasetPipelineDefinition[DatasetPipelineCandidate, DatasetPipelineRow]: + source: DatasetPipelineSourceLike, + transform: DatasetPipelineTransform[DatasetPipelineRow], + target: DatasetPipelineTargetLike, +) -> DatasetPipelineDefinition[DatasetPipelineRow]: definition = DatasetPipelineDefinition( name=name, - source=source, + source=_normalize_source(source), transform=transform, - target=target, + target=_normalize_target(target), ) _DATASET_PIPELINES.append(definition) return definition diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py index d3426ac4..4a1aff01 100644 --- a/py/src/braintrust/trace.py +++ b/py/src/braintrust/trace.py @@ -63,9 +63,10 @@ def __init__( root_span_id: str, state: BraintrustState, span_type_filter: Optional[list[str]] = None, + include_scorers: bool = False, ): # Build the filter expression for root_span_id and optionally span_attributes.type - filter_expr = self._build_filter(root_span_id, span_type_filter) + filter_expr = self._build_filter(root_span_id, span_type_filter, include_scorers) super().__init__( object_type=object_type, @@ -75,7 +76,11 @@ def __init__( self._state = state @staticmethod - def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]: + def _build_filter( + root_span_id: str, + span_type_filter: Optional[list[str]] = None, + include_scorers: bool = False, + ) -> dict[str, Any]: """Build BTQL filter expression.""" children = [ # Base filter: root_span_id = 'value' @@ -84,23 +89,32 @@ def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = Non "left": {"op": "ident", "name": ["root_span_id"]}, "right": {"op": "literal", "value": root_span_id}, }, - # Exclude span_attributes.purpose = 'scorer' - { - "op": "or", - "children": [ - { - "op": "isnull", - "expr": {"op": "ident", "name": ["span_attributes", "purpose"]}, - }, - { - "op": "ne", - "left": {"op": "ident", "name": ["span_attributes", "purpose"]}, - "right": {"op": "literal", "value": "scorer"}, - }, - ], - }, ] + if not include_scorers: + children.append( + { + "op": "or", + "children": [ + { + "op": "isnull", + "expr": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + }, + { + "op": "ne", + "left": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + "right": {"op": "literal", "value": "scorer"}, + }, + ], + } + ) + # If span type filter specified, add it if span_type_filter and len(span_type_filter) > 0: children.append( @@ -122,6 +136,7 @@ def _get_state(self) -> BraintrustState: SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]] +SpanFetchWithOptionsFn = Callable[[Optional[list[str]], bool], Awaitable[list[SpanData]]] class GetThreadOptions(TypedDict, total=False): @@ -151,7 +166,14 @@ def __init__( if fetch_fn is not None: # Direct fetch function injection (for testing) - self._fetch_fn = fetch_fn + async def _fetch_fn( + span_type: Optional[list[str]], + include_scorers: bool = False, + ) -> list[SpanData]: + del include_scorers + return await fetch_fn(span_type) + + self._fetch_fn: SpanFetchWithOptionsFn = _fetch_fn else: # Standard constructor with SpanFetcher if object_type is None or object_id is None or root_span_id is None or get_state is None: @@ -159,7 +181,10 @@ def __init__( "Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state" ) - async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: + async def _fetch_fn( + span_type: Optional[list[str]], + include_scorers: bool = False, + ) -> list[SpanData]: state = await get_state() fetcher = SpanFetcher( object_type=object_type, @@ -167,21 +192,14 @@ async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: root_span_id=root_span_id, state=state, span_type_filter=span_type, + include_scorers=include_scorers, ) rows = list(fetcher.fetch()) - # Filter out scorer spans - filtered = [ - row - for row in rows - if not ( - isinstance(row.get("span_attributes"), dict) - and row.get("span_attributes", {}).get("purpose") == "scorer" - ) - ] return [ SpanData( input=row.get("input"), output=row.get("output"), + expected=row.get("expected"), metadata=row.get("metadata"), span_id=row.get("span_id"), span_parents=row.get("span_parents"), @@ -190,22 +208,33 @@ async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: _xact_id=row.get("_xact_id"), _pagination_key=row.get("_pagination_key"), root_span_id=row.get("root_span_id"), + created=row.get("created"), + tags=row.get("tags"), ) - for row in filtered + for row in rows ] self._fetch_fn = _fetch_fn - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans( + self, + span_type: Optional[list[str]] = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Get spans, using cache when possible. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans """ + if include_scorers: + return await self._fetch_fn(span_type, True) + # If we've fetched all spans, just filter from cache if self._all_fetched: return self._get_from_cache(span_type) @@ -230,7 +259,7 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat async def _fetch_spans(self, span_type: Optional[list[str]]) -> None: """Fetch spans from the server.""" - spans = await self._fetch_fn(span_type) + spans = await self._fetch_fn(span_type, False) for span in spans: span_attrs = span.span_attributes or {} @@ -266,12 +295,18 @@ def get_configuration(self) -> dict[str, str]: """Get the trace configuration (object_type, object_id, root_span_id).""" ... - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans( + self, + span_type: Optional[list[str]] = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all spans for this root span. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -351,7 +386,12 @@ def get_configuration(self) -> dict[str, str]: "root_span_id": self._root_span_id, } - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans( + self, + span_type: Optional[list[str]] = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all rows for this root span from its parent object (experiment or project logs). First checks the local span cache for recently logged spans, then falls @@ -359,6 +399,7 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -367,7 +408,11 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id) if cached_spans and len(cached_spans) > 0: # Filter by purpose - spans = [span for span in cached_spans if not (span.span_attributes or {}).get("purpose") == "scorer"] + spans = [ + span + for span in cached_spans + if include_scorers or not (span.span_attributes or {}).get("purpose") == "scorer" + ] # Filter by span type if requested if span_type and len(span_type) > 0: @@ -378,6 +423,7 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat SpanData( input=span.input, output=span.output, + expected=getattr(span, "expected", None), metadata=span.metadata, span_id=span.span_id, span_parents=span.span_parents, @@ -387,7 +433,7 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat ] # Fall back to CachedSpanFetcher for BTQL fetching with caching - return await self._cached_fetcher.get_spans(span_type) + return await self._cached_fetcher.get_spans(span_type, include_scorers=include_scorers) async def get_thread(self, options: GetThreadOptions | None = None) -> list[Any]: """