Skip to content

Commit 090633d

Browse files
authored
fix: allow eval filtering with EvalCase objects (#211)
1 parent 514aece commit 090633d

2 files changed

Lines changed: 83 additions & 2 deletions

File tree

py/src/braintrust/framework.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import traceback
1010
import warnings
1111
from collections import defaultdict
12-
from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Sequence
12+
from collections.abc import Awaitable, Callable, Coroutine, Iterable, Iterator, Mapping, Sequence
1313
from concurrent.futures import ThreadPoolExecutor
1414
from contextlib import contextmanager
1515
from multiprocessing import cpu_count
@@ -1154,7 +1154,10 @@ def parse_filters(filters: list[str]) -> list[Filter]:
11541154
def evaluate_filter(object, filter: Filter):
    """Return True if *object* matches *filter*.

    Walks ``filter.path`` segment by segment into *object*: mapping-style
    lookup (``.get``) when the current node is a ``Mapping``, attribute
    lookup (``getattr`` with a ``None`` default) otherwise — this is what
    lets both plain dicts and EvalCase objects be filtered. A missing
    segment (``None``) short-circuits to ``False``; otherwise the leaf
    value is JSON-serialized and tested against ``filter.pattern``.
    """
    node = object
    for segment in filter.path:
        node = (
            node.get(segment)
            if isinstance(node, Mapping)
            else getattr(node, segment, None)
        )
        # Path fell off the object — no match possible.
        if node is None:
            return False
    return filter.pattern.match(serialize_json_with_plain_string(node)) is not None

py/src/braintrust/test_framework.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib.util
2+
import re
23
from typing import List
34
from unittest.mock import MagicMock
45

@@ -11,6 +12,9 @@
1112
EvalHooks,
1213
EvalResultWithSummary,
1314
Evaluator,
15+
Filter,
16+
evaluate_filter,
17+
parse_filters,
1418
run_evaluator,
1519
)
1620
from .score import Score, Scorer
@@ -626,3 +630,77 @@ async def test_run_evaluator_empty_dataset_warns(capsys):
626630
captured = capsys.readouterr()
627631
assert "Warning" in captured.err
628632
assert "empty" in captured.err.lower()
633+
634+
635+
class TestEvaluateFilter:
    """Regression tests for https://github.com/braintrustdata/braintrust-sdk-python/issues/207."""

    # Same filter must work whether the datum is a plain dict or an EvalCase.
    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello", "metadata": {"name": "foo"}},
            EvalCase(input="hello", metadata={"name": "foo"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_match(self, datum):
        name_filter = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
        assert evaluate_filter(datum, name_filter) is True

    # A non-matching value is rejected for both representations.
    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello", "metadata": {"name": "bar"}},
            EvalCase(input="hello", metadata={"name": "bar"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_no_match(self, datum):
        name_filter = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
        assert evaluate_filter(datum, name_filter) is False

    # Absent metadata must short-circuit to False, not raise.
    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello"},
            EvalCase(input="hello"),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_missing_key(self, datum):
        name_filter = Filter(path=["metadata", "name"], pattern=re.compile("foo"))
        assert evaluate_filter(datum, name_filter) is False

    def test_evaluate_filter_nested_metadata(self):
        # Anchored pattern against one key of a multi-key metadata dict.
        case = EvalCase(input="hello", metadata={"priority": "P0", "owner": "alice"})
        priority_filter = Filter(path=["metadata", "priority"], pattern=re.compile("^P0$"))
        assert evaluate_filter(case, priority_filter) is True

    def test_evaluate_filter_input_field(self):
        # Paths may descend into the (dict-valued) input field as well.
        case = EvalCase(input={"text": "hello world"}, metadata={"name": "foo"})
        input_filter = Filter(path=["input", "text"], pattern=re.compile("hello"))
        assert evaluate_filter(case, input_filter) is True

    @pytest.mark.asyncio
    async def test_run_evaluator_with_filter_and_evalcase(self):
        # End-to-end: filters parsed from CLI-style strings applied to EvalCase data.
        cases = [
            EvalCase(input="hello", metadata={"name": "foo"}),
            EvalCase(input="world", metadata={"name": "bar"}),
        ]

        evaluator = Evaluator(
            project_name="test-project",
            eval_name="test-filter-evalcase",
            data=cases,
            task=lambda x: x,
            scores=[],
            experiment_name=None,
            metadata=None,
        )

        parsed = parse_filters(["metadata.name=foo"])
        result = await run_evaluator(
            experiment=None, evaluator=evaluator, position=None, filters=parsed
        )

        # Only the "foo" case should pass the filter
        assert len(result.results) == 1
        assert result.results[0].input == "hello"

0 commit comments

Comments
 (0)