Skip to content
Merged
3 changes: 2 additions & 1 deletion py/src/braintrust/bt_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import json
import math
import warnings
from typing import Any, Callable, Mapping, NamedTuple, cast, overload
from collections.abc import Callable, Mapping
from typing import Any, NamedTuple, cast, overload


# Try to import orjson for better performance
Expand Down
10 changes: 5 additions & 5 deletions py/src/braintrust/contrib/temporal/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Dict
from typing import Any

import pytest
import pytest_asyncio
Expand All @@ -31,7 +31,7 @@ class TestHeaderSerialization:
def test_span_context_to_headers_with_valid_context(self):
interceptor = BraintrustInterceptor()
span_context = {"trace_id": "test-trace-id", "span_id": "test-span-id"}
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
headers: dict[str, temporalio.api.common.v1.Payload] = {}

result_headers = interceptor._span_context_to_headers(span_context, headers)

Expand All @@ -40,8 +40,8 @@ def test_span_context_to_headers_with_valid_context(self):

def test_span_context_to_headers_with_empty_context(self):
interceptor = BraintrustInterceptor()
span_context: Dict[str, Any] = {}
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
span_context: dict[str, Any] = {}
headers: dict[str, temporalio.api.common.v1.Payload] = {}

result_headers = interceptor._span_context_to_headers(span_context, headers)

Expand Down Expand Up @@ -78,7 +78,7 @@ def test_span_context_from_headers_with_valid_header(self):

def test_span_context_from_headers_with_missing_header(self):
interceptor = BraintrustInterceptor()
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
headers: dict[str, temporalio.api.common.v1.Payload] = {}

result = interceptor._span_context_from_headers(headers)

Expand Down
4 changes: 1 addition & 3 deletions py/src/braintrust/devserver/schemas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
from collections.abc import Sequence
from typing import Any, Union, get_args, get_origin, get_type_hints

from typing_extensions import TypedDict
from typing import Any, TypedDict, Union, get_args, get_origin, get_type_hints


# This is not beautiful code, but it saves us from introducing Pydantic as a dependency, and it is fairly
Expand Down
68 changes: 29 additions & 39 deletions py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
Any,
Generic,
Literal,
Protocol,
TypedDict,
TypeVar,
Union,
)

from tqdm.asyncio import tqdm as async_tqdm
from tqdm.auto import tqdm as std_tqdm
from typing_extensions import Protocol, TypedDict

from .generated_types import FunctionFormat, FunctionOutputType, ObjectReference
from .git_fields import GitMetadataSettings, RepoInfo
Expand Down Expand Up @@ -216,13 +216,8 @@ class EvalScorerArgs(SerializableDataClass, Generic[Input, Output, Expected]):
metadata: Metadata | None = None


OneOrMoreScores = Union[float, int, bool, None, Score, list[Score]]
OneOrMoreClassifications = Union[
None,
Classification,
Mapping[str, Any],
list[Classification | Mapping[str, Any]],
]
OneOrMoreScores = float | int | bool | None | Score | list[Score]
OneOrMoreClassifications = None | Classification | Mapping[str, Any] | list[Classification | Mapping[str, Any]]


# Synchronous scorer interface - implements callable
Expand All @@ -247,20 +242,19 @@ class AsyncScorerLike(Protocol, Generic[Input, Output, Expected]):
async def eval_async(self, output: Output, expected: Expected | None = None, **kwargs: Any) -> OneOrMoreScores: ...


# Union type for any kind of scorer (for typing)
ScorerLike = Union[SyncScorerLike[Input, Output, Expected], AsyncScorerLike[Input, Output, Expected]]
ScorerLike = SyncScorerLike[Input, Output, Expected] | AsyncScorerLike[Input, Output, Expected]

EvalScorer = Union[
ScorerLike[Input, Output, Expected],
type[ScorerLike[Input, Output, Expected]],
Callable[[Input, Output, Expected], OneOrMoreScores],
Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]],
]
EvalScorer = (
ScorerLike[Input, Output, Expected]
| type[ScorerLike[Input, Output, Expected]]
| Callable[[Input, Output, Expected], OneOrMoreScores]
| Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]]
)

EvalClassifier = Union[
Callable[[Input, Output, Expected], OneOrMoreClassifications],
Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]],
]
EvalClassifier = (
Callable[[Input, Output, Expected], OneOrMoreClassifications]
| Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]]
)


