|
1 | 1 | import importlib.util |
| 2 | +import re |
2 | 3 | from typing import List |
3 | 4 | from unittest.mock import MagicMock |
4 | 5 |
|
|
11 | 12 | EvalHooks, |
12 | 13 | EvalResultWithSummary, |
13 | 14 | Evaluator, |
| 15 | + Filter, |
| 16 | + evaluate_filter, |
| 17 | + parse_filters, |
14 | 18 | run_evaluator, |
15 | 19 | ) |
16 | 20 | from .score import Score, Scorer |
@@ -626,3 +630,77 @@ async def test_run_evaluator_empty_dataset_warns(capsys): |
626 | 630 | captured = capsys.readouterr() |
627 | 631 | assert "Warning" in captured.err |
628 | 632 | assert "empty" in captured.err.lower() |
| 633 | + |
| 634 | + |
class TestEvaluateFilter:
    """Regression tests for https://github.com/braintrustdata/braintrust-sdk-python/issues/207."""

    @staticmethod
    def _name_filter():
        """Return the filter shared by the match / no-match / missing-key cases."""
        return Filter(path=["metadata", "name"], pattern=re.compile("foo"))

    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello", "metadata": {"name": "foo"}},
            EvalCase(input="hello", metadata={"name": "foo"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_match(self, datum):
        """A metadata.name of 'foo' satisfies the filter for dicts and EvalCases alike."""
        assert evaluate_filter(datum, self._name_filter()) is True

    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello", "metadata": {"name": "bar"}},
            EvalCase(input="hello", metadata={"name": "bar"}),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_no_match(self, datum):
        """A non-matching metadata.name value is rejected."""
        assert evaluate_filter(datum, self._name_filter()) is False

    @pytest.mark.parametrize(
        "datum",
        [
            {"input": "hello"},
            EvalCase(input="hello"),
        ],
        ids=["dict", "evalcase"],
    )
    def test_evaluate_filter_missing_key(self, datum):
        """Absent metadata means the filter path cannot resolve, so no match."""
        assert evaluate_filter(datum, self._name_filter()) is False

    def test_evaluate_filter_nested_metadata(self):
        """An anchored pattern matches one key of a multi-key metadata dict."""
        case = EvalCase(input="hello", metadata={"priority": "P0", "owner": "alice"})
        flt = Filter(path=["metadata", "priority"], pattern=re.compile("^P0$"))
        assert evaluate_filter(case, flt) is True

    def test_evaluate_filter_input_field(self):
        """Filter paths can reach into a structured input, not just metadata."""
        case = EvalCase(input={"text": "hello world"}, metadata={"name": "foo"})
        flt = Filter(path=["input", "text"], pattern=re.compile("hello"))
        assert evaluate_filter(case, flt) is True

    @pytest.mark.asyncio
    async def test_run_evaluator_with_filter_and_evalcase(self):
        """End-to-end: run_evaluator honors parsed filters for EvalCase data."""
        evaluator = Evaluator(
            project_name="test-project",
            eval_name="test-filter-evalcase",
            data=[
                EvalCase(input="hello", metadata={"name": "foo"}),
                EvalCase(input="world", metadata={"name": "bar"}),
            ],
            task=lambda x: x,
            scores=[],
            experiment_name=None,
            metadata=None,
        )

        parsed = parse_filters(["metadata.name=foo"])
        result = await run_evaluator(experiment=None, evaluator=evaluator, position=None, filters=parsed)

        # Only the "foo" case should pass the filter
        assert len(result.results) == 1
        assert result.results[0].input == "hello"
0 commit comments