Skip to content

Commit b9d48b0

Browse files
author
Shrey Modi
committed
addressed comments
1 parent 2ff684b commit b9d48b0

8 files changed

Lines changed: 37 additions & 66 deletions

File tree

.github/workflows/rollout.yml

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@ run-name: rollout:${{ fromJSON(inputs.metadata).rollout_id }}
55
on:
66
workflow_dispatch:
77
inputs:
8-
model:
9-
description: 'Model to use'
10-
required: true
8+
completion_params:
9+
description: 'JSON completion params (optional, includes model_kwargs)'
10+
required: false
1111
type: string
1212
metadata:
1313
description: 'JSON serialized metadata object'
@@ -17,10 +17,7 @@ on:
1717
description: 'Base URL for the model API'
1818
required: true
1919
type: string
20-
completion_params:
21-
description: 'JSON completion params (optional, includes model_kwargs)'
22-
required: false
23-
type: string
20+
2421

2522
jobs:
2623
rollout:
@@ -45,7 +42,6 @@ jobs:
4542
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
4643
run: |
4744
python tests/github_actions/rollout_worker.py \
48-
--model "${{ inputs.model }}" \
45+
--completion-params '${{ inputs.completion_params }}' \
4946
--metadata '${{ inputs.metadata }}' \
50-
--model-base-url "${{ inputs.model_base_url }}" \
51-
${{ inputs.completion_params && format('--completion-params ''{0}''', inputs.completion_params) || '' }}
47+
--model-base-url "${{ inputs.model_base_url }}"

eval_protocol/pytest/github_action_rollout_processor.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -94,10 +94,9 @@ def _dispatch_workflow():
9494
payload = {
9595
"ref": self.ref,
9696
"inputs": {
97-
"model": model,
97+
"completion_params": json.dumps(init_request.completion_params),
9898
"metadata": init_request.metadata.model_dump_json(),
9999
"model_base_url": init_request.model_base_url,
100-
"completion_params": json.dumps(init_request.completion_params),
101100
},
102101
}
103102
r = requests.post(url, json=payload, headers=self._headers(), timeout=30)

tests/github_actions/rollout_worker.py

Lines changed: 15 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -17,20 +17,23 @@ def main():
1717
parser = argparse.ArgumentParser(description="GitHub Actions rollout worker")
1818

1919
# Required arguments from workflow inputs
20-
parser.add_argument("--model", required=True, help="Model to use")
21-
parser.add_argument("--completion-params", required=False, help="JSON completion params (optional)")
20+
parser.add_argument("--completion-params", required=True, help="JSON completion params (includes model)")
2221
parser.add_argument("--metadata", required=True, help="JSON serialized metadata object")
2322
parser.add_argument("--model-base-url", required=True, help="Base URL for the model API")
2423

2524
args = parser.parse_args()
2625

27-
# Parse the metadata
28-
completion_params = {}
29-
if args.completion_params:
30-
try:
31-
completion_params = json.loads(args.completion_params)
32-
except Exception as e:
33-
print(f"⚠️ Failed to parse completion_params: {e}")
26+
# Parse completion_params
27+
try:
28+
completion_params = json.loads(args.completion_params)
29+
except Exception as e:
30+
print(f"❌ Failed to parse completion_params: {e}")
31+
exit(1)
32+
33+
model = completion_params.get("model")
34+
if not model:
35+
print("Error: model is required in completion_params")
36+
exit(1)
3437

3538
try:
3639
metadata = json.loads(args.metadata)
@@ -42,7 +45,7 @@ def main():
4245
row_id = metadata["row_id"]
4346

4447
print(f"🚀 Starting rollout {rollout_id}")
45-
print(f" Model: {args.model}")
48+
print(f" Model: {model}")
4649
print(f" Row ID: {row_id}")
4750

4851
dataset = [ # In this example, worker has access to the dataset and we use index to associate rows.
@@ -57,16 +60,8 @@ def main():
5760
print(f" Messages: {len(messages)} messages")
5861

5962
try:
60-
completion_kwargs = {"model": args.model, "messages": messages}
61-
# Parse and apply completion_params if provided
62-
if args.completion_params:
63-
try:
64-
cp = json.loads(args.completion_params)
65-
if cp.get("model_kwargs"):
66-
completion_kwargs.update(cp["model_kwargs"])
67-
print(f" Applied model_kwargs: {cp.get('model_kwargs')}")
68-
except Exception as e:
69-
print(f"⚠️ Failed to parse completion_params: {e}")
63+
# Build completion kwargs from completion_params
64+
completion_kwargs = {"messages": messages, **completion_params}
7065

7166
client = OpenAI(base_url=args.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
7267

tests/github_actions/test_github_actions_rollout.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,8 +55,7 @@ def rows() -> List[EvaluationRow]:
5555

5656
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
5757
@pytest.mark.parametrize(
58-
"completion_params",
59-
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
58+
"completion_params", [{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}]
6059
)
6160
@evaluation_test(
6261
data_loaders=DynamicDataLoader(

tests/remote_server/remote_server.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -40,16 +40,8 @@ def _worker():
4040
if not model:
4141
raise ValueError("model is required in completion_params")
4242

43-
completion_kwargs = {
44-
"model": model,
45-
"messages": req.messages,
46-
}
47-
48-
# Apply model_kwargs if present
49-
if req.completion_params.get("model_kwargs"):
50-
model_kwargs = req.completion_params["model_kwargs"]
51-
if isinstance(model_kwargs, dict):
52-
completion_kwargs.update(model_kwargs)
43+
# Spread all completion_params (model, temperature, max_tokens, etc.)
44+
completion_kwargs = {"messages": req.messages, **req.completion_params}
5345

5446
if req.tools:
5547
completion_kwargs["tools"] = req.tools

tests/remote_server/remote_server_multi_turn.py

Lines changed: 4 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,6 @@ def _worker():
3838

3939
client = OpenAI(base_url=req.model_base_url, api_key=os.environ.get("FIREWORKS_API_KEY"))
4040

41-
# Apply model_kwargs if present
42-
if req.completion_params.get("model_kwargs"):
43-
model_kwargs = req.completion_params["model_kwargs"]
44-
if isinstance(model_kwargs, dict):
45-
completion_kwargs.update(model_kwargs)
4641
# Build up conversation over 6 turns (3 user messages + 3 assistant responses)
4742
# Convert Message objects to dicts for OpenAI API
4843
conversation_history = [{"role": m.role, "content": m.content} for m in req.messages]
@@ -55,9 +50,8 @@ def _worker():
5550
# First completion (turns 1-2: initial user message + assistant response)
5651
logger.info(f"Turn 1-2: Sending initial completion request to model {model}")
5752
completion = client.chat.completions.create(
58-
model=model,
59-
messages=conversation_history, # type: ignore,
60-
**completion_kwargs,
53+
messages=conversation_history, # type: ignore
54+
**req.completion_params,
6155
)
6256
assistant_message = completion.choices[0].message
6357
assistant_content = assistant_message.content or ""
@@ -68,8 +62,8 @@ def _worker():
6862
conversation_history.append({"role": "user", "content": follow_up_questions[0]})
6963
logger.info(f"Turn 3: User asks: {follow_up_questions[0]}")
7064
completion = client.chat.completions.create(
71-
model=model,
7265
messages=conversation_history, # type: ignore
66+
**req.completion_params,
7367
)
7468
assistant_message = completion.choices[0].message
7569
assistant_content = assistant_message.content or ""
@@ -80,8 +74,8 @@ def _worker():
8074
conversation_history.append({"role": "user", "content": follow_up_questions[1]})
8175
logger.info(f"Turn 5: User asks: {follow_up_questions[1]}")
8276
completion = client.chat.completions.create(
83-
model=model,
8477
messages=conversation_history, # type: ignore
78+
**req.completion_params,
8579
)
8680
assistant_message = completion.choices[0].message
8781
assistant_content = assistant_message.content or ""

tests/remote_server/test_remote_fireworks.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ def rows() -> List[EvaluationRow]:
6060
@pytest.mark.skipif(os.environ.get("CI") == "true", reason="Only run this test locally (skipped in CI)")
6161
@pytest.mark.parametrize(
6262
"completion_params",
63-
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "model_kwargs": {"temperature": 0.5}}],
63+
[{"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b", "temperature": 0.5}],
6464
)
6565
@evaluation_test(
6666
data_loaders=DynamicDataLoader(
@@ -85,8 +85,6 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
8585
assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
8686
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
8787
)
88-
assert row.input_metadata.completion_params["model_kwargs"] == {"temperature": 0.5}, (
89-
"Row should have correct model_kwargs"
90-
)
88+
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"
9189

9290
return row

typescript/index.ts

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -89,19 +89,17 @@ export function initRequestToCompletionParams(
8989
throw new Error("messages is required");
9090
}
9191

92-
const baseParams: ChatCompletionCreateParamsNonStreaming = {
92+
// Spread completion_params directly (model, temperature, max_tokens, etc.)
93+
const { model: _, ...otherParams } = initRequest.completion_params || {};
94+
95+
const completionParams: ChatCompletionCreateParamsNonStreaming = {
9396
model: model,
9497
messages: initRequest.messages,
9598
...(toolsToOpenAI && { tools: toolsToOpenAI }),
99+
...otherParams // Spreads temperature, max_tokens, etc.
96100
};
97101

98-
// Apply model_kwargs if present
99-
const model_kwargs = initRequest.completion_params?.['model_kwargs'];
100-
if (model_kwargs && typeof model_kwargs === 'object') {
101-
Object.assign(baseParams, model_kwargs);
102-
}
103-
104-
return baseParams;
102+
return completionParams;
105103
}
106104

107105
export function createLangfuseConfigTags(initRequest: InitRequest): string[] {

0 commit comments

Comments
 (0)