diff --git a/py/src/braintrust/framework.py b/py/src/braintrust/framework.py
index cbe529cf..b6077ef6 100644
--- a/py/src/braintrust/framework.py
+++ b/py/src/braintrust/framework.py
@@ -85,6 +85,7 @@ class EvalCase(SerializableDataClass, Generic[Input, Expected]):
     expected: Expected | None = None
     metadata: Metadata | None = None
     tags: Sequence[str] | None = None
+    trial_count: int | None = None

     # These fields are only set if the EvalCase is part of a Dataset.
     id: str | None = None
@@ -1677,7 +1678,12 @@ async def with_max_concurrency(coro):
             disable=position is None,
         ) as pbar:
             async for datum in pbar:
-                for trial_index in range(evaluator.trial_count):
+                if isinstance(datum, dict):
+                    datum_trial_count = datum.get("trial_count")
+                else:
+                    datum_trial_count = getattr(datum, "trial_count", None)
+                trial_count = datum_trial_count if datum_trial_count is not None else evaluator.trial_count
+                for trial_index in range(trial_count):
                     tasks.append(asyncio.create_task(with_max_concurrency(run_evaluator_task(datum, trial_index))))

     if not tasks:
diff --git a/py/src/braintrust/test_framework.py b/py/src/braintrust/test_framework.py
index b83ef883..e428fa63 100644
--- a/py/src/braintrust/test_framework.py
+++ b/py/src/braintrust/test_framework.py
@@ -290,6 +290,119 @@ def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
         trial_data.append((input_value, hooks.trial_index))
         return input_value * 2

     assert sorted(input_2_trials) == [0, 1]


+@pytest.mark.asyncio
+async def test_per_input_trial_count_overrides_global():
+    """Test that per-input trial_count overrides the global trial_count."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count",
+        data=[
+            EvalCase(input=1, expected=2),  # inherits global trial_count=2
+            EvalCase(input=2, expected=4, trial_count=5),  # overrides to 5
+            EvalCase(input=3, expected=6, trial_count=1),  # overrides to 1
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        trial_count=2,
+    )
+
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # 2 + 5 + 1 = 8 total results
+    assert len(result.results) == 8
+    assert len(trial_data) == 8
+
+    input_1_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 1])
+    input_2_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 2])
+    input_3_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 3])
+
+    assert input_1_trials == [0, 1]
+    assert input_2_trials == [0, 1, 2, 3, 4]
+    assert input_3_trials == [0]
+
+
+@pytest.mark.asyncio
+async def test_per_input_trial_count_without_global():
+    """Test that per-input trial_count works when no global trial_count is set."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count-no-global",
+        data=[
+            EvalCase(input=1, expected=2),  # uses default trial_count=1
+            EvalCase(input=2, expected=4, trial_count=3),  # overrides to 3
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+    )
+
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # 1 + 3 = 4 total results
+    assert len(result.results) == 4
+    assert len(trial_data) == 4
+
+    input_1_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 1])
+    input_2_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 2])
+
+    assert input_1_trials == [0]
+    assert input_2_trials == [0, 1, 2]
+
+
+@pytest.mark.asyncio
+async def test_per_input_trial_count_with_dict_data():
+    """Test that per-input trial_count works when data items are plain dicts."""
+    trial_data: List[tuple] = []  # (input, trial_index)
+
+    def task_with_hooks(input_value: int, hooks: EvalHooks) -> int:
+        trial_data.append((input_value, hooks.trial_index))
+        return input_value * 2
+
+    evaluator = Evaluator(
+        project_name="test-project",
+        eval_name="test-per-input-trial-count-dict",
+        data=[
+            {"input": 1, "expected": 2},  # inherits global trial_count=2
+            {"input": 2, "expected": 4, "trial_count": 4},  # overrides to 4
+            {"input": 3, "expected": 6, "trial_count": 1},  # overrides to 1
+        ],
+        task=task_with_hooks,
+        scores=[],
+        experiment_name=None,
+        metadata=None,
+        trial_count=2,
+    )
+
+    result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=[])
+
+    # 2 + 4 + 1 = 7 total results
+    assert len(result.results) == 7
+    assert len(trial_data) == 7
+
+    input_1_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 1])
+    input_2_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 2])
+    input_3_trials = sorted([trial_idx for inp, trial_idx in trial_data if inp == 3])
+
+    assert input_1_trials == [0, 1]
+    assert input_2_trials == [0, 1, 2, 3]
+    assert input_3_trials == [0]
+
+
 @pytest.mark.vcr
 @pytest.mark.asyncio
 async def test_scorer_spans_have_purpose_attribute(with_memory_logger, with_simulate_login):
diff --git a/py/src/braintrust/types/_eval.py b/py/src/braintrust/types/_eval.py
index c8b5dc6d..0f5be193 100644
--- a/py/src/braintrust/types/_eval.py
+++ b/py/src/braintrust/types/_eval.py
@@ -25,6 +25,7 @@ class EvalCaseDictNoOutput(Generic[Input], TypedDict):
     input: Input
     metadata: NotRequired[dict[str, Any] | None]
    tags: NotRequired[Sequence[str] | None]
+    trial_count: NotRequired[int | None]
     id: NotRequired[str | None]
     _xact_id: NotRequired[str | None]
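For context, a minimal usage sketch of the per-case override this diff enables. It is not part of the change itself; it assumes `Eval` and `EvalCase` are importable from the top-level `braintrust` package and that `Eval` forwards `data`, `task`, `scores`, and `trial_count` to `Evaluator`, as the tests above suggest. The project name and task are placeholders.

# Hypothetical usage sketch (assumes top-level Eval/EvalCase exports).
from braintrust import Eval, EvalCase

Eval(
    "my-project",  # hypothetical project name
    data=[
        EvalCase(input="a", expected="A"),                  # no override: runs 2 trials (global default below)
        EvalCase(input="b", expected="B", trial_count=5),   # noisy case: per-case override to 5 trials
        {"input": "c", "expected": "C", "trial_count": 1},  # plain dicts resolve via datum.get("trial_count")
    ],
    task=lambda s: s.upper(),
    scores=[],
    trial_count=2,  # evaluator-level default for cases without their own trial_count
)

Per the resolution logic added in framework.py, a case-level trial_count, when present and not None, takes precedence; otherwise the evaluator-level value applies.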