Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 79 additions & 39 deletions py/src/braintrust/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from collections.abc import Mapping
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, TypedDict
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypedDict, cast

from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError as JSONSchemaValidationError
from typing_extensions import NotRequired
from typing_extensions import NotRequired, TypeGuard

from .generated_types import PromptBlockDataNullish, PromptOptionsNullish
from .prompt import PromptData
from .serializable_data_class import SerializableDataClass

Expand All @@ -17,12 +18,17 @@
from .logger import Prompt


class PromptDataDict(TypedDict):
    """Plain-dict shape of prompt data: an optional prompt block and options.

    Dict counterpart of ``PromptData`` for use as a prompt parameter
    ``default`` without constructing a ``PromptData`` instance.
    """

    prompt: NotRequired[PromptBlockDataNullish | None]
    options: NotRequired[PromptOptionsNullish | None]


class PromptParameter(TypedDict):
    """A prompt parameter specification.

    ``default`` accepts either a ``PromptData`` instance or a plain
    ``PromptDataDict`` mapping.
    """

    type: Literal["prompt"]
    name: NotRequired[str | None]
    # Diff fusion left two conflicting `default` declarations; keep the
    # widened one that also accepts a PromptDataDict mapping.
    default: NotRequired[PromptDataDict | PromptData | None]
    description: NotRequired[str | None]


Expand All @@ -41,6 +47,18 @@ class ModelParameter(TypedDict):
ParametersSchema = Mapping[str, Any]


class _PydanticModelType(Protocol):
    """Structural type for a pydantic model class.

    Values matching this protocol are the model *classes* themselves, so
    calling one with no arguments constructs a default instance.
    """

    def __call__(self) -> Any: ...


class _PydanticV1ModelType(_PydanticModelType, Protocol):
    """Pydantic v1 model class, validating input via ``parse_obj``.

    ``self`` is the model class object (matches are classes, not instances).
    """

    def parse_obj(self, obj: Any) -> Any: ...


class _PydanticV2ModelType(_PydanticModelType, Protocol):
    """Pydantic v2 model class, validating input via ``model_validate``.

    ``self`` is the model class object (matches are classes, not instances).
    """

    def model_validate(self, obj: Any) -> Any: ...


@dataclass
class RemoteEvalParameters(SerializableDataClass):
id: str | None
Expand Down Expand Up @@ -81,16 +99,24 @@ def _pydantic_to_json_schema(model: Any) -> dict[str, Any]:
raise ValueError(f"Cannot convert {model} to JSON schema - not a pydantic model")


def _is_prompt_parameter(schema: Any) -> bool:
def _is_prompt_parameter(schema: object) -> TypeGuard[PromptParameter]:
return isinstance(schema, dict) and schema.get("type") == "prompt"


def _is_model_parameter(schema: Any) -> bool:
def _is_model_parameter(schema: object) -> TypeGuard[ModelParameter]:
return isinstance(schema, dict) and schema.get("type") == "model"


def _is_pydantic_model(schema: object) -> TypeGuard[_PydanticModelType]:
    """Return True if *schema* is a pydantic model class (v1 or v2).

    Delegates to the version-specific checks, which also require *schema*
    to be a class — duck-typed instances do not match.
    """
    # Diff fusion left the old hasattr-based body alongside the new
    # delegating one; keep the delegating version.
    return _is_pydantic_v1_model(schema) or _is_pydantic_v2_model(schema)


def _is_pydantic_v1_model(schema: object) -> TypeGuard[_PydanticV1ModelType]:
return isinstance(schema, type) and hasattr(schema, "parse_obj")


def _is_pydantic_v2_model(schema: object) -> TypeGuard[_PydanticV2ModelType]:
return isinstance(schema, type) and hasattr(schema, "model_validate")


