
[5/n][trainer] feat: flowgrpo - test #56

Open
AndyZhou952 wants to merge 22 commits into zhtmike:diffusers_engine from AndyZhou952:trainer-pr

Conversation

@AndyZhou952

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

test

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

@AndyZhou952 AndyZhou952 marked this pull request as ready for review April 1, 2026 09:19
@AndyZhou952 AndyZhou952 force-pushed the trainer-pr branch 3 times, most recently from 2ab30f2 to c7a829f on April 1, 2026 10:00
@AndyZhou952 AndyZhou952 marked this pull request as draft April 1, 2026 13:58
@zhtmike
Owner

zhtmike commented Apr 1, 2026

Possibly ask AI to cherry-pick eb393e7 to your branch?

@AndyZhou952
Author

remaining todos: I think also need to move the changes in core_algo to diffusion/core_algo? also separate files for the test case? WDYT? @zhtmike

@zhtmike
Owner

zhtmike commented Apr 1, 2026

> remaining todos: I think also need to move the changes in core_algo to diffusion/core_algo? also separate files for the test case? WDYT? @zhtmike

Yes it is good

@AndyZhou952
Author

AndyZhou952 commented Apr 1, 2026

> remaining todos: I think also need to move the changes in core_algo to diffusion/core_algo? also separate files for the test case? WDYT? @zhtmike
>
> Yes it is good

might need to move diffusion_algos.py in 4/n then also

will handle the cherry pick and the file move tmr

@zhtmike
Owner

zhtmike commented Apr 1, 2026

> remaining todos: I think also need to move the changes in core_algo to diffusion/core_algo? also separate files for the test case? WDYT? @zhtmike
>
> Yes it is good
>
> might need to move diffusion_algos.py in 4/n then also
>
> will handle the cherry pick and the file move tmr

The CI is OK, maybe in this PR.


```python
def _compute_old_log_prob(self, batch: DataProto):
    # TODO: remove step 1, 2, 4 after we make the whole training tensordict and padding free
    # step 1: convert dataproto to tensordict.
```

also clean up the `step xxx` comments in `_compute_ref_log_prob` to make it consistent

```diff
 adv_estimator_fn = core_algos.get_adv_estimator_fn(adv_estimator)
 adv_kwargs = {
-    "token_level_rewards": data.batch["token_level_rewards"],
+    "token_level_rewards": data.batch["sample_level_rewards"],
```

Totally change `token_level_rewards` to `sample_level_rewards` for all diffusion APIs.
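A hedged sketch of that rename; the helper and field names below are hypothetical, not the PR's actual code:

```python
# Hypothetical sketch: build diffusion advantage-estimator kwargs keyed on
# "sample_level_rewards" rather than the LLM-style "token_level_rewards".
def build_diffusion_adv_kwargs(batch: dict, config=None) -> dict:
    """Map diffusion batch fields to estimator kwargs (illustrative only)."""
    return {
        "sample_level_rewards": batch["sample_level_rewards"],
        "config": config,
    }

kwargs = build_diffusion_adv_kwargs({"sample_level_rewards": [1.0, 0.0, 1.0]})
```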

```python
OPTIMAL_TOKEN_BASELINE = "optimal_token_baseline"
TIR_OPTIMAL_TOKEN_BASELINE = "tir_optimal_token_baseline"
GDPO = "gdpo"
FLOW_GRPO = "flow_grpo"
```

possibly move to the diffusion part alone
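One way that split could look; the enum name below is an assumption, with member names mirroring the snippet above:

```python
from enum import Enum

# Hypothetical sketch: a diffusion-only estimator enum, separated from the
# shared LLM AdvantageEstimator as the reviewer suggests.
class DiffusionAdvantageEstimator(str, Enum):
    GDPO = "gdpo"
    FLOW_GRPO = "flow_grpo"

name = DiffusionAdvantageEstimator.FLOW_GRPO.value
```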

```python
assert torch.equal(result.batch["returns"], torch.full((4, 6), 3.0))


def test_compute_advantage_dispatches_generic_estimator_with_diffusion_kwargs(
```

no need

```python
)


def test_compute_response_mask_returns_valid_step_mask() -> None:
```

no need

```python
assert torch.equal(valid_step_mask, torch.ones((3, 5), dtype=torch.int32))


def test_build_diffusion_advantage_kwargs_maps_diffusion_batch_fields() -> None:
```

no need

```python
assert adv_kwargs["config"] is config


def test_compute_advantage_uses_diffusion_module_for_flow_grpo(monkeypatch: pytest.MonkeyPatch) -> None:
```

No need; we only need to test the core algo. One is enough.

```python
assert result.batch["returns"].shape == (4, 6)


def test_flow_grpo_estimator_registered_from_diffusion_module() -> None:
```

no need.

```python
# gdpo_reward_weights: per-dimension weights for aggregation (default: equal weights).
gdpo_reward_keys: Optional[list[str]] = None
gdpo_reward_weights: Optional[list[float]] = None
global_std: bool = True
```

move to a new DiffusionAlgoConfig, since most of these are LLM-related params
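A minimal sketch of what such a split-out config could look like; the class and defaults are assumptions, not the repo's actual code:

```python
from dataclasses import dataclass
from typing import Optional

# Hypothetical sketch: diffusion-specific knobs moved out of the
# LLM-oriented algorithm config into their own dataclass.
@dataclass
class DiffusionAlgoConfig:
    gdpo_reward_keys: Optional[list] = None
    # Per-dimension weights for aggregation (default: equal weights).
    gdpo_reward_weights: Optional[list] = None
    global_std: bool = True

cfg = DiffusionAlgoConfig()
```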

```python
) -> DataProto:
    """Compute diffusion advantages using the shared estimator registry."""
    adv_kwargs = _build_diffusion_advantage_kwargs(data, config=config)
    if adv_estimator == AdvantageEstimator.FLOW_GRPO:
```

Make it generic; no if/else branch.
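A hedged sketch of branch-free dispatch through a registry; the function and registry names are hypothetical stand-ins for the repo's `get_adv_estimator_fn` machinery:

```python
# Hypothetical sketch: every estimator (flow_grpo included) resolves through
# one registry lookup, so no per-estimator if/else branch is needed.
ADV_ESTIMATORS = {}

def register_adv_estimator(name):
    def deco(fn):
        ADV_ESTIMATORS[name] = fn
        return fn
    return deco

@register_adv_estimator("flow_grpo")
def flow_grpo_adv(sample_level_rewards, config=None):
    # Toy group-mean baseline, illustrative only.
    mean = sum(sample_level_rewards) / len(sample_level_rewards)
    return [r - mean for r in sample_level_rewards]

def compute_advantage(adv_estimator, **adv_kwargs):
    # Generic: look the estimator up; registering a new one needs no edit here.
    return ADV_ESTIMATORS[adv_estimator](**adv_kwargs)

advs = compute_advantage("flow_grpo", sample_level_rewards=[1.0, 2.0, 3.0])
# → [-1.0, 0.0, 1.0]
```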

@AndyZhou952 AndyZhou952 marked this pull request as ready for review April 2, 2026 09:39