@pytest.mark.asyncio
@pytest.mark.skipif(
    sys.version_info < (3, 14),
    reason="PEP 649 lazy annotation evaluation is 3.14+",
)
async def test_hooks_with_type_checking_only_annotation():
    """Regression test for #263.

    Under Python 3.14 (PEP 649) `inspect.signature(fn)` uses the VALUE
    annotation format by default, so annotations are evaluated eagerly and
    a name that only exists under TYPE_CHECKING raises `NameError`.
    `_call_user_fn_args` swallowed that error with a bare except and fell
    back to forwarding every kwarg; the second signature lookup, the one
    deciding whether `hooks` gets injected, raised as well, so the task was
    invoked without `hooks` and crashed.
    """
    # The annotation has to reference the unknown name *unquoted* to hit the
    # PEP 649 lazy-evaluation path. A module-level definition would break on
    # Python <3.14, hence the function is compiled with exec, and only when
    # this 3.14-only test actually runs.
    task_source = (
        "def task_with_unresolvable_hooks_annotation(\n"
        " input_value: int,\n"
        " hooks: EvalHooks[frozenset[SomeType]],\n"
        ") -> int:\n"
        " saw_hooks.append(hooks is not None)\n"
        " return input_value * 2\n"
    )
    hook_flags: list[bool] = []
    namespace: dict = {"saw_hooks": hook_flags, "EvalHooks": EvalHooks}
    exec(task_source, namespace)

    evaluator = Evaluator(
        project_name="test-project",
        eval_name="test-pep649-typecheck-only",
        data=[EvalCase(input=1, expected=2)],
        task=namespace["task_with_unresolvable_hooks_annotation"],
        scores=[],
        experiment_name=None,
        metadata=None,
        trial_count=1,
    )

    outcome = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])

    # Exactly one case ran, the task received a hooks object, and it
    # completed without error producing the doubled input.
    assert len(outcome.results) == 1
    assert hook_flags == [True]
    assert outcome.results[0].error is None
    assert outcome.results[0].output == 2
def get_signature(fn: Callable) -> inspect.Signature:
    """Return the signature of ``fn`` without requiring resolvable annotations.

    On Python 3.14+ (PEP 649) ``inspect.signature`` evaluates annotations
    eagerly in VALUE format by default, so an annotation that references a
    TYPE_CHECKING-only import raises ``NameError``. Requesting the FORWARDREF
    format turns unresolvable names into ``ForwardRef`` objects instead.
    Callers in this package only look at parameter names/kinds, never at the
    annotation values. On older interpreters this is a plain
    ``inspect.signature(fn)`` call.

    Args:
        fn: The callable to introspect.

    Returns:
        The ``inspect.Signature`` describing ``fn``.
    """
    # Before 3.14 there is no annotationlib and no annotation_format keyword.
    if sys.version_info < (3, 14):
        return inspect.signature(fn)

    # Imported lazily: the module only exists on 3.14+.
    from annotationlib import Format

    return inspect.signature(fn, annotation_format=Format.FORWARDREF)  # pylint: disable=unexpected-keyword-arg