def _get_pydantic_fields(schema: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -152,14 +178,18 @@ def _resolve_local_json_schema_refs(

def _serialize_pydantic_parameter_schema(schema: Any) -> dict[str, Any]:
    """Serialize a pydantic parameter model to a flattened JSON schema dict.

    Local ``$ref``s are inlined and the ``$defs``/``definitions`` tables are
    dropped. If the model has exactly one field named ``value``, that field's
    schema is returned directly so scalar parameters serialize as scalars
    rather than wrapper objects.
    """
    # Diff fusion left old un-cast lines alongside the new cast-based ones;
    # keep the cast-based version.
    schema_json = _pydantic_to_json_schema(schema)
    schema_root = cast(dict[str, JSONValue], schema_json)
    schema_json = cast(
        dict[str, Any],
        _resolve_local_json_schema_refs(schema_root, schema_root),
    )
    schema_json.pop("$defs", None)
    schema_json.pop("definitions", None)
    fields = _get_pydantic_fields(schema)
    if len(fields) == 1 and "value" in fields:
        properties = schema_json.get("properties")
        if isinstance(properties, dict) and isinstance(properties.get("value"), dict):
            # Copy so the caller can mutate without aliasing the schema dict.
            return dict(cast(Mapping[str, Any], properties["value"]))
    return schema_json


Expand All @@ -186,13 +216,13 @@ def is_eval_parameter_schema(schema: Any) -> bool:


def _prompt_data_to_dict(
prompt_data: PromptData | dict[str, Any] | None,
prompt_data: PromptDataDict | PromptData | None,
) -> dict[str, Any] | None:
if prompt_data is None:
return None
if isinstance(prompt_data, PromptData):
return prompt_data.as_dict()
return prompt_data
return dict(prompt_data)


def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt":
Expand All @@ -201,7 +231,7 @@ def _create_prompt(name: str, prompt_data: dict[str, Any]) -> "Prompt":
return Prompt.from_prompt_data(name, PromptData.from_dict_deep(prompt_data))


def _apply_defaults_to_json_schema_instance(instance: Any, schema: dict[str, Any]) -> Any:
def _apply_defaults_to_json_schema_instance(instance: Any, schema: Mapping[str, Any]) -> Any:
if not isinstance(instance, dict):
return instance
if schema.get("type") != "object":
Expand Down Expand Up @@ -283,10 +313,12 @@ def _validate_local_parameters(
prompt_data = None
if value is not None:
prompt_data = value
elif schema.get("default") is not None:
prompt_data = _prompt_data_to_dict(schema["default"])
else:
raise ValueError(f"Parameter '{name}' is required")
default = schema.get("default")
if default is not None:
prompt_data = _prompt_data_to_dict(default)
else:
raise ValueError(f"Parameter '{name}' is required")

if prompt_data is None:
raise ValueError(f"Parameter '{name}' is required")
Expand All @@ -310,20 +342,26 @@ def _validate_local_parameters(
result[name] = default_instance.value
except Exception as exc:
raise ValueError(f"Parameter '{name}' is required") from exc
elif hasattr(schema, "parse_obj"):
result[name] = schema.parse_obj({"value": value}).value
else:
result[name] = schema.model_validate({"value": value}).value
if _is_pydantic_v1_model(schema):
result[name] = schema.parse_obj({"value": value}).value
elif _is_pydantic_v2_model(schema):
result[name] = schema.model_validate({"value": value}).value
else:
raise ValueError(f"Cannot validate {schema} - not a pydantic model")
else:
if value is None:
try:
result[name] = schema()
except Exception as exc:
raise ValueError(f"Parameter '{name}' is required") from exc
elif hasattr(schema, "parse_obj"):
result[name] = schema.parse_obj(value)
else:
result[name] = schema.model_validate(value)
if _is_pydantic_v1_model(schema):
result[name] = schema.parse_obj(value)
elif _is_pydantic_v2_model(schema):
result[name] = schema.model_validate(value)
else:
raise ValueError(f"Cannot validate {schema} - not a pydantic model")
else:
result[name] = value
except JSONSchemaValidationError as exc:
Expand Down Expand Up @@ -368,38 +406,38 @@ def serialize_eval_parameters(parameters: EvalParameters) -> dict[str, Any]:

for name, schema in parameters.items():
if _is_prompt_parameter(schema):
parameter_data = {
prompt_parameter_data: dict[str, Any] = {
"type": "prompt",
"description": schema.get("description"),
}
default = schema.get("default")
if default is not None:
parameter_data["default"] = _prompt_data_to_dict(default)
result[name] = parameter_data
prompt_default = schema.get("default")
if prompt_default is not None:
prompt_parameter_data["default"] = _prompt_data_to_dict(prompt_default)
result[name] = prompt_parameter_data
elif _is_model_parameter(schema):
parameter_data = {
model_parameter_data: dict[str, Any] = {
"type": "model",
"description": schema.get("description"),
}
default = schema.get("default")
if default is not None:
parameter_data["default"] = default
result[name] = parameter_data
model_default = schema.get("default")
if model_default is not None:
model_parameter_data["default"] = model_default
result[name] = model_parameter_data
elif schema is None:
result[name] = {
"type": "data",
"schema": {},
}
else:
schema_json = _serialize_pydantic_parameter_schema(schema)
parameter_data = {
data_parameter_data: dict[str, Any] = {
"type": "data",
"schema": schema_json,
"description": schema_json.get("description"),
}
if "default" in schema_json:
parameter_data["default"] = schema_json["default"]
result[name] = parameter_data
data_parameter_data["default"] = schema_json["default"]
result[name] = data_parameter_data

return result

Expand Down Expand Up @@ -436,8 +474,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema:
default = _prompt_data_to_dict(schema.get("default"))
if default is not None:
property_schema["default"] = default
if schema.get("description") is not None:
property_schema["description"] = schema["description"]
description = schema.get("description")
if description is not None:
property_schema["description"] = description
properties[name] = property_schema
elif _is_model_parameter(schema):
property_schema = {
Expand All @@ -446,8 +485,9 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema:
}
if "default" in schema:
property_schema["default"] = schema.get("default")
if schema.get("description") is not None:
property_schema["description"] = schema["description"]
description = schema.get("description")
if description is not None:
property_schema["description"] = description
properties[name] = property_schema
elif schema is None:
properties[name] = {}
Expand All @@ -457,7 +497,7 @@ def parameters_to_json_schema(parameters: EvalParameters) -> ParametersSchema:
if _parameter_required(schema):
required.append(name)

result: ParametersSchema = {
result: dict[str, Any] = {
"type": "object",
"properties": properties,
"additionalProperties": True,
Expand Down
28 changes: 28 additions & 0 deletions py/src/braintrust/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,34 @@ def test_validate_remote_parameters_allows_prompt_overrides():
assert built["model"] == "gpt-5-nano"


def test_validate_local_parameters_accepts_v1_and_v2_pydantic_shapes():
    """Schemas duck-typed like pydantic v1 (parse_obj) and v2 (model_validate) both validate."""

    class V1Shaped:
        @classmethod
        def parse_obj(cls, value):
            return {"validated_by": "parse_obj", "value": value}

    class V2Shaped:
        @classmethod
        def model_validate(cls, value):
            return {"validated_by": "model_validate", "value": value}

    values = {
        "legacy": {"input": "legacy"},
        "modern": {"input": "modern"},
    }
    schemas = {
        "legacy": V1Shaped,
        "modern": V2Shaped,
    }

    result = validate_parameters(values, schemas)

    expected = {
        "legacy": {"validated_by": "parse_obj", "value": {"input": "legacy"}},
        "modern": {"validated_by": "model_validate", "value": {"input": "modern"}},
    }
    assert result == expected


@pytest.mark.skipif(not HAS_PYDANTIC, reason="pydantic not installed")
def test_parameters_to_json_schema_uses_scalar_schema_for_single_value_models():
from pydantic import BaseModel
Expand Down
31 changes: 31 additions & 0 deletions py/src/braintrust/type_tests/test_parameters_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Type-check tests for saved parameter schema definitions."""

from braintrust import projects
from braintrust.parameters import EvalParameters


def test_parameters_create_accepts_discriminated_schema_entries() -> None:
    """Discriminated prompt/model entries type-check as EvalParameters and round-trip through create."""
    project = projects.create("test-project")

    # Literals kept inline so the type checker sees the TypedDict shapes directly.
    schema: EvalParameters = {
        "model": {"type": "model", "default": "gpt-5-mini"},
        "main": {
            "type": "prompt",
            "default": {
                "prompt": {
                    "type": "chat",
                    "messages": [{"role": "user", "content": "{{input}}"}],
                },
                "options": {"model": "gpt-5-mini"},
            },
        },
    }

    returned = project.parameters.create(name="test-parameters", schema=schema)

    assert returned is schema
Loading