Skip to content

Commit 0269b68

Browse files
chore: modernize typing for Python 3.10+ minimum (#356)
1 parent 3308c56 commit 0269b68

27 files changed

Lines changed: 146 additions & 167 deletions

py/src/braintrust/bt_json.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import json
33
import math
44
import warnings
5-
from typing import Any, Callable, Mapping, NamedTuple, cast, overload
5+
from collections.abc import Callable, Mapping
6+
from typing import Any, NamedTuple, cast, overload
67

78

89
# Try to import orjson for better performance

py/src/braintrust/contrib/temporal/test_temporal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from dataclasses import dataclass
66
from datetime import timedelta
7-
from typing import Any, Dict
7+
from typing import Any
88

99
import pytest
1010
import pytest_asyncio
@@ -31,7 +31,7 @@ class TestHeaderSerialization:
3131
def test_span_context_to_headers_with_valid_context(self):
3232
interceptor = BraintrustInterceptor()
3333
span_context = {"trace_id": "test-trace-id", "span_id": "test-span-id"}
34-
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
34+
headers: dict[str, temporalio.api.common.v1.Payload] = {}
3535

3636
result_headers = interceptor._span_context_to_headers(span_context, headers)
3737

@@ -40,8 +40,8 @@ def test_span_context_to_headers_with_valid_context(self):
4040

4141
def test_span_context_to_headers_with_empty_context(self):
4242
interceptor = BraintrustInterceptor()
43-
span_context: Dict[str, Any] = {}
44-
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
43+
span_context: dict[str, Any] = {}
44+
headers: dict[str, temporalio.api.common.v1.Payload] = {}
4545

4646
result_headers = interceptor._span_context_to_headers(span_context, headers)
4747

@@ -78,7 +78,7 @@ def test_span_context_from_headers_with_valid_header(self):
7878

7979
def test_span_context_from_headers_with_missing_header(self):
8080
interceptor = BraintrustInterceptor()
81-
headers: Dict[str, temporalio.api.common.v1.Payload] = {}
81+
headers: dict[str, temporalio.api.common.v1.Payload] = {}
8282

8383
result = interceptor._span_context_from_headers(headers)
8484

py/src/braintrust/devserver/schemas.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import json
22
from collections.abc import Sequence
3-
from typing import Any, Union, get_args, get_origin, get_type_hints
4-
5-
from typing_extensions import TypedDict
3+
from typing import Any, TypedDict, Union, get_args, get_origin, get_type_hints
64

75

86
# This is not beautiful code, but it saves us from introducing Pydantic as a dependency, and it is fairly

py/src/braintrust/framework.py

Lines changed: 29 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
Any,
1818
Generic,
1919
Literal,
20+
Protocol,
21+
TypedDict,
2022
TypeVar,
21-
Union,
2223
)
2324

2425
from tqdm.asyncio import tqdm as async_tqdm
2526
from tqdm.auto import tqdm as std_tqdm
26-
from typing_extensions import Protocol, TypedDict
2727

2828
from .generated_types import FunctionFormat, FunctionOutputType, ObjectReference
2929
from .git_fields import GitMetadataSettings, RepoInfo
@@ -216,13 +216,8 @@ class EvalScorerArgs(SerializableDataClass, Generic[Input, Output, Expected]):
216216
metadata: Metadata | None = None
217217

218218

219-
OneOrMoreScores = Union[float, int, bool, None, Score, list[Score]]
220-
OneOrMoreClassifications = Union[
221-
None,
222-
Classification,
223-
Mapping[str, Any],
224-
list[Classification | Mapping[str, Any]],
225-
]
219+
OneOrMoreScores = float | int | bool | None | Score | list[Score]
220+
OneOrMoreClassifications = None | Classification | Mapping[str, Any] | list[Classification | Mapping[str, Any]]
226221

227222

228223
# Synchronous scorer interface - implements callable
@@ -247,20 +242,19 @@ class AsyncScorerLike(Protocol, Generic[Input, Output, Expected]):
247242
async def eval_async(self, output: Output, expected: Expected | None = None, **kwargs: Any) -> OneOrMoreScores: ...
248243

249244

250-
# Union type for any kind of scorer (for typing)
251-
ScorerLike = Union[SyncScorerLike[Input, Output, Expected], AsyncScorerLike[Input, Output, Expected]]
245+
ScorerLike = SyncScorerLike[Input, Output, Expected] | AsyncScorerLike[Input, Output, Expected]
252246

