[5/n][trainer] feat: flowgrpo - test #56
AndyZhou952 wants to merge 22 commits into zhtmike:diffusers_engine from
Conversation
2ab30f2 to c7a829f
Possibly ask AI to cherry-pick eb393e7 to your branch?
Remaining TODOs: I think we also need to move the changes in core_algo to diffusion/core_algo, and use separate files for the test cases. WDYT? @zhtmike
Yes, it is good.
We might need to move diffusion_algos.py in 4/n as well; I will also handle the cherry-pick and the file move tomorrow.
The CI is OK, maybe handle it in this PR.
def _compute_old_log_prob(self, batch: DataProto):
    # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free
    # step 1: convert dataproto to tensordict.
Also clean up the "step xxx" comments in _compute_ref_log_prob to keep them consistent.
  adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
  adv_kwargs = {
-     "token_level_rewards": data.batch["token_level_rewards"],
+     "token_level_rewards": data.batch["sample_level_rewards"],
Change token_level_rewards to sample_level_rewards everywhere in the diffusion APIs.
verl/trainer/ppo/core_algos.py
Outdated
OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
GDPO = "gdpo"
FLOW_GRPO = "flow_grpo"
Possibly move these to the diffusion part alone.
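One way the split could look is a diffusion-only estimator enum living next to the diffusion code. This is a sketch of the suggestion, not existing verl code; the class name `DiffusionAdvantageEstimator` is hypothetical, while the member names are taken from the diff above:

```python
from enum import Enum

class AdvantageEstimator(str, Enum):
    """LLM estimators stay in verl/trainer/ppo/core_algos.py."""
    GAE = "gae"
    GRPO = "grpo"

class DiffusionAdvantageEstimator(str, Enum):
    """Diffusion-only estimators, moved out to the diffusion package."""
    OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
    TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
    GDPO = "gdpo"
    FLOW_GRPO = "flow_grpo"

# str-mixin enums still compare equal to their string values,
# so config files can keep using plain strings like "flow_grpo".
print(DiffusionAdvantageEstimator.FLOW_GRPO.value)
```

Because both classes mix in `str`, call sites that match on the raw string value keep working regardless of which enum a name lives in.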
assert torch.equal(result.batch["returns"], torch.full((4, 6), 3.0))

def test_compute_advantage_dispatches_generic_estimator_with_diffusion_kwargs(
)

def test_compute_response_mask_returns_valid_step_mask() -> None:
assert torch.equal(valid_step_mask, torch.ones((3, 5), dtype=torch.int32))

def test_build_diffusion_advantage_kwargs_maps_diffusion_batch_fields() -> None:
assert adv_kwargs["config"] is config

def test_compute_advantage_uses_diffusion_module_for_flow_grpo(monkeypatch: pytest.MonkeyPatch) -> None:
No need. We only need to test the core algo; one test is enough.
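A single focused test at the core-algo level could look like the sketch below. The group-normalization body is an illustrative stand-in for the real flow-GRPO estimator (it is not verl's implementation); the point is the test shape, which exercises the pure function directly instead of the trainer dispatch:

```python
import math

def flow_grpo_advantage(sample_level_rewards, group_size):
    """Group-wise reward normalization (illustrative stand-in for the real estimator)."""
    advantages = []
    for i in range(0, len(sample_level_rewards), group_size):
        group = sample_level_rewards[i:i + group_size]
        mean = sum(group) / len(group)
        std = math.sqrt(sum((r - mean) ** 2 for r in group) / len(group))
        # epsilon guards against zero-variance groups
        advantages.extend((r - mean) / (std + 1e-6) for r in group)
    return advantages

def test_flow_grpo_advantage_zero_mean_per_group():
    adv = flow_grpo_advantage([1.0, 2.0, 3.0, 4.0], group_size=2)
    # each group of two should normalize to zero mean
    assert abs(sum(adv[:2])) < 1e-5
    assert abs(sum(adv[2:])) < 1e-5

test_flow_grpo_advantage_zero_mean_per_group()
```

Testing the estimator function directly keeps the test independent of DataProto plumbing, so the dispatch-level tests above can be dropped.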
assert result.batch["returns"].shape == (4, 6)

def test_flow_grpo_estimator_registered_from_diffusion_module() -> None:
# gdpo_reward_weights: per-dimension weights for aggregation (default: equal weights).
gdpo_reward_keys: Optional[list[str]] = None
gdpo_reward_weights: Optional[list[float]] = None
global_std: bool = True
Move these to a new DiffusionAlgoConfig, since most of the surrounding params are LLM-related.
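A minimal sketch of that split, assuming dataclass-based configs (field names come from the diff above; `DiffusionAlgoConfig` and the `diffusion` nesting are the suggestion, not existing verl code):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class DiffusionAlgoConfig:
    """Diffusion-specific algorithm knobs, split out of the shared algo config."""
    # gdpo_reward_weights: per-dimension weights for aggregation (default: equal weights).
    gdpo_reward_keys: Optional[list] = None
    gdpo_reward_weights: Optional[list] = None
    global_std: bool = True

@dataclass
class AlgoConfig:
    """LLM-related params stay at the top level; diffusion params nest under one field."""
    adv_estimator: str = "grpo"
    diffusion: DiffusionAlgoConfig = field(default_factory=DiffusionAlgoConfig)

# diffusion knobs are addressed as algorithm.diffusion.<name>
cfg = AlgoConfig(diffusion=DiffusionAlgoConfig(global_std=False))
print(cfg.diffusion.global_std)
```

Nesting keeps the LLM config free of diffusion-only fields while leaving every existing LLM parameter path unchanged.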
verl/trainer/diffusion/advantage.py
Outdated
) -> DataProto:
    """Compute diffusion advantages using the shared estimator registry."""
    adv_kwargs = _build_diffusion_advantage_kwargs(data, config=config)
    if adv_estimator == AdvantageEstimator.FLOW_GRPO:
Make it generic: no if/else branches.
What does this PR do?
test
Checklist Before Starting

- Format the PR title as [{modules}] {type}: {description} (this will be checked by the CI).
- {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off, like [megatron, fsdp, doc].
- {type} is in feat, fix, refactor, chore, test.
- If the PR is breaking, add [BREAKING] to the beginning of the title, e.g. [BREAKING][fsdp, megatron] feat: dynamic batching.

Test
API and Usage Example

# Add code snippet or script demonstrating how to use this

Design & Code Changes
Checklist Before Submitting

Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- Run pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always.
- Request CI in the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.