1212import asyncio
1313from flask import Flask , request , jsonify
1414from openai import OpenAI
15+ import openai
1516from dotenv import load_dotenv
1617
1718from eval_protocol import Status , InitRequest , FireworksTracingHttpHandler , RolloutIdFilter
@@ -49,56 +50,80 @@ def filter(self, record: logging.LogRecord) -> bool:
4950app = Flask (__name__ )
5051
5152
52- async def execute_rollout_background (req , api_key ):
53+ async def execute_rollout_background (req : InitRequest , api_key : str ):
5354 """Execute the OpenAI completion in background and log results"""
5455 # Attach rollout_id filter to logger
5556 logger = logging .getLogger (f"{ __name__ } .{ req .metadata .rollout_id } " )
5657 logger .addFilter (RolloutIdFilter (req .metadata .rollout_id ))
5758
58- try :
59- model = req .completion_params .get ("model" )
60- # Uncomment if you need to strip fireworks_ai/ prefix
61- # if model and isinstance(model, str) and model.startswith("fireworks_ai/"):
62- # model = model[len("fireworks_ai/"):]
63-
64- # Prepare completion arguments
65- completion_kwargs = {
66- "messages" : req .messages ,
67- # "messages": [{"role": "user", "content": "Hello, how are you?"}],
68- "model" : model ,
69- "temperature" : req .completion_params .get ("temperature" ),
70- "max_tokens" : req .completion_params .get ("max_tokens" ),
71- }
59+ model = req .completion_params .get ("model" )
60+ # Uncomment if you need to strip fireworks_ai/ prefix
61+ # if model and isinstance(model, str) and model.startswith("fireworks_ai/"):
62+ # model = model[len("fireworks_ai/"):]
63+
64+ # Prepare completion arguments
65+ completion_kwargs = {
66+ "messages" : req .messages ,
67+ # "messages": [{"role": "user", "content": "Hello, how are you?"}],
68+ "model" : model ,
69+ "temperature" : req .completion_params .get ("temperature" ),
70+ "max_tokens" : req .completion_params .get ("max_tokens" ),
71+ }
72+
73+ # Add tools if present
74+ if req .tools :
75+ completion_kwargs ["tools" ] = req .tools
76+
77+ logger .info (
78+ f"DEBUG: { req .model_base_url } , COMPLETION_KWARGS: { completion_kwargs } , API_KEY: { api_key } , MODEL: { model } "
79+ )
7280
73- # Add tools if present
74- if req .tools :
75- completion_kwargs [ "tools" ] = req .tools
81+ # Create AsyncOpenAI client
82+ # client = AsyncOpenAI(base_url= req.model_base_url, api_key=api_key)
83+ client = OpenAI ( base_url = req .model_base_url , api_key = api_key )
7684
77- logger .info (
78- f"DEBUG: { req .model_base_url } , COMPLETION_KWARGS: { completion_kwargs } , API_KEY: { api_key } , MODEL: { model } "
79- )
85+ logger .info (f"Sending completion request to model { model } " )
8086
81- # Create AsyncOpenAI client
82- # client = AsyncOpenAI(base_url=req.model_base_url, api_key=api_key)
83- client = OpenAI (base_url = req .model_base_url , api_key = api_key )
87+ # Make the async model call with timeout
88+ import time
8489
85- logger .info (f"Sending completion request to model { model } " )
90+ logger .info (f"timing start: { time . time () } " )
8691
87- # Make the async model call with timeout
88- import time
89-
90- logger .info (f"timing start: { time .time ()} " )
92+ try :
9193 completion = client .chat .completions .create (** completion_kwargs )
92- logger .info (f"Completed response: { completion } " )
93- logger .info (f"timing end: { time .time ()} " )
94- # Log successful completion - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
95- logger .info (f"Rollout { req .metadata .rollout_id } completed" , extra = {"status" : Status .rollout_finished ()})
96-
94+ except (
95+ openai .AuthenticationError ,
96+ openai .PermissionDeniedError ,
97+ ) as e :
98+ # These errors should be logged and will be retried by RemoteRolloutProcessor
99+ logger .error (
100+ f"Rollout { req .metadata .rollout_id } failed: { e } " ,
101+ extra = {"status" : Status .rollout_permission_denied_error (str (e ))},
102+ )
103+ return
104+ except openai .NotFoundError as e :
105+ logger .error (
106+ f"Rollout { req .metadata .rollout_id } failed: { e } " , extra = {"status" : Status .rollout_not_found_error (str (e ))}
107+ )
108+ return
109+ except openai .RateLimitError as e :
110+ logger .error (
111+ f"Rollout { req .metadata .rollout_id } failed: { e } " ,
112+ extra = {"status" : Status .rollout_resource_exhausted_error (str (e ))},
113+ )
114+ return
97115 except Exception as e :
98- # Log error with structured status - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
116+ # Non-OpenAI errors (shouldn't normally happen but catch anyway)
99117 logger .error (
100- f"Rollout { req .metadata .rollout_id } failed: { e } " , extra = {"status" : Status .rollout_error_from_exception (e )}
118+ f"Rollout { req .metadata .rollout_id } failed with unexpected error: { e } " ,
119+ extra = {"status" : Status .rollout_internal_error (str (e ))},
101120 )
121+ return
122+
123+ logger .info (f"Completed response: { completion } " )
124+ logger .info (f"timing end: { time .time ()} " )
125+ # Log successful completion - THIS IS WHAT RemoteRolloutProcessor POLLS FOR
126+ logger .info (f"Rollout { req .metadata .rollout_id } completed" , extra = {"status" : Status .rollout_finished ()})
102127
103128
104129@app .route ("/init" , methods = ["POST" ])
@@ -114,7 +139,7 @@ async def init():
114139 # Validate required fields
115140 if not req .messages :
116141 error_msg = "messages is required"
117- logger .error (error_msg , extra = {"status" : Status .rollout_error (error_msg )})
142+ logger .error (error_msg , extra = {"status" : Status .rollout_internal_error (error_msg )})
118143 return jsonify ({"error" : error_msg }), 400
119144
120145 # Get API key (prefer request api_key, fallback to environment)
@@ -126,7 +151,7 @@ async def init():
126151 api_key = os .environ .get ("FIREWORKS_API_KEY" )
127152 else :
128153 error_msg = "API key not provided in request or environment variable"
129- logger .error (error_msg , extra = {"status" : Status .rollout_error (error_msg )})
154+ logger .error (error_msg , extra = {"status" : Status .rollout_internal_error (error_msg )})
130155 return jsonify ({"error" : error_msg }), 401
131156
132157 # 🔥 FIRE: Return immediately with acceptance (within 30s requirement)
@@ -137,7 +162,7 @@ async def init():
137162 }
138163
139164 # Fire and forget: Execute rollout asynchronously
140- asyncio .create_task (execute_rollout_background (req , api_key ))
165+ asyncio .create_task (execute_rollout_background (req , api_key or "" ))
141166
142167 return jsonify (response_data ), 200
143168
0 commit comments