Skip to content

Commit d22ed3a

Browse files
committed
Address PR comments for python parameters impl
1 parent 1b28df4 commit d22ed3a

7 files changed

Lines changed: 175 additions & 123 deletions

File tree

py/src/braintrust/cli/eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ async def run_evaluator_task(evaluator, position, opts: EvaluatorOpts):
133133
dataset = evaluator.data
134134

135135
parameters = None
136-
if RemoteEvalParameters.is_parameters(evaluator.parameters) and evaluator.parameters.id is not None:
136+
if isinstance(evaluator.parameters, RemoteEvalParameters) and evaluator.parameters.id is not None:
137137
parameters = {"id": evaluator.parameters.id}
138138
if evaluator.parameters.version is not None:
139139
parameters["version"] = evaluator.parameters.version

py/src/braintrust/devserver/eval_hooks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
from collections.abc import Callable
1111
from typing import Any
1212

13+
from ..parameters import ValidatedParameters
14+
1315

1416
class EvalHooks:
1517
"""Hooks provided to eval tasks for progress reporting."""
1618

1719
def __init__(
1820
self,
1921
report_progress: Callable[[dict[str, Any]], None] | None = None,
20-
parameters: dict[str, Any] | None = None,
22+
parameters: ValidatedParameters | None = None,
2123
):
2224
self._report_progress = report_progress
2325
self.parameters = parameters or {}

py/src/braintrust/devserver/server.py

Lines changed: 24 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,22 @@
2525
)
2626
)
2727

28-
from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent
28+
from ..framework import (
29+
EvalAsync,
30+
EvalHooks,
31+
EvalScorer,
32+
Evaluator,
33+
ExperimentSummary,
34+
SSEProgressEvent,
35+
)
2936
from ..generated_types import FunctionId
3037
from ..logger import BraintrustState, bt_iscoroutinefunction
31-
from ..parameters import RemoteEvalParameters, serialize_remote_eval_parameters_container, validate_parameters
38+
from ..parameters import (
39+
RemoteEvalParameters,
40+
ValidatedParameters,
41+
serialize_remote_eval_parameters_container,
42+
validate_parameters,
43+
)
3244
from ..span_identifier_v4 import parse_parent
3345
from .auth import AuthorizationMiddleware
3446
from .cache import cached_login
@@ -42,39 +54,16 @@
4254

4355

4456
class _ParameterOverrideHooks:
    """Proxy around an eval hooks object that substitutes its own parameters.

    ``parameters`` returns the value supplied at construction time; every
    other attribute access is forwarded to the wrapped hooks object.
    """

    def __init__(self, hooks: EvalHooks[Any], parameters: ValidatedParameters):
        """Store the wrapped hooks and the parameters that replace its own.

        Args:
            hooks: The hooks object that all non-overridden attribute
                lookups are forwarded to.
            parameters: The value returned by the ``parameters`` property.
        """
        self._hooks = hooks
        self._parameters = parameters

    @property
    def parameters(self) -> ValidatedParameters:
        """Return the overriding parameters rather than the wrapped hooks' own."""
        return self._parameters

    def __getattr__(self, name: str) -> Any:
        """Forward attributes not found on the proxy to the wrapped hooks.

        Args:
            name: The attribute being looked up.

        Returns:
            Whatever ``getattr`` on the wrapped hooks object returns.

        Raises:
            AttributeError: If ``name`` is one of the proxy's own fields and
                has not been set yet, or the wrapped hooks object lacks it.
        """
        # __getattr__ only runs after normal lookup has failed. If that
        # happens for one of our own fields, __init__ has not run (e.g. the
        # instance was built with __new__), and reading self._hooks below
        # would re-enter __getattr__ until RecursionError.
        if name in ("_hooks", "_parameters"):
            raise AttributeError(name)
        return getattr(self._hooks, name)
7867

7968

8069
class CheckAuthorizedMiddleware(BaseHTTPMiddleware):
@@ -192,7 +181,7 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse:
192181
# Set up SSE headers for streaming
193182
sse_queue = SSEQueue()
194183

195-
async def task(input, hooks):
184+
async def task(input: Any, hooks: EvalHooks[Any]):
196185
task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters)
197186
if bt_iscoroutinefunction(evaluator.task):
198187
result = await evaluator.task(input, task_hooks)
@@ -228,7 +217,7 @@ def stream_fn(event: SSEProgressEvent):
228217
eval_kwargs = {
229218
k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"]
230219
}
231-
if validated_parameters is not None and not RemoteEvalParameters.is_parameters(evaluator.parameters):
220+
if validated_parameters is not None and not isinstance(evaluator.parameters, RemoteEvalParameters):
232221
eval_kwargs["parameters"] = validated_parameters
233222

234223
try:
@@ -329,7 +318,10 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non
329318

330319

331320
def run_dev_server(
332-
evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None
321+
evaluators: list[Evaluator[Any, Any]],
322+
host: str = "localhost",
323+
port: int = 8300,
324+
org_name: str | None = None,
333325
):
334326
"""Start the dev server.
335327

py/src/braintrust/framework.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@
4242
stringify_exception,
4343
)
4444
from .logger import init as _init_experiment
45-
from .parameters import EvalParameters, RemoteEvalParameters, is_eval_parameter_schema, validate_parameters
45+
from .parameters import (
46+
EvalParameters,
47+
RemoteEvalParameters,
48+
ValidatedParameters,
49+
is_eval_parameter_schema,
50+
validate_parameters,
51+
)
4652
from .resource_manager import ResourceManager
4753
from .score import Score, is_score, is_scorer
4854
from .serializable_data_class import SerializableDataClass
@@ -215,7 +221,7 @@ def meta(self, **info: Any) -> None:
215221

216222
@property
217223
@abc.abstractmethod
218-
def parameters(self) -> dict[str, Any] | None:
224+
def parameters(self) -> ValidatedParameters | None:
219225
"""
220226
The parameters for the current evaluation. These are the validated parameter values
221227
that were passed to the evaluator.
@@ -744,7 +750,7 @@ async def make_empty_summary():
744750
dataset = evaluator.data
745751