253-
EvalScorer = Union[
254-
ScorerLike[Input, Output, Expected],
255-
type[ScorerLike[Input, Output, Expected]],
256-
Callable[[Input, Output, Expected], OneOrMoreScores],
257-
Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]],
258-
]
247+
EvalScorer = (
248+
ScorerLike[Input, Output, Expected]
249+
| type[ScorerLike[Input, Output, Expected]]
250+
| Callable[[Input, Output, Expected], OneOrMoreScores]
251+
| Callable[[Input, Output, Expected], Awaitable[OneOrMoreScores]]
252+
)
259253

260-
EvalClassifier = Union[
261-
Callable[[Input, Output, Expected], OneOrMoreClassifications],
262-
Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]],
263-
]
254+
EvalClassifier = (
255+
Callable[[Input, Output, Expected], OneOrMoreClassifications]
256+
| Callable[[Input, Output, Expected], Awaitable[OneOrMoreClassifications]]
257+
)
264258

265259

266260
@dataclasses.dataclass
@@ -278,27 +272,23 @@ class BaseExperiment:
278272
"""
279273

280274

281-
_AnyEvalCase = Union[
282-
EvalCase[Input, Expected],
283-
EvalCaseDict[Input, Expected],
284-
EvalCaseDictNoOutput[Input],
285-
ExperimentDatasetEvent,
286-
]
275+
_AnyEvalCase = (
276+
EvalCase[Input, Expected] | EvalCaseDict[Input, Expected] | EvalCaseDictNoOutput[Input] | ExperimentDatasetEvent
277+
)
287278

288-
_EvalDataObject = Union[
289-
Iterable[_AnyEvalCase[Input, Expected]],
290-
Iterator[_AnyEvalCase[Input, Expected]],
291-
Awaitable[Iterator[_AnyEvalCase[Input, Expected]]],
292-
Callable[[], Union[Iterator[_AnyEvalCase[Input, Expected]], Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]],
293-
BaseExperiment,
294-
]
279+
_EvalDataObject = (
280+
Iterable[_AnyEvalCase[Input, Expected]]
281+
| Iterator[_AnyEvalCase[Input, Expected]]
282+
| Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]
283+
| Callable[[], Iterator[_AnyEvalCase[Input, Expected]] | Awaitable[Iterator[_AnyEvalCase[Input, Expected]]]]
284+
| BaseExperiment
285+
)
295286

296-
EvalData = Union[_EvalDataObject[Input, Expected], type[_EvalDataObject[Input, Expected]], Dataset]
287+
EvalData = _EvalDataObject[Input, Expected] | type[_EvalDataObject[Input, Expected]] | Dataset
297288

298-
EvalTask = Union[
299-
Callable[[Input], Union[Output, Awaitable[Output]]],
300-
Callable[[Input, EvalHooks[Expected]], Union[Output, Awaitable[Output]]],
301-
]
289+
EvalTask = (
290+
Callable[[Input], Output | Awaitable[Output]] | Callable[[Input, EvalHooks[Expected]], Output | Awaitable[Output]]
291+
)
302292

303293
ErrorScoreHandler = Callable[[Span, EvalCase[Input, Expected], Sequence[str]], dict[str, float] | None]
304294

py/src/braintrust/functions/stream.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import json
1010
from collections.abc import Generator, Iterable
1111
from itertools import tee
12-
from typing import Literal, Union
12+
from typing import Literal
1313

1414
from sseclient import SSEClient
1515

@@ -79,13 +79,9 @@ class BraintrustInvokeError(ValueError):
7979
pass
8080

8181

82-
BraintrustStreamChunk = Union[
83-
BraintrustTextChunk,
84-
BraintrustJsonChunk,
85-
BraintrustErrorChunk,
86-
BraintrustConsoleChunk,
87-
BraintrustProgressChunk,
88-
]
82+
BraintrustStreamChunk = (
83+
BraintrustTextChunk | BraintrustJsonChunk | BraintrustErrorChunk | BraintrustConsoleChunk | BraintrustProgressChunk
84+
)
8985

9086

9187
class BraintrustStream:

py/src/braintrust/integrations/claude_agent_sdk/_constants.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
from collections.abc import Mapping
12
from dataclasses import dataclass
23
from enum import Enum
34
from types import MappingProxyType
4-
from typing import Final, Mapping
5+
from typing import Final
56

67

78
class MessageClassName(str, Enum):

py/src/braintrust/integrations/langchain/callbacks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from typing import (
88
Any,
99
TypedDict,
10-
Union,
1110
)
1211
from uuid import UUID
1312

@@ -528,7 +527,7 @@ def on_llm_new_token(
528527
self,
529528
token: str,
530529
*,
531-
chunk: Union["GenerationChunk", "ChatGenerationChunk"] | None = None, # type: ignore
530+
chunk: "GenerationChunk | ChatGenerationChunk | None" = None, # type: ignore
532531
run_id: UUID,
533532
parent_run_id: UUID | None = None,
534533
**kwargs: Any,

py/src/braintrust/integrations/langchain/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Sequence
1+
from collections.abc import Sequence
2+
from typing import Any
23
from unittest.mock import ANY
34

45

py/src/braintrust/integrations/langchain/test_callbacks.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import uuid
55
from pathlib import Path
6-
from typing import Dict, List, Union, cast
6+
from typing import cast
77

88
import pytest
99
from braintrust import logger
@@ -56,7 +56,7 @@ def test_llm_calls(logger_memory_logger):
5656
presence_penalty=0,
5757
n=1,
5858
)
59-
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
59+
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
6060
chain.invoke({"number": "2"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]})
6161

6262
spans = memory_logger.pop()
@@ -159,7 +159,7 @@ def test_chain_with_memory(logger_memory_logger):
159159
handler = BraintrustCallbackHandler(logger=test_logger)
160160
prompt = ChatPromptTemplate.from_template("{history} User: {input}")
161161
model = ChatOpenAI(model="gpt-4o-mini")
162-
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
162+
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
163163

164164
memory = {"history": "Assistant: Hello! How can I assist you today?"}
165165
chain.invoke(
@@ -399,7 +399,7 @@ def test_parallel_execution(logger_memory_logger):
399399

400400
map_chain.invoke({"topic": "bear"}, config={"callbacks": [cast(BaseCallbackHandler, handler)]})
401401

402-
spans = cast(List, memory_logger.pop())
402+
spans = cast(list, memory_logger.pop())
403403

404404
# Find the LLM spans
405405
llm_spans = find_spans_by_attributes(spans, name="ChatOpenAI")
@@ -480,16 +480,16 @@ def test_langgraph_state_management(logger_memory_logger):
480480
n=1,
481481
)
482482

483-
def say_hello(state: Dict[str, str]):
483+
def say_hello(state: dict[str, str]):
484484
response = model.invoke("Say hello")
485-
return cast(Union[str, List[str], Dict[str, str]], response.content)
485+
return cast(str | list[str] | dict[str, str], response.content)
486486

487-
def say_bye(state: Dict[str, str]):
487+
def say_bye(state: dict[str, str]):
488488
print("From the 'sayBye' node: Bye world!")
489489
return "Bye"
490490

491491
workflow = (
492-
StateGraph(state_schema=Dict[str, str])
492+
StateGraph(state_schema=dict[str, str])
493493
.add_node("sayHello", say_hello)
494494
.add_node("sayBye", say_bye)
495495
.add_edge(START, "sayHello")
@@ -837,10 +837,10 @@ def test_streaming_ttft(logger_memory_logger):
837837
max_completion_tokens=50,
838838
streaming=True,
839839
)
840-
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
840+
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
841841

842842
# Collect chunks to verify streaming works
843-
chunks: List[str] = []
843+
chunks: list[str] = []
844844
for chunk in chain.stream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}):
845845
if chunk.content:
846846
chunks.append(str(chunk.content))
@@ -1272,9 +1272,9 @@ async def test_async_streaming(logger_memory_logger):
12721272
handler = BraintrustCallbackHandler(logger=test_logger)
12731273
prompt = ChatPromptTemplate.from_template("Count from 1 to 3.")
12741274
model = ChatOpenAI(model="gpt-4o-mini", max_completion_tokens=20, streaming=True)
1275-
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
1275+
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
12761276

1277-
chunks: List[str] = []
1277+
chunks: list[str] = []
12781278
async for chunk in chain.astream({}, config={"callbacks": [cast(BaseCallbackHandler, handler)]}):
12791279
if chunk.content:
12801280
chunks.append(str(chunk.content))

py/src/braintrust/integrations/langchain/test_context.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# pyright: reportTypedDictNotRequiredAccess=none
2-
from typing import Dict
32
from unittest.mock import ANY
43

54
import pytest
@@ -57,7 +56,7 @@ def test_global_handler(logger_memory_logger):
5756
presence_penalty=0,
5857
n=1,
5958
)
60-
chain: RunnableSerializable[Dict[str, str], BaseMessage] = prompt.pipe(model)
59+
chain: RunnableSerializable[dict[str, str], BaseMessage] = prompt.pipe(model)
6160

6261
message = chain.invoke({"number": "2"})
6362

0 commit comments

Comments (0)