From a1c7b000053bed0a183c8b48515ba9815c83adee Mon Sep 17 00:00:00 2001 From: Abhijeet Prasad Date: Tue, 3 Mar 2026 16:22:02 -0500 Subject: [PATCH] feat: add tags support to SDK builder classes --- py/src/braintrust/cli/push.py | 2 + py/src/braintrust/cli/test_push.py | 101 ++++++++++++++++ py/src/braintrust/conftest.py | 9 ++ py/src/braintrust/framework2.py | 21 +++- py/src/braintrust/test_framework2.py | 168 ++++++++++++++++----------- 5 files changed, 234 insertions(+), 67 deletions(-) create mode 100644 py/src/braintrust/cli/test_push.py diff --git a/py/src/braintrust/cli/push.py b/py/src/braintrust/cli/push.py index ff7af8d8..4bcd6a68 100644 --- a/py/src/braintrust/cli/push.py +++ b/py/src/braintrust/cli/push.py @@ -251,6 +251,8 @@ def _collect_function_function_defs( } if f.metadata is not None: j["metadata"] = f.metadata + if f.tags is not None: + j["tags"] = list(f.tags) if f.parameters is None: raise ValueError(f"Function {f.name} has no supplied parameters") j["function_schema"] = { diff --git a/py/src/braintrust/cli/test_push.py b/py/src/braintrust/cli/test_push.py new file mode 100644 index 00000000..24207ca7 --- /dev/null +++ b/py/src/braintrust/cli/test_push.py @@ -0,0 +1,101 @@ +"""Tests for push command serialization.""" + +import pytest + +pydantic = pytest.importorskip("pydantic") + +from ..framework2 import ( + global_, + projects, +) +from .push import _collect_function_function_defs + + +class ToolInput(pydantic.BaseModel): + value: int + + +@pytest.fixture(autouse=True) +def clear_global_state(): + global_.functions.clear() + global_.prompts.clear() + yield + global_.functions.clear() + global_.prompts.clear() + + +class TestPushMetadata: + """Tests for metadata in push command serialization.""" + + def test_collect_function_function_defs_includes_metadata(self, mock_project_ids): + project = projects.create("test-project") + metadata = {"version": "1.0", "author": "test"} + + tool = project.tools.create( + handler=lambda x: x, + name="test-tool", + parameters=ToolInput, + metadata=metadata, + ) + global_.functions.append(tool) + + functions = [] + _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") + + assert len(functions) == 1 + assert functions[0]["metadata"] == metadata + assert functions[0]["name"] == "test-tool" + + def test_collect_function_function_defs_excludes_metadata_when_none(self, mock_project_ids): + project = projects.create("test-project") + + tool = project.tools.create( + handler=lambda x: x, + name="test-tool", + parameters=ToolInput, + ) + global_.functions.append(tool) + + functions = [] + _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") + + assert len(functions) == 1 + assert "metadata" not in functions[0] + + +class TestPushTags: + """Tests for tags in push command serialization.""" + + def test_collect_function_function_defs_includes_tags(self, mock_project_ids): + project = projects.create("test-project") + tags = ["production", "v1"] + + tool = project.tools.create( + handler=lambda x: x, + name="test-tool", + parameters=ToolInput, + tags=tags, + ) + global_.functions.append(tool) + + functions = [] + _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") + + assert len(functions) == 1 + assert functions[0]["tags"] == ["production", "v1"] + + def test_collect_function_function_defs_excludes_tags_when_none(self, mock_project_ids): + project = projects.create("test-project") + + tool = project.tools.create( + handler=lambda x: x, + name="test-tool", + parameters=ToolInput, + ) + global_.functions.append(tool) + + functions = [] + _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") + + assert len(functions) == 1 + assert "tags" not in functions[0] diff --git a/py/src/braintrust/conftest.py b/py/src/braintrust/conftest.py index 0a20821e..24bd76af 100644 --- a/py/src/braintrust/conftest.py +++ b/py/src/braintrust/conftest.py @@ -1,6 +1,8 @@ import os +from unittest.mock import MagicMock import pytest +from braintrust.framework2 import ProjectIdCache def _patch_vcr_aiohttp_stubs(): @@ -67,6 +69,13 @@ def cached_content(self): _patch_vcr_aiohttp_stubs() +@pytest.fixture +def mock_project_ids(): + mock = MagicMock(spec=ProjectIdCache) + mock.get.return_value = "project-123" + return mock + + @pytest.fixture(autouse=True) def override_app_url_for_tests(): """ diff --git a/py/src/braintrust/framework2.py b/py/src/braintrust/framework2.py index b5450885..a3170843 100644 --- a/py/src/braintrust/framework2.py +++ b/py/src/braintrust/framework2.py @@ -1,6 +1,6 @@ import dataclasses import json -from collections.abc import Callable +from collections.abc import Callable, Sequence from typing import Any, overload import slugify @@ -53,6 +53,7 @@ class CodeFunction: returns: Any if_exists: IfExists | None metadata: dict[str, Any] | None = None + tags: Sequence[str] | None = None @dataclasses.dataclass @@ -69,6 +70,7 @@ class CodePrompt: id: str | None if_exists: IfExists | None metadata: dict[str, Any] | None = None + tags: Sequence[str] | None = None def to_function_definition(self, if_exists: IfExists | None, project_ids: ProjectIdCache) -> dict[str, Any]: prompt_data = self.prompt @@ -102,6 +104,8 @@ def to_function_definition(self, if_exists: IfExists | None, project_ids: Projec j["function_type"] = self.function_type if self.metadata is not None: j["metadata"] = self.metadata + if self.tags is not None: + j["tags"] = list(self.tags) return j @@ -124,6 +128,7 @@ def create( returns: Any = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, ) -> CodeFunction: """Creates a tool. @@ -136,6 +141,7 @@ def create( returns: The tool's output schema, as a Pydantic model. if_exists: What to do if the tool already exists. metadata: Custom metadata to attach to the tool. + tags: A list of tags for the tool. Returns: A handle to the created tool, that can be used in a prompt. @@ -160,6 +166,7 @@ def create( returns=returns, if_exists=if_exists, metadata=metadata, + tags=tags, ) self.project.add_code_function(f) return f @@ -186,6 +193,7 @@ def create( tools: list[CodeFunction | SavedFunctionId | ToolFunctionDefinition] | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, ) -> CodePrompt: ... @overload # messages only, no prompt @@ -202,6 +210,7 @@ def create( tools: list[CodeFunction | SavedFunctionId | ToolFunctionDefinition] | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, ) -> CodePrompt: ... def create( @@ -218,6 +227,7 @@ def create( tools: list[CodeFunction | SavedFunctionId | ToolFunctionDefinition] | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, ): """Creates a prompt. @@ -233,6 +243,7 @@ def create( tools: The tools to use for the prompt. if_exists: What to do if the prompt already exists. metadata: Custom metadata to attach to the prompt. + tags: A list of tags for the prompt. """ self._task_counter += 1 if not name: @@ -282,6 +293,7 @@ def create( id=id, if_exists=if_exists, metadata=metadata, + tags=tags, ) self.project.add_prompt(p) return p @@ -304,6 +316,7 @@ def create( description: str | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, handler: Callable[..., Any], parameters: Any, returns: Any = None, @@ -319,6 +332,7 @@ def create( description: str | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, prompt: str, model: str, params: ModelParams | None = None, @@ -336,6 +350,7 @@ def create( description: str | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, messages: list[ChatCompletionMessageParam], model: str, params: ModelParams | None = None, @@ -351,6 +366,7 @@ def create( description: str | None = None, if_exists: IfExists | None = None, metadata: dict[str, Any] | None = None, + tags: Sequence[str] | None = None, # Code scorer params. handler: Callable[..., Any] | None = None, parameters: Any = None, @@ -371,6 +387,7 @@ def create( description: The description of the scorer. if_exists: What to do if the scorer already exists. metadata: Custom metadata to attach to the scorer. + tags: A list of tags for the scorer. The remaining args are mutually exclusive; that is, the function will only accept args from one of the following overloads. @@ -410,6 +427,7 @@ def create( returns=returns, if_exists=if_exists, metadata=metadata, + tags=tags, ) self.project.add_code_function(f) return f @@ -449,6 +467,7 @@ def create( id=None, if_exists=if_exists, metadata=metadata, + tags=tags, ) self.project.add_prompt(p) return p diff --git a/py/src/braintrust/test_framework2.py b/py/src/braintrust/test_framework2.py index 8f86eedf..9b06c5b5 100644 --- a/py/src/braintrust/test_framework2.py +++ b/py/src/braintrust/test_framework2.py @@ -1,14 +1,10 @@ -"""Tests for framework2 module, specifically metadata support.""" +"""Tests for framework2 module, specifically metadata and tags support.""" import importlib.util -from unittest.mock import MagicMock import pytest -from .framework2 import ( - ProjectIdCache, - projects, -) +from .framework2 import projects # Check if pydantic is available HAS_PYDANTIC = importlib.util.find_spec("pydantic") is not None @@ -18,7 +14,6 @@ class TestCodeFunctionMetadata: """Tests for CodeFunction metadata support.""" def test_code_function_with_metadata(self): - """Test that CodeFunction stores metadata correctly.""" project = projects.create("test-project") metadata = {"version": "1.0", "author": "test"} @@ -34,7 +29,6 @@ def test_code_function_with_metadata(self): assert tool.slug == "test-tool" def test_code_function_without_metadata(self): - """Test that CodeFunction works without metadata.""" project = projects.create("test-project") tool = project.tools.create( @@ -50,7 +44,6 @@ class TestCodePromptMetadata: """Tests for CodePrompt metadata support.""" def test_code_prompt_with_metadata(self): - """Test that CodePrompt stores metadata correctly.""" project = projects.create("test-project") metadata = {"category": "greeting", "priority": "high"} @@ -65,7 +58,6 @@ def test_code_prompt_with_metadata(self): assert prompt.name == "test-prompt" def test_code_prompt_without_metadata(self): - """Test that CodePrompt works without metadata.""" project = projects.create("test-project") prompt = project.prompts.create( @@ -76,8 +68,7 @@ def test_code_prompt_without_metadata(self): assert prompt.metadata is None - def test_code_prompt_to_function_definition_includes_metadata(self): - """Test that to_function_definition includes metadata when present.""" + def test_code_prompt_to_function_definition_includes_metadata(self, mock_project_ids): project = projects.create("test-project") metadata = {"version": "2.0", "tag": "production"} @@ -88,17 +79,13 @@ def test_code_prompt_to_function_definition_includes_metadata(self): metadata=metadata, ) - mock_project_ids = MagicMock(spec=ProjectIdCache) - mock_project_ids.get.return_value = "project-123" - func_def = prompt.to_function_definition(None, mock_project_ids) assert func_def["metadata"] == metadata assert func_def["name"] == "test-prompt" assert func_def["project_id"] == "project-123" - def test_code_prompt_to_function_definition_excludes_metadata_when_none(self): - """Test that to_function_definition excludes metadata when None.""" + def test_code_prompt_to_function_definition_excludes_metadata_when_none(self, mock_project_ids): project = projects.create("test-project") prompt = project.prompts.create( @@ -107,9 +94,6 @@ def test_code_prompt_to_function_definition_excludes_metadata_when_none(self): model="gpt-4", ) - mock_project_ids = MagicMock(spec=ProjectIdCache) - mock_project_ids.get.return_value = "project-123" - func_def = prompt.to_function_definition(None, mock_project_ids) assert "metadata" not in func_def @@ -120,7 +104,6 @@ class TestScorerMetadata: @pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") def test_code_scorer_with_metadata(self): - """Test that code scorer stores metadata correctly.""" from pydantic import BaseModel class ScorerInput(BaseModel): @@ -144,7 +127,6 @@ def my_scorer(output: str, expected: str) -> float: assert scorer.name == "test-scorer" def test_llm_scorer_with_metadata(self): - """Test that LLM scorer stores metadata correctly.""" project = projects.create("test-project") metadata = {"type": "llm_classifier", "version": "2.0"} @@ -161,73 +143,127 @@ def test_llm_scorer_with_metadata(self): assert scorer.name == "llm-scorer" -@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") -class TestPushMetadata: - """Tests for metadata in push command serialization.""" +class TestCodeFunctionTags: + """Tests for CodeFunction tags support.""" - def test_collect_function_function_defs_includes_metadata(self): - """Test that _collect_function_function_defs includes metadata.""" - from pydantic import BaseModel + def test_code_function_with_tags(self): + project = projects.create("test-project") + tags = ["production", "v1"] - from .cli.push import _collect_function_function_defs - from .framework2 import global_ + tool = project.tools.create( + handler=lambda x: x, + name="test-tool", + parameters=None, + tags=tags, + ) - class ToolInput(BaseModel): - value: int + assert tool.tags == tags + def test_code_function_without_tags(self): project = projects.create("test-project") - metadata = {"version": "1.0", "author": "test"} - - global_.functions.clear() tool = project.tools.create( handler=lambda x: x, name="test-tool", - parameters=ToolInput, - metadata=metadata, + parameters=None, ) - global_.functions.append(tool) - mock_project_ids = MagicMock(spec=ProjectIdCache) - mock_project_ids.get.return_value = "project-123" + assert tool.tags is None - functions = [] - _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") - assert len(functions) == 1 - assert functions[0]["metadata"] == metadata - assert functions[0]["name"] == "test-tool" +class TestCodePromptTags: + """Tests for CodePrompt tags support.""" - global_.functions.clear() + def test_code_prompt_with_tags(self): + project = projects.create("test-project") + tags = ["greeting", "v2"] - def test_collect_function_function_defs_excludes_metadata_when_none(self): - """Test that _collect_function_function_defs excludes metadata when None.""" - from pydantic import BaseModel + prompt = project.prompts.create( + name="test-prompt", + prompt="Hello {{name}}", + model="gpt-4", + tags=tags, + ) + + assert prompt.tags == tags - from .cli.push import _collect_function_function_defs - from .framework2 import global_ + def test_code_prompt_without_tags(self): + project = projects.create("test-project") + + prompt = project.prompts.create( + name="test-prompt", + prompt="Hello {{name}}", + model="gpt-4", + ) - class ToolInput(BaseModel): - value: int + assert prompt.tags is None + def test_code_prompt_to_function_definition_includes_tags(self, mock_project_ids): project = projects.create("test-project") + tags = ["production", "scorer"] - global_.functions.clear() + prompt = project.prompts.create( + name="test-prompt", + prompt="Hello {{name}}", + model="gpt-4", + tags=tags, + ) - tool = project.tools.create( - handler=lambda x: x, - name="test-tool", - parameters=ToolInput, + func_def = prompt.to_function_definition(None, mock_project_ids) + + assert func_def["tags"] == ["production", "scorer"] + + def test_code_prompt_to_function_definition_excludes_tags_when_none(self, mock_project_ids): + project = projects.create("test-project") + + prompt = project.prompts.create( + name="test-prompt", + prompt="Hello {{name}}", + model="gpt-4", + ) + + func_def = prompt.to_function_definition(None, mock_project_ids) + + assert "tags" not in func_def + + +class TestScorerTags: + """Tests for Scorer tags support.""" + + @pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") + def test_code_scorer_with_tags(self): + from pydantic import BaseModel + + class ScorerInput(BaseModel): + output: str + expected: str + + project = projects.create("test-project") + tags = ["accuracy", "v1"] + + def my_scorer(output: str, expected: str) -> float: + return 1.0 if output == expected else 0.0 + + scorer = project.scorers.create( + handler=my_scorer, + name="test-scorer", + parameters=ScorerInput, + tags=tags, ) - global_.functions.append(tool) - mock_project_ids = MagicMock(spec=ProjectIdCache) - mock_project_ids.get.return_value = "project-123" + assert scorer.tags == tags - functions = [] - _collect_function_function_defs(mock_project_ids, functions, "bundle-123", "error") + def test_llm_scorer_with_tags(self): + project = projects.create("test-project") + tags = ["classifier", "v2"] - assert len(functions) == 1 - assert "metadata" not in functions[0] + scorer = project.scorers.create( + name="llm-scorer", + prompt="Is this correct?", + model="gpt-4", + use_cot=True, + choice_scores={"yes": 1.0, "no": 0.0}, + tags=tags, + ) - global_.functions.clear() + assert scorer.tags == tags