From 9f099d42292bd1ecafcc0dfd626954be21952f30 Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Tue, 9 Jun 2026 15:34:57 -0700 Subject: [PATCH 1/2] fix(cosmos3): pin VAE latent norm buffers to encode output device Under sharded placement (device_map="balanced"), vae.encode() runs on the VAE's own device while the mean/inv_std buffers were pinned to x.device, causing a cross-device RuntimeError. Compute raw_mu first, then pin the normalization buffers to its device so all tensors share one device. --- src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py index 5425b7b575eb..39012327b61c 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos3_omni.py @@ -450,9 +450,9 @@ def _encode_video(self, x: torch.Tensor) -> torch.Tensor: matches Wan2pt2VAEInterface; no autocast (WanVAE was trained with is_amp=False).""" in_dtype = x.dtype dtype = self.vae.dtype - mean = self._vae_latents_mean.to(device=x.device, dtype=dtype) - inv_std = self._vae_latents_inv_std.to(device=x.device, dtype=dtype) raw_mu = retrieve_latents(self.vae.encode(x.to(dtype)), sample_mode="argmax") + mean = self._vae_latents_mean.to(device=raw_mu.device, dtype=dtype) + inv_std = self._vae_latents_inv_std.to(device=raw_mu.device, dtype=dtype) return ((raw_mu - mean.view(1, -1, 1, 1, 1)) * inv_std.view(1, -1, 1, 1, 1)).to(in_dtype) def decode_sound(self, latent: torch.Tensor) -> torch.Tensor: From 6edc5fdd28b724e0f356facd36ec0510ca842057 Mon Sep 17 00:00:00 2001 From: Atharva Joshi Date: Tue, 23 Jun 2026 10:46:04 -0700 Subject: [PATCH 2/2] =?UTF-8?q?feat(cosmos3):=20multi-GPU=20inference=20?= =?UTF-8?q?=E2=80=94=20context=20+=20tensor=20parallelism?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cosmos 3 cannot use diffusers' declarative `_cp_plan` CP path: it is grouped-query attention (the shared Ulysses kernel assumes K/V share the query head count), its understanding (causal) and generation (full) streams are separate packed sequences (gen attends to cat(und, gen)), and per-pathway lengths are ragged. The model carries no parallelism logic -- it exposes only small, CP-agnostic seams; all sharding lives outside it, in a reusable example module. Model (transformer_cosmos3.py): adds two default-None `forward` seams -- `_cp_shard_fn` (shards und/gen + rotary before the decoder layers) and `_cp_gather_fn` (gathers/unpads after the final norm) -- and extracts `Cosmos3AttnProcessor._run_attention` as an override point. The non-parallel path is unchanged. Helpers (examples/cosmos3/cosmos_parallel.py): one importable module, two orthogonal and composable axes: * Context parallelism (Ulysses) -- `enable_cosmos3_context_parallel`. Shards the sequence; brackets the two attention pathways with all-to-all (DTensor redistribute), repeats GQA KV heads, pads ragged lengths and masks padded generation keys. * Tensor parallelism (Megatron) -- `enable_cosmos3_tensor_parallel`. Column/row-shards the attention + MLP weights so a checkpoint that does not fit one GPU (Super, ~120 GB) loads across several; weights load to CPU then shard layer by layer. Both expand KV heads to the query-head count and call SDPA with enable_gqa=False so it dispatches to the flash kernel; enable_gqa=True forces the math path, which materializes the full [S, S] score matrix and OOMs on long videos. A dense `Cosmos3FlashAttnProcessor` (`enable_cosmos3_flash_attention`) provides the same for TP without CP. CLI (examples/cosmos3/inference_cosmos3.py): imports these helpers, so any modality (text-to-image/video, image-to-video, sound, action) runs single- or multi-GPU via `--tp-degree` / `--cp-degree` (their product must equal --nproc_per_node). Single-GPU behavior is unchanged. Docs + example README updated. Verified: CP attention core is bit-exact vs non-CP in fp32 (max|d|=0), and a full 36-layer forward matches CP-on vs CP-off to ~1e-6 in fp32 (bf16 differs only by floating-point rounding). --- docs/source/en/api/pipelines/cosmos3.md | 89 ++++ examples/cosmos3/README.md | 65 ++- examples/cosmos3/cosmos_parallel.py | 426 ++++++++++++++++++ examples/cosmos3/inference_cosmos3.py | 162 +++++-- .../transformers/transformer_cosmos3.py | 47 +- 5 files changed, 747 insertions(+), 42 deletions(-) create mode 100644 examples/cosmos3/cosmos_parallel.py diff --git a/docs/source/en/api/pipelines/cosmos3.md b/docs/source/en/api/pipelines/cosmos3.md index 7ce1ff4f58cf..3b4e128d3054 100644 --- a/docs/source/en/api/pipelines/cosmos3.md +++ b/docs/source/en/api/pipelines/cosmos3.md @@ -464,6 +464,95 @@ if result.action is not None: +## Context parallelism + +For long videos or high resolutions, a single forward pass can exceed the memory and latency budget of one GPU. Cosmos 3 supports **context parallelism (CP)** to shard the sequence dimension across multiple GPUs, splitting the attention computation so each device holds only a slice of the tokens. + +Cosmos 3 supports **Ulysses** context parallelism (all-to-all sequence/head exchange). Ring attention is not supported. + +Unlike most diffusers models, Cosmos 3 does **not** wire CP into the transformer or the declarative [`~ModelMixin.enable_parallelism`] path: its grouped-query attention, separate understanding/generation streams (the generation stream attends to both), and ragged per-stream lengths can't be expressed as a `_cp_plan`. Instead, the model exposes small no-op shard/gather seams, and the implementation lives in [`examples/cosmos3/cosmos_parallel.py`](https://github.com/huggingface/diffusers/blob/main/examples/cosmos3/cosmos_parallel.py) — a self-contained module you can read end to end and adapt. It offers two orthogonal, composable sharding axes: + +| Helper | Shards | Use for | +|---|---|---| +| `enable_cosmos3_context_parallel(transformer, cp_mesh)` | sequence (CP / Ulysses) | latency on a model that fits one GPU (`Nano`) | +| `enable_cosmos3_tensor_parallel(transformer, tp_mesh)` | weights (TP) | fitting a model that doesn't fit one GPU (`Super`) | + +Use either alone or both together on a 2-D `(tp, cp)` mesh (see [Fitting large models with tensor parallelism](#fitting-large-models-with-tensor-parallelism)). + +Two requirements are specific to Cosmos 3: + +- Use the `native` attention backend. Cosmos 3 uses grouped-query attention (GQA), and the native SDPA backend is the only one that accepts `enable_gqa` (cuDNN and flash reject it). The helpers expand the KV heads to the query-head count and call SDPA with `enable_gqa=False` so it still dispatches to the flash kernel (the math fallback would materialize the full `[S, S]` scores and OOM on long sequences). +- The CP (Ulysses) degree must divide the query-head count (32 for `Nano`, 64 for `Super`); for TP, the degree must divide the KV heads (8). The understanding (text) and generation (video/sound) streams are sharded independently along the sequence, and ragged lengths are zero-padded internally to a multiple of the world size. + +### Run it + +The full CLI [`examples/cosmos3/inference_cosmos3.py`](https://github.com/huggingface/diffusers/blob/main/examples/cosmos3/inference_cosmos3.py) reuses these helpers, so **any modality** (text-to-image/video, image-to-video, sound, action modes) runs multi-GPU via `--tp-degree` / `--cp-degree`. Launch with [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html); `--tp-degree * --cp-degree` must equal `--nproc_per_node`. Every rank produces the same output; rank 0 writes it. + +```bash +# CP only — Nano (fits one GPU); CP degree must divide 32 query heads. +torchrun --nproc_per_node=4 examples/cosmos3/inference_cosmos3.py --model nano --cp-degree 4 --prompt "..." + +# TP only — Super; TP degree must divide 64 query heads and 8 KV heads. +torchrun --nproc_per_node=4 examples/cosmos3/inference_cosmos3.py --model super --tp-degree 4 --prompt "..." + +# TP + CP — Super, with sound (TP=2 x CP=2 across 4 GPUs). +torchrun --nproc_per_node=4 examples/cosmos3/inference_cosmos3.py \ + --model super --tp-degree 2 --cp-degree 2 --enable-sound --prompt "..." +``` + +`Super`'s ~120 GB of weights do not fit on one 96 GB GPU, so it needs TP; `Nano` fits on a single GPU, so CP for it is a pure latency optimization. (Omit both flags to run single-GPU.) + +### Fitting large models with tensor parallelism + +CP shards *activations* but replicates every weight on every rank, so it does not reduce a model's weight footprint — a model that doesn't fit on one GPU still won't fit under CP alone. To shard the **weights**, `enable_cosmos3_tensor_parallel(transformer, tp_mesh)` applies Megatron-style tensor parallelism on a second, orthogonal mesh axis: + +- The attention and MLP projections are column/row sharded across the TP group (`to_q/to_k/to_v` + `add_q/k/v` and the MLPs' `gate/up` are column-parallel; `to_out/to_add_out` and the MLPs' `down` are row-parallel with an all-reduce). Each rank ends up owning `query_heads / tp` query heads and `kv_heads / tp` KV heads. +- TP composes with CP on a 2-D `(tp, cp)` device mesh: TP splits heads/weights persistently, CP shards the sequence on top. The constraints are `tp` divides the KV heads (8), and `tp * cp` divides the query heads (32 for `Nano`, 64 for `Super`). +- Weights are loaded to CPU and sharded onto the GPUs layer by layer, so the full model is never materialized on a single device. + +> [!TIP] +> TP issues an all-reduce on every attention and MLP block, so it is bandwidth-heavy. On hosts without NVLink it is the dominant cost; prefer the smallest TP degree that makes the weights fit and put the remaining GPUs into CP. + +### Use it in your own pipeline + +The CLI flags are convenient, but you can call the helpers directly. Build the device mesh, apply TP *before* the model lands on the GPUs, switch to the `native` backend, then enable CP — the rest of your pipeline code is unchanged: + +```python +from torch.distributed.device_mesh import init_device_mesh + +# Make the helper module importable. +import sys +sys.path.insert(0, "examples/cosmos3") +from cosmos_parallel import ( + enable_cosmos3_context_parallel, + enable_cosmos3_flash_attention, + enable_cosmos3_tensor_parallel, +) + +# torchrun sets RANK / WORLD_SIZE / LOCAL_RANK. Pick tp_degree * cp_degree == world size. +mesh = init_device_mesh("cuda", (tp_degree, cp_degree), mesh_dim_names=("tp", "cp")) + +# Load on CPU first; a TP-sharded model may not fit one GPU. +pipe = Cosmos3OmniPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) +if tp_degree > 1: + enable_cosmos3_tensor_parallel(pipe.transformer, mesh["tp"]) # shard weights -> GPUs +pipe.to(f"cuda:{local_rank}") # move the replicated remainder +pipe.transformer.set_attention_backend("native") +if cp_degree > 1: + enable_cosmos3_context_parallel(pipe.transformer, mesh["cp"]) # shard the sequence +elif tp_degree > 1: + enable_cosmos3_flash_attention(pipe.transformer) # GQA-safe dense attention + +# `pipe(...)` is called exactly as in the single-GPU workflows above. +``` + +For CP only (no weight sharding), use a 1-D mesh: `init_device_mesh("cuda", (world_size,), mesh_dim_names=("cp",))` and just `enable_cosmos3_context_parallel`. + +> [!TIP] +> On some multi-GPU topologies the first NCCL all-to-all can hang. If a CP run stalls at the start of the first denoising step, set `NCCL_P2P_DISABLE=1` in the environment before launching `torchrun`. + +CP and TP compose with all the workflows above (text-to-video, image-to-video, text-to-video with sound, and action-conditioned generation) and with both the `Nano` and `Super` checkpoints — only the pipeline construction and the parallelism setup lines change. + ## Metadata templates `tokenize_prompt` appends short metadata sentences inside the user message so the LLM sees the conditioning the model was trained with. The positive prompt gets sentences like *"The video is 7.9 seconds long and is of 24 FPS."* and *"This video is of 720x1280 resolution."*; the negative prompt gets the inverse (*"… is not …"*). diff --git a/examples/cosmos3/README.md b/examples/cosmos3/README.md index dd4be5dc286f..08fd1f50969c 100644 --- a/examples/cosmos3/README.md +++ b/examples/cosmos3/README.md @@ -5,9 +5,14 @@ The canonical reference for `Cosmos3OmniPipeline` lives in the diffusers docs: examples there as the source of truth for application code — they cover text-to-image, text-to-video, image-to-video, and text+sound modes. -This directory provides a small CLI wrapper (`inference_cosmos3.py`) that exercises the full -load → encode → denoise → decode path against either the Hub release or a local checkpoint -during development. +This directory provides two files: + +- `inference_cosmos3.py` — the runnable CLI (text-to-image/video, image-to-video, sound, action + modes). Single-GPU by default; pass `--tp-degree` / `--cp-degree` and launch with `torchrun` + to run any modality multi-GPU (see [Multi-GPU inference](#multi-gpu-inference-context-parallelism) + below). +- `cosmos_parallel.py` — the importable multi-GPU helpers (context + tensor parallelism). No + `main`; the CLI imports from it. Read it to understand or adapt the sharding. ## Setup @@ -168,3 +173,57 @@ Pick the tier that matches the native resolution of your conditioning input (`48 | `--no-duration-template` | off | Skip the duration metadata sentence appended to the prompt and negative prompt. Ignored for `--num-frames 1` and for action modes (which build a structured caption instead). | | `--no-resolution-template` | off | Skip the resolution metadata sentence appended to the prompt and negative prompt. Ignored for action modes. | | `--output` | `.` | Directory to write `sample.jpg` or `sample.mp4`. | + +## Multi-GPU inference (context parallelism) + +Cosmos 3 can be sharded across GPUs on two orthogonal axes (implemented in `cosmos_parallel.py`): + +- **Context parallelism (CP)** — `enable_cosmos3_context_parallel`. The *sequence* is sharded + across GPUs and attention runs with two Ulysses all-to-all collectives per layer, cutting + per-step latency for long videos / high resolutions. Weights are replicated, so this is for + models that already fit one GPU (`Nano`). +- **Tensor parallelism (TP)** — `enable_cosmos3_tensor_parallel`. The attention and MLP *weight* + matrices are sharded across GPUs (Megatron-style), so a checkpoint that doesn't fit one GPU + (`Super`, ~120 GB) loads. The sequence is not sharded. +- **TP + CP** — both at once on a 2-D `(tp, cp)` mesh: a large model *and* latency. + +The model itself carries no parallelism logic — it exposes small no-op shard/gather seams, and +`cosmos_parallel.py` implements the entire path (collectives, GQA KV-head handling, ragged-length +padding, the dual-pathway attention, weight sharding) behind those two helpers. It is meant to be +read end to end and adapted. + +The CLI imports these helpers, so you run **any modality** (text-to-image/video, image-to-video, +sound, action modes) multi-GPU by adding `--tp-degree` / `--cp-degree` and launching with +[torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html) — `--tp-degree * --cp-degree` +must equal `--nproc_per_node`: + +```bash +# CP only (Nano): CP degree must divide the 32 query heads. +torchrun --nproc_per_node 4 examples/cosmos3/inference_cosmos3.py --model nano --cp-degree 4 --prompt "..." + +# TP only (Super): TP degree must divide the 64 query heads and 8 KV heads. +torchrun --nproc_per_node 4 examples/cosmos3/inference_cosmos3.py --model super --tp-degree 4 --prompt "..." + +# TP + CP (Super), 4 GPUs as 2 x 2, with sound: +torchrun --nproc_per_node 4 examples/cosmos3/inference_cosmos3.py \ + --model super --tp-degree 2 --cp-degree 2 --enable-sound --prompt "A waterfall in a forest." +``` + +Notes: + +- The helpers use the `native` attention backend (the only one that supports GQA's `enable_gqa`), + and expand the KV heads to the query-head count so SDPA picks the flash kernel — passing + `enable_gqa=True` forces the math kernel, which materializes the full `[S, S]` scores and OOMs + on long sequences. +- Only Ulysses is supported (not ring attention). +- The CP/Ulysses degree must divide the query heads (32 for `Nano`, 64 for `Super`). For TP, + `tp` must divide the KV heads (8), and `tp * cp` must divide the query heads. +- TP all-reduces on every block, so it's bandwidth-heavy — use the smallest TP degree that makes + the weights fit and put the remaining GPUs into CP. +- Generation size is set with the usual CLI flags (`--num-frames` / `--height` / `--width`), and + multi-GPU runs require a seed for reproducibility across ranks (the CLI sets one if you omit `--seed`). +- On some multi-GPU topologies the first NCCL all-to-all can hang; if a run stalls at the first + denoising step, set `NCCL_P2P_DISABLE=1` before launching. + +See the [pipeline docs](../../docs/source/en/api/pipelines/cosmos3.md#context-parallelism) for how +to enable CP and TP from your own pipeline code. diff --git a/examples/cosmos3/cosmos_parallel.py b/examples/cosmos3/cosmos_parallel.py new file mode 100644 index 000000000000..beed164a716a --- /dev/null +++ b/examples/cosmos3/cosmos_parallel.py @@ -0,0 +1,426 @@ +# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Multi-GPU helpers for Cosmos 3, implemented entirely outside the model. + +This module holds the implementation; the runnable entry point is +``inference_cosmos3.py`` (pass ``--tp-degree`` / ``--cp-degree`` and launch with +``torchrun`` to use any of these across all the pipeline's modalities). Two orthogonal +sharding axes are provided, composable on a 2-D ``(tp, cp)`` device mesh: + + * Context parallelism (CP / Ulysses) — ``enable_cosmos3_context_parallel``. Shards the + *sequence* across GPUs; attention runs with two all-to-all collectives per layer + (gather seq / scatter heads -> local attention -> gather heads / scatter seq). + Replicates the weights, so it cuts latency but not weight memory. + * Tensor parallelism (TP) — ``enable_cosmos3_tensor_parallel``. Shards the attention + and MLP *weight* matrices (Megatron-style), so a checkpoint that doesn't fit one GPU + (Cosmos3-Super, ~120 GB) loads across several. Attention stays dense; pair it with + ``enable_cosmos3_flash_attention`` (or with CP) so GQA uses the flash kernel. + +The model carries no parallelism logic — it exposes small no-op seams: +``transformer._cp_shard_fn`` / ``_cp_gather_fn`` (sequence shard/gather around the +decoder stack) and ``Cosmos3AttnProcessor._run_attention`` (an override seam for the +attention core). The helpers below wire those up. + +Why Cosmos 3 needs a custom CP path (not diffusers' declarative ``_cp_plan``): + 1. grouped-query attention — K/V heads must be repeated to match the query heads; + 2. separate understanding (causal) / generation (full) streams, with the generation + stream attending to ``cat(und, gen)``; + 3. ragged per-stream lengths — each pathway is padded independently and the padded + generation keys are masked. + +GQA + the flash kernel: SDPA's flash/cuDNN kernels reject ``enable_gqa`` and the native +kernel falls back to math (which materializes the full ``[S, S]`` scores and OOMs on long +sequences). Both attention paths here instead expand the KV heads up to the query-head +count and call SDPA with ``enable_gqa=False``, so it dispatches to flash (O(S) memory). +""" + +import torch + +from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn +from diffusers.models.transformers.transformer_cosmos3 import Cosmos3AttnProcessor + + +try: # torch >= 2.4 + from torch.distributed.tensor import DTensor, Replicate, Shard +except ImportError: # pragma: no cover - older torch + from torch.distributed._tensor import DTensor, Replicate, Shard + + +def _repeat_kv_heads(x, repeats): + """Repeat KV heads (for GQA): ``[seq, num_kv_heads, d] -> [seq, num_kv_heads * repeats, d]``. + + Each KV head is repeated ``repeats`` times contiguously, matching GQA grouping + (query head i pairs with KV group i // repeats). + """ + if repeats == 1: + return x + seq_len, num_kv_heads, head_dim = x.shape + x = x[:, :, None, :].expand(seq_len, num_kv_heads, repeats, head_dim) + return x.reshape(seq_len, num_kv_heads * repeats, head_dim) + + +# ============================================================================= +# Context parallelism (Ulysses) +# ============================================================================= +# --- Collective primitives (all-to-all / all-gather via DTensor redistribute) --- +def _cp_all_to_all(local_input, scatter_dim, gather_dim, cp_mesh): + """All-to-all via DTensor redistribute: gather ``gather_dim``, scatter ``scatter_dim``.""" + dt = DTensor.from_local(local_input, cp_mesh, [Shard(gather_dim)], run_check=False) + return dt.redistribute(cp_mesh, [Shard(scatter_dim)]).to_local() + + +def _cp_all_gather(local_input, gather_dim, cp_mesh): + """All-gather via DTensor redistribute: ``Shard(gather_dim) -> Replicate()``.""" + dt = DTensor.from_local(local_input, cp_mesh, [Shard(gather_dim)], run_check=False) + return dt.redistribute(cp_mesh, [Replicate()]).to_local() + + +def _cp_gather_seq_scatter_heads(x, cp_mesh): + """``[seq/cp, h, d] -> [seq, h/cp, d]``.""" + return _cp_all_to_all(x, scatter_dim=1, gather_dim=0, cp_mesh=cp_mesh) + + +def _cp_gather_heads_scatter_seq(x, cp_mesh): + """``[seq, h/cp, d] -> [seq/cp, h, d]``.""" + return _cp_all_to_all(x, scatter_dim=0, gather_dim=1, cp_mesh=cp_mesh) + + +def _cp_pad_dim0(x, target_len): + """Zero-pad ``x`` along dim 0 up to ``target_len`` (no-op if already there).""" + pad = target_len - x.shape[0] + if pad <= 0: + return x + return torch.cat([x, x.new_zeros((pad, *x.shape[1:]))], dim=0) + + +def _cp_shard_dim0(x, cp_mesh): + """Keep this rank's contiguous shard along dim 0 (dim 0 must be divisible).""" + world = cp_mesh.size() + if world == 1: + return x + rank = cp_mesh.get_local_rank() + shard = x.shape[0] // world + return x.narrow(0, rank * shard, shard).contiguous() + + +def _cp_gather_dim0(x, cp_mesh): + """Reassemble the full sequence (dim 0) from per-rank shards, in rank order.""" + if cp_mesh.size() == 1: + return x + return _cp_all_gather(x, gather_dim=0, cp_mesh=cp_mesh) + + +# --- Sharding / gathering of the dual-pathway packed sequence --- +def shard_cosmos3_sequence(und_seq, gen_seq, rotary_emb, cp_mesh): + """Pad each pathway to a multiple of the CP world size, then shard the und/gen + hidden states and their rotary embeddings along the sequence dim, independently + per pathway. ``rotary_emb`` is ``(cos_und, sin_und, cos_gen, sin_gen)``. + + Returns ``(und_seq, gen_seq, rotary_emb, meta)`` where ``meta`` records the real + and padded per-pathway lengths so attention can mask padded keys and the caller + can slice the padding off after gathering. + """ + world = cp_mesh.size() + cos_und, sin_und, cos_gen, sin_gen = rotary_emb + und_real, gen_real = und_seq.shape[0], gen_seq.shape[0] + und_padded = ((und_real + world - 1) // world) * world + gen_padded = ((gen_real + world - 1) // world) * world + meta = {"und_real": und_real, "gen_real": gen_real, "und_padded": und_padded, "gen_padded": gen_padded} + + und_seq = _cp_shard_dim0(_cp_pad_dim0(und_seq, und_padded), cp_mesh) + gen_seq = _cp_shard_dim0(_cp_pad_dim0(gen_seq, gen_padded), cp_mesh) + rotary_emb = ( + _cp_shard_dim0(_cp_pad_dim0(cos_und, und_padded), cp_mesh), + _cp_shard_dim0(_cp_pad_dim0(sin_und, und_padded), cp_mesh), + _cp_shard_dim0(_cp_pad_dim0(cos_gen, gen_padded), cp_mesh), + _cp_shard_dim0(_cp_pad_dim0(sin_gen, gen_padded), cp_mesh), + ) + return und_seq, gen_seq, rotary_emb, meta + + +def build_gen_key_mask(meta, dtype, device): + """Additive key mask for the generation pathway's full attention. Padded und/gen + key positions are set to ``-inf`` so real generation queries ignore them. Returns + ``None`` when no padding was added. Shape ``[1, 1, 1, S_k]``. + """ + if meta["und_real"] == meta["und_padded"] and meta["gen_real"] == meta["gen_padded"]: + return None + s_k = meta["und_padded"] + meta["gen_padded"] + mask = torch.zeros(s_k, dtype=dtype, device=device) + neg_inf = torch.finfo(dtype).min + mask[meta["und_real"] : meta["und_padded"]] = neg_inf + mask[meta["und_padded"] + meta["gen_real"] :] = neg_inf + return mask.view(1, 1, 1, s_k) + + +def gather_and_unpad(und_out, gen_out, meta, cp_mesh): + """Gather each pathway to its full padded length, then slice off the padding.""" + und_out = _cp_gather_dim0(und_out, cp_mesh)[: meta["und_real"]] + gen_out = _cp_gather_dim0(gen_out, cp_mesh)[: meta["gen_real"]] + return und_out, gen_out + + +def cosmos3_cp_attention(cp_mesh, q_und, k_und, v_und, q_gen, k_gen, v_gen, gen_key_mask=None): + """Ulysses context-parallel attention for the dual-pathway packed sequence. + + All inputs are *sequence-sharded*: q ``[S/cp, H, D]``, k/v ``[S/cp, Hkv, D]`` + (no batch dim). Returns ``(causal_out, full_out)`` flattened to ``[S/cp, H*D]`` + and ``[S_gen/cp, H*D]``. The understanding pathway self-attends causally; the + generation pathway attends to the concatenation of und+gen keys/values. + """ + world = cp_mesh.size() + q_heads = q_und.shape[1] + kv_heads = k_und.shape[1] + if q_heads % world != 0: + raise ValueError(f"Query heads ({q_heads}) must be divisible by CP world size ({world}).") + + # GQA, step 1: repeat KV heads up to a multiple of the world size so the head + # scatter in the all-to-all is valid (each rank must receive an equal share). + kv_head_repeats = max(world // kv_heads, 1) + repeated_kv_heads = kv_heads * kv_head_repeats + if repeated_kv_heads % world != 0: + raise ValueError(f"Repeated KV heads ({repeated_kv_heads}) must be divisible by CP world size ({world}).") + if kv_head_repeats > 1: + k_und = _repeat_kv_heads(k_und, kv_head_repeats) + v_und = _repeat_kv_heads(v_und, kv_head_repeats) + k_gen = _repeat_kv_heads(k_gen, kv_head_repeats) + v_gen = _repeat_kv_heads(v_gen, kv_head_repeats) + + # all-to-all #1: gather sequence, scatter heads -> [S, H/cp, D] + q_und = _cp_gather_seq_scatter_heads(q_und, cp_mesh) + k_und = _cp_gather_seq_scatter_heads(k_und, cp_mesh) + v_und = _cp_gather_seq_scatter_heads(v_und, cp_mesh) + q_gen = _cp_gather_seq_scatter_heads(q_gen, cp_mesh) + k_gen = _cp_gather_seq_scatter_heads(k_gen, cp_mesh) + v_gen = _cp_gather_seq_scatter_heads(v_gen, cp_mesh) + + # GQA, step 2: locally expand each rank's KV head shard up to its query-head + # count, so attention runs with equal Q/K/V heads (enable_gqa=False). This lets + # SDPA dispatch to the flash / memory-efficient kernel (O(S) memory); passing + # enable_gqa=True instead forces the math fallback, which materializes the full + # [S, S] score matrix and OOMs on long sequences (CP shards heads across ranks, + # but each rank still attends over the *full* sequence length). GQA grouping is + # preserved: local query head i pairs with local KV group i // head_repeats. + q_local, kv_local = q_und.shape[1], k_und.shape[1] + if q_local % kv_local != 0: + raise ValueError(f"Local query heads ({q_local}) must be a multiple of local KV heads ({kv_local}).") + head_repeats = q_local // kv_local + if head_repeats > 1: + k_und = _repeat_kv_heads(k_und, head_repeats) + v_und = _repeat_kv_heads(v_und, head_repeats) + k_gen = _repeat_kv_heads(k_gen, head_repeats) + v_gen = _repeat_kv_heads(v_gen, head_repeats) + + # Understanding pathway: causal self-attention over the full und sequence. + causal_out = dispatch_attention_fn( + q_und.unsqueeze(0), + k_und.unsqueeze(0), + v_und.unsqueeze(0), + is_causal=True, + enable_gqa=False, + backend=AttentionBackendName.NATIVE, + parallel_config=None, + ).squeeze(0) + + # Generation pathway: full attention over cat(und, gen) keys/values. + all_k = torch.cat([k_und, k_gen], dim=0) + all_v = torch.cat([v_und, v_gen], dim=0) + full_out = dispatch_attention_fn( + q_gen.unsqueeze(0), + all_k.unsqueeze(0), + all_v.unsqueeze(0), + attn_mask=gen_key_mask, + is_causal=False, + enable_gqa=False, + backend=AttentionBackendName.NATIVE, + parallel_config=None, + ).squeeze(0) + + # all-to-all #2: gather heads, scatter sequence -> [S/cp, H, D] + causal_out = _cp_gather_heads_scatter_seq(causal_out, cp_mesh) + full_out = _cp_gather_heads_scatter_seq(full_out, cp_mesh) + return causal_out.flatten(-2, -1), full_out.flatten(-2, -1) + + +class Cosmos3CPAttnProcessor(Cosmos3AttnProcessor): + """Cosmos 3 attention processor whose attention core runs Ulysses CP. + + It reuses the base processor's projection + rotary code (``__call__``) and overrides + only ``_run_attention`` to bracket the two pathways with all-to-all collectives. The + Ulysses mesh is read from ``self.cp_mesh``; the per-call generation key mask is read + from the attention module (stamped each forward by the shard seam). + """ + + def __init__(self, cp_mesh): + self.cp_mesh = cp_mesh + + def _run_attention(self, attn, q_und, k_und, v_und, q_gen, k_gen, v_gen): + return cosmos3_cp_attention( + self.cp_mesh, + q_und, + k_und, + v_und, + q_gen, + k_gen, + v_gen, + gen_key_mask=getattr(attn, "_cp_gen_key_mask", None), + ) + + +def enable_cosmos3_context_parallel(transformer, cp_mesh): + """Wire Ulysses context parallelism onto a ``Cosmos3OmniTransformer`` instance. + + Sets the CP attention processor on every decoder layer and installs the shard/gather + seams so the model shards each pathway across ``cp_mesh`` and gathers before decode. + All CP state lives in the closures below, so the model itself stays CP-free. + """ + processor = Cosmos3CPAttnProcessor(cp_mesh) + for layer in transformer.layers: + layer.self_attn.set_processor(processor) + + state = {"meta": None} + + def shard_fn(und_seq, gen_seq, rotary_emb): + und_seq, gen_seq, rotary_emb, meta = shard_cosmos3_sequence(und_seq, gen_seq, rotary_emb, cp_mesh) + state["meta"] = meta + gen_key_mask = build_gen_key_mask(meta, und_seq.dtype, und_seq.device) + # The processor reads the mask off the attention module each forward. + for layer in transformer.layers: + layer.self_attn._cp_gen_key_mask = gen_key_mask + return und_seq, gen_seq, rotary_emb + + def gather_fn(und_out, gen_out): + return gather_and_unpad(und_out, gen_out, state["meta"], cp_mesh) + + transformer._cp_shard_fn = shard_fn + transformer._cp_gather_fn = gather_fn + return transformer + + +# ============================================================================= +# Dense flash attention (GQA-safe; for TP without CP) +# ============================================================================= +class Cosmos3FlashAttnProcessor(Cosmos3AttnProcessor): + """Dense attention that expands the GQA KV heads to the query-head count so SDPA + uses the flash kernel (``enable_gqa=False``) instead of the math fallback, which + materializes the full ``[S, S]`` score matrix and OOMs on long sequences. + + No collectives: attention is computed locally over the full sequence (the head + counts on the attention module are the rank-local values set by + ``enable_cosmos3_tensor_parallel``). Use this with TP-only; CP installs its own + processor that already handles the flash dispatch. + """ + + def _run_attention(self, attn, q_und, k_und, v_und, q_gen, k_gen, v_gen): + repeats = attn.num_attention_heads // attn.num_key_value_heads + k_und, v_und = _repeat_kv_heads(k_und, repeats), _repeat_kv_heads(v_und, repeats) + k_gen, v_gen = _repeat_kv_heads(k_gen, repeats), _repeat_kv_heads(v_gen, repeats) + + # Understanding pathway: causal self-attention. + causal_out = dispatch_attention_fn( + q_und.unsqueeze(0), + k_und.unsqueeze(0), + v_und.unsqueeze(0), + is_causal=True, + enable_gqa=False, + backend=AttentionBackendName.NATIVE, + parallel_config=None, + ).squeeze(0) + + # Generation pathway: full attention over cat(und, gen) keys/values. + all_k = torch.cat([k_und, k_gen], dim=0) + all_v = torch.cat([v_und, v_gen], dim=0) + full_out = dispatch_attention_fn( + q_gen.unsqueeze(0), + all_k.unsqueeze(0), + all_v.unsqueeze(0), + is_causal=False, + enable_gqa=False, + backend=AttentionBackendName.NATIVE, + parallel_config=None, + ).squeeze(0) + return causal_out.flatten(-2, -1), full_out.flatten(-2, -1) + + +def enable_cosmos3_flash_attention(transformer): + """Install the dense flash-attention processor on every decoder layer.""" + processor = Cosmos3FlashAttnProcessor() + for layer in transformer.layers: + layer.self_attn.set_processor(processor) + return transformer + + +# ============================================================================= +# Tensor parallelism (shard the attention + MLP weights) +# ============================================================================= +def enable_cosmos3_tensor_parallel(transformer, tp_mesh): + """Shard every decoder layer's attention and MLP projection weights across + ``tp_mesh`` (Megatron-style tensor parallelism), so a model whose weights don't + fit on one GPU (e.g. Cosmos3-Super, ~120 GB) can be served across several. + + Layout per layer: + * column-parallel (output / head dim sharded): ``to_q/to_k/to_v`` and + ``add_q_proj/add_k_proj/add_v_proj``, plus both MLPs' ``gate_proj/up_proj``; + * row-parallel (input dim sharded, all-reduce on output): ``to_out`` and + ``to_add_out``, plus both MLPs' ``down_proj``. + + Each rank then owns ``num_attention_heads / tp`` query heads and + ``num_key_value_heads / tp`` KV heads, so the per-layer head counts on every + attention module are rewritten to their local values — the projection + reshape + code in ``Cosmos3AttnProcessor`` reads these. The embeddings, final norms, lm_head + and modality projections stay replicated (they're a small fraction of the weights). + + Memory note: weights are loaded to CPU first, then each layer is moved to its GPU + and sharded in place, so the full model is never materialized on one device. + + Composes with Ulysses CP: this shards weights only and never touches the attention + processor, so combine it with ``enable_cosmos3_context_parallel`` (CP installs its + own processor) on a 2-D ``(tp, cp)`` mesh, or with ``enable_cosmos3_flash_attention`` + for TP without CP. + """ + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module + + tp = tp_mesh.size() + dev = torch.device("cuda", torch.cuda.current_device()) + plan = { + "self_attn.to_q": ColwiseParallel(), + "self_attn.to_k": ColwiseParallel(), + "self_attn.to_v": ColwiseParallel(), + "self_attn.to_out": RowwiseParallel(), + "self_attn.add_q_proj": ColwiseParallel(), + "self_attn.add_k_proj": ColwiseParallel(), + "self_attn.add_v_proj": ColwiseParallel(), + "self_attn.to_add_out": RowwiseParallel(), + "mlp.gate_proj": ColwiseParallel(), + "mlp.up_proj": ColwiseParallel(), + "mlp.down_proj": RowwiseParallel(), + "mlp_moe_gen.gate_proj": ColwiseParallel(), + "mlp_moe_gen.up_proj": ColwiseParallel(), + "mlp_moe_gen.down_proj": RowwiseParallel(), + } + for layer in transformer.layers: + attn = layer.self_attn + if attn.num_attention_heads % tp != 0 or attn.num_key_value_heads % tp != 0: + raise ValueError( + f"TP degree {tp} must divide both the query heads ({attn.num_attention_heads}) " + f"and the KV heads ({attn.num_key_value_heads})." + ) + layer.to(dev) # full layer transiently on this rank's GPU, then sharded in place + parallelize_module(layer, tp_mesh, plan) + # Projections now emit only this rank's head shard; the processor reshapes + # with the local counts. + attn.num_attention_heads //= tp + attn.num_key_value_heads //= tp + attn.num_key_value_groups = attn.num_attention_heads // attn.num_key_value_heads + return transformer diff --git a/examples/cosmos3/inference_cosmos3.py b/examples/cosmos3/inference_cosmos3.py index e9a5f5f369bb..d26be77bcf73 100644 --- a/examples/cosmos3/inference_cosmos3.py +++ b/examples/cosmos3/inference_cosmos3.py @@ -20,11 +20,20 @@ Text-to-video-with-sound (requires a sound-capable checkpoint): python inference_cosmos3.py --prompt "..." --enable-sound + +Multi-GPU (any modality above): launch with torchrun and pass parallelism degrees. +``--tp-degree`` shards the weights (so large checkpoints fit), ``--cp-degree`` shards +the sequence (Ulysses, lower latency); ``--nproc_per_node`` must equal their product. +These reuse the helpers in the ``cosmos_{context,tensor}_parallel_inference.py`` examples. + # TP=2 x CP=2 across 4 GPUs (Super): + torchrun --nproc_per_node 4 inference_cosmos3.py --model super --tp-degree 2 --cp-degree 2 --prompt "..." """ import argparse import json +import os import pathlib +import sys import urllib.request import torch @@ -34,6 +43,15 @@ from diffusers.utils import encode_video, export_to_video, load_image, load_video +# Multi-GPU helpers (context + tensor parallelism) live in the sibling module. +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) +from cosmos_parallel import ( # noqa: E402 + enable_cosmos3_context_parallel, + enable_cosmos3_flash_attention, + enable_cosmos3_tensor_parallel, +) + + HF_REPOS = { "nano": "nvidia/Cosmos3-Nano", "super": "nvidia/Cosmos3-Super", @@ -99,6 +117,25 @@ def main(): help="Override the scheduler's flow-matching shift (UniPCMultistepScheduler.flow_shift).", ) parser.add_argument("--seed", type=int, default=None, help="Random seed for latent initialization.") + parser.add_argument( + "--tp-degree", + type=int, + default=1, + help=( + "Tensor-parallel degree: shard the model weights across this many GPUs (so large checkpoints " + "fit). Must divide the query heads and KV heads. >1 requires launching with torchrun." + ), + ) + parser.add_argument( + "--cp-degree", + type=int, + default=1, + help=( + "Context-parallel (Ulysses) degree: shard the sequence across this many GPUs (lower latency). " + "--tp-degree * --cp-degree must equal --nproc_per_node, and must divide the query heads. " + ">1 requires launching with torchrun." + ), + ) parser.add_argument( "--enable-sound", action="store_true", @@ -158,17 +195,68 @@ def main(): ) args = parser.parse_args() + tp, cp = args.tp_degree, args.cp_degree + distributed = tp * cp > 1 + + if distributed: + import torch.distributed as dist + from torch.distributed.device_mesh import init_device_mesh + + dist.init_process_group(backend="nccl") + rank = dist.get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", rank)) + world = dist.get_world_size() + torch.cuda.set_device(local_rank) + dev = torch.device("cuda", local_rank) + if world != tp * cp: + raise ValueError( + f"--nproc_per_node ({world}) must equal --tp-degree * --cp-degree ({tp} * {cp} = {tp * cp})." + ) + if args.seed is None: + args.seed = 42 # all ranks must start from identical latents + else: + rank, dev = 0, torch.device("cuda") + + def log(msg): + if rank == 0: + print(msg, flush=True) + hf_repo = HF_REPOS[args.model] - print(f"Downloading pipeline from {hf_repo}") + log(f"Downloading pipeline from {hf_repo}") pipeline_path = pathlib.Path(snapshot_download(repo_id=hf_repo)) - print(f"Loading pipeline from {pipeline_path} …") - pipeline = Cosmos3OmniPipeline.from_pretrained( - str(pipeline_path), - torch_dtype=torch.bfloat16, - device_map="cuda", - enable_safety_checker=not args.disable_safety_checker, - ) - print("Pipeline loaded successfully.") + log(f"Loading pipeline from {pipeline_path} …") + + if distributed: + # Load on CPU first (a TP-sharded model may not fit one GPU), then place / shard. + pipeline = Cosmos3OmniPipeline.from_pretrained( + str(pipeline_path), + torch_dtype=torch.bfloat16, + enable_safety_checker=not args.disable_safety_checker, + ) + qh = pipeline.transformer.config.num_attention_heads + kv = pipeline.transformer.config.num_key_value_heads + if kv % tp != 0: + raise ValueError(f"--tp-degree ({tp}) must divide the {kv} KV heads.") + if qh % world != 0: + raise ValueError(f"--tp-degree * --cp-degree ({world}) must divide the {qh} query heads.") + mesh = init_device_mesh("cuda", (tp, cp), mesh_dim_names=("tp", "cp")) + if tp > 1: + enable_cosmos3_tensor_parallel(pipeline.transformer, mesh["tp"]) # shard weights -> GPUs + pipeline.to(dev) # place the replicated remainder (embeddings, norms, VAE, ...) + pipeline.transformer.set_attention_backend("native") # GQA-capable backend + if cp > 1: + enable_cosmos3_context_parallel(pipeline.transformer, mesh["cp"]) # shard the sequence + elif tp > 1: + enable_cosmos3_flash_attention(pipeline.transformer) # GQA-safe dense attention + log(f"Parallelism: TP={tp} x CP={cp} over {world} GPUs.") + else: + pipeline = Cosmos3OmniPipeline.from_pretrained( + str(pipeline_path), + torch_dtype=torch.bfloat16, + device_map="cuda", + enable_safety_checker=not args.disable_safety_checker, + ) + log("Pipeline loaded successfully.") if args.flow_shift is not None: pipeline.scheduler = UniPCMultistepScheduler.from_config( @@ -176,7 +264,8 @@ def main(): ) output_dir = pathlib.Path(args.output) - output_dir.mkdir(parents=True, exist_ok=True) + if rank == 0: + output_dir.mkdir(parents=True, exist_ok=True) generator = torch.Generator().manual_seed(args.seed) if args.seed is not None else None if args.action_mode is not None: @@ -224,31 +313,36 @@ def main(): enable_safety_check=not args.no_safety_check, ) - if args.num_frames == 1: - save_path = output_dir / "sample.jpg" - result.video[0].save(save_path, format="JPEG", quality=85) - else: - save_path = output_dir / "sample.mp4" - if result.sound is not None: - assert pipeline.sound_tokenizer is not None - encode_video( - result.video, - fps=int(args.fps), - audio=result.sound, - audio_sample_rate=pipeline.sound_tokenizer.config.sampling_rate, - output_path=str(save_path), - ) + # Every rank produces the same output under parallelism; only rank 0 writes it. + if rank == 0: + if args.num_frames == 1: + save_path = output_dir / "sample.jpg" + result.video[0].save(save_path, format="JPEG", quality=85) else: - # macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). - export_to_video(result.video, str(save_path), fps=int(args.fps), quality=10, macro_block_size=1) - print(f"Saved: {save_path}") - - if result.action is not None: - for action in result.action: - action_path = output_dir / "sample_action.json" - with open(action_path, "w") as f: - json.dump(action.tolist(), f) - print(f"Saved: {action_path}") + save_path = output_dir / "sample.mp4" + if result.sound is not None: + assert pipeline.sound_tokenizer is not None + encode_video( + result.video, + fps=int(args.fps), + audio=result.sound, + audio_sample_rate=pipeline.sound_tokenizer.config.sampling_rate, + output_path=str(save_path), + ) + else: + # macro_block_size=1 allows arbitrary frame sizes (Cosmos3 outputs are not always divisible by 16). + export_to_video(result.video, str(save_path), fps=int(args.fps), quality=10, macro_block_size=1) + print(f"Saved: {save_path}") + + if result.action is not None: + for action in result.action: + action_path = output_dir / "sample_action.json" + with open(action_path, "w") as f: + json.dump(action.tolist(), f) + print(f"Saved: {action_path}") + + if distributed: + dist.destroy_process_group() if __name__ == "__main__": diff --git a/src/diffusers/models/transformers/transformer_cosmos3.py b/src/diffusers/models/transformers/transformer_cosmos3.py index d6c26f927cd1..fe47c190c2aa 100644 --- a/src/diffusers/models/transformers/transformer_cosmos3.py +++ b/src/diffusers/models/transformers/transformer_cosmos3.py @@ -67,6 +67,21 @@ def __call__( q_gen = q_gen * cos_gen + _rotate_half(q_gen) * sin_gen k_gen = k_gen * cos_gen + _rotate_half(k_gen) * sin_gen + causal_out, full_out = self._run_attention(attn, q_und, k_und, v_und, q_gen, k_gen, v_gen) + + # Per-pathway output projection + und_out = attn.to_out(causal_out) + gen_out = attn.to_add_out(full_out) + return und_out, gen_out + + def _run_attention(self, attn, q_und, k_und, v_und, q_gen, k_gen, v_gen): + """Run the two attention pathways and return ``(causal_out, full_out)``, each + flattened to ``[seq, num_attention_heads * head_dim]``. + + This is an override seam: subclasses can change how attention is computed while reusing the shared projection + and rotary code in ``__call__``. The context-parallel processor in ``examples/cosmos3`` overrides it to bracket + the two pathways with Ulysses all-to-all collectives. + """ # Causal pathway (understanding): und tokens self-attend with causal masking. causal_out = dispatch_attention_fn( q_und.unsqueeze(0), @@ -92,11 +107,7 @@ def __call__( parallel_config=self._parallel_config, ) full_out = full_out.squeeze(0).flatten(-2, -1) - - # Per-pathway output projection - und_out = attn.to_out(causal_out) - gen_out = attn.to_add_out(full_out) - return und_out, gen_out + return causal_out, full_out def _rotate_half(x: torch.Tensor) -> torch.Tensor: @@ -295,6 +306,16 @@ class Cosmos3OmniTransformer(ModelMixin, ConfigMixin, PeftAdapterMixin, Attentio _repeated_blocks = ["Cosmos3VLTextMoTDecoderLayer"] _skip_layerwise_casting_patterns = ["embed_tokens", "time_embedder", "norm"] _keep_in_fp32_modules = ["time_embedder"] + # Optional context-parallelism seams. They default to ``None`` (no-op) so the + # model itself carries no CP logic. `forward` applies `_cp_shard_fn` to the + # per-pathway hidden states + rotary embeddings before the decoder layers, and + # `_cp_gather_fn` to the per-pathway outputs after the final norm. An external + # helper (see `examples/cosmos3/cosmos_parallel.py`) sets these to + # shard/gather across a device mesh and installs a context-parallel attention + # processor — the packed dual-pathway + GQA + ragged-length structure cannot be + # expressed as diffusers' declarative `_cp_plan`, so CP lives outside the model. + _cp_shard_fn = None + _cp_gather_fn = None # `dtype` is injected into init_dict by ModelMixin.from_pretrained (configuration_utils.py:289), # so __init__ must accept it. Excluding it here keeps save_pretrained from writing it into # config.json — the value is a load-time runtime hint, not part of the model architecture. @@ -678,6 +699,14 @@ def forward( und_seq = hidden_states[:und_len] gen_seq = hidden_states[und_len:] rotary_emb = (cos[:und_len], sin[:und_len], cos[und_len:], sin[und_len:]) + + # Optional context-parallelism shard seam (no-op unless set by an external + # helper, e.g. `examples/cosmos3/cosmos_parallel.py`). When set, it + # shards each pathway's sequence and rotary embeddings across a device mesh, so + # the decoder layers below run on local sequence shards. + if self._cp_shard_fn is not None: + und_seq, gen_seq, rotary_emb = self._cp_shard_fn(und_seq, gen_seq, rotary_emb) + for decoder_layer in self.layers: if torch.is_grad_enabled() and self.gradient_checkpointing: und_seq, gen_seq = self._gradient_checkpointing_func( @@ -687,6 +716,14 @@ def forward( und_seq, gen_seq = decoder_layer(und_seq, gen_seq, rotary_emb) und_out = self.norm(und_seq) gen_out = self.norm_moe_gen(gen_seq) + + # Optional context-parallelism gather seam: re-gather the full per-pathway + # sequence on every rank (and drop the padding) before the global-index decode + # below, since the downstream indexes address positions in the unpadded joint + # sequence. No-op unless `_cp_shard_fn`'s counterpart is set. + if self._cp_gather_fn is not None: + und_out, gen_out = self._cp_gather_fn(und_out, gen_out) + last_hidden_state = torch.cat([und_out, gen_out], dim=0) # Decode vision predictions from the joint hidden state.