Skip to content

Commit d719a6c

Browse files
committed
updated server
1 parent 4c855e7 commit d719a6c

2 files changed

Lines changed: 65 additions & 40 deletions

File tree

eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py

Lines changed: 64 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import asyncio
1313
from flask import Flask, request, jsonify
1414
from openai import OpenAI
15+
import openai
1516
from dotenv import load_dotenv
1617

1718
from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter
@@ -49,56 +50,80 @@ def filter(self, record: logging.LogRecord) -> bool:
4950
app = Flask(__name__)
5051

5152

52-
async def execute_rollout_background(req, api_key):
53+
async def execute_rollout_background(req: InitRequest, api_key: str):
5354
"""Execute the OpenAI completion in background and log results"""
5455
# Attach rollout_id filter to logger
5556
logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}")
5657
logger.addFilter(RolloutIdFilter(req.metadata.rollout_id))
5758

58-
try:
59-
model = req.completion_params.get("model")
60-
# Uncomment if you need to strip fireworks_ai/ prefix
61-
# if model and isinstance(model, str) and model.startswith("fireworks_ai/"):
62-
# model = model[len("fireworks_ai/"):]
63-
64-
# Prepare completion arguments
65-
completion_kwargs = {
66-
"messages": req.messages,
67-
# "messages": [{"role": "user", "content": "Hello, how are you?"}],
68-
"model": model,
69-
"temperature": req.completion_params.get("temperature"),
70-
"max_tokens": req.completion_params.get("max_tokens"),
71-
}
59+
model = req.completion_params.get("model")
60+
# Uncomment if you need to strip fireworks_ai/ prefix
61+
# if model and isinstance(model, str) and model.startswith("fireworks_ai/"):
62+
# model = model[len("fireworks_ai/"):]
63+
64+
# Prepare completion arguments
65+
completion_kwargs = {
66+
"messages": req.messages,
67+
# "messages": [{"role": "user", "content": "Hello, how are you?"}],
68+
"model": model,
69+
"temperature": req.completion_params.get("temperature"),
70+
"max_tokens": req.completion_params.get("max_tokens"),
71+
}
72+
73+
# Add tools if present
74+
if req.tools:
75+
completion_kwargs["tools"] = req.tools
76+
77+
logger.info(
78+
f"DEBUG: {req.model_base_url}, COMPLETION_KWARGS: {completion_kwargs}, API_KEY: {api_key}, MODEL: {model}"
79+
)
7280

73-
# Add tools if present
74-
if req.tools:
75-
completion_kwargs["tools"] = req.tools
81+
# Create AsyncOpenAI client
82+
# client = AsyncOpenAI(base_url=req.model_base_url, api_key=api_key)
83+
client = OpenAI(base_url=req.model_base_url, api_key=api_key)
7684

77-
logger.info(
78-
f"DEBUG: {req.model_base_url}, COMPLETION_KWARGS: {completion_kwargs}, API_KEY: {api_key}, MODEL: {model}"
79-
)
85+
logger.info(f"Sending completion request to model {model}")
8086

81-
# Create AsyncOpenAI client
82-
# client = AsyncOpenAI(base_url=req.model_base_url, api_key=api_key)
83-
client = OpenAI(base_url=req.model_base_url, api_key=api_key)
87+
# Make the async model call with timeout
88+
import time
8489

85-
logger.info(f"Sending completion request to model {model}")
90+
logger.info(f"timing start: {time.time()}")
8691

87-
# Make the async model call with timeout
88-
import time
89-
90-
logger.info(f"timing start: {time.time()}")
92+
try:
9193
completion = client.chat.completions.create(**completion_kwargs)
92-
logger.info(f"Completed response: {completion}")
93-
logger.info(f"timing end: {time.time()}")
94-
# Log successful completion - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
95-
logger.info(f"Rollout {req.metadata.rollout_id} completed", extra={"status": Status.rollout_finished()})
96-
94+
except (
95+
openai.AuthenticationError,
96+
openai.PermissionDeniedError,
97+
) as e:
98+
# These errors should be logged and will be retried by RemoteRolloutProcessor
99+
logger.error(
100+
f"Rollout {req.metadata.rollout_id} failed: {e}",
101+
extra={"status": Status.rollout_permission_denied_error(str(e))},
102+
)
103+
return
104+
except openai.NotFoundError as e:
105+
logger.error(
106+
f"Rollout {req.metadata.rollout_id} failed: {e}", extra={"status": Status.rollout_not_found_error(str(e))}
107+
)
108+
return
109+
except openai.RateLimitError as e:
110+
logger.error(
111+
f"Rollout {req.metadata.rollout_id} failed: {e}",
112+
extra={"status": Status.rollout_resource_exhausted_error(str(e))},
113+
)
114+
return
97115
except Exception as e:
98-
# Log error with structured status - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
116+
# Non-OpenAI errors (shouldn't normally happen but catch anyway)
99117
logger.error(
100-
f"Rollout {req.metadata.rollout_id} failed: {e}", extra={"status": Status.rollout_error_from_exception(e)}
118+
f"Rollout {req.metadata.rollout_id} failed with unexpected error: {e}",
119+
extra={"status": Status.rollout_internal_error(str(e))},
101120
)
121+
return
122+
123+
logger.info(f"Completed response: {completion}")
124+
logger.info(f"timing end: {time.time()}")
125+
# Log successful completion - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
126+
logger.info(f"Rollout {req.metadata.rollout_id} completed", extra={"status": Status.rollout_finished()})
102127

103128

104129
@app.route("/init", methods=["POST"])
@@ -114,7 +139,7 @@ async def init():
114139
# Validate required fields
115140
if not req.messages:
116141
error_msg = "messages is required"
117-
logger.error(error_msg, extra={"status": Status.rollout_error(error_msg)})
142+
logger.error(error_msg, extra={"status": Status.rollout_internal_error(error_msg)})
118143
return jsonify({"error": error_msg}), 400
119144

120145
# Get API key (prefer request api_key, fallback to environment)
@@ -126,7 +151,7 @@ async def init():
126151
api_key = os.environ.get("FIREWORKS_API_KEY")
127152
else:
128153
error_msg = "API key not provided in request or environment variable"
129-
logger.error(error_msg, extra={"status": Status.rollout_error(error_msg)})
154+
logger.error(error_msg, extra={"status": Status.rollout_internal_error(error_msg)})
130155
return jsonify({"error": error_msg}), 401
131156

132157
# 🔥 FIRE: Return immediately with acceptance (within 30s requirement)
@@ -137,7 +162,7 @@ async def init():
137162
}
138163

139164
# Fire and forget: Execute rollout asynchronously
140-
asyncio.create_task(execute_rollout_background(req, api_key))
165+
asyncio.create_task(execute_rollout_background(req, api_key or ""))
141166

142167
return jsonify(response_data), 200
143168

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
openai>=1.0.0
22
python-dotenv>=0.19.0
3-
eval_protocol>=0.2.70
3+
eval_protocol>=0.2.71
44
Flask[async]==3.0.3

0 commit comments

Comments
 (0)