From d2246788e1013456bc6f9da6d1962b936939e8a9 Mon Sep 17 00:00:00 2001 From: John Huang Date: Wed, 6 May 2026 16:12:54 -0700 Subject: [PATCH 1/5] allow TypedDict prompt data type --- py/src/braintrust/parameters.py | 125 ++++++++++++------ .../type_tests/test_parameters_schema.py | 31 +++++ 2 files changed, 115 insertions(+), 41 deletions(-) create mode 100644 py/src/braintrust/type_tests/test_parameters_schema.py diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 0d87ed4d..43c96352 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -3,12 +3,13 @@ from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, cast -from jsonschema import Draft7Validator -from jsonschema.exceptions import ValidationError as JSONSchemaValidationError -from typing_extensions import NotRequired +from jsonschema import Draft7Validator # type: ignore[import-untyped] +from jsonschema.exceptions import ValidationError as JSONSchemaValidationError # type: ignore[import-untyped] +from typing_extensions import NotRequired, TypeGuard +from .generated_types import PromptBlockDataNullish, PromptOptionsNullish from .prompt import PromptData from .serializable_data_class import SerializableDataClass @@ -17,12 +18,20 @@ from .logger import Prompt +class PromptDataDefaultDict(TypedDict): + prompt: NotRequired[PromptBlockDataNullish | None] + options: NotRequired[PromptOptionsNullish | None] + + +PromptParameterDefault = PromptDataDefaultDict + + class PromptParameter(TypedDict): """A prompt parameter specification.""" type: Literal["prompt"] name: NotRequired[str | None] - default: NotRequired[PromptData | None] + default: NotRequired[PromptParameterDefault | None] description: NotRequired[str | None] @@ -41,6 +50,18 @@ class ModelParameter(TypedDict): ParametersSchema = Mapping[str, Any] +class _PydanticModelType(Protocol): + def __call__(self) -> Any: ... + + +class _PydanticV1ModelType(_PydanticModelType, Protocol): + def parse_obj(self, obj: Any) -> Any: ... + + +class _PydanticV2ModelType(_PydanticModelType, Protocol): + def model_validate(self, obj: Any) -> Any: ... + + @dataclass class RemoteEvalParameters(SerializableDataClass): id: str | None @@ -81,16 +102,24 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]: raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model") -def _is_prompt_parameter(schema: Any) -> bool: +def _is_prompt_parameter(schema: object) -> TypeGuard[PromptParameter]: return isinstance(schema, dict) and schema.get("type") == "prompt" -def _is_model_parameter(schema: Any) -> bool: +def _is_model_parameter(schema: object) -> TypeGuard[ModelParameter]: return isinstance(schema, dict) and schema.get("type") == "model" -def _is_pydantic_model(schema: Any) -> bool: - return hasattr(schema, "parse_obj") or hasattr(schema, "model_validate") +def _is_pydantic_model(schema: object) -> TypeGuard[_PydanticModelType]: + return _is_pydantic_v1_model(schema) or _is_pydantic_v2_model(schema) + + +def _is_pydantic_v1_model(schema: object) -> TypeGuard[_PydanticV1ModelType]: + return isinstance(schema, type) and hasattr(schema, "parse_obj") + + +def _is_pydantic_v2_model(schema: object) -> TypeGuard[_PydanticV2ModelType]: + return isinstance(schema, type) and hasattr(schema, "model_validate") def _get_pydantic_fields(schema: Any) -> dict[str, Any]: @@ -152,14 +181,18 @@ def _resolve_local_json_schema_refs( def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]: schema_json = _pydantic_to_json_schema(schema) - schema_json = _resolve_local_json_schema_refs(schema_json, schema_json) + schema_root = cast(dict[str, JSONValue], schema_json) + schema_json = cast( + dict[str, Any], + _resolve_local_json_schema_refs(schema_root, schema_root), + ) schema_json.pop("$defs", None) schema_json.pop("definitions", None) fields = _get_pydantic_fields(schema) if len(fields) == 1 and "value" in fields: properties = schema_json.get("properties") if isinstance(properties, dict) and isinstance(properties.get("value"), dict): - return dict(properties["value"]) + return dict(cast(Mapping[str, Any], properties["value"])) return schema_json @@ -186,13 +219,13 @@ def is_eval_parameter_schema(schema: Any) -> bool: def _prompt_data_to_dict( - prompt_data: PromptData | dict[str, Any] | None, + prompt_data: PromptParameterDefault | PromptData | None, ) -> dict[str, Any] | None: if prompt_data is None: return None if isinstance(prompt_data, PromptData): return prompt_data.as_dict() - return prompt_data + return dict(prompt_data) def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt": @@ -201,7 +234,7 @@ def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt": return Prompt.from_prompt_data(name, PromptData.from_dict_deep(prompt_data)) -def _apply_defaults_to_json_schema_instance(instance: Any, schema: dict[str, Any]) -> Any: +def _apply_defaults_to_json_schema_instance(instance: Any, schema: Mapping[str, Any]) -> Any: if not isinstance(instance, dict): return instance if schema.get("type") != "object": @@ -283,10 +316,12 @@ def _validate_local_parameters( prompt_data = None if value is not None: prompt_data = value - elif schema.get("default") is not None: - prompt_data = _prompt_data_to_dict(schema["default"]) else: - raise ValueError(f"Parameter '{name}' is required") + default = schema.get("default") + if default is not None: + prompt_data = _prompt_data_to_dict(default) + else: + raise ValueError(f"Parameter '{name}' is required") if prompt_data is None: raise ValueError(f"Parameter '{name}' is required") @@ -310,20 +345,26 @@ def _validate_local_parameters( result[name] = default_instance.value except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc - elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj({"value": value}).value else: - result[name] = schema.model_validate({"value": value}).value + if _is_pydantic_v1_model(schema): + result[name] = schema.parse_obj({"value": value}).value + elif _is_pydantic_v2_model(schema): + result[name] = schema.model_validate({"value": value}).value + else: + raise ValueError(f"Cannot validate {schema} - not a pydantic model") else: if value is None: try: result[name] = schema() except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc - elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj(value) else: - result[name] = schema.model_validate(value) + if _is_pydantic_v1_model(schema): + result[name] = schema.parse_obj(value) + elif _is_pydantic_v2_model(schema): + result[name] = schema.model_validate(value) + else: + raise ValueError(f"Cannot validate {schema} - not a pydantic model") else: result[name] = value except JSONSchemaValidationError as exc: @@ -368,23 +409,23 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]: for name, schema in parameters.items(): if _is_prompt_parameter(schema): - parameter_data = { + prompt_parameter_data: dict[str, Any] = { "type": "prompt", "description": schema.get("description"), } - default = schema.get("default") - if default is not None: - parameter_data["default"] = _prompt_data_to_dict(default) - result[name] = parameter_data + prompt_default = schema.get("default") + if prompt_default is not None: + prompt_parameter_data["default"] = _prompt_data_to_dict(prompt_default) + result[name] = prompt_parameter_data elif _is_model_parameter(schema): - parameter_data = { + model_parameter_data: dict[str, Any] = { "type": "model", "description": schema.get("description"), } - default = schema.get("default") - if default is not None: - parameter_data["default"] = default - result[name] = parameter_data + model_default = schema.get("default") + if model_default is not None: + model_parameter_data["default"] = model_default + result[name] = model_parameter_data elif schema is None: result[name] = { "type": "data", @@ -392,14 +433,14 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]: } else: schema_json = _serialize_pydantic_parameter_schema(schema) - parameter_data = { + data_parameter_data: dict[str, Any] = { "type": "data", "schema": schema_json, "description": schema_json.get("description"), } if "default" in schema_json: - parameter_data["default"] = schema_json["default"] - result[name] = parameter_data + data_parameter_data["default"] = schema_json["default"] + result[name] = data_parameter_data return result @@ -436,8 +477,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: default = _prompt_data_to_dict(schema.get("default")) if default is not None: property_schema["default"] = default - if schema.get("description") is not None: - property_schema["description"] = schema["description"] + description = schema.get("description") + if description is not None: + property_schema["description"] = description properties[name] = property_schema elif _is_model_parameter(schema): property_schema = { @@ -446,8 +488,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: } if "default" in schema: property_schema["default"] = schema.get("default") - if schema.get("description") is not None: - property_schema["description"] = schema["description"] + description = schema.get("description") + if description is not None: + property_schema["description"] = description properties[name] = property_schema elif schema is None: properties[name] = {} @@ -457,7 +500,7 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: if _parameter_required(schema): required.append(name) - result: ParametersSchema = { + result: dict[str, Any] = { "type": "object", "properties": properties, "additionalProperties": True, diff --git a/py/src/braintrust/type_tests/test_parameters_schema.py b/py/src/braintrust/type_tests/test_parameters_schema.py new file mode 100644 index 00000000..920f2816 --- /dev/null +++ b/py/src/braintrust/type_tests/test_parameters_schema.py @@ -0,0 +1,31 @@ +"""Type-check tests for saved parameter schema definitions.""" + +from braintrust import projects +from braintrust.parameters import EvalParameters + + +def test_parameters_create_accepts_discriminated_schema_entries() -> None: + project = projects.create("test-project") + + schema: EvalParameters = { + "model": { + "type": "model", + "default": "gpt-5-mini", + }, + "main": { + "type": "prompt", + "default": { + "prompt": { + "type": "chat", + "messages": [{"role": "user", "content": "{{input}}"}], + }, + "options": { + "model": "gpt-5-mini", + }, + }, + }, + } + + created = project.parameters.create(name="test-parameters", schema=schema) + + assert created is schema From 8f794bf1adda4462b511b18aae66915dab8710a2 Mon Sep 17 00:00:00 2001 From: John Huang Date: Wed, 6 May 2026 16:16:27 -0700 Subject: [PATCH 2/5] test for pydantic versions --- py/src/braintrust/test_parameters.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/py/src/braintrust/test_parameters.py b/py/src/braintrust/test_parameters.py index 74e7f2a1..601e00c1 100644 --- a/py/src/braintrust/test_parameters.py +++ b/py/src/braintrust/test_parameters.py @@ -179,6 +179,34 @@ def test_validate_remote_parameters_allows_prompt_overrides(): assert built["model"] == "gpt-5-nano" +def test_validate_local_parameters_accepts_v1_and_v2_pydantic_shapes(): + class LegacyParameter: + @classmethod + def parse_obj(cls, value): + return {"validated_by": "parse_obj", "value": value} + + class ModernParameter: + @classmethod + def model_validate(cls, value): + return {"validated_by": "model_validate", "value": value} + + result = validate_parameters( + { + "legacy": {"input": "legacy"}, + "modern": {"input": "modern"}, + }, + { + "legacy": LegacyParameter, + "modern": ModernParameter, + }, + ) + + assert result == { + "legacy": {"validated_by": "parse_obj", "value": {"input": "legacy"}}, + "modern": {"validated_by": "model_validate", "value": {"input": "modern"}}, + } + + @pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") def test_parameters_to_json_schema_uses_scalar_schema_for_single_value_models(): from pydantic import BaseModel From c7594dfea86657ea13359254708d76ae36b7b0c9 Mon Sep 17 00:00:00 2001 From: John Huang Date: Wed, 6 May 2026 16:26:11 -0700 Subject: [PATCH 3/5] remove type suppressions --- py/src/braintrust/parameters.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 43c96352..3cd586a8 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -5,8 +5,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, cast -from jsonschema import Draft7Validator # type: ignore[import-untyped] -from jsonschema.exceptions import ValidationError as JSONSchemaValidationError # type: ignore[import-untyped] +from jsonschema import Draft7Validator +from jsonschema.exceptions import ValidationError as JSONSchemaValidationError from typing_extensions import NotRequired, TypeGuard from .generated_types import PromptBlockDataNullish, PromptOptionsNullish From bab0e113bcdbc8f659bb75298648445ec6209ab5 Mon Sep 17 00:00:00 2001 From: John Huang Date: Wed, 6 May 2026 16:27:18 -0700 Subject: [PATCH 4/5] remove unneccesary type --- py/src/braintrust/parameters.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 3cd586a8..a1cc321a 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -18,20 +18,17 @@ from .logger import Prompt -class PromptDataDefaultDict(TypedDict): +class PromptDataDict(TypedDict): prompt: NotRequired[PromptBlockDataNullish | None] options: NotRequired[PromptOptionsNullish | None] -PromptParameterDefault = PromptDataDefaultDict - - class PromptParameter(TypedDict): """A prompt parameter specification.""" type: Literal["prompt"] name: NotRequired[str | None] - default: NotRequired[PromptParameterDefault | None] + default: NotRequired[PromptDataDict | None] description: NotRequired[str | None] @@ -219,7 +216,7 @@ def is_eval_parameter_schema(schema: Any) -> bool: def _prompt_data_to_dict( - prompt_data: PromptParameterDefault | PromptData | None, + prompt_data: PromptDataDict | PromptData | None, ) -> dict[str, Any] | None: if prompt_data is None: return None From 6d8438bb7e1b768014d9c066584964e1794af478 Mon Sep 17 00:00:00 2001 From: john Date: Wed, 6 May 2026 17:33:26 -0700 Subject: [PATCH 5/5] Update py/src/braintrust/parameters.py Co-authored-by: Abhijeet Prasad --- py/src/braintrust/parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index a1cc321a..c8b71ddc 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -28,7 +28,7 @@ class PromptParameter(TypedDict): type: Literal["prompt"] name: NotRequired[str | None] - default: NotRequired[PromptDataDict | None] + default: NotRequired[PromptDataDict | PromptData | None] description: NotRequired[str | None]