Skip to content

Commit 115c765

Browse files
committed
implemented hardening in generic_server
1 parent 3c8d8f2 commit 115c765

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,6 @@ package.json
243243
tau2-bench
244244
*.err
245245
eval-protocol
246+
_pytest_deps/
247+
.test_deps/
248+
.test_deps/

eval_protocol/generic_server.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import importlib
2+
import logging
23
import os
34
from typing import Any, Dict, List, Optional
45

@@ -9,12 +10,15 @@
910
# Assuming these models are correctly defined in eval_protocol.models
1011
from eval_protocol.models import EvaluateResult, Message
1112

13+
logger = logging.getLogger(__name__)
14+
1215

1316
# --- Request and Response Models ---
1417
class EvaluationRequest(BaseModel):
1518
messages: List[Dict[str, Any]] # Could also be List[Message] if we enforce that model on input
1619
ground_truth: Optional[str] = None
17-
kwargs: Optional[Dict[str, Any]] = {}
20+
# Default to None instead of a `{}` literal so the model declares no mutable default; consumers must treat None as "no extra kwargs".
21+
kwargs: Optional[Dict[str, Any]] = None
1822

1923

2024
# --- Global variable to store the loaded reward function ---
@@ -74,8 +78,10 @@ async def evaluate_endpoint(request: EvaluationRequest):
7478
if not isinstance(result, EvaluateResult):
7579
# This case should ideally not happen if functions are correctly decorated
7680
# and return EvaluateResult, but good to have a fallback.
77-
print(
78-
f"Warning: Reward function '{_REWARD_FUNCTION_NAME}' did not return an EvaluateResult instance. Type: {type(result)}"
81+
logger.warning(
82+
"Reward function '%s' did not return an EvaluateResult instance. Type: %s",
83+
_REWARD_FUNCTION_NAME,
84+
type(result),
7985
)
8086
# Attempt to construct an EvaluateResult if it's a dict-like object,
8187
# otherwise, this will raise an error or return a poorly formed response.
@@ -89,15 +95,18 @@ async def evaluate_endpoint(request: EvaluationRequest):
8995

9096
return result
9197
except ValidationError as ve: # Pydantic validation error from reward function's input/output
92-
print(f"Validation Error calling reward function '{_REWARD_FUNCTION_NAME}': {ve}")
98+
logger.warning(
99+
"Validation error calling reward function '%s': %s",
100+
_REWARD_FUNCTION_NAME,
101+
ve,
102+
)
93103
raise HTTPException(
94104
status_code=422,
95105
detail=f"Input/Output validation error for reward function: {ve.errors()}",
96106
)
97107
except Exception as e:
98-
print(f"Error during evaluation with reward function '{_REWARD_FUNCTION_NAME}': {e}")
99-
# Consider logging the full traceback here
100-
raise HTTPException(status_code=500, detail=f"Internal server error during evaluation: {str(e)}")
108+
logger.exception("Error during evaluation with reward function '%s'", _REWARD_FUNCTION_NAME)
109+
raise HTTPException(status_code=500, detail="Internal server error during evaluation.")
101110

102111

103112
@app.get("/health")
@@ -121,9 +130,9 @@ def load_reward_function(import_string: str):
121130
module = importlib.import_module(module_path)
122131
_LOADED_REWARD_FUNCTION = getattr(module, function_name)
123132
_REWARD_FUNCTION_NAME = import_string
124-
print(f"Successfully loaded reward function: {_REWARD_FUNCTION_NAME}")
133+
logger.info("Successfully loaded reward function: %s", _REWARD_FUNCTION_NAME)
125134
except Exception as e:
126-
print(f"Error loading reward function from '{import_string}': {e}")
135+
logger.exception("Error loading reward function from '%s'", import_string)
127136
_LOADED_REWARD_FUNCTION = None
128137
_REWARD_FUNCTION_NAME = "Error loading"
129138
raise # Re-raise to make it fatal if loading fails on startup
@@ -153,13 +162,16 @@ def load_reward_function(import_string: str):
153162
try:
154163
load_reward_function(args.import_string)
155164
except Exception:
156-
print("Failed to load reward function. Exiting.")
165+
logger.error("Failed to load reward function. Exiting.")
157166
exit(1)
158167

159168
if not _LOADED_REWARD_FUNCTION:
160-
print(f"Reward function {_REWARD_FUNCTION_NAME} could not be loaded. Server will not start correctly.")
169+
logger.error(
170+
"Reward function %s could not be loaded. Server will not start correctly.",
171+
_REWARD_FUNCTION_NAME,
172+
)
161173
# Exit here rather than start a server without a loaded reward function (the alternative would be to let it run and fail on /evaluate).
162174
exit(1)
163175

164-
print(f"Starting server for reward function: {args.import_string} on http://{args.host}:{args.port}")
176+
logger.info("Starting server for reward function: %s on http://%s:%s", args.import_string, args.host, args.port)
165177
uvicorn.run(app, host=args.host, port=args.port) # reload=args.reload for dev

tests/test_generic_server.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,11 @@ def test_evaluate_endpoint_reward_function_raises_error(self):
168168
request_payload = EvaluationRequest(messages=[{"role": "user", "content": "test"}])
169169
response = self.client.post("/evaluate", json=request_payload.model_dump())
170170
assert response.status_code == 500
171-
assert "Intentional error in dummy_reward_func_error" in response.json()["detail"]
171+
assert response.json()["detail"] == "Internal server error during evaluation."
172+
173+
def test_evaluation_request_kwargs_defaults_to_none(self):
174+
payload = EvaluationRequest(messages=[{"role": "user", "content": "test"}])
175+
assert payload.kwargs is None
172176

173177
def test_evaluate_endpoint_function_returns_invalid_type(self):
174178
"""

0 commit comments

Comments
 (0)