# support groupwise scoring #101
```diff
@@ -73,9 +73,11 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
     _litellm = importlib.import_module("litellm")
     acompletion = getattr(_litellm, "acompletion")
+    logger.debug(f"********** request_params: {request_params} **********")
     response = await acompletion(**request_params)

     assistant_content = response.choices[0].message.content or ""
+    logger.debug(f"********** assistant_content: {assistant_content} **********")
```
> **Contributor:** ditto
```diff
     tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None

     converted_tool_calls = None
```
*Large diffs are not rendered by default.*
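Aside (not part of the diff): the `importlib.import_module("litellm")` indirection reads like a lazy/optional-dependency pattern, so `litellm` is only imported when a rollout actually runs. A minimal sketch of the same idea; the helper name and error handling are illustrative, not from the PR:

```python
import importlib


def resolve_acompletion():
    """Resolve litellm.acompletion lazily so litellm stays an optional dependency."""
    try:
        _litellm = importlib.import_module("litellm")
    except ImportError as exc:
        # Hypothetical failure message; the PR does not show its error handling.
        raise RuntimeError("litellm must be installed to run single-turn rollouts") from exc
    return getattr(_litellm, "acompletion")
```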
```diff
@@ -19,14 +19,11 @@
 Dataset = List[EvaluationRow]

-EvaluationTestMode = Literal["batch", "pointwise"]
+EvaluationTestMode = Literal["pointwise", "groupwise", "listwise"]
 """
-"batch": (default) expects test function to handle full dataset.
-"pointwise": applies test function to each row.
-
-How to choose between "batch" and "pointwise":
-If your evaluation requires the rollout of all rows to be passed into your eval to compute the score, use "batch".
-If your evaluation can be computed pointwise, use "pointwise" as EP can pipeline the rollouts and evals to be faster.
+"pointwise": (default) applies test function to each row (rollout result).
+"groupwise": applies test function to a group of rollout results from the same original row (for use cases such as dpo/grpo).
+"listwise": applies test function to the whole dataset.
```
> **Contributor:** listwise is confusing, probably just "all" or something
>
> **Author (Collaborator):** sure
```diff
 """

 """
```
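To make the new modes concrete: based on the docstring above and the test below, each mode implies a different test-function shape. A hedged sketch (the function names are illustrative; only the groupwise signature is confirmed by this PR):

```python
from typing import List

from eval_protocol.models import EvaluationRow


# "pointwise" (default): one rollout result in, one scored row out.
def eval_pointwise(row: EvaluationRow) -> EvaluationRow: ...


# "groupwise": all rollouts derived from the same original row, e.g. one per
# completion_params entry, scored relative to each other (DPO/GRPO-style).
def eval_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: ...


# "listwise": the whole dataset at once (the mode previously called "batch").
def eval_listwise(rows: List[EvaluationRow]) -> List[EvaluationRow]: ...
```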
```diff
@@ -0,0 +1,28 @@
+from typing import List
+
+from eval_protocol.models import EvaluationRow, Message, EvaluateResult
+from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test
+
+
+@evaluation_test(
+    input_messages=[
+        [
+            Message(role="user", content="What is the capital of France?"),
+        ]
+    ],
+    completion_params=[
+        {"model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"},
+        {"model": "fireworks_ai/accounts/fireworks/models/gpt-4.1"},
+    ],
+    rollout_processor=SingleTurnRolloutProcessor(),
+    mode="groupwise",
+)
+def test_pytest_groupwise(rows: List[EvaluationRow]) -> List[EvaluationRow]:
+    """Run math evaluation on sample dataset using pytest interface."""
+    assert rows[0].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
+    assert rows[1].input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-4.1"
+    rows[0].evaluation_result = EvaluateResult(score=1.0, reason="test")
+    rows[1].evaluation_result = EvaluateResult(score=0.0, reason="test")
+    print(rows[0].model_dump_json())
+    print(rows[1].model_dump_json())
+    return rows
```
> **Review comment:** delete
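The test above assigns fixed scores to exercise the plumbing. For illustration only (not part of the PR), a groupwise scorer would typically rank the group's rollouts against each other; a minimal sketch, assuming the assistant reply is the last entry in `row.messages` and using an arbitrary shortest-answer heuristic chosen purely for demonstration:

```python
from typing import List

from eval_protocol.models import EvaluationRow, EvaluateResult


def score_group(rows: List[EvaluationRow]) -> List[EvaluationRow]:
    """Give the shortest completion in the group score 1.0, the rest 0.0."""
    # Assumption: the assistant completion is the last message on each row.
    lengths = [len(r.messages[-1].content or "") for r in rows]
    best = lengths.index(min(lengths))
    for i, row in enumerate(rows):
        row.evaluation_result = EvaluateResult(
            score=1.0 if i == best else 0.0,
            reason="shortest answer in group" if i == best else "longer than the best answer",
        )
    return rows
```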