@dataclasses.dataclass
Expand All @@ -278,27 +272,23 @@ class BaseExperiment:
"""


_AnyEvalCase = Union[
EvalCase[Input, Expected],
EvalCaseDict[Input, Expected],
EvalCaseDictNoOutput[Input],
ExperimentDatasetEvent,
]
_AnyEvalCase = (
EvalCase[Input, Expected] | EvalCaseDict[Input, Expected] | EvalCaseDictNoOutput[Input] | ExperimentDatasetEvent
)

_EvalDataObject = Union[
Iterable[_AnyEvalCase[Input, Expected]],
Iterator[_AnyEvalCase[Input, Expected]],
Awaitable[Iterator[_AnyEvalCase[Input, Expected]]],
Callable[[], Union[Iterator[_AnyEvalCase[Input, Expected]], Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]],
BaseExperiment,
]
_EvalDataObject = (
Iterable[_AnyEvalCase[Input, Expected]]
| Iterator[_AnyEvalCase[Input, Expected]]
| Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]
| Callable[[], Iterator[_AnyEvalCase[Input, Expected]] | Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]
| BaseExperiment
)

EvalData = Union[_EvalDataObject[Input, Expected], type[_EvalDataObject[Input, Expected]], Dataset]
EvalData = _EvalDataObject[Input, Expected] | type[_EvalDataObject[Input, Expected]] | Dataset

EvalTask = Union[
Callable[[Input], Union[Output, Awaitable[Output]]],
Callable[[Input, EvalHooks[Expected]], Union[Output, Awaitable[Output]]],
]
EvalTask = (
Callable[[Input], Output | Awaitable[Output]] | Callable[[Input, EvalHooks[Expected]], Output | Awaitable[Output]]
)

ErrorScoreHandler = Callable[[Span, EvalCase[Input, Expected], Sequence[str]], dict[str, float] | None]

Expand Down
12 changes: 4 additions & 8 deletions py/src/braintrust/functions/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import json
from collections.abc import Generator, Iterable
from itertools import tee
from typing import Literal, Union
from typing import Literal

from sseclient import SSEClient

Expand Down Expand Up @@ -79,13 +79,9 @@ class BraintrustInvokeError(ValueError):
pass


BraintrustStreamChunk = Union[
BraintrustTextChunk,
BraintrustJsonChunk,
BraintrustErrorChunk,
BraintrustConsoleChunk,
BraintrustProgressChunk,
]
BraintrustStreamChunk = (
BraintrustTextChunk | BraintrustJsonChunk | BraintrustErrorChunk | BraintrustConsoleChunk | BraintrustProgressChunk
)


class BraintrustStream:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections.abc import Mapping
from dataclasses import dataclass
from enum import Enum
from types import MappingProxyType
from typing import Final, Mapping
from typing import Final


class MessageClassName(str, Enum):
Expand Down
3 changes: 1 addition & 2 deletions py/src/braintrust/integrations/langchain/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from typing import (
Any,
TypedDict,
Union,
)
from uuid import UUID

Expand Down Expand Up @@ -528,7 +527,7 @@ def on_llm_new_token(
self,
token: str,
*,
chunk: Union["GenerationChunk", "ChatGenerationChunk"] | None = None, # type: ignore
chunk: "GenerationChunk | ChatGenerationChunk | None" = None, # type: ignore
run_id: UUID,
parent_run_id: UUID | None = None,
**kwargs: Any,
Expand Down
3 changes: 2 additions & 1 deletion py/src/braintrust/integrations/langchain/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Sequence
from collections.abc import Sequence
from typing import Any
from unittest.mock import ANY


Expand Down
24 changes: 12 additions & 12 deletions py/src/braintrust/integrations/langchain/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
import uuid
from pathlib import Path
from typing import Dict, List, Union, cast
from typing import cast

import pytest
from braintrust import logger
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_llm_calls(logger_memory_logger):
presence_penalty=0,
n=1,
)
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
chain.invoke({"number": "2"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]})

spans = memory_logger.pop()
Expand Down Expand Up @@ -159,7 +159,7 @@ def test_chain_with_memory(logger_memory_logger):
handler = BraintrustCallbackHandler(logger=test_logger)
prompt = ChatPromptTemplate.from_template("{history} User: {input}")
model = ChatOpenAI(model="gpt-4o-mini")
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)

memory = {"history": "Assistant: Hello! How can I assist you today?"}
chain.invoke(
Expand Down Expand Up @@ -399,7 +399,7 @@ def test_parallel_execution(logger_memory_logger):

map_chain.invoke({"topic": "bear"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]})

spans = cast(List, memory_logger.pop())
spans = cast(list, memory_logger.pop())

# Find the LLM spans
llm_spans = find_spans_by_attributes(spans, name="ChatOpenAI")
Expand Down Expand Up @@ -480,16 +480,16 @@ def test_langgraph_state_management(logger_memory_logger):
n=1,
)

def say_hello(state: Dict[str, str]):
def say_hello(state: dict[str, str]):
response = model.invoke("Say hello")
return cast(Union[str, List[str], Dict[str, str]], response.content)
return cast(str | list[str] | dict[str, str], response.content)

def say_bye(state: Dict[str, str]):
def say_bye(state: dict[str, str]):
print("From the 'sayBye' node: Bye world!")
return "Bye"

workflow = (
StateGraph(state_schema=Dict[str, str])
StateGraph(state_schema=dict[str, str])
.add_node("sayHello", say_hello)
.add_node("sayBye", say_bye)
.add_edge(START, "sayHello")
Expand Down Expand Up @@ -837,10 +837,10 @@ def test_streaming_ttft(logger_memory_logger):
max_completion_tokens=50,
streaming=True,
)
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)

# Collect chunks to verify streaming works
chunks: List[str] = []
chunks: list[str] = []
for chunk in chain.stream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}):
if chunk.content:
chunks.append(str(chunk.content))
Expand Down Expand Up @@ -1272,9 +1272,9 @@ async def test_async_streaming(logger_memory_logger):
handler = BraintrustCallbackHandler(logger=test_logger)
prompt = ChatPromptTemplate.from_template("Count from 1 to 3.")
model = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20, streaming=True)
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)

chunks: List[str] = []
chunks: list[str] = []
async for chunk in chain.astream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}):
if chunk.content:
chunks.append(str(chunk.content))
Expand Down
3 changes: 1 addition & 2 deletions py/src/braintrust/integrations/langchain/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# pyright: reportTypedDictNotRequiredAccess=none
from typing import Dict
from unittest.mock import ANY

import pytest
Expand Down Expand Up @@ -57,7 +56,7 @@ def test_global_handler(logger_memory_logger):
presence_penalty=0,
n=1,
)
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)

message = chain.invoke({"number": "2"})

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Dict
from typing import Any

import pytest
from openai import AsyncOpenAI
Expand Down Expand Up @@ -61,7 +61,7 @@ def memory_logger():
yield bgl


def _assert_metrics_are_valid(metrics: Dict[str, Any]):
def _assert_metrics_are_valid(metrics: dict[str, Any]):
assert metrics["tokens"] > 0
assert metrics["prompt_tokens"] > 0
assert metrics["completion_tokens"] > 0
Expand Down
13 changes: 5 additions & 8 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
Literal,
TypedDict,
TypeVar,
Union,
cast,
overload,
)
Expand Down Expand Up @@ -1600,7 +1599,7 @@ def init(
base_experiment_id: str | None = None,
repo_info: RepoInfo | None = None,
state: BraintrustState | None = None,
) -> Union["Experiment", "ReadonlyExperiment"]:
) -> "Experiment | ReadonlyExperiment":
"""
Log in, and then initialize a new experiment in a specified project. If the project does not exist, it will be created.

