|
25 | 25 | ) |
26 | 26 | ) |
27 | 27 |
|
28 | | -from ..framework import EvalAsync, EvalScorer, Evaluator, ExperimentSummary, SSEProgressEvent |
| 28 | +from ..framework import ( |
| 29 | + EvalAsync, |
| 30 | + EvalHooks, |
| 31 | + EvalScorer, |
| 32 | + Evaluator, |
| 33 | + ExperimentSummary, |
| 34 | + SSEProgressEvent, |
| 35 | +) |
29 | 36 | from ..generated_types import FunctionId |
30 | 37 | from ..logger import BraintrustState, bt_iscoroutinefunction |
31 | | -from ..parameters import RemoteEvalParameters, serialize_remote_eval_parameters_container, validate_parameters |
| 38 | +from ..parameters import ( |
| 39 | + RemoteEvalParameters, |
| 40 | + ValidatedParameters, |
| 41 | + serialize_remote_eval_parameters_container, |
| 42 | + validate_parameters, |
| 43 | +) |
32 | 44 | from ..span_identifier_v4 import parse_parent |
33 | 45 | from .auth import AuthorizationMiddleware |
34 | 46 | from .cache import cached_login |
|
42 | 54 |
|
43 | 55 |
|
44 | 56 | class _ParameterOverrideHooks: |
45 | | - def __init__(self, hooks: Any, parameters: dict[str, Any]): |
| 57 | + def __init__(self, hooks: EvalHooks[Any], parameters: ValidatedParameters): |
46 | 58 | self._hooks = hooks |
47 | 59 | self._parameters = parameters |
48 | 60 |
|
49 | 61 | @property |
50 | | - def metadata(self): |
51 | | - return self._hooks.metadata |
52 | | - |
53 | | - @property |
54 | | - def expected(self): |
55 | | - return self._hooks.expected |
56 | | - |
57 | | - @property |
58 | | - def span(self): |
59 | | - return self._hooks.span |
60 | | - |
61 | | - @property |
62 | | - def trial_index(self): |
63 | | - return self._hooks.trial_index |
64 | | - |
65 | | - @property |
66 | | - def tags(self): |
67 | | - return self._hooks.tags |
68 | | - |
69 | | - @property |
70 | | - def parameters(self): |
| 62 | + def parameters(self) -> ValidatedParameters: |
71 | 63 | return self._parameters |
72 | 64 |
|
73 | | - def report_progress(self, progress): |
74 | | - return self._hooks.report_progress(progress) |
75 | | - |
76 | | - def meta(self, **info: Any): |
77 | | - return self._hooks.meta(**info) |
| 65 | + def __getattr__(self, name: str): |
| 66 | + return getattr(self._hooks, name) |
78 | 67 |
|
79 | 68 |
|
80 | 69 | class CheckAuthorizedMiddleware(BaseHTTPMiddleware): |
@@ -192,7 +181,7 @@ async def run_eval(request: Request) -> JSONResponse | StreamingResponse: |
192 | 181 | # Set up SSE headers for streaming |
193 | 182 | sse_queue = SSEQueue() |
194 | 183 |
|
195 | | - async def task(input, hooks): |
| 184 | + async def task(input: Any, hooks: EvalHooks[Any]): |
196 | 185 | task_hooks = hooks if validated_parameters is None else _ParameterOverrideHooks(hooks, validated_parameters) |
197 | 186 | if bt_iscoroutinefunction(evaluator.task): |
198 | 187 | result = await evaluator.task(input, task_hooks) |
@@ -228,7 +217,7 @@ def stream_fn(event: SSEProgressEvent): |
228 | 217 | eval_kwargs = { |
229 | 218 | k: v for (k, v) in evaluator.__dict__.items() if k not in ["eval_name", "project_name", "parameter_values"] |
230 | 219 | } |
231 | | - if validated_parameters is not None and not RemoteEvalParameters.is_parameters(evaluator.parameters): |
| 220 | + if validated_parameters is not None and not isinstance(evaluator.parameters, RemoteEvalParameters): |
232 | 221 | eval_kwargs["parameters"] = validated_parameters |
233 | 222 |
|
234 | 223 | try: |
@@ -329,7 +318,10 @@ def create_app(evaluators: list[Evaluator[Any, Any]], org_name: str | None = Non |
329 | 318 |
|
330 | 319 |
|
331 | 320 | def run_dev_server( |
332 | | - evaluators: list[Evaluator[Any, Any]], host: str = "localhost", port: int = 8300, org_name: str | None = None |
| 321 | + evaluators: list[Evaluator[Any, Any]], |
| 322 | + host: str = "localhost", |
| 323 | + port: int = 8300, |
| 324 | + org_name: str | None = None, |
333 | 325 | ): |
334 | 326 | """Start the dev server. |
335 | 327 |
|
|
0 commit comments