diff --git a/py/src/braintrust/bt_json.py b/py/src/braintrust/bt_json.py index 76dadb01..e6035d6b 100644 --- a/py/src/braintrust/bt_json.py +++ b/py/src/braintrust/bt_json.py @@ -2,7 +2,8 @@ import json import math import warnings -from typing import Any, Callable, Mapping, NamedTuple, cast, overload +from collections.abc import Callable, Mapping +from typing import Any, NamedTuple, cast, overload # Try to import orjson for better performance diff --git a/py/src/braintrust/contrib/temporal/test_temporal.py b/py/src/braintrust/contrib/temporal/test_temporal.py index 8ed87264..03999310 100644 --- a/py/src/braintrust/contrib/temporal/test_temporal.py +++ b/py/src/braintrust/contrib/temporal/test_temporal.py @@ -4,7 +4,7 @@ import uuid from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict +from typing import Any import pytest import pytest_asyncio @@ -31,7 +31,7 @@ class TestHeaderSerialization: def test_span_context_to_headers_with_valid_context(self): interceptor = BraintrustInterceptor() span_context = {"trace_id": "test-trace-id", "span_id": "test-span-id"} - headers: Dict[str, temporalio.api.common.v1.Payload] = {} + headers: dict[str, temporalio.api.common.v1.Payload] = {} result_headers = interceptor._span_context_to_headers(span_context, headers) @@ -40,8 +40,8 @@ def test_span_context_to_headers_with_valid_context(self): def test_span_context_to_headers_with_empty_context(self): interceptor = BraintrustInterceptor() - span_context: Dict[str, Any] = {} - headers: Dict[str, temporalio.api.common.v1.Payload] = {} + span_context: dict[str, Any] = {} + headers: dict[str, temporalio.api.common.v1.Payload] = {} result_headers = interceptor._span_context_to_headers(span_context, headers) @@ -78,7 +78,7 @@ def test_span_context_from_headers_with_valid_header(self): def test_span_context_from_headers_with_missing_header(self): interceptor = BraintrustInterceptor() - headers: Dict[str, temporalio.api.common.v1.Payload] = {} + headers: dict[str, temporalio.api.common.v1.Payload] = {} result = interceptor._span_context_from_headers(headers) diff --git a/py/src/braintrust/devserver/schemas.py b/py/src/braintrust/devserver/schemas.py index cd8f49da..a359a93d 100644 --- a/py/src/braintrust/devserver/schemas.py +++ b/py/src/braintrust/devserver/schemas.py @@ -1,8 +1,6 @@ import json from collections.abc import Sequence -from typing import Any, Union, get_args, get_origin, get_type_hints - -from typing_extensions import TypedDict +from typing import Any, TypedDict, Union, get_args, get_origin, get_type_hints # This is not beautiful code, but it saves us from introducing Pydantic as a dependency, and it is fairly diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py index 5ad20a25..554e4007 100644 --- a/py/src/braintrust/framework.py +++ b/py/src/braintrust/framework.py @@ -17,13 +17,13 @@ Any, Generic, Literal, + Protocol, + TypedDict, TypeVar, - Union, ) from tqdm.asyncio import tqdm as async_tqdm from tqdm.auto import tqdm as std_tqdm -from typing_extensions import Protocol, TypedDict from .generated_types import FunctionFormat, FunctionOutputType, ObjectReference from .git_fields import GitMetadataSettings, RepoInfo @@ -216,13 +216,8 @@ class EvalScorerArgs(SerializableDataClass, Generic[Input, Output, Expected]): metadata: Metadata | None = None -OneOrMoreScores = Union[float, int, bool, None, Score, list[Score]] -OneOrMoreClassifications = Union[ - None, - Classification, - Mapping[str, Any], - list[Classification | Mapping[str, Any]], -] +OneOrMoreScores = float | int | bool | None | Score | list[Score] +OneOrMoreClassifications = None | Classification | Mapping[str, Any] | list[Classification | Mapping[str, Any]] # Synchronous scorer interface - implements callable @@ -247,20 +242,19 @@ class AsyncScorerLike(Protocol, Generic[Input, Output, Expected]): async def eval_async(self, output: Output, expected: Expected | None = None, **kwargs: Any) -> OneOrMoreScores: ... -# Union type for any kind of scorer (for typing) -ScorerLike = Union[SyncScorerLike[Input, Output, Expected], AsyncScorerLike[Input, Output, Expected]] +ScorerLike = SyncScorerLike[Input, Output, Expected] | AsyncScorerLike[Input, Output, Expected] -EvalScorer = Union[ - ScorerLike[Input, Output, Expected], - type[ScorerLike[Input, Output, Expected]], - Callable[[Input, Output, Expected], OneOrMoreScores], - Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]], -] +EvalScorer = ( + ScorerLike[Input, Output, Expected] + | type[ScorerLike[Input, Output, Expected]] + | Callable[[Input, Output, Expected], OneOrMoreScores] + | Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]] +) -EvalClassifier = Union[ - Callable[[Input, Output, Expected], OneOrMoreClassifications], - Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]], -] +EvalClassifier = ( + Callable[[Input, Output, Expected], OneOrMoreClassifications] + | Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]] +) @dataclasses.dataclass @@ -278,27 +272,23 @@ class BaseExperiment: """ -_AnyEvalCase = Union[ - EvalCase[Input, Expected], - EvalCaseDict[Input, Expected], - EvalCaseDictNoOutput[Input], - ExperimentDatasetEvent, -] +_AnyEvalCase = ( + EvalCase[Input, Expected] | EvalCaseDict[Input, Expected] | EvalCaseDictNoOutput[Input] | ExperimentDatasetEvent +) -_EvalDataObject = Union[ - Iterable[_AnyEvalCase[Input, Expected]], - Iterator[_AnyEvalCase[Input, Expected]], - Awaitable[Iterator[_AnyEvalCase[Input, Expected]]], - Callable[[], Union[Iterator[_AnyEvalCase[Input, Expected]], Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]], - BaseExperiment, -] +_EvalDataObject = ( + Iterable[_AnyEvalCase[Input, Expected]] + | Iterator[_AnyEvalCase[Input, Expected]] + | Awaitable[Iterator[_AnyEvalCase[Input, Expected]]] + | Callable[[], Iterator[_AnyEvalCase[Input, Expected]] | Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]] + | BaseExperiment +) -EvalData = Union[_EvalDataObject[Input, Expected], type[_EvalDataObject[Input, Expected]], Dataset] +EvalData = _EvalDataObject[Input, Expected] | type[_EvalDataObject[Input, Expected]] | Dataset -EvalTask = Union[ - Callable[[Input], Union[Output, Awaitable[Output]]], - Callable[[Input, EvalHooks[Expected]], Union[Output, Awaitable[Output]]], -] +EvalTask = ( + Callable[[Input], Output | Awaitable[Output]] | Callable[[Input, EvalHooks[Expected]], Output | Awaitable[Output]] +) ErrorScoreHandler = Callable[[Span, EvalCase[Input, Expected], Sequence[str]], dict[str, float] | None] diff --git a/py/src/braintrust/functions/stream.py b/py/src/braintrust/functions/stream.py index ba651315..d3a6b84d 100644 --- a/py/src/braintrust/functions/stream.py +++ b/py/src/braintrust/functions/stream.py @@ -9,7 +9,7 @@ import json from collections.abc import Generator, Iterable from itertools import tee -from typing import Literal, Union +from typing import Literal from sseclient import SSEClient @@ -79,13 +79,9 @@ class BraintrustInvokeError(ValueError): pass -BraintrustStreamChunk = Union[ - BraintrustTextChunk, - BraintrustJsonChunk, - BraintrustErrorChunk, - BraintrustConsoleChunk, - BraintrustProgressChunk, -] +BraintrustStreamChunk = ( + BraintrustTextChunk | BraintrustJsonChunk | BraintrustErrorChunk | BraintrustConsoleChunk | BraintrustProgressChunk +) class BraintrustStream: diff --git a/py/src/braintrust/integrations/claude_agent_sdk/_constants.py b/py/src/braintrust/integrations/claude_agent_sdk/_constants.py index fa228a34..f8757deb 100644 --- a/py/src/braintrust/integrations/claude_agent_sdk/_constants.py +++ b/py/src/braintrust/integrations/claude_agent_sdk/_constants.py @@ -1,7 +1,8 @@ +from collections.abc import Mapping from dataclasses import dataclass from enum import Enum from types import MappingProxyType -from typing import Final, Mapping +from typing import Final class MessageClassName(str, Enum): diff --git a/py/src/braintrust/integrations/langchain/callbacks.py b/py/src/braintrust/integrations/langchain/callbacks.py index 50650da0..a80a5625 100644 --- a/py/src/braintrust/integrations/langchain/callbacks.py +++ b/py/src/braintrust/integrations/langchain/callbacks.py @@ -7,7 +7,6 @@ from typing import ( Any, TypedDict, - Union, ) from uuid import UUID @@ -528,7 +527,7 @@ def on_llm_new_token( self, token: str, *, - chunk: Union["GenerationChunk", "ChatGenerationChunk"] | None = None, # type: ignore + chunk: "GenerationChunk | ChatGenerationChunk | None" = None, # type: ignore run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any, diff --git a/py/src/braintrust/integrations/langchain/helpers.py b/py/src/braintrust/integrations/langchain/helpers.py index f75b96db..7271ef37 100644 --- a/py/src/braintrust/integrations/langchain/helpers.py +++ b/py/src/braintrust/integrations/langchain/helpers.py @@ -1,4 +1,5 @@ -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from unittest.mock import ANY diff --git a/py/src/braintrust/integrations/langchain/test_callbacks.py b/py/src/braintrust/integrations/langchain/test_callbacks.py index da7181d3..adeaa37d 100644 --- a/py/src/braintrust/integrations/langchain/test_callbacks.py +++ b/py/src/braintrust/integrations/langchain/test_callbacks.py @@ -3,7 +3,7 @@ import time import uuid from pathlib import Path -from typing import Dict, List, Union, cast +from typing import cast import pytest from braintrust import logger @@ -56,7 +56,7 @@ def test_llm_calls(logger_memory_logger): presence_penalty=0, n=1, ) - chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model) + chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model) chain.invoke({"number": "2"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}) spans = memory_logger.pop() @@ -159,7 +159,7 @@ def test_chain_with_memory(logger_memory_logger): handler = BraintrustCallbackHandler(logger=test_logger) prompt = ChatPromptTemplate.from_template("{history} User: {input}") model = ChatOpenAI(model="gpt-4o-mini") - chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model) + chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model) memory = {"history": "Assistant: Hello! How can I assist you today?"} chain.invoke( @@ -399,7 +399,7 @@ def test_parallel_execution(logger_memory_logger): map_chain.invoke({"topic": "bear"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}) - spans = cast(List, memory_logger.pop()) + spans = cast(list, memory_logger.pop()) # Find the LLM spans llm_spans = find_spans_by_attributes(spans, name="ChatOpenAI") @@ -480,16 +480,16 @@ def test_langgraph_state_management(logger_memory_logger): n=1, ) - def say_hello(state: Dict[str, str]): + def say_hello(state: dict[str, str]): response = model.invoke("Say hello") - return cast(Union[str, List[str], Dict[str, str]], response.content) + return cast(str | list[str] | dict[str, str], response.content) - def say_bye(state: Dict[str, str]): + def say_bye(state: dict[str, str]): print("From the 'sayBye' node: Bye world!") return "Bye" workflow = ( - StateGraph(state_schema=Dict[str, str]) + StateGraph(state_schema=dict[str, str]) .add_node("sayHello", say_hello) .add_node("sayBye", say_bye) .add_edge(START, "sayHello") @@ -837,10 +837,10 @@ def test_streaming_ttft(logger_memory_logger): max_completion_tokens=50, streaming=True, ) - chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model) + chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model) # Collect chunks to verify streaming works - chunks: List[str] = [] + chunks: list[str] = [] for chunk in chain.stream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}): if chunk.content: chunks.append(str(chunk.content)) @@ -1272,9 +1272,9 @@ async def test_async_streaming(logger_memory_logger): handler = BraintrustCallbackHandler(logger=test_logger) prompt = ChatPromptTemplate.from_template("Count from 1 to 3.") model = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20, streaming=True) - chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model) + chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model) - chunks: List[str] = [] + chunks: list[str] = [] async for chunk in chain.astream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}): if chunk.content: chunks.append(str(chunk.content)) diff --git a/py/src/braintrust/integrations/langchain/test_context.py b/py/src/braintrust/integrations/langchain/test_context.py index 2076c4b3..37cf552a 100644 --- a/py/src/braintrust/integrations/langchain/test_context.py +++ b/py/src/braintrust/integrations/langchain/test_context.py @@ -1,5 +1,4 @@ # pyright: reportTypedDictNotRequiredAccess=none -from typing import Dict from unittest.mock import ANY import pytest @@ -57,7 +56,7 @@ def test_global_handler(logger_memory_logger): presence_penalty=0, n=1, ) - chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model) + chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model) message = chain.invoke({"number": "2"}) diff --git a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_wrap_openai.py b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_wrap_openai.py index c1dfceb3..dc112b6c 100644 --- a/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_wrap_openai.py +++ b/py/src/braintrust/integrations/pydantic_ai/test_pydantic_ai_wrap_openai.py @@ -1,5 +1,5 @@ import time -from typing import Any, Dict +from typing import Any import pytest from openai import AsyncOpenAI @@ -61,7 +61,7 @@ def memory_logger(): yield bgl -def _assert_metrics_are_valid(metrics: Dict[str, Any]): +def _assert_metrics_are_valid(metrics: dict[str, Any]): assert metrics["tokens"] > 0 assert metrics["prompt_tokens"] > 0 assert metrics["completion_tokens"] > 0 diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 32cd554b..6ed497a1 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -28,7 +28,6 @@ Literal, TypedDict, TypeVar, - Union, cast, overload, ) @@ -1600,7 +1599,7 @@ def init( base_experiment_id: str | None = None, repo_info: RepoInfo | None = None, state: BraintrustState | None = None, -) -> Union["Experiment", "ReadonlyExperiment"]: +) -> "Experiment | ReadonlyExperiment": """ Log in, and then initialize a new experiment in a specified project. If the project does not exist, it will be created. @@ -1767,7 +1766,7 @@ def compute_metadata(): return ret -def init_experiment(*args, **kwargs) -> Union["Experiment", "ReadonlyExperiment"]: +def init_experiment(*args, **kwargs) -> "Experiment | ReadonlyExperiment": """Alias for `init`""" return init(*args, **kwargs) @@ -2392,7 +2391,7 @@ def parent_context(parent: str | None, state: BraintrustState | None = None): def get_span_parent_object( parent: str | None = None, state: BraintrustState | None = None -) -> Union[SpanComponentsV4, "Logger", "Experiment", Span]: +) -> "SpanComponentsV4 | Logger | Experiment | Span": """Mainly for internal use. Return the parent object for starting a span in a global context. Applies precedence: current span > propagated parent string > experiment > logger.""" @@ -4857,7 +4856,7 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: def render_message(render: Callable[[str], str], message: PromptMessage): base = {k: v for (k, v) in message.as_dict().items() if v is not None} # TODO: shouldn't load_prompt guarantee content is a PromptMessage? - content = cast(Union[str, list[Union[TextPart, ImagePart]], dict[str, Any]], message.content) + content = cast(str | list[TextPart | ImagePart] | dict[str, Any], message.content) if content is not None: if isinstance(content, str): base["content"] = render(content) @@ -5618,9 +5617,7 @@ def __str__(self): class TracedThreadPoolExecutor(concurrent.futures.ThreadPoolExecutor): - # Returns Any because Future[T] generic typing was stabilized in Python 3.9, - # but we maintain compatibility with older type checkers. - def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future[Any]: # Capture all current context variables context = contextvars.copy_context() diff --git a/py/src/braintrust/merge_row_batch.py b/py/src/braintrust/merge_row_batch.py index c9047775..066744ce 100644 --- a/py/src/braintrust/merge_row_batch.py +++ b/py/src/braintrust/merge_row_batch.py @@ -1,5 +1,5 @@ from collections.abc import Callable, Sequence -from typing import Any, Optional, TypeVar +from typing import Any, TypeVar from .db_fields import IS_MERGE_FIELD @@ -8,7 +8,7 @@ from .util import merge_dicts -_MergedRowKey = tuple[Optional[Any], ...] +_MergedRowKey = tuple[Any | None, ...] def _generate_merged_row_key(row: dict[str, Any]) -> _MergedRowKey: diff --git a/py/src/braintrust/otel/context.py b/py/src/braintrust/otel/context.py index bb65be77..ea9703bc 100644 --- a/py/src/braintrust/otel/context.py +++ b/py/src/braintrust/otel/context.py @@ -1,7 +1,7 @@ """Unified context management using OTEL's built-in context.""" import logging -from typing import Any, Optional +from typing import Any from braintrust.context import ParentSpanIds, SpanInfo from braintrust.logger import Span @@ -18,7 +18,7 @@ class ContextManager: def __init__(self): pass - def get_current_span_info(self) -> Optional["SpanInfo"]: + def get_current_span_info(self) -> "SpanInfo | None": """Get information about the currently active span from OTEL context.""" # Get the current span from OTEL context diff --git a/py/src/braintrust/prompt.py b/py/src/braintrust/prompt.py index d4b7fa19..e9f413b7 100644 --- a/py/src/braintrust/prompt.py +++ b/py/src/braintrust/prompt.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, Union +from typing import Literal from .generated_types import PromptOptions from .serializable_data_class import SerializableDataClass @@ -61,7 +61,7 @@ class PromptChatBlock(SerializableDataClass): type: Literal["chat"] = "chat" -PromptBlockData = Union[PromptCompletionBlock, PromptChatBlock] +PromptBlockData = PromptCompletionBlock | PromptChatBlock @dataclass diff --git a/py/src/braintrust/serializable_data_class.py b/py/src/braintrust/serializable_data_class.py index 8f9eeefc..ba32ecca 100644 --- a/py/src/braintrust/serializable_data_class.py +++ b/py/src/braintrust/serializable_data_class.py @@ -1,5 +1,6 @@ import dataclasses import json +import types from typing import Union, get_origin @@ -39,7 +40,7 @@ def from_dict_deep(cls, d: dict): and issubclass(fields[k].type, SerializableDataClass) ): filtered[k] = fields[k].type.from_dict_deep(v) - elif get_origin(fields[k].type) == Union: + elif get_origin(fields[k].type) is Union or isinstance(fields[k].type, types.UnionType): for t in fields[k].type.__args__: if t == type(None) and v is None: filtered[k] = None diff --git a/py/src/braintrust/span_cache.py b/py/src/braintrust/span_cache.py index f3248d0c..ee926614 100644 --- a/py/src/braintrust/span_cache.py +++ b/py/src/braintrust/span_cache.py @@ -11,7 +11,7 @@ import os import tempfile import uuid -from typing import Any, Optional +from typing import Any from braintrust.types import Metadata from braintrust.util import merge_dicts @@ -28,11 +28,11 @@ class CachedSpan: def __init__( self, span_id: str, - input: Optional[Any] = None, - output: Optional[Any] = None, + input: Any | None = None, + output: Any | None = None, metadata: Metadata | None = None, - span_parents: Optional[list[str]] = None, - span_attributes: Optional[dict[str, Any]] = None, + span_parents: list[str] | None = None, + span_attributes: dict[str, Any] | None = None, ): self.span_id = span_id self.input = input @@ -104,7 +104,7 @@ class SpanCache: """ def __init__(self, disabled: bool = False): - self._cache_file_path: Optional[str] = None + self._cache_file_path: str | None = None self._initialized = False # Tracks whether the cache was explicitly disabled (via constructor or disable()) self._explicitly_disabled = disabled @@ -226,7 +226,7 @@ def _flush_write_buffer(self) -> None: # This can happen if disk is full or file permissions changed pass - def get_by_root_span_id(self, root_span_id: str) -> Optional[list[CachedSpan]]: + def get_by_root_span_id(self, root_span_id: str) -> list[CachedSpan] | None: """ Get all cached spans for a given rootSpanId. diff --git a/py/src/braintrust/test_context.py b/py/src/braintrust/test_context.py index 17b3c6d1..80e47048 100644 --- a/py/src/braintrust/test_context.py +++ b/py/src/braintrust/test_context.py @@ -24,7 +24,8 @@ def _threadpool_scenario(test_logger, with_memory_logger): import subprocess import sys import threading -from typing import AsyncGenerator, Callable, Generator, TypeVar +from collections.abc import AsyncGenerator, Callable, Generator +from typing import TypeVar import braintrust import pytest @@ -277,7 +278,6 @@ async def async_worker(): assert parent_log["span_id"] in worker_log.get("span_parents", []), "Worker should have parent as parent" -@pytest.mark.skipif(sys.version_info < (3, 9), reason="to_thread requires Python 3.9+") @pytest.mark.asyncio async def test_to_thread_preserves_context(test_logger, with_memory_logger): """ diff --git a/py/src/braintrust/test_framework.py b/py/src/braintrust/test_framework.py index 6368cee3..608b7585 100644 --- a/py/src/braintrust/test_framework.py +++ b/py/src/braintrust/test_framework.py @@ -1,6 +1,5 @@ import importlib.util import re -from typing import List from unittest.mock import MagicMock import pytest @@ -213,7 +212,7 @@ def _run_eval_sync(self, *args, **kwargs): @pytest.mark.asyncio async def test_hooks_trial_index(): """Test that trial_index is correctly passed to task via hooks.""" - trial_indices: List[int] = [] + trial_indices: list[int] = [] # Task that captures trial indices def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: @@ -253,7 +252,7 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: @pytest.mark.asyncio async def test_hooks_trial_index_multiple_inputs(): """Test trial_index with multiple inputs to ensure proper indexing.""" - trial_data: List[tuple] = [] # (input, trial_index) + trial_data: list[tuple] = [] # (input, trial_index) def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: trial_data.append((input_value, hooks.trial_index)) @@ -293,7 +292,7 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: @pytest.mark.asyncio async def test_per_input_trial_count_overrides_global(): """Test that per-input trial_count overrides the global trial_count.""" - trial_data: List[tuple] = [] # (input, trial_index) + trial_data: list[tuple] = [] # (input, trial_index) def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: trial_data.append((input_value, hooks.trial_index)) @@ -332,7 +331,7 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: @pytest.mark.asyncio async def test_per_input_trial_count_without_global(): """Test that per-input trial_count works when no global trial_count is set.""" - trial_data: List[tuple] = [] # (input, trial_index) + trial_data: list[tuple] = [] # (input, trial_index) def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: trial_data.append((input_value, hooks.trial_index)) @@ -367,7 +366,7 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: @pytest.mark.asyncio async def test_per_input_trial_count_with_dict_data(): """Test that per-input trial_count works when data items are plain dicts.""" - trial_data: List[tuple] = [] # (input, trial_index) + trial_data: list[tuple] = [] # (input, trial_index) def task_with_hooks(input_value: int, hooks: EvalHooks) -> int: trial_data.append((input_value, hooks.trial_index)) diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 7662ad77..e8c22bdc 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -5,7 +5,7 @@ import logging import os import time -from typing import AsyncGenerator, List +from collections.abc import AsyncGenerator from unittest import TestCase from unittest.mock import MagicMock, patch @@ -384,7 +384,7 @@ def test_load_parameters_prefers_version_over_environment_for_id(self): assert "environment" not in mock_api_conn.get_json.call_args.args[1] def test_extract_attachments_no_op(self): - attachments: List[BaseAttachment] = [] + attachments: list[BaseAttachment] = [] _extract_attachments({}, attachments) self.assertEqual(len(attachments), 0) @@ -441,7 +441,7 @@ def test_extract_attachments_with_attachments(self): } saved_nested = event["nested"] - attachments: List[BaseAttachment] = [] + attachments: list[BaseAttachment] = [] _extract_attachments(event, attachments) self.assertEqual( @@ -3135,7 +3135,7 @@ def test_extract_attachments_with_json_attachment(self): }, } - attachments: List[BaseAttachment] = [] + attachments: list[BaseAttachment] = [] _extract_attachments(event, attachments) self.assertEqual(len(attachments), 1) diff --git a/py/src/braintrust/test_serializable_data_class.py b/py/src/braintrust/test_serializable_data_class.py index 0cade6ae..e31e2078 100644 --- a/py/src/braintrust/test_serializable_data_class.py +++ b/py/src/braintrust/test_serializable_data_class.py @@ -1,14 +1,13 @@ import unittest from dataclasses import dataclass -from typing import List, Optional from .serializable_data_class import SerializableDataClass @dataclass class PromptData(SerializableDataClass): - prompt: Optional[str] = None - options: Optional[dict] = None + prompt: str | None = None + options: dict | None = None @dataclass @@ -18,9 +17,9 @@ class PromptSchema(SerializableDataClass): _xact_id: str name: str slug: str - description: Optional[str] + description: str | None prompt_data: PromptData - tags: Optional[List[str]] + tags: list[str] | None class TestSerializableDataClass(unittest.TestCase): diff --git a/py/src/braintrust/test_util.py b/py/src/braintrust/test_util.py index 90f18602..0dd27568 100644 --- a/py/src/braintrust/test_util.py +++ b/py/src/braintrust/test_util.py @@ -1,6 +1,5 @@ import os import unittest -from typing import List import pytest @@ -129,9 +128,9 @@ def compute_value(): lazy = LazyValue(compute_value, use_mutex=True) # Launch multiple threads that all try to get() simultaneously - threads: List[threading.Thread] = [] - results: List[str] = [] - errors: List[Exception] = [] + threads: list[threading.Thread] = [] + results: list[str] = [] + errors: list[Exception] = [] def worker(): try: diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py index d3426ac4..24bcefa2 100644 --- a/py/src/braintrust/trace.py +++ b/py/src/braintrust/trace.py @@ -6,7 +6,8 @@ """ import asyncio -from typing import Any, Awaitable, Callable, Optional, Protocol, TypedDict +from collections.abc import Awaitable, Callable +from typing import Any, Protocol, TypedDict from braintrust.functions.invoke import invoke from braintrust.logger import BraintrustState, ObjectFetcher @@ -18,12 +19,12 @@ class SpanData: def __init__( self, - input: Optional[Any] = None, - output: Optional[Any] = None, + input: Any | None = None, + output: Any | None = None, metadata: Metadata | None = None, - span_id: Optional[str] = None, - span_parents: Optional[list[str]] = None, - span_attributes: Optional[dict[str, Any]] = None, + span_id: str | None = None, + span_parents: list[str] | None = None, + span_attributes: dict[str, Any] | None = None, **kwargs: Any, ): self.input = input @@ -62,7 +63,7 @@ def __init__( object_id: str, root_span_id: str, state: BraintrustState, - span_type_filter: Optional[list[str]] = None, + span_type_filter: list[str] | None = None, ): # Build the filter expression for root_span_id and optionally span_attributes.type filter_expr = self._build_filter(root_span_id, span_type_filter) @@ -75,7 +76,7 @@ def __init__( self._state = state @staticmethod - def _build_filter(root_span_id: str, span_type_filter: Optional[list[str]] = None) -> dict[str, Any]: + def _build_filter(root_span_id: str, span_type_filter: list[str] | None = None) -> dict[str, Any]: """Build BTQL filter expression.""" children = [ # Base filter: root_span_id = 'value' @@ -121,7 +122,7 @@ def _get_state(self) -> BraintrustState: return self._state -SpanFetchFn = Callable[[Optional[list[str]]], Awaitable[list[SpanData]]] +SpanFetchFn = Callable[[list[str] | None], Awaitable[list[SpanData]]] class GetThreadOptions(TypedDict, total=False): @@ -140,11 +141,11 @@ class CachedSpanFetcher: def __init__( self, - object_type: Optional[str] = None, # Literal["experiment", "project_logs", "playground_logs"] - object_id: Optional[str] = None, - root_span_id: Optional[str] = None, - get_state: Optional[Callable[[], Awaitable[BraintrustState]]] = None, - fetch_fn: Optional[SpanFetchFn] = None, + object_type: str | None = None, # Literal["experiment", "project_logs", "playground_logs"] + object_id: str | None = None, + root_span_id: str | None = None, + get_state: Callable[[], Awaitable[BraintrustState]] | None = None, + fetch_fn: SpanFetchFn | None = None, ): self._span_cache: dict[str, list[SpanData]] = {} self._all_fetched = False @@ -159,7 +160,7 @@ def __init__( "Must provide either fetch_fn or all of object_type, object_id, root_span_id, get_state" ) - async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: + async def _fetch_fn(span_type: list[str] | None) -> list[SpanData]: state = await get_state() fetcher = SpanFetcher( object_type=object_type, @@ -196,7 +197,7 @@ async def _fetch_fn(span_type: Optional[list[str]]) -> list[SpanData]: self._fetch_fn = _fetch_fn - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: """ Get spans, using cache when possible. @@ -228,7 +229,7 @@ async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanDat await self._fetch_spans(missing_types) return self._get_from_cache(span_type) - async def _fetch_spans(self, span_type: Optional[list[str]]) -> None: + async def _fetch_spans(self, span_type: list[str] | None) -> None: """Fetch spans from the server.""" spans = await self._fetch_fn(span_type) @@ -239,7 +240,7 @@ async def _fetch_spans(self, span_type: Optional[list[str]]) -> None: self._span_cache[span_type_str] = [] self._span_cache[span_type_str].append(span) - def _get_from_cache(self, span_type: Optional[list[str]]) -> list[SpanData]: + def _get_from_cache(self, span_type: list[str] | None) -> list[SpanData]: """Get spans from cache, optionally filtering by type.""" if not span_type or len(span_type) == 0: # Return all spans @@ -266,7 +267,7 @@ def get_configuration(self) -> dict[str, str]: """Get the trace configuration (object_type, object_id, root_span_id).""" ... - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: """ Fetch all spans for this root span. @@ -307,7 +308,7 @@ def __init__( object_type: str, # Literal["experiment", "project_logs", "playground_logs"] object_id: str, root_span_id: str, - ensure_spans_flushed: Optional[Callable[[], Awaitable[None]]], + ensure_spans_flushed: Callable[[], Awaitable[None]] | None, state: BraintrustState, ): # Initialize dict with trace_ref for JSON serialization @@ -327,7 +328,7 @@ def __init__( self._ensure_spans_flushed = ensure_spans_flushed self._state = state self._spans_flushed = False - self._spans_flush_promise: Optional[asyncio.Task[None]] = None + self._spans_flush_promise: asyncio.Task[None] | None = None self._thread_cache: dict[str, asyncio.Task[list[Any]]] = {} async def get_state() -> BraintrustState: @@ -351,7 +352,7 @@ def get_configuration(self) -> dict[str, str]: "root_span_id": self._root_span_id, } - async def get_spans(self, span_type: Optional[list[str]] = None) -> list[SpanData]: + async def get_spans(self, span_type: list[str] | None = None) -> list[SpanData]: """ Fetch all rows for this root span from its parent object (experiment or project logs). First checks the local span cache for recently logged spans, then falls diff --git a/py/src/braintrust/types/_eval.py b/py/src/braintrust/types/_eval.py index 0f5be193..528df11e 100644 --- a/py/src/braintrust/types/_eval.py +++ b/py/src/braintrust/types/_eval.py @@ -5,7 +5,8 @@ underscore-prefixed so pyright strict mode doesn't flag them as private. """ -from typing import Any, Generic, Sequence, TypeVar +from collections.abc import Sequence +from typing import Any, Generic, TypeVar from typing_extensions import NotRequired, TypedDict diff --git a/py/src/braintrust/util.py b/py/src/braintrust/util.py index 3541cb5f..7fdb8abb 100644 --- a/py/src/braintrust/util.py +++ b/py/src/braintrust/util.py @@ -7,7 +7,7 @@ import urllib.parse from collections.abc import Callable, Mapping from dataclasses import dataclass -from typing import Any, Generic, Literal, TypedDict, TypeVar, Union +from typing import Any, Generic, Literal, TypedDict, TypeVar from requests import HTTPError, Response @@ -179,7 +179,7 @@ class _LazyValuePendingState: has_succeeded: Literal[False] = False -_LazyValueState = Union[_LazyValueResolvedState[T], _LazyValuePendingState] +_LazyValueState = _LazyValueResolvedState[T] | _LazyValuePendingState class LazyValue(Generic[T]): diff --git a/py/src/braintrust/wrappers/langsmith_wrapper.py b/py/src/braintrust/wrappers/langsmith_wrapper.py index b22117df..83e1fe01 100644 --- a/py/src/braintrust/wrappers/langsmith_wrapper.py +++ b/py/src/braintrust/wrappers/langsmith_wrapper.py @@ -40,7 +40,8 @@ def my_function(inputs: dict) -> dict: import inspect import logging import os -from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, ParamSpec, TypeVar +from collections.abc import Callable, Iterable, Iterator +from typing import Any, ParamSpec, TypeVar from braintrust.framework import EvalCase from braintrust.logger import NOOP_SPAN, current_span, init_logger, traced @@ -50,7 +51,7 @@ def my_function(inputs: dict) -> dict: logger = logging.getLogger(__name__) # Global list to store Braintrust eval results when running in tandem mode -_braintrust_eval_results: List[Any] = [] +_braintrust_eval_results: list[Any] = [] # TODO: langsmith.test/unit/expect, langsmith.AsyncClient, trace __all__ = [ @@ -68,7 +69,7 @@ def my_function(inputs: dict) -> dict: R = TypeVar("R") -def get_braintrust_results() -> List[Any]: +def get_braintrust_results() -> list[Any]: """Get all Braintrust eval results collected during tandem mode.""" return _braintrust_eval_results.copy() @@ -79,9 +80,9 @@ def clear_braintrust_results() -> None: def setup_langsmith( - api_key: Optional[str] = None, - project_id: Optional[str] = None, - project_name: Optional[str] = None, + api_key: str | None = None, + project_id: str | None = None, + project_name: str | None = None, standalone: bool = False, ) -> bool: """ @@ -169,7 +170,7 @@ def decorator(fn: Callable[P, R]) -> Callable[P, R]: def wrap_client( - Client: Any, project_name: Optional[str] = None, project_id: Optional[str] = None, standalone: bool = False + Client: Any, project_name: str | None = None, project_id: str | None = None, standalone: bool = False ) -> Any: """ Wrap langsmith.Client to redirect evaluate() and aevaluate() to Braintrust's Eval. @@ -203,9 +204,7 @@ def wrap_client( return Client -def make_evaluate_wrapper( - *, project_name: Optional[str] = None, project_id: Optional[str] = None, standalone: bool = False -): +def make_evaluate_wrapper(*, project_name: str | None = None, project_id: str | None = None, standalone: bool = False): def evaluate_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: result = None if not standalone: @@ -231,7 +230,7 @@ def evaluate_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any def make_aevaluate_wrapper( - *, project_name: Optional[str] = None, project_id: Optional[str] = None, standalone: bool = False + *, project_name: str | None = None, project_id: str | None = None, standalone: bool = False ): async def aevaluate_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: result = None @@ -258,7 +257,7 @@ async def aevaluate_wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) def wrap_evaluate( - evaluate: F, project_name: Optional[str] = None, project_id: Optional[str] = None, standalone: bool = False + evaluate: F, project_name: str | None = None, project_id: str | None = None, standalone: bool = False ) -> F: """ Wrap module-level langsmith.evaluate to redirect to Braintrust's Eval. @@ -282,8 +281,8 @@ def wrap_evaluate( def wrap_aevaluate( aevaluate: F, - project_name: Optional[str] = None, - project_id: Optional[str] = None, + project_name: str | None = None, + project_id: str | None = None, standalone: bool = False, ) -> F: """ @@ -318,8 +317,8 @@ def _is_patched(obj: Any) -> bool: def _run_braintrust_eval( args: Any, kwargs: Any, - project_name: Optional[str] = None, - project_id: Optional[str] = None, + project_name: str | None = None, + project_id: str | None = None, ) -> Any: """Run Braintrust Eval with LangSmith-style arguments.""" from braintrust.framework import Eval @@ -356,8 +355,8 @@ def _run_braintrust_eval( async def _run_braintrust_eval_async( args: Any, kwargs: Any, - project_name: Optional[str] = None, - project_id: Optional[str] = None, + project_name: str | None = None, + project_id: str | None = None, ) -> Any: """Run Braintrust EvalAsync with LangSmith-style arguments.""" from braintrust.framework import EvalAsync @@ -396,7 +395,7 @@ async def _run_braintrust_eval_async( # ============================================================================= -def _wrap_output(output: Any) -> Dict[str, Any]: +def _wrap_output(output: Any) -> dict[str, Any]: """Wrap non-dict outputs the same way LangSmith does.""" if not isinstance(output, dict): return {"output": output} @@ -413,7 +412,7 @@ def _make_braintrust_scorer( """ evaluator_name = getattr(evaluator, "__name__", "score") - def braintrust_scorer(input: Any, output: Any, expected: Optional[Any] = None, **kwargs: Any) -> Any: + def braintrust_scorer(input: Any, output: Any, expected: Any | None = None, **kwargs: Any) -> Any: from braintrust.score import Score # Run the evaluator with LangSmith's signature diff --git a/py/src/braintrust/xact_ids.py b/py/src/braintrust/xact_ids.py index 0327066e..1231e3bd 100644 --- a/py/src/braintrust/xact_ids.py +++ b/py/src/braintrust/xact_ids.py @@ -9,9 +9,7 @@ def modular_multiply(value: int, prime: int): return (value * prime) % MOD -# value : int | str -# Cannot use a | because of python 3.8 -def prettify_xact(value) -> str: +def prettify_xact(value: int | str) -> str: encoded = modular_multiply(int(value), COPRIME) return hex(encoded)[2:].rjust(16, "0")