Expand Down Expand Up @@ -1767,7 +1766,7 @@ def compute_metadata():
return ret


def init_experiment(*args, **kwargs) -> Union["Experiment", "ReadonlyExperiment"]:
def init_experiment(*args, **kwargs) -> "Experiment | ReadonlyExperiment":
"""Alias for `init`"""

return init(*args, **kwargs)
Expand Down Expand Up @@ -2392,7 +2391,7 @@ def parent_context(parent: str | None, state: BraintrustState | None = None):

def get_span_parent_object(
parent: str | None = None, state: BraintrustState | None = None
) -> Union[SpanComponentsV4, "Logger", "Experiment", Span]:
) -> "SpanComponentsV4 | Logger | Experiment | Span":
"""Mainly for internal use. Return the parent object for starting a span in a global context.
Applies precedence: current span > propagated parent string > experiment > logger."""

Expand Down Expand Up @@ -4857,7 +4856,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None:
def render_message(render: Callable[[str], str], message: PromptMessage):
base = {k: v for (k, v) in message.as_dict().items() if v is not None}
# TODO: shouldn't load_prompt guarantee content is a PromptMessage?
content = cast(Union[str, list[Union[TextPart, ImagePart]], dict[str, Any]], message.content)
content = cast(str | list[TextPart | ImagePart] | dict[str, Any], message.content)
if content is not None:
if isinstance(content, str):
base["content"] = render(content)
Expand Down Expand Up @@ -5618,9 +5617,7 @@ def __str__(self):


class TracedThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor):
# Returns Any because Future[T] generic typing was stabilized in Python 3.9,
# but we maintain compatibility with older type checkers.
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future[Any]:
# Capture all current context variables
context = contextvars.copy_context()

Expand Down
Loading