746752
experiment_parameters = None
747-
if RemoteEvalParameters.is_parameters(evaluator.parameters) and evaluator.parameters.id is not None:
753+
if isinstance(evaluator.parameters, RemoteEvalParameters) and evaluator.parameters.id is not None:
748754
experiment_parameters = {"id": evaluator.parameters.id}
749755
if evaluator.parameters.version is not None:
750756
experiment_parameters["version"] = evaluator.parameters.version
@@ -1162,7 +1168,7 @@ def __init__(
11621168
trial_index: int = 0,
11631169
tags: Sequence[str] | None = None,
11641170
report_progress: Callable[[TaskProgressEvent], None] = None,
1165-
parameters: dict[str, Any] | None = None,
1171+
parameters: ValidatedParameters | None = None,
11661172
):
11671173
if metadata is not None:
11681174
self.update({"metadata": metadata})
@@ -1220,7 +1226,7 @@ def report_progress(self, event: TaskProgressEvent):
12201226
return self._report_progress(event)
12211227

12221228
@property
1223-
def parameters(self) -> dict[str, Any] | None:
1229+
def parameters(self) -> ValidatedParameters | None:
12241230
return self._parameters
12251231

12261232

@@ -1403,7 +1409,7 @@ def get_other_fields(s):
14031409

14041410
if evaluator.parameter_values is not None:
14051411
resolved_evaluator_parameters = evaluator.parameter_values
1406-
elif RemoteEvalParameters.is_parameters(evaluator.parameters):
1412+
elif isinstance(evaluator.parameters, RemoteEvalParameters):
14071413
resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)
14081414
elif is_eval_parameter_schema(evaluator.parameters):
14091415
resolved_evaluator_parameters = validate_parameters({}, evaluator.parameters)

py/src/braintrust/logger.py

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2044,7 +2044,7 @@ def _get_parameters_ref(
20442044
) -> ParametersRef | None:
20452045
if parameters is None:
20462046
return None
2047-
if RemoteEvalParameters.is_parameters(parameters):
2047+
if isinstance(parameters, RemoteEvalParameters):
20482048
if parameters.id is None:
20492049
return None
20502050
ref: ParametersRef = {"id": parameters.id}
@@ -2085,43 +2085,39 @@ def load_parameters(
20852085
:returns: A `RemoteEvalParameters` object.
20862086
"""
20872087
if version is not None and environment is not None:
2088-
raise ValueError(
2089-
"Cannot specify both 'version' and 'environment' parameters. Please use only one (remove the other)."
2090-
)
2088+
raise ValueError("Cannot specify both 'version' and 'environment' parameters.")
20912089

2092-
if id:
2093-
pass
2094-
elif not project and not project_id:
2090+
if id is None and not project and not project_id:
20952091
raise ValueError("Must specify at least one of project or project_id")
2096-
elif not slug:
2092+
if id is None and not slug:
20972093
raise ValueError("Must specify slug")
20982094

2095+
should_fall_back_to_cache = version is None and environment is None
2096+
query_args = _populate_args(
2097+
{
2098+
"version": version,
2099+
"environment": environment,
2100+
}
2101+
)
2102+
20992103
try:
21002104
login(org_name=org_name, api_key=api_key, app_url=app_url)
21012105
if id:
2102-
parameters_args = {}
2103-
if version is not None:
2104-
parameters_args["version"] = version
2105-
if environment is not None:
2106-
parameters_args["environment"] = environment
2107-
response = _state.api_conn().get_json(f"/v1/function/{id}", parameters_args)
2106+
response = _state.api_conn().get_json(f"/v1/function/{id}", query_args)
21082107
if response is not None:
21092108
response = {"objects": [response]}
21102109
else:
2111-
args = _populate_args(
2112-
{
2113-
"project_name": project,
2114-
"project_id": project_id,
2115-
"slug": slug,
2116-
"version": version,
2117-
"environment": environment,
2118-
"function_type": "parameters",
2119-
}
2120-
)
2110+
args = {
2111+
"project_name": project,
2112+
"project_id": project_id,
2113+
"slug": slug,
2114+
"function_type": "parameters",
2115+
**query_args,
2116+
}
21212117
response = _state.api_conn().get_json("/v1/function", args)
21222118
except Exception as server_error:
2123-
if environment is not None or version is not None:
2124-
raise ValueError("Parameters not found with specified parameters") from server_error
2119+
if not should_fall_back_to_cache:
2120+
raise
21252121

21262122
eprint(f"Failed to load parameters, attempting to fall back to cache: {server_error}")
21272123
try:
@@ -2139,7 +2135,7 @@ def load_parameters(
21392135
f"Parameters with id {id} not found (not found on server or in local cache): {cache_error}"
21402136
) from server_error
21412137
raise ValueError(
2142-
f"Parameters {slug} (version {version or 'latest'}) not found in {project or project_id} (not found on server or in local cache): {cache_error}"
2138+
f"Parameters {slug} not found in {project or project_id} (not found on server or in local cache): {cache_error}"
21432139
) from server_error
21442140

21452141
if response is None or "objects" not in response or len(response["objects"]) == 0:

0 commit comments

Comments
 (0)