diff --git a/py/src/braintrust/parameters.py b/py/src/braintrust/parameters.py index 0d87ed4d..c8b71ddc 100644 --- a/py/src/braintrust/parameters.py +++ b/py/src/braintrust/parameters.py @@ -3,12 +3,13 @@ from collections.abc import Mapping from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal, TypedDict +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, cast from jsonschema import Draft7Validator from jsonschema.exceptions import ValidationError as JSONSchemaValidationError -from typing_extensions import NotRequired +from typing_extensions import NotRequired, TypeGuard +from .generated_types import PromptBlockDataNullish, PromptOptionsNullish from .prompt import PromptData from .serializable_data_class import SerializableDataClass @@ -17,12 +18,17 @@ from .logger import Prompt +class PromptDataDict(TypedDict): + prompt: NotRequired[PromptBlockDataNullish | None] + options: NotRequired[PromptOptionsNullish | None] + + class PromptParameter(TypedDict): """A prompt parameter specification.""" type: Literal["prompt"] name: NotRequired[str | None] - default: NotRequired[PromptData | None] + default: NotRequired[PromptDataDict | PromptData | None] description: NotRequired[str | None] @@ -41,6 +47,18 @@ class ModelParameter(TypedDict): ParametersSchema = Mapping[str, Any] +class _PydanticModelType(Protocol): + def __call__(self) -> Any: ... + + +class _PydanticV1ModelType(_PydanticModelType, Protocol): + def parse_obj(self, obj: Any) -> Any: ... + + +class _PydanticV2ModelType(_PydanticModelType, Protocol): + def model_validate(self, obj: Any) -> Any: ... + + @dataclass class RemoteEvalParameters(SerializableDataClass): id: str | None @@ -81,16 +99,24 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]: raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model") -def _is_prompt_parameter(schema: Any) -> bool: +def _is_prompt_parameter(schema: object) -> TypeGuard[PromptParameter]: return isinstance(schema, dict) and schema.get("type") == "prompt" -def _is_model_parameter(schema: Any) -> bool: +def _is_model_parameter(schema: object) -> TypeGuard[ModelParameter]: return isinstance(schema, dict) and schema.get("type") == "model" -def _is_pydantic_model(schema: Any) -> bool: - return hasattr(schema, "parse_obj") or hasattr(schema, "model_validate") +def _is_pydantic_model(schema: object) -> TypeGuard[_PydanticModelType]: + return _is_pydantic_v1_model(schema) or _is_pydantic_v2_model(schema) + + +def _is_pydantic_v1_model(schema: object) -> TypeGuard[_PydanticV1ModelType]: + return isinstance(schema, type) and hasattr(schema, "parse_obj") + + +def _is_pydantic_v2_model(schema: object) -> TypeGuard[_PydanticV2ModelType]: + return isinstance(schema, type) and hasattr(schema, "model_validate") def _get_pydantic_fields(schema: Any) -> dict[str, Any]: @@ -152,14 +178,18 @@ def _resolve_local_json_schema_refs( def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]: schema_json = _pydantic_to_json_schema(schema) - schema_json = _resolve_local_json_schema_refs(schema_json, schema_json) + schema_root = cast(dict[str, JSONValue], schema_json) + schema_json = cast( + dict[str, Any], + _resolve_local_json_schema_refs(schema_root, schema_root), + ) schema_json.pop("$defs", None) schema_json.pop("definitions", None) fields = _get_pydantic_fields(schema) if len(fields) == 1 and "value" in fields: properties = schema_json.get("properties") if isinstance(properties, dict) and isinstance(properties.get("value"), dict): - return dict(properties["value"]) + return dict(cast(Mapping[str, Any], properties["value"])) return schema_json @@ -186,13 +216,13 @@ def is_eval_parameter_schema(schema: Any) -> bool: def _prompt_data_to_dict( - prompt_data: PromptData | dict[str, Any] | None, + prompt_data: PromptDataDict | PromptData | None, ) -> dict[str, Any] | None: if prompt_data is None: return None if isinstance(prompt_data, PromptData): return prompt_data.as_dict() - return prompt_data + return dict(prompt_data) def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt": @@ -201,7 +231,7 @@ def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt": return Prompt.from_prompt_data(name, PromptData.from_dict_deep(prompt_data)) -def _apply_defaults_to_json_schema_instance(instance: Any, schema: dict[str, Any]) -> Any: +def _apply_defaults_to_json_schema_instance(instance: Any, schema: Mapping[str, Any]) -> Any: if not isinstance(instance, dict): return instance if schema.get("type") != "object": @@ -283,10 +313,12 @@ def _validate_local_parameters( prompt_data = None if value is not None: prompt_data = value - elif schema.get("default") is not None: - prompt_data = _prompt_data_to_dict(schema["default"]) else: - raise ValueError(f"Parameter '{name}' is required") + default = schema.get("default") + if default is not None: + prompt_data = _prompt_data_to_dict(default) + else: + raise ValueError(f"Parameter '{name}' is required") if prompt_data is None: raise ValueError(f"Parameter '{name}' is required") @@ -310,20 +342,26 @@ def _validate_local_parameters( result[name] = default_instance.value except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc - elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj({"value": value}).value else: - result[name] = schema.model_validate({"value": value}).value + if _is_pydantic_v1_model(schema): + result[name] = schema.parse_obj({"value": value}).value + elif _is_pydantic_v2_model(schema): + result[name] = schema.model_validate({"value": value}).value + else: + raise ValueError(f"Cannot validate {schema} - not a pydantic model") else: if value is None: try: result[name] = schema() except Exception as exc: raise ValueError(f"Parameter '{name}' is required") from exc - elif hasattr(schema, "parse_obj"): - result[name] = schema.parse_obj(value) else: - result[name] = schema.model_validate(value) + if _is_pydantic_v1_model(schema): + result[name] = schema.parse_obj(value) + elif _is_pydantic_v2_model(schema): + result[name] = schema.model_validate(value) + else: + raise ValueError(f"Cannot validate {schema} - not a pydantic model") else: result[name] = value except JSONSchemaValidationError as exc: @@ -368,23 +406,23 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]: for name, schema in parameters.items(): if _is_prompt_parameter(schema): - parameter_data = { + prompt_parameter_data: dict[str, Any] = { "type": "prompt", "description": schema.get("description"), } - default = schema.get("default") - if default is not None: - parameter_data["default"] = _prompt_data_to_dict(default) - result[name] = parameter_data + prompt_default = schema.get("default") + if prompt_default is not None: + prompt_parameter_data["default"] = _prompt_data_to_dict(prompt_default) + result[name] = prompt_parameter_data elif _is_model_parameter(schema): - parameter_data = { + model_parameter_data: dict[str, Any] = { "type": "model", "description": schema.get("description"), } - default = schema.get("default") - if default is not None: - parameter_data["default"] = default - result[name] = parameter_data + model_default = schema.get("default") + if model_default is not None: + model_parameter_data["default"] = model_default + result[name] = model_parameter_data elif schema is None: result[name] = { "type": "data", @@ -392,14 +430,14 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]: } else: schema_json = _serialize_pydantic_parameter_schema(schema) - parameter_data = { + data_parameter_data: dict[str, Any] = { "type": "data", "schema": schema_json, "description": schema_json.get("description"), } if "default" in schema_json: - parameter_data["default"] = schema_json["default"] - result[name] = parameter_data + data_parameter_data["default"] = schema_json["default"] + result[name] = data_parameter_data return result @@ -436,8 +474,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: default = _prompt_data_to_dict(schema.get("default")) if default is not None: property_schema["default"] = default - if schema.get("description") is not None: - property_schema["description"] = schema["description"] + description = schema.get("description") + if description is not None: + property_schema["description"] = description properties[name] = property_schema elif _is_model_parameter(schema): property_schema = { @@ -446,8 +485,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: } if "default" in schema: property_schema["default"] = schema.get("default") - if schema.get("description") is not None: - property_schema["description"] = schema["description"] + description = schema.get("description") + if description is not None: + property_schema["description"] = description properties[name] = property_schema elif schema is None: properties[name] = {} @@ -457,7 +497,7 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema: if _parameter_required(schema): required.append(name) - result: ParametersSchema = { + result: dict[str, Any] = { "type": "object", "properties": properties, "additionalProperties": True, diff --git a/py/src/braintrust/test_parameters.py b/py/src/braintrust/test_parameters.py index 74e7f2a1..601e00c1 100644 --- a/py/src/braintrust/test_parameters.py +++ b/py/src/braintrust/test_parameters.py @@ -179,6 +179,34 @@ def test_validate_remote_parameters_allows_prompt_overrides(): assert built["model"] == "gpt-5-nano" +def test_validate_local_parameters_accepts_v1_and_v2_pydantic_shapes(): + class LegacyParameter: + @classmethod + def parse_obj(cls, value): + return {"validated_by": "parse_obj", "value": value} + + class ModernParameter: + @classmethod + def model_validate(cls, value): + return {"validated_by": "model_validate", "value": value} + + result = validate_parameters( + { + "legacy": {"input": "legacy"}, + "modern": {"input": "modern"}, + }, + { + "legacy": LegacyParameter, + "modern": ModernParameter, + }, + ) + + assert result == { + "legacy": {"validated_by": "parse_obj", "value": {"input": "legacy"}}, + "modern": {"validated_by": "model_validate", "value": {"input": "modern"}}, + } + + @pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed") def test_parameters_to_json_schema_uses_scalar_schema_for_single_value_models(): from pydantic import BaseModel diff --git a/py/src/braintrust/type_tests/test_parameters_schema.py b/py/src/braintrust/type_tests/test_parameters_schema.py new file mode 100644 index 00000000..920f2816 --- /dev/null +++ b/py/src/braintrust/type_tests/test_parameters_schema.py @@ -0,0 +1,31 @@ +"""Type-check tests for saved parameter schema definitions.""" + +from braintrust import projects +from braintrust.parameters import EvalParameters + + +def test_parameters_create_accepts_discriminated_schema_entries() -> None: + project = projects.create("test-project") + + schema: EvalParameters = { + "model": { + "type": "model", + "default": "gpt-5-mini", + }, + "main": { + "type": "prompt", + "default": { + "prompt": { + "type": "chat", + "messages": [{"role": "user", "content": "{{input}}"}], + }, + "options": { + "model": "gpt-5-mini", + }, + }, + }, + } + + created = project.parameters.create(name="test-parameters", schema=schema) + + assert created is schema