-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathremote_server.py
More file actions
106 lines (80 loc) · 3.45 KB
/
remote_server.py
File metadata and controls
106 lines (80 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import threading
import argparse
import uvicorn
from fastapi import FastAPI
from openai import OpenAI
import logging
from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter
# Message used to force a rollout error; stays None unless main() overwrites it
# from the --force-early-error command-line option.
force_early_error_message = None

# FastAPI application object; the /init route below is registered on it.
app = FastAPI()

# Install the Fireworks tracing handler on the root logger so that records
# propagated to the root logger are passed to it.
fireworks_handler = FireworksTracingHttpHandler()
logging.getLogger().addHandler(fireworks_handler)
@app.post("/init")
def init(req: InitRequest):
    """Start a rollout in a background thread and return immediately.

    A logger named ``<module>.<rollout_id>`` is fetched and a
    ``RolloutIdFilter`` for that rollout id is attached to it (presumably so
    records can be correlated with the rollout -- confirm against
    ``eval_protocol``). A daemon thread then sends one chat-completion request
    built from ``req`` and reports the outcome through that logger by passing
    ``extra={"status": ...}`` on the final record.

    Args:
        req: Rollout request. This handler reads ``metadata.rollout_id``,
            ``messages``, ``completion_params``, ``tools`` and
            ``model_base_url``.

    Returns:
        None. The outcome is reported through log records, not through the
        HTTP response.
    """
    rollout_id = req.metadata.rollout_id

    # Per-rollout logger with the rollout_id filter attached.
    logger = logging.getLogger(f"{__name__}.{rollout_id}")
    logger.addFilter(RolloutIdFilter(rollout_id))

    def _worker():
        """Run a single chat completion and always log a terminal status."""
        # Set once a record carrying a rollout_error status has been logged.
        # The ``finally`` clause keys off this flag (not off the global option)
        # so that a failure occurring *before* the forced-error branch still
        # produces a terminal status instead of leaving pollers waiting.
        error_status_logged = False
        try:
            if not req.messages:
                raise ValueError("messages is required")

            model = req.completion_params.get("model")
            if not model:
                raise ValueError("model is required in completion_params")

            # Spread all completion_params (model, temperature, max_tokens, etc.)
            completion_kwargs = {"messages": req.messages, **req.completion_params}
            if req.tools:
                completion_kwargs["tools"] = req.tools

            logger.info(f"Final completion_kwargs: {completion_kwargs}")

            client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))

            logger.info(f"Sending completion request to model {model}")
            completion = client.chat.completions.create(**completion_kwargs)
            logger.info(f"Completed response: {completion}")

            # If --force-early-error was given, report a rollout_error with
            # that message and abort the worker.
            if force_early_error_message:
                logger.error(
                    force_early_error_message,
                    extra={"status": Status.rollout_error(force_early_error_message)},
                )
                error_status_logged = True
                raise RuntimeError(force_early_error_message)
        except Exception as e:
            # Best-effort: the failure is logged and swallowed so the
            # ``finally`` clause can still emit a terminal status.
            logger.error(f"❌ Error in rollout {rollout_id}: {e}")
        finally:
            # Mark the rollout as done (even after an error) to unblock
            # polling, unless an explicit error status was already reported.
            if not error_status_logged:
                logger.info(
                    f"Rollout {rollout_id} completed",
                    extra={"status": Status.rollout_finished()},
                )

    # Daemon thread: the HTTP handler returns without waiting for the rollout.
    threading.Thread(target=_worker, daemon=True).start()
def main():
    """Parse command-line options and serve ``app`` with uvicorn.

    Options:
        --host: bind address; defaults to the ``REMOTE_SERVER_HOST``
            environment variable, or ``127.0.0.1`` when it is unset.
        --port: bind port; defaults to the ``REMOTE_SERVER_PORT`` environment
            variable, or ``3000`` when it is unset.
        --force-early-error: message stored in the module-level
            ``force_early_error_message`` before the server starts.

    A non-integer port (from the command line, or from the environment when
    ``--port`` is not given) is reported as a normal argparse usage error.
    """
    global force_early_error_message

    parser = argparse.ArgumentParser(description="Run the remote server for evaluation protocol")
    parser.add_argument(
        "--host",
        type=str,
        default=os.getenv("REMOTE_SERVER_HOST", "127.0.0.1"),
        help="Host to bind the server to (default: 127.0.0.1 or REMOTE_SERVER_HOST env var)",
    )
    parser.add_argument(
        "--port",
        type=int,
        # Keep the default as a string: argparse runs string defaults through
        # ``type`` only when the option is absent from the command line, so a
        # malformed REMOTE_SERVER_PORT yields a usage error rather than a
        # ValueError traceback, and is ignored when --port is given explicitly.
        default=os.getenv("REMOTE_SERVER_PORT", "3000"),
        help="Port to bind the server to (default: 3000 or REMOTE_SERVER_PORT env var)",
    )
    parser.add_argument(
        "--force-early-error",
        type=str,
        default=None,
        help="If set, /init will immediately return after logging a rollout_error with this message",
    )
    args = parser.parse_args()

    # Publish the option through the module global before the server starts.
    force_early_error_message = args.force_early_error

    uvicorn.run(app, host=args.host, port=args.port)
# Start the server only when this file is executed as a script; importing the
# module does not call main().
if __name__ == "__main__":
    main()