diff --git a/py/src/braintrust/__init__.py b/py/src/braintrust/__init__.py index 72252a1d..4b50b9b9 100644 --- a/py/src/braintrust/__init__.py +++ b/py/src/braintrust/__init__.py @@ -62,6 +62,7 @@ def is_equal(expected, output): from .audit import * from .auto import auto_instrument as auto_instrument +from .dataset_pipeline import * from .framework import * from .framework2 import * from .functions.invoke import * diff --git a/py/src/braintrust/dataset_pipeline.py b/py/src/braintrust/dataset_pipeline.py new file mode 100644 index 00000000..4cbbbf4f --- /dev/null +++ b/py/src/braintrust/dataset_pipeline.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass +from typing import Any, Generic, Literal, Protocol, TypeAlias, TypeVar + +from typing_extensions import NotRequired, TypedDict + +from .generated_types import ObjectReference +from .logger import Metadata +from .trace import Trace + + +DatasetPipelineScope: TypeAlias = Literal["span", "trace"] + + +class DatasetPipelineSource(TypedDict, total=False): + project_id: str + project_name: str + org_name: str + filter: str + scope: DatasetPipelineScope + + +@dataclass(frozen=True) +class PipelineSource: + filter: str | None = None + scope: DatasetPipelineScope | None = None + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + + def as_dict(self) -> DatasetPipelineSource: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "filter": self.filter, + "scope": self.scope, + } + ) + + +class DatasetPipelineTarget(TypedDict): + dataset_name: str + project_id: NotRequired[str] + project_name: NotRequired[str] + org_name: NotRequired[str] + description: NotRequired[str] + metadata: NotRequired[Metadata] + + +@dataclass(frozen=True) +class PipelineTarget: + dataset_name: str + project_name: str | None = None + project_id: str | None = None + org_name: str | None = None + description: str | None = None + metadata: Metadata | None = None + + def as_dict(self) -> DatasetPipelineTarget: + return _drop_none( + { + "project_id": self.project_id, + "project_name": self.project_name, + "org_name": self.org_name, + "dataset_name": self.dataset_name, + "description": self.description, + "metadata": self.metadata, + } + ) + + +class DatasetPipelineRow(TypedDict, total=False): + id: str + input: Any | None + expected: Any | None + tags: Sequence[str] | None + metadata: Metadata | None + origin: ObjectReference + + +Row = TypeVar("Row", bound=DatasetPipelineRow) + + +class DatasetPipelineTransformArgs(TypedDict, total=False): + input: Any | None + output: Any | None + metadata: Metadata | None + expected: Any | None + trace: Trace + + +DatasetPipelineTransformResult: TypeAlias = Row | Sequence[Row] | None +DatasetPipelineSourceLike: TypeAlias = DatasetPipelineSource | PipelineSource +DatasetPipelineTargetLike: TypeAlias = DatasetPipelineTarget | PipelineTarget + + +def _drop_none(values: dict[str, Any]) -> dict[str, Any]: + return {key: value for key, value in values.items() if value is not None} + + +def _normalize_source(source: DatasetPipelineSourceLike) -> DatasetPipelineSource: + if isinstance(source, PipelineSource): + return source.as_dict() + return dict(source) + + +def _normalize_target(target: DatasetPipelineTargetLike) -> DatasetPipelineTarget: + if isinstance(target, PipelineTarget): + return target.as_dict() + return dict(target) + + +class DatasetPipelineTransform(Protocol[Row]): + def __call__( + self, + input: Any | None = None, + output: Any | None = None, + metadata: Metadata | None = None, + expected: Any | None = None, + trace: Trace | None = None, + ) -> DatasetPipelineTransformResult[Row] | Awaitable[DatasetPipelineTransformResult[Row]]: ... + + +@dataclass(frozen=True) +class DatasetPipelineDefinition(Generic[Row]): + source: DatasetPipelineSource + transform: DatasetPipelineTransform[Row] + target: DatasetPipelineTarget + name: str | None = None + + +_DATASET_PIPELINES: list[DatasetPipelineDefinition[Any]] = [] + + +def get_registered_dataset_pipelines() -> list[DatasetPipelineDefinition[Any]]: + return list(_DATASET_PIPELINES) + + +def is_dataset_pipeline_definition(value: object) -> bool: + return isinstance(value, DatasetPipelineDefinition) + + +def DatasetPipeline( + name: str | None = None, + *, + source: DatasetPipelineSourceLike, + transform: DatasetPipelineTransform[DatasetPipelineRow], + target: DatasetPipelineTargetLike, +) -> DatasetPipelineDefinition[DatasetPipelineRow]: + definition = DatasetPipelineDefinition( + name=name, + source=_normalize_source(source), + transform=transform, + target=_normalize_target(target), + ) + _DATASET_PIPELINES.append(definition) + return definition diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py index 24bcefa2..5bdb814f 100644 --- a/py/src/braintrust/trace.py +++ b/py/src/braintrust/trace.py @@ -64,9 +64,10 @@ def __init__( root_span_id: str, state: BraintrustState, span_type_filter: list[str] | None = None, + include_scorers: bool = False, ): # Build the filter expression for root_span_id and optionally span_attributes.type - filter_expr = self._build_filter(root_span_id, span_type_filter) + filter_expr = self._build_filter(root_span_id, span_type_filter, include_scorers) super().__init__( object_type=object_type, @@ -76,7 +77,11 @@ def __init__( self._state = state @staticmethod - def _build_filter(root_span_id: str, span_type_filter: list[str] | None = None) -> dict[str, Any]: + def _build_filter( + root_span_id: str, + span_type_filter: list[str] | None = None, + include_scorers: bool = False, + ) -> dict[str, Any]: """Build BTQL filter expression.""" children = [ # Base filter: root_span_id = 'value' @@ -85,23 +90,32 @@ def _build_filter(root_span_id: str, span_type_filter: list[str] | None = None) "left": {"op": "ident", "name": ["root_span_id"]}, "right": {"op": "literal", "value": root_span_id}, }, - # Exclude span_attributes.purpose = 'scorer' - { - "op": "or", - "children": [ - { - "op": "isnull", - "expr": {"op": "ident", "name": ["span_attributes", "purpose"]}, - }, - { - "op": "ne", - "left": {"op": "ident", "name": ["span_attributes", "purpose"]}, - "right": {"op": "literal", "value": "scorer"}, - }, - ], - }, ] + if not include_scorers: + children.append( + { + "op": "or", + "children": [ + { + "op": "isnull", + "expr": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + }, + { + "op": "ne", + "left": { + "op": "ident", + "name": ["span_attributes", "purpose"], + }, + "right": {"op": "literal", "value": "scorer"}, + }, + ], + } + ) + # If span type filter specified, add it if span_type_filter and len(span_type_filter) > 0: children.append( @@ -123,6 +137,7 @@ def _get_state(self) -> BraintrustState: SpanFetchFn = Callable[[list[str] | None], Awaitable[list[SpanData]]] +SpanFetchWithOptionsFn = Callable[[list[str] | None, bool], Awaitable[list[SpanData]]] class GetThreadOptions(TypedDict, total=False): @@ -152,7 +167,14 @@ def __init__( if fetch_fn is not None: # Direct fetch function injection (for testing) - self._fetch_fn = fetch_fn + async def _fetch_fn( + span_type: list[str] | None, + include_scorers: bool = False, + ) -> list[SpanData]: + del include_scorers + return await fetch_fn(span_type) + + self._fetch_fn: SpanFetchWithOptionsFn = _fetch_fn else: # Standard constructor with SpanFetcher if object_type is None or object_id is None or root_span_id is None or get_state is None: @@ -160,7 +182,10 @@ def __init__( "Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state" ) - async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: + async def _fetch_fn( + span_type: list[str] | None, + include_scorers: bool = False, + ) -> list[SpanData]: state = await get_state() fetcher = SpanFetcher( object_type=object_type, @@ -168,21 +193,14 @@ async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: root_span_id=root_span_id, state=state, span_type_filter=span_type, + include_scorers=include_scorers, ) rows = list(fetcher.fetch()) - # Filter out scorer spans - filtered = [ - row - for row in rows - if not ( - isinstance(row.get("span_attributes"), dict) - and row.get("span_attributes", {}).get("purpose") == "scorer" - ) - ] return [ SpanData( input=row.get("input"), output=row.get("output"), + expected=row.get("expected"), metadata=row.get("metadata"), span_id=row.get("span_id"), span_parents=row.get("span_parents"), @@ -191,22 +209,33 @@ async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: _xact_id=row.get("_xact_id"), _pagination_key=row.get("_pagination_key"), root_span_id=row.get("root_span_id"), + created=row.get("created"), + tags=row.get("tags"), ) - for row in filtered + for row in rows ] self._fetch_fn = _fetch_fn - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Get spans, using cache when possible. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans """ + if include_scorers: + return await self._fetch_fn(span_type, True) + # If we've fetched all spans, just filter from cache if self._all_fetched: return self._get_from_cache(span_type) @@ -231,7 +260,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: async def _fetch_spans(self, span_type: list[str] | None) -> None: """Fetch spans from the server.""" - spans = await self._fetch_fn(span_type) + spans = await self._fetch_fn(span_type, False) for span in spans: span_attrs = span.span_attributes or {} @@ -267,12 +296,18 @@ def get_configuration(self) -> dict[str, str]: """Get the trace configuration (object_type, object_id, root_span_id).""" ... - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all spans for this root span. Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -352,7 +387,12 @@ def get_configuration(self) -> dict[str, str]: "root_span_id": self._root_span_id, } - async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: + async def get_spans( + self, + span_type: list[str] | None = None, + *, + include_scorers: bool = False, + ) -> list[SpanData]: """ Fetch all rows for this root span from its parent object (experiment or project logs). First checks the local span cache for recently logged spans, then falls @@ -360,6 +400,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: Args: span_type: Optional list of span types to filter by + include_scorers: Include spans with span_attributes.purpose = "scorer" Returns: List of matching spans @@ -368,7 +409,11 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: cached_spans = self._state.span_cache.get_by_root_span_id(self._root_span_id) if cached_spans and len(cached_spans) > 0: # Filter by purpose - spans = [span for span in cached_spans if not (span.span_attributes or {}).get("purpose") == "scorer"] + spans = [ + span + for span in cached_spans + if include_scorers or not (span.span_attributes or {}).get("purpose") == "scorer" + ] # Filter by span type if requested if span_type and len(span_type) > 0: @@ -379,6 +424,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: SpanData( input=span.input, output=span.output, + expected=getattr(span, "expected", None), metadata=span.metadata, span_id=span.span_id, span_parents=span.span_parents, @@ -388,7 +434,7 @@ async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: ] # Fall back to CachedSpanFetcher for BTQL fetching with caching - return await self._cached_fetcher.get_spans(span_type) + return await self._cached_fetcher.get_spans(span_type, include_scorers=include_scorers) async def get_thread(self, options: GetThreadOptions | None = None) -> list[Any]: """