-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathserver.py
More file actions
271 lines (217 loc) · 9.79 KB
/
server.py
File metadata and controls
271 lines (217 loc) · 9.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
import importlib
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Union, Tuple, cast
import uvicorn # type: ignore[reportMissingImports]
from fastapi import FastAPI, HTTPException, Request # type: ignore[reportMissingImports]
from pydantic import BaseModel, Field # type: ignore[reportMissingImports]
from .models import EvaluateResult
# Setup logging
# NOTE(review): basicConfig() runs at import time, so merely importing this
# module installs an INFO-level handler on the root logger when none is
# configured yet — confirm this side effect is intended for library use.
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by the server class and the app factory below.
logger = logging.getLogger(__name__)
class Message(BaseModel):
    """Model for a conversation message.

    Only ``role`` and ``content`` are declared; any additional keys sent by the
    client are kept on the instance rather than rejected (``extra="allow"``).
    """

    # Pydantic v2 configuration (replaces the deprecated nested ``class Config``).
    # Allow extra fields
    model_config = {"extra": "allow"}

    # Speaker of the message; not validated against a fixed set of roles here.
    role: str
    # Text body of the message.
    content: str
class RewardRequest(BaseModel):
    """Request model for reward endpoints.

    ``messages`` is required. ``ground_truth`` is optional and may be either a
    plain string or a list of messages. Any other top-level keys in the request
    body are accepted (``extra="allow"``) so the endpoints can forward them to
    the reward function as keyword arguments.
    """

    # Pydantic v2 configuration (replaces the deprecated nested ``class Config``).
    # Allow extra fields for arbitrary kwargs
    model_config = {"extra": "allow"}

    messages: List[Message] = Field(..., description="List of conversation messages")
    ground_truth: Optional[Union[str, List[Message]]] = Field(
        None, description="Ground truth data (string or list of messages) for context"
    )
class RewardServer:
    """
    Server for hosting reward functions.

    This class creates a FastAPI server that can host reward functions. On
    construction it imports the function named by ``func_path`` and registers
    three routes on ``self.app``: ``GET /``, ``POST /reward`` and ``GET /health``.

    Args:
        func_path: Path to the reward function to host (e.g., "module.path:function_name")
        host: Host to bind the server to
        port: Port to bind the server to
    """

    def __init__(
        self,
        func_path: str,
        host: str = "0.0.0.0",
        port: int = 8000,
    ):
        self.func_path = func_path
        self.host = host
        self.port = port
        self.app = FastAPI(title="Reward Function Server")
        # Load the reward function (raises if the path cannot be resolved)
        self.reward_func = self._load_function()
        # Register the endpoints
        self._setup_routes()

    def _load_function(self):
        """Load the reward function from the provided path.

        Returns:
            The attribute named after the ``:`` in ``self.func_path``, looked up
            on the module named before it.

        Raises:
            ValueError: If ``self.func_path`` contains no ``:`` separator.
            ImportError: If the module cannot be imported or lacks the named
                attribute; the underlying error is attached as ``__cause__``.
        """
        if ":" not in self.func_path:
            raise ValueError(f"Invalid func_path format: {self.func_path}, expected 'module.path:function_name'")

        # Split only on the first colon so the remainder is taken verbatim as the attribute name.
        module_path, func_name = self.func_path.split(":", 1)
        try:
            module = importlib.import_module(module_path)
            func = getattr(module, func_name)
        except (ImportError, AttributeError) as e:
            # Chain the original error so the real failure stays visible in tracebacks.
            raise ImportError(f"Failed to load function from path {self.func_path}: {str(e)}") from e

        logger.info(f"Loaded reward function {func_name} from {module_path}")
        return func

    def _setup_routes(self):
        """Set up the API routes on ``self.app``."""

        @self.app.get("/")
        async def root():
            """Get server info."""
            return {
                "status": "ok",
                "reward_function": self.func_path,
                "endpoints": ["/reward"],
            }

        @self.app.post("/reward")
        async def reward(request: RewardRequest):
            """
            Get reward score for messages.

            Args:
                request: RewardRequest object with messages and optional parameters

            Returns:
                A dict containing at least ``score``. A dict result from the
                reward function is returned as-is, an ``EvaluateResult`` is
                dumped to a dict, and a legacy ``(score, components)`` tuple
                becomes ``{"score": ..., "metrics": ...}``.

            Raises:
                HTTPException: Status 500 for any error raised while handling
                    the request, including an unsupported return type.
            """
            try:
                # Extract kwargs from the request: every field other than the two
                # declared ones, using the Pydantic v2 API as elsewhere in this file.
                kwargs = request.model_dump(exclude={"messages", "ground_truth"})

                # Set default for ground_truth if not provided and expected as list
                ground_truth_data = request.ground_truth
                if ground_truth_data is None:
                    # This default applies if ground_truth is expected to be a list of messages for context
                    ground_truth_data = request.messages[:-1] if request.messages else []

                # Call the reward function
                # NOTE(review): messages are passed as Message models here, whereas
                # create_app() passes plain dicts — confirm which the function expects.
                result = self.reward_func(
                    messages=request.messages,
                    ground_truth=ground_truth_data,
                    **kwargs,
                )

                # Handle different return types
                # The self.reward_func is expected to be decorated by the new @reward_function,
                # which returns a dictionary.
                if isinstance(result, dict) and "score" in result:
                    return result
                elif isinstance(result, EvaluateResult):  # Should not happen if func is from new decorator
                    logger.warning("Reward function returned EvaluateResult object directly to server; expected dict.")
                    return result.model_dump()
                elif isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
                    logger.warning("Reward function returned legacy tuple format to server.")
                    score, components = result
                    return {"score": score, "metrics": components}
                else:
                    raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")
            except Exception as e:
                # logger.exception keeps the same message but also records the traceback.
                logger.exception("Error processing reward request: %s", e)
                raise HTTPException(status_code=500, detail=str(e)) from e

        @self.app.get("/health")
        async def health():
            """Health check endpoint."""
            return {"status": "ok"}

    def run(self):
        """Run the server with uvicorn on the configured host and port."""
        logger.info(f"Starting reward server on {self.host}:{self.port}")
        uvicorn.run(self.app, host=self.host, port=self.port)
def serve(func_path: str, host: str = "0.0.0.0", port: int = 8000):
    """
    Expose a reward function over HTTP.

    Builds a RewardServer for the given function path and starts it.

    Args:
        func_path: Location of the reward function, written as "module.path:function_name"
        host: Interface address the server binds to
        port: TCP port the server listens on
    """
    RewardServer(func_path=func_path, host=host, port=port).run()
# ngrok-based serve_tunnel is deprecated in favor of Serveo via subprocess_manager.
# def serve_tunnel(func_path: str, port: int = 8000):
# """
# Serve a reward function with an ngrok tunnel.
# DEPRECATED.
# """
# try:
# import pyngrok.ngrok as ngrok # type: ignore
# except ImportError:
# raise ImportError(
# "The 'pyngrok' package is required to use serve_tunnel. "
# "Please install it with 'pip install pyngrok'."
# )
#
# # Open the tunnel
# tunnel = ngrok.connect(port)
# public_url = tunnel.public_url
#
# # Print the tunnel URL
# logger.info(f"Reward function available at: {public_url}/reward")
#
# # Start the server
# serve(func_path=func_path, host="0.0.0.0", port=port)
def create_app(reward_func: Callable[..., EvaluateResult]) -> FastAPI:
    """
    Create a FastAPI app for the given reward function.

    This function creates a FastAPI app that can be used to serve a reward function.
    It's particularly useful for testing or when you want to manage the lifecycle
    of the app yourself. The app exposes ``GET /``, ``POST /reward`` and
    ``GET /health``.

    Args:
        reward_func: The reward function to serve. It is called with
            ``messages`` and ``ground_truth`` keyword arguments plus any extra
            fields present in the request body.

    Returns:
        A FastAPI app instance
    """
    app = FastAPI(title="Reward Function Server")

    @app.get("/")
    async def root():
        """Get server info."""
        return {"status": "ok", "endpoints": ["/reward"]}

    @app.post("/reward")
    async def reward(request_data: RewardRequest):
        """
        Get reward score for messages.

        Args:
            request_data: RewardRequest object with messages and optional parameters

        Returns:
            A dict containing at least ``score``. A dict result from the reward
            function is returned as-is, an ``EvaluateResult`` is dumped to a
            dict, and a legacy ``(score, components)`` tuple becomes
            ``{"score": ..., "metrics": ...}``.

        Raises:
            HTTPException: Status 500 for any error raised while handling the
                request, including an unsupported return type.
        """
        try:
            # Convert Pydantic models to dictionaries using model_dump (Pydantic v2)
            messages = [msg.model_dump() for msg in request_data.messages]

            ground_truth_data: Optional[Union[str, List[Dict[str, Any]]]] = None
            if isinstance(request_data.ground_truth, str):
                ground_truth_data = request_data.ground_truth
            elif isinstance(request_data.ground_truth, list):
                ground_truth_data = [msg.model_dump() for msg in request_data.ground_truth]

            # Extract kwargs from any extra fields; excluding the declared fields up
            # front avoids serializing the messages a second time just to drop them.
            kwargs = request_data.model_dump(exclude={"messages", "ground_truth"})

            # Set default for ground_truth if not provided and expected as list
            if ground_truth_data is None:
                # This default applies if ground_truth is expected to be a list of messages for context
                ground_truth_data = messages[:-1] if messages else []

            # Call the reward function
            result = reward_func(messages=messages, ground_truth=ground_truth_data, **kwargs)

            # Handle different return types
            # The reward_func is expected to be decorated by the new @reward_function,
            # which returns a dictionary.
            if isinstance(result, dict) and "score" in result:
                return result
            elif isinstance(result, EvaluateResult):  # Should not happen if func is from new decorator
                logger.warning(
                    "Reward function passed to create_app returned EvaluateResult object directly; expected dict after decoration."
                )
                return result.model_dump()
            elif isinstance(result, tuple) and len(result) == 2:  # Legacy tuple
                logger.warning("Reward function passed to create_app returned legacy tuple format.")
                score, components = cast(Tuple[float, Dict[str, Any]], result)
                return {"score": score, "metrics": components}
            else:
                raise TypeError(f"Invalid return type from reward function after decoration: {type(result)}")
        except Exception as e:
            # logger.exception keeps the same message but also records the traceback.
            logger.exception("Error processing reward request: %s", e)
            raise HTTPException(status_code=500, detail=str(e)) from e

    @app.get("/health")
    async def health():
        """Health check endpoint."""
        return {"status": "ok"}

    return app