@@ -17,20 +17,23 @@ def main():
1717 parser = argparse .ArgumentParser (description = "GitHub Actions rollout worker" )
1818
1919 # Required arguments from workflow inputs
20- parser .add_argument ("--model" , required = True , help = "Model to use" )
21- parser .add_argument ("--completion-params" , required = False , help = "JSON completion params (optional)" )
20+ parser .add_argument ("--completion-params" , required = True , help = "JSON completion params (includes model)" )
2221 parser .add_argument ("--metadata" , required = True , help = "JSON serialized metadata object" )
2322 parser .add_argument ("--model-base-url" , required = True , help = "Base URL for the model API" )
2423
2524 args = parser .parse_args ()
2625
27- # Parse the metadata
28- completion_params = {}
29- if args .completion_params :
30- try :
31- completion_params = json .loads (args .completion_params )
32- except Exception as e :
33- print (f"⚠️ Failed to parse completion_params: { e } " )
26+ # Parse completion_params
27+ try :
28+ completion_params = json .loads (args .completion_params )
29+ except Exception as e :
30+ print (f"❌ Failed to parse completion_params: { e } " )
31+ exit (1 )
32+
33+ model = completion_params .get ("model" )
34+ if not model :
35+ print ("Error: model is required in completion_params" )
36+ exit (1 )
3437
3538 try :
3639 metadata = json .loads (args .metadata )
@@ -42,7 +45,7 @@ def main():
4245 row_id = metadata ["row_id" ]
4346
4447 print (f"🚀 Starting rollout { rollout_id } " )
45- print (f" Model: { args . model } " )
48+ print (f" Model: { model } " )
4649 print (f" Row ID: { row_id } " )
4750
4851 dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
@@ -57,16 +60,8 @@ def main():
5760 print (f" Messages: { len (messages )} messages" )
5861
5962 try :
60- completion_kwargs = {"model" : args .model , "messages" : messages }
61- # Parse and apply completion_params if provided
62- if args .completion_params :
63- try :
64- cp = json .loads (args .completion_params )
65- if cp .get ("model_kwargs" ):
66- completion_kwargs .update (cp ["model_kwargs" ])
67- print (f" Applied model_kwargs: { cp .get ('model_kwargs' )} " )
68- except Exception as e :
69- print (f"⚠️ Failed to parse completion_params: { e } " )
63+ # Build completion kwargs from completion_params
64+ completion_kwargs = {"messages" : messages , ** completion_params }
7065
7166 client = OpenAI (base_url = args .model_base_url , api_key = os .environ .get ("FIREWORKS_API_KEY" ))
7267
0 commit comments