fix(qwen3-vl): CP-local vision-embed hook + AllGatherVisionEmbeddings.apply kwarg#9
Open
Zhichenzzz wants to merge 1 commit into
Open
fix(qwen3-vl): CP-local vision-embed hook + AllGatherVisionEmbeddings.apply kwarg#9Zhichenzzz wants to merge 1 commit into
Zhichenzzz wants to merge 1 commit into
Conversation
…dings.apply kwarg Add Qwen3VLModel._cp_local_vision_embed_indices: when a training framework pre-shards input_ids across context-parallel ranks (THD packed, zigzag CP layout), select only this rank's vision-embedding rows so the scatter into the CP-local vision mask matches. All-gathers only per-chunk vision-token counts, not token ids, and keeps control flow uniform across CP ranks before the collective. Also pass cp_group positionally to AllGatherVisionEmbeddings.apply (torch.autograd.Function.apply rejects keyword args).
6c78756 to
eeb84d7
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this PR does
Make the bridge's Qwen3-VL model work when the training framework passes CP-pre-sharded packed inputs (Megatron's standard load-balanced "zigzag" context-parallel layout), plus one small API fix:
Qwen3VLModel.forwardcomputes vision embeddings for the full sequence and scatters them into the vision-token positions ofinput_ids. Wheninput_idsare already sharded across CP ranks (detected viacu_seqlens_q[-1] == cp_size * len(input_ids)), the local vision mask covers only this rank's vision tokens, so the rows must be filtered to match.Qwen3VLModel._cp_local_vision_embed_indicescomputes the selection natively: it all-gathers only per-chunk vision-token counts (a few ints per segment — not the token ids) and places this rank's vision tokens in the full vision-token order. Stock full-sequence behaviour is unchanged (the method returnsNonefor full inputs, cp=1, or unrecognized layouts).AllGatherVisionEmbeddings.applykwarg fix —torch.autograd.Function.applyrejects keyword arguments; passcp_grouppositionally at the two call sites.Design notes
cu_seqlens(identical on all CP ranks) — rank-dependent shortcuts happen strictly after the collective, otherwise the CP group deadlocks.Validation
Qwen3-VL-2B, CP2 TP4 and CP2 TP2, THD packing, geo3k RL with miles (radixark/miles#1308): coherent rollouts,
train_rollout_logprob_abs_diff0.0127–0.0131 (same healthy band as non-CP runs), no NCCL timeouts, stock non-CP path unchanged.Companion PR: radixark/miles#1308 (per-segment mRoPE positions for the same pre-sharded path; its vision plumbing is deleted in favour of this native support).