Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py/src/braintrust/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from .serializable_data_class import SerializableDataClass
from .span_types import SpanTypeAttribute
from .types._eval import EvalCaseDict, EvalCaseDictNoOutput, ExperimentDatasetEvent
from .util import bt_iscoroutinefunction, eprint, merge_dicts
from .util import bt_iscoroutinefunction, eprint, get_signature, merge_dicts


Input = TypeVar("Input")
Expand Down Expand Up @@ -471,7 +471,7 @@ def run_f(args, kwargs, ctx):

def _call_user_fn_args(fn, kwargs):
try:
signature = inspect.signature(fn)
signature = get_signature(fn)
except:
return [], kwargs

Expand Down Expand Up @@ -1585,7 +1585,7 @@ def report_progress(event: TaskProgressEvent):
# Check if the task takes a hooks argument
task_args = [datum.input]
try:
if len(inspect.signature(evaluator.task).parameters) == 2:
if len(get_signature(evaluator.task).parameters) == 2:
task_args.append(hooks)
except:
pass
Expand Down
3 changes: 2 additions & 1 deletion py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
encode_uri_component,
eprint,
get_caller_location,
get_signature,
mask_api_key,
merge_dicts,
parse_env_var_float,
Expand Down Expand Up @@ -2466,7 +2467,7 @@ def decorator(span_args, span_kwargs, f: F):
span_args += (f.__name__,)

try:
f_sig = inspect.signature(f)
f_sig = get_signature(f)
except:
f_sig = None

Expand Down
51 changes: 51 additions & 0 deletions py/src/braintrust/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib.util
import re
import sys
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -209,6 +210,56 @@ def _run_eval_sync(self, *args, **kwargs):
assert result.summary.scores[scorer_name].score == 1.0


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 14),
reason="PEP 649 lazy annotation evaluation is 3.14+",
)
async def test_hooks_with_type_checking_only_annotation():
"""Regression test for #263.

On Python 3.14 (PEP 649), `inspect.signature(fn)` defaults to VALUE
format, which eagerly evaluates annotations and raises `NameError` for
TYPE_CHECKING-only imports. The bare except in `_call_user_fn_args`
used to fall back to passing every kwarg through, and the separate
signature call that decides whether to inject `hooks` would also raise,
leaving the task to crash on a missing `hooks` argument.
"""
# The unresolved name in the annotation must be unquoted to trigger the
# PEP 649 lazy-eval path; defining this at module top-level would fail
# on Python <3.14, so build it via exec inside the 3.14-only branch.
saw_hooks: list[bool] = []
ns: dict = {"saw_hooks": saw_hooks, "EvalHooks": EvalHooks}
exec(
"def task_with_unresolvable_hooks_annotation(\n"
" input_value: int,\n"
" hooks: EvalHooks[frozenset[SomeType]],\n"
") -> int:\n"
" saw_hooks.append(hooks is not None)\n"
" return input_value * 2\n",
ns,
)
task = ns["task_with_unresolvable_hooks_annotation"]

evaluator = Evaluator(
project_name="test-project",
eval_name="test-pep649-typecheck-only",
data=[EvalCase(input=1, expected=2)],
task=task,
scores=[],
experiment_name=None,
metadata=None,
trial_count=1,
)

result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])

assert len(result.results) == 1
assert saw_hooks == [True]
assert result.results[0].error is None
assert result.results[0].output == 2


@pytest.mark.asyncio
async def test_hooks_trial_index():
"""Test that trial_index is correctly passed to task via hooks."""
Expand Down
14 changes: 14 additions & 0 deletions py/src/braintrust/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,20 @@ def parse_env_var_float(name: str, default: float) -> float:
BT_IS_ASYNC_ATTRIBUTE = "_BT_IS_ASYNC"


def get_signature(fn: Callable) -> inspect.Signature:
# On Python 3.14+ (PEP 649), inspect.signature evaluates annotations
# eagerly in VALUE format by default. Annotations referencing
# TYPE_CHECKING-only imports raise NameError. Use FORWARDREF so
# unresolvable names become ForwardRef objects; callers here only
# inspect parameter names/kinds, not annotation values.
if sys.version_info >= (3, 14):
import annotationlib

kwargs = {"annotation_format": annotationlib.Format.FORWARDREF}
return inspect.signature(fn, **kwargs) # pylint: disable=unexpected-keyword-arg
return inspect.signature(fn)


# Taken from
# https://stackoverflow.com/questions/5574702/how-do-i-print-to-stderr-in-python.
def is_numeric(v):
Expand Down