fix(qwen3-vl): CP-local vision-embed hook + AllGatherVisionEmbeddings.apply kwarg #9

Open
Zhichenzzz wants to merge 1 commit into bridge from fix/qwen3vl-cp-vision-hook


Zhichenzzz commented Jun 8, 2026

What this PR does

Make the bridge's Qwen3-VL model work when the training framework passes CP-pre-sharded packed inputs (Megatron's standard load-balanced "zigzag" context-parallel layout), plus one small API fix:

  1. Native CP-local vision-embed selection. Qwen3VLModel.forward computes vision embeddings for the full sequence and scatters them into the vision-token positions of input_ids. When input_ids is already sharded across CP ranks (detected via cu_seqlens_q[-1] == cp_size * len(input_ids)), the local vision mask covers only this rank's vision tokens, so the embedding rows must be filtered to match. Qwen3VLModel._cp_local_vision_embed_indices computes the selection natively: it all-gathers only the per-chunk vision-token counts (a few ints per segment, not the token ids) and locates this rank's vision tokens within the full vision-token order. Stock full-sequence behaviour is unchanged (the method returns None for full inputs, cp=1, or unrecognized layouts).
  2. AllGatherVisionEmbeddings.apply kwarg fix. torch.autograd.Function.apply rejects keyword arguments; pass cp_group positionally at the two call sites.
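
The selection in item 1 can be illustrated with a minimal pure-Python sketch. This is not the PR's code. It assumes each packed segment is split into 2 * cp_size equal chunks, with rank r holding chunks r and 2 * cp_size - 1 - r (the zigzag layout as it is commonly described). VISION, zigzag_chunks, shard, is_cp_presharded, local_counts and select_rows are hypothetical names, and the all-gather is simulated with a plain list.

```python
from itertools import accumulate

VISION = -1  # hypothetical placeholder id standing in for a vision token


def zigzag_chunks(cu_seqlens, cp_size, rank):
    """Yield (start, end) spans of the full packed sequence held by `rank`."""
    for seg_start, seg_end in zip(cu_seqlens, cu_seqlens[1:]):
        chunk = (seg_end - seg_start) // (2 * cp_size)
        for c in (rank, 2 * cp_size - 1 - rank):
            yield seg_start + c * chunk, seg_start + (c + 1) * chunk


def shard(full_ids, cu_seqlens, cp_size, rank):
    """Build the CP-local input_ids a framework would hand to this rank."""
    return [t for s, e in zigzag_chunks(cu_seqlens, cp_size, rank) for t in full_ids[s:e]]


def is_cp_presharded(local_ids, cu_seqlens, cp_size):
    """The detection check from the description: cu_seqlens still spans the full sequence."""
    return cp_size > 1 and cu_seqlens[-1] == cp_size * len(local_ids)


def local_counts(local_ids, cu_seqlens, cp_size):
    """Vision-token count for each local chunk (two chunks per segment)."""
    counts, pos = [], 0
    for seg_start, seg_end in zip(cu_seqlens, cu_seqlens[1:]):
        chunk = (seg_end - seg_start) // (2 * cp_size)
        for _ in range(2):
            counts.append(sum(t == VISION for t in local_ids[pos:pos + chunk]))
            pos += chunk
    return counts


def select_rows(gathered, cp_size, rank):
    """gathered[r] holds rank r's local_counts; return this rank's rows of the full vision embeddings."""
    num_segs = len(gathered[0]) // 2
    ordered = []  # counts in full-sequence order: segment-major, chunk 0 .. 2*cp_size-1
    for s in range(num_segs):
        row = [0] * (2 * cp_size)
        for r in range(cp_size):
            row[r] = gathered[r][2 * s]
            row[2 * cp_size - 1 - r] = gathered[r][2 * s + 1]
        ordered.extend(row)
    offsets = [0] + list(accumulate(ordered))  # first vision row of each global chunk
    rows = []
    for s in range(num_segs):
        for c in (rank, 2 * cp_size - 1 - rank):
            k = s * 2 * cp_size + c
            rows.extend(range(offsets[k], offsets[k + 1]))
    return rows


# Two packed segments (lengths 8 and 4), cp_size = 2.
V = VISION
full_ids = [1, V, V, 2, 3, V, 4, 5, V, 6, 7, V]
cu_seqlens = [0, 8, 12]
cp_size = 2

local = [shard(full_ids, cu_seqlens, cp_size, r) for r in range(cp_size)]
gathered = [local_counts(ids, cu_seqlens, cp_size) for ids in local]  # simulated all-gather
rows_by_rank = [select_rows(gathered, cp_size, r) for r in range(cp_size)]
print(rows_by_rank)  # [[0, 3, 4], [1, 2]]
```

Only the small count lists cross ranks here. Each rank then derives its row indices locally from prefix sums, which is the property the description relies on.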

Design notes

  • An earlier revision exposed an overridable identity hook for frameworks to inject the selection; review feedback rightly flagged that as awkward, so the selection is now implemented natively (the zigzag layout is Megatron's own convention, not framework-specific).
  • Control flow before the count all-gather depends only on cu_seqlens (identical on all CP ranks). Rank-dependent shortcuts happen strictly after the collective; otherwise the CP group would deadlock.
  • A rank whose chunks contain no vision tokens selects zero rows (empty index), matching its empty local mask.
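
The ordering constraint in the notes above can be illustrated with a small simulation. This is a sketch under loud assumptions: threads stand in for CP ranks, a threading.Barrier with a timeout stands in for the count all-gather, and run_ranks is a hypothetical helper, not part of the bridge or of Megatron.

```python
import threading


def run_ranks(counts_per_rank, shortcut_before_collective):
    """Simulate CP ranks as threads; the barrier stands in for the count all-gather."""
    world = len(counts_per_rank)
    barrier = threading.Barrier(world)
    results = [None] * world

    def rank_fn(rank):
        has_vision = sum(counts_per_rank[rank]) > 0
        if shortcut_before_collective and not has_vision:
            # Rank-dependent exit before the collective: peers are left waiting.
            results[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=0.5)  # every rank must reach the collective
        except threading.BrokenBarrierError:
            results[rank] = "hung"  # a real process group would block until its timeout
            return
        # A rank-dependent shortcut is safe only here, after the collective.
        results[rank] = "rows" if has_vision else "empty"

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(world)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


counts = [[1, 0], [0, 0]]  # rank 1 holds no vision tokens
safe = run_ranks(counts, shortcut_before_collective=False)
unsafe = run_ranks(counts, shortcut_before_collective=True)
print(safe)    # ['rows', 'empty']
print(unsafe)  # ['hung', 'skipped']
```

In the safe ordering the rank without vision tokens still joins the collective and then selects zero rows, which matches its empty local mask.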

Validation

Qwen3-VL-2B, CP2 TP4 and CP2 TP2, THD packing, geo3k RL with miles (radixark/miles#1308): coherent rollouts, train_rollout_logprob_abs_diff 0.0127–0.0131 (same healthy band as non-CP runs), no NCCL timeouts, stock non-CP path unchanged.

Companion PR: radixark/miles#1308 (per-segment mRoPE positions for the same pre-sharded path; its vision plumbing is deleted in favour of this native support).

…dings.apply kwarg

Add Qwen3VLModel._cp_local_vision_embed_indices: when a training framework
pre-shards input_ids across context-parallel ranks (THD packed, zigzag CP
layout), select only this rank's vision-embedding rows so the scatter into the
CP-local vision mask matches. All-gathers only per-chunk vision-token counts,
not token ids, and keeps control flow uniform across CP ranks before the
collective. Also pass cp_group positionally to AllGatherVisionEmbeddings.apply
(torch.autograd.Function.apply rejects keyword args).

Zhichenzzz force-pushed the fix/qwen3vl-cp-vision-hook branch from 6c78756 to eeb84d7 on June 18, 2026 21:39