pithtrain/models/deepseek_v2_lite.py (18 additions, 5 deletions)
@@ -408,7 +408,7 @@ def forward(
                 query_states,
                 key_states,
                 value_states.contiguous(),
-                softmax_scale=self.softmax_scale,
+                sm_scale=self.softmax_scale,
                 cp_group=self.cp_group,
             )
         else:
@@ -646,6 +646,7 @@ def __init__(
         self.stage_id = stage_id
         self.cp_group = cp_group
         self.cp_rank = cp_group.rank() if cp_group is not None else 0
+        self.cp_size = cp_group.size() if cp_group is not None else 1
         self.embed_tokens = (
             nn.Embedding(config.vocab_size, config.hidden_size) if stage_id == 0 else None
         )
@@ -714,11 +715,23 @@ def forward(
         hidden_states = self.embed_tokens(hidden_states)

         seq_len = hidden_states.shape[1]
-        offset = self.cp_rank * seq_len
-        cos, sin = self.rotary_emb(hidden_states, seq_len=offset + seq_len)
+        # Zigzag CP layout: the local seq_len tokens come from two non-contiguous
+        # global chunks. Build the global position IDs by concatenating the
+        # front block and the mirror back block, then gather cos/sin by position.
+        block = seq_len // 2
+        global_seq_len = seq_len * self.cp_size
+        front_start = self.cp_rank * block
+        back_start = (2 * self.cp_size - self.cp_rank - 1) * block
+        position_ids = torch.cat(
+            [
+                torch.arange(front_start, front_start + block, device=hidden_states.device),
+                torch.arange(back_start, back_start + block, device=hidden_states.device),
+            ]
+        )
+        cos, sin = self.rotary_emb(hidden_states, seq_len=global_seq_len)
         position_embeddings = (
-            cos[offset : offset + seq_len].unsqueeze(0).to(dtype=hidden_states.dtype),
-            sin[offset : offset + seq_len].unsqueeze(0).to(dtype=hidden_states.dtype),
+            cos[position_ids].unsqueeze(0).to(dtype=hidden_states.dtype),
+            sin[position_ids].unsqueeze(0).to(dtype=hidden_states.dtype),
         )
         for _, layer in self.layers.items():
             layer._position_embeddings = position_embeddings
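Note on the layout: the mapping the new `position_ids` code computes can be checked in isolation. A minimal sketch (the `zigzag_position_ids` helper is illustrative, not part of this diff), assuming the standard zigzag ring-attention layout where rank r holds chunk r and its mirror chunk 2*cp_size - 1 - r:

```python
import torch

def zigzag_position_ids(cp_rank: int, cp_size: int, local_seq_len: int) -> torch.Tensor:
    """Global position IDs for one rank's local tokens, mirroring the model code."""
    block = local_seq_len // 2
    front_start = cp_rank * block
    back_start = (2 * cp_size - cp_rank - 1) * block
    return torch.cat(
        [
            torch.arange(front_start, front_start + block),
            torch.arange(back_start, back_start + block),
        ]
    )

cp_size, local_seq_len = 4, 4  # 16 global positions, two blocks of 2 per rank
all_ids = [zigzag_position_ids(r, cp_size, local_seq_len) for r in range(cp_size)]
for r, ids in enumerate(all_ids):
    print(f"rank {r}: {ids.tolist()}")
# rank 0: [0, 1, 14, 15]
# rank 1: [2, 3, 12, 13]
# rank 2: [4, 5, 10, 11]
# rank 3: [6, 7, 8, 9]

# The ranks partition the global positions exactly.
assert torch.cat(all_ids).sort().values.equal(torch.arange(cp_size * local_seq_len))
```

The mirrored pairing is what balances causal attention across ranks: each rank gets one early (cheap) block and one late (expensive) block.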
pithtrain/models/qwen3_30b_a3b.py (16 additions, 6 deletions)
@@ -406,7 +406,7 @@ def forward(
                 query_states,
                 key_states,
                 value_states,
-                softmax_scale=self.scaling,
+                sm_scale=self.scaling,
                 cp_group=self.cp_group,
             )

@@ -673,6 +673,7 @@ def __init__(
         self.num_stages = num_stages
         self.cp_group = cp_group
         self.cp_rank = cp_group.rank() if cp_group is not None else 0
+        self.cp_size = cp_group.size() if cp_group is not None else 1

         hidden_size = config.hidden_size
         num_attention_heads = config.num_attention_heads
@@ -765,12 +766,21 @@ def forward(

         bsz, seq_len, _ = hidden_states.shape

-        offset = self.cp_rank * seq_len
-        cos, sin = self.rotary_emb(hidden_states, seq_len=offset + seq_len)
-        position_embeddings = (
-            cos[offset : offset + seq_len].unsqueeze(0),
-            sin[offset : offset + seq_len].unsqueeze(0),
+        # Zigzag CP layout: the local seq_len tokens come from two non-contiguous
+        # global chunks. Build the global position IDs by concatenating the
+        # front block and the mirror back block, then gather cos/sin by position.
+        block = seq_len // 2
+        global_seq_len = seq_len * self.cp_size
+        front_start = self.cp_rank * block
+        back_start = (2 * self.cp_size - self.cp_rank - 1) * block
+        position_ids = torch.cat(
+            [
+                torch.arange(front_start, front_start + block, device=hidden_states.device),
+                torch.arange(back_start, back_start + block, device=hidden_states.device),
+            ]
         )
+        cos, sin = self.rotary_emb(hidden_states, seq_len=global_seq_len)
+        position_embeddings = (cos[position_ids].unsqueeze(0), sin[position_ids].unsqueeze(0))

         for layer_idx_str, layer in self.layers.items():
             layer._position_embeddings = position_embeddings
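The data-loader side of this layout does not appear in the diff. For context, a sketch of the sharding these position IDs assume (the `zigzag_shard` helper is hypothetical):

```python
import torch

def zigzag_shard(tokens: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    # Split the global sequence into 2*cp_size equal chunks; rank r keeps
    # chunk r plus its mirror chunk 2*cp_size - 1 - r.
    chunks = tokens.chunk(2 * cp_size)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - cp_rank - 1]])

tokens = torch.arange(16)
print(zigzag_shard(tokens, cp_rank=1, cp_size=4).tolist())  # [2, 3, 12, 13]
```

Rank 1's shard carries global positions [2, 3, 12, 13], matching `position_ids` above, whereas the removed `offset = self.cp_rank * seq_len` slice assumed the contiguous run [4, 5, 6, 7]. That mismatch is what the new gather fixes.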
pithtrain/modules/training.py (6 additions, 3 deletions)
@@ -316,9 +316,12 @@ def setup_model(cfg: TrainingCfg, ctx: TrainingCtx, distributed: DistributedCtx)
     module_config.ep_size = ep_size
     assert hasattr(module_config, "hidden_size")
     assert isinstance(module_config.hidden_size, int)
-    assert cfg.sequence_length % cp_size == 0, (
-        f"sequence_length ({cfg.sequence_length}) must be divisible by context_parallel_size ({cp_size})"
-    )
+    if cfg.sequence_length % (2 * cp_size) != 0:
+        raise ValueError(
+            f"sequence_length ({cfg.sequence_length}) must be divisible by "
+            f"2 * context_parallel_size ({2 * cp_size}); zigzag ring attention "
+            f"splits the sequence into 2*cp_size equal chunks"
+        )

     hidden_size = module_config.hidden_size

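An illustration (made-up numbers) of why the check tightened from `cp_size` to `2 * cp_size`: each rank's local slice is halved again into two blocks, so a length that only passes the old check leaves `block = seq_len // 2` covering fewer positions than the rank actually holds:

```python
cp_size = 2
sequence_length = 6                          # old check passes: 6 % 2 == 0
local_seq_len = sequence_length // cp_size   # 3 tokens per rank
block = local_seq_len // 2                   # floor division -> 1
print(2 * block, local_seq_len)              # 2 3: one position ID short per rank
# new check rejects it up front: 6 % (2 * 2) == 2 != 0
```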