From 7c7ef547206b49da40f2ebe638ba6373d12242ff Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Mon, 26 Jan 2026 17:12:28 +0800 Subject: [PATCH] 1. add zip hidden state 2. usp does not support online mode 3. loss shared for usp, split target 4. Tighten USP batching and padding behavior 5. optimize offline dataloader for sp --- scripts/prepare_hidden_states.py | 52 ++++++- scripts/train_eagle3.py | 82 ++++------ specforge/core/eagle3.py | 184 ++++++++++++----------- specforge/core/eagle3_adapters.py | 133 ++++++++++++++++ specforge/data/preprocessing.py | 160 +++++++++++++++++++- specforge/data/utils.py | 35 ++++- specforge/modeling/draft/llama3_eagle.py | 37 ++--- tests/test_layers/test_decoder.py | 156 ++++++++++++++++++- 8 files changed, 646 insertions(+), 193 deletions(-) create mode 100644 specforge/core/eagle3_adapters.py diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index d201ca479..304677d26 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -34,6 +34,7 @@ import argparse import gc +import gzip import hashlib import os from concurrent.futures import ThreadPoolExecutor @@ -58,6 +59,7 @@ ) from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model from specforge.utils import ( + print_args_with_dots, print_with_rank, rank_0_priority, safe_conversations_generator, @@ -123,8 +125,8 @@ def parse_args(): others_group.add_argument( "--num-io-threads", type=int, - default=4, - help="Number of threads for async I/O operations", + default=None, + help="Number of threads for async I/O operations (default: all of CPU cores).", ) others_group.add_argument( "--num-workers", type=int, default=4, help="Number of workers for DataLoader" @@ -141,6 +143,17 @@ def parse_args(): default=2000, help="Number of files per subdirectory.", ) + others_group.add_argument( + "--compress", + action="store_true", + help="Compress hidden state files on disk (gzip).", + ) + others_group.add_argument( + "--compression-level", + type=int, + default=6, + help="Gzip compression level (1-9).", + ) sglang_group = parser.add_argument_group("sglang") SGLangBackendArgs.add_args(sglang_group) @@ -215,6 +228,8 @@ def __init__( num_io_threads: int = 4, io_queue_size: int = 50, file_group_size: int = 2000, + compress: bool = False, + compression_level: int = 6, ): """ Args: @@ -231,6 +246,9 @@ def __init__( self.num_io_threads = num_io_threads self.io_queue_size = io_queue_size self.file_group_size = file_group_size + self.compress = compress + self.compression_level = compression_level + self.file_extension = ".ckpt.gz" if self.compress else ".ckpt" # progress bar should only shown on TP rank = 0 self.show_progress = dist.get_rank(get_tp_group()) == 0 @@ -282,7 +300,13 @@ def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None: ) return - torch.save(asdict(data_point), output_file) + if self.compress: + with gzip.open( + output_file, "wb", compresslevel=self.compression_level + ) as f: + torch.save(asdict(data_point), f) + else: + torch.save(asdict(data_point), output_file) def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None: """ @@ -365,14 +389,22 @@ def _check_existing_files_batch( return [False] * len(global_indices) def check_single_file(idx): - return os.path.exists(self._get_file_path(output_path, idx)) + if os.path.exists(self._get_file_path(output_path, idx)): + return True + legacy_ckpt = self._get_file_path(output_path, idx, extension=".ckpt") + compressed_ckpt = self._get_file_path( + output_path, idx, extension=".ckpt.gz" + ) + return os.path.exists(legacy_ckpt) or os.path.exists(compressed_ckpt) # Parallel file existence check with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor: exists = list(executor.map(check_single_file, global_indices)) return exists - def _get_file_path(self, output_path: str, idx: int) -> str: + def _get_file_path( + self, output_path: str, idx: int, extension: Optional[str] = None + ) -> str: """ A helper function to get the standard file path for the data point with the given index. @@ -383,9 +415,10 @@ def _get_file_path(self, output_path: str, idx: int) -> str: Returns: str: The file path for the data point. """ + ext = self.file_extension if extension is None else extension group_idx = (idx // self.file_group_size) * self.file_group_size grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}" - return os.path.join(output_path, grouped_subdir, f"data_{idx}.ckpt") + return os.path.join(output_path, grouped_subdir, f"data_{idx}{ext}") @torch.no_grad() def generate( @@ -553,9 +586,12 @@ def main(): args.aux_hidden_states_layers = [ int(x) for x in args.aux_hidden_states_layers.split(",") ] - + if args.num_io_threads is None: + cpu_cores = os.cpu_count() or 1 + args.num_io_threads = max(1, cpu_cores) # Initialize distributed environment (TP + DP) init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_args_with_dots(args) # Build target model (with TP) target_model_config = AutoConfig.from_pretrained( @@ -657,6 +693,8 @@ def main(): num_io_threads=args.num_io_threads, io_queue_size=args.io_queue_size, file_group_size=args.file_group_size, + compress=args.compress, + compression_level=args.compression_level, # Other params like io_queue_size can also be added to argparse ) as hidden_states_generator: diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 4237dc94a..0d87c33ad 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -35,7 +35,6 @@ destroy_distributed, get_dp_group, get_draft_dp_group, - get_draft_sp_group, get_tp_group, init_distributed, ) @@ -340,9 +339,24 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + if args.attention_backend == "usp": + sp_sanity_check(args) + + +def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) + assert ( + args.batch_size == 1 + ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" + + assert args.sp_ring_size * args.sp_ulysses_size > 1, ( + f"USP requires sp_ring_size * sp_ulysses_size > 1. " + f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}." + ) + + assert args.train_hidden_states_path is not None, f"USP only support offline mode" if args.eval_data_path is not None and args.eval_hidden_states_path is not None: raise ValueError( @@ -453,6 +467,8 @@ def build_dataloaders( train_eagle3_dataset = build_offline_eagle3_dataset( args.train_hidden_states_path, args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), ) train_dataloader = prepare_dp_dataloaders( @@ -488,6 +504,8 @@ def build_dataloaders( eval_eagle3_dataset = build_offline_eagle3_dataset( args.eval_hidden_states_path, args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), ) eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, @@ -619,6 +637,9 @@ def run_forward( loss_mask=loss_mask, target=target, hidden_states=hidden_states, + position_ids=( + data["position_ids"].cuda() if "position_ids" in data else None + ), image_grid_thw=image_grid_thw, is_vlm=args.is_vlm, ) @@ -675,56 +696,13 @@ def record_metrcs( tracker.log(logdict, step=global_step) -def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor: +def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor: """ - Process: TP split -> Pad to Max Len -> SP gather. + Get the data shard from the tensor. """ - # 1. TP: Slice the tensor along the batch dimension - tp_group = get_tp_group() - tp_size = dist.get_world_size(tp_group) - tp_rank = dist.get_rank(tp_group) - - local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank] - - # 2. SP: Handle dynamic sequence lengths and Gather - sp_group = get_draft_sp_group() - - if sp_group is not None and dist.get_world_size(sp_group) > 1: - sp_world_size = dist.get_world_size(sp_group) - local_seq_len = local_tp_shard.size(sp_dim) - - # Find global max sequence length in SP group - len_tensor = torch.tensor( - [local_seq_len], device=local_tp_shard.device, dtype=torch.long - ) - dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group) - max_seq_len = len_tensor.item() - - # Pad local tensor if necessary - # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1 - if local_seq_len < max_seq_len: - pad_size = max_seq_len - local_seq_len - - pad_config = [0] * (local_tp_shard.ndim * 2) - - pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1 - pad_config[pad_idx] = pad_size - - # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed - local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0) - else: - local_tp_shard_padded = local_tp_shard - - gathered_shards = [ - torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size) - ] - dist.all_gather( - gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group - ) - - return torch.cat(gathered_shards, dim=sp_dim) - - return local_tp_shard + tp_size = dist.get_world_size(get_tp_group()) + tp_rank = dist.get_rank(get_tp_group()) + return tensor.chunk(tp_size, dim=0)[tp_rank] def main(): @@ -890,7 +868,11 @@ def main(): # 7.1 Training Step # ================================================ plosses, acces = run_forward( - args, eagle3_model, data, target_model, is_online + args, + eagle3_model, + data, + target_model, + is_online, ) run_backward_and_update(args, plosses, optimizer, global_step) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index abf43527f..1e2f04e7e 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -25,16 +25,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.checkpoint import checkpoint from transformers.cache_utils import DynamicCache -from yunchang import EXTRACT_FUNC_DICT +from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter from specforge.core.loss import LogSoftmaxLoss -from specforge.distributed import ( - gather_outputs_and_unpad, - get_sp_ring_group, - get_sp_ulysses_group, -) from specforge.modeling.draft import Eagle3DraftModel from specforge.utils import padding @@ -74,23 +68,66 @@ def __init__( self.attention_backend = attention_backend self.target_model = target_model + def _make_adapter(self) -> BackendAdapter: if self.attention_backend == "usp": - self.extract_func = EXTRACT_FUNC_DICT["basic"] - self.sp_ring_degree = torch.distributed.get_world_size(get_sp_ring_group()) - self.sp_ulysses_degree = torch.distributed.get_world_size( - get_sp_ulysses_group() + return UspAdapter(self) + return SdpaLikeAdapter(self) + + def _acc_and_loss( + self, + *, + logits: torch.Tensor, + target_p: torch.Tensor, + position_mask: torch.Tensor, + loss_mask: torch.Tensor, + adapter: BackendAdapter, + ) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + local_correct = ( + (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) + ).sum() + local_denom = loss_mask.sum().clamp_min(1e-6) + local_correct, local_denom = adapter.reduce_metrics( + local_correct=local_correct, local_denom=local_denom + ) + acc = local_correct / local_denom + + loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + loss = adapter.reduce_loss(loss) + return acc, loss + + def _prepare_position_ids( + self, + position_ids: Optional[torch.Tensor], + *, + seq_length: int, + past_key_values_length: int, + device: torch.device, + is_vlm: bool, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.attention_backend == "usp": + return position_ids + if position_ids is None: + if is_vlm: + mrope_positions_ids, _ = self.target_model.get_rope_index( + input_ids=input_ids, image_grid_thw=image_grid_thw + ) + return mrope_positions_ids + return ( + torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + .unsqueeze(0) + .view(-1, seq_length) ) - self.sp_world_size = self.sp_ring_degree * self.sp_ulysses_degree - self.sp_rank = torch.distributed.get_rank() % self.sp_world_size - - @torch.compile() - def prepare_usp_input(self, full_input): - shared_input = self.extract_func( - full_input, - rank=self.sp_rank, - world_size=self.sp_world_size, - ).clone() - return shared_input + + position_ids = position_ids.long() + return position_ids.view(-1, seq_length) def forward( self, @@ -131,35 +168,21 @@ def forward( past_key_values_length = 0 # Step 2: project the concatenated hidden states to the target hidden size - if self.attention_backend == "usp": - # NOTE: Split first for USP to parallelize computation and ensure - # gradient consistency without redundant full-sequence projection. - hidden_states = self.prepare_usp_input(hidden_states) hidden_states = self.draft_model.project_hidden_states(hidden_states) # Step 3: process kv cache, position ids and position ids if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - if is_vlm: - mrope_positions_ids, mrope_position_delta = ( - self.target_model.get_rope_index( - input_ids=input_ids, image_grid_thw=image_grid_thw - ) - ) - position_ids = mrope_positions_ids - else: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = self._prepare_position_ids( + position_ids=position_ids, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + device=hidden_states.device, + is_vlm=is_vlm, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + ) # Step 4: handle attention mask if attention_mask is None: @@ -177,28 +200,11 @@ def forward( past_key_values_length=past_key_values_length, ) - def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): - # 1. Compute Logits(The part that consumes the most VRAM.) - logits_ = self.draft_model.compute_logits(hs) - logits = gather_outputs_and_unpad(logits_, gather_dim=1) - - # 2. Compute Loss - loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask) - - # 3. Compute Accuracy - with torch.no_grad(): - acc_val = _compute_metric_acc( - logits=logits, - target_p=tgt_p, - position_mask=pos_mask, - loss_mask=l_mask, - ) - return loss_val, acc_val - # Step 5: run TTT plosses = [] vlosses = [] acces = [] + adapter = self._make_adapter() # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift global_input_ids = input_ids if self.attention_backend in ["sdpa", "fa", "usp"]: @@ -211,25 +217,31 @@ def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): raise ValueError(f"Unknown attention backend: {self.attention_backend}") for idx in range(self.length): - target_p = target_p_padded[:, idx : idx + seq_length, :] - if self.attention_backend == "usp": - input_ids = self.prepare_usp_input(global_input_ids) - else: - input_ids = global_input_ids - + state = adapter.step_view( + idx=idx, + ttt_length=self.length, + global_input_ids=global_input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + position_ids=position_ids, + hidden_states=hidden_states, + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=seq_length, + ) is_last = idx == self.length - 1 # Step 5.1: embed the input ids - inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = self.draft_model.embed_input_ids(state.input_ids) inputs_embeds = inputs_embeds.to(hidden_states.dtype) # Step 5.2: run the draft model backbone hidden_states_out = self.draft_model.backbone( input_embeds=inputs_embeds, - hidden_states=hidden_states, + hidden_states=state.hidden_states, cache_hidden=cache_hidden, - attention_mask=attention_mask, - position_ids=position_ids, + attention_mask=state.attention_mask, + position_ids=state.position_ids, past_key_values=past_key_values, use_cache=True, ) @@ -237,22 +249,20 @@ def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): # update hidden states for next step hidden_states = hidden_states_out - if hidden_states.requires_grad: - loss, acc = checkpoint( - compute_loss_and_acc_checkpointed, - hidden_states, - target_p, - position_mask, - loss_mask, - use_reentrant=False, - ) - else: - loss, acc = compute_loss_and_acc_checkpointed( - hidden_states, target_p, position_mask, loss_mask - ) + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) - plosses.append(loss) + # Step 5.5 + 5.6: metric and loss + acc, loss = self._acc_and_loss( + logits=logits, + target_p=state.target_p, + position_mask=state.position_mask, + loss_mask=state.loss_mask, + adapter=adapter, + ) acces.append(acc) + plosses.append(loss) + if not is_last: # Step 5.7: we need to update the loss mask global_input_ids = padding(global_input_ids, left=False) diff --git a/specforge/core/eagle3_adapters.py b/specforge/core/eagle3_adapters.py new file mode 100644 index 000000000..555c16efc --- /dev/null +++ b/specforge/core/eagle3_adapters.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.distributed.nn.functional as dist_nn + +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group + + +@dataclass +class StepState: + input_ids: torch.Tensor + hidden_states: torch.Tensor + position_ids: torch.Tensor + attention_mask: torch.Tensor + target_p: torch.Tensor + position_mask: torch.Tensor + loss_mask: torch.Tensor + + +class BackendAdapter: + def __init__(self, model: "OnlineEagle3Model"): + self.m = model + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + raise NotImplementedError + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + return loss + + +class SdpaLikeAdapter(BackendAdapter): + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() + return StepState( + input_ids=global_input_ids, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + target_p=target_p, + position_mask=position_mask, + loss_mask=loss_mask, + ) + + +class UspAdapter(BackendAdapter): + def __init__(self, model: "OnlineEagle3Model"): + super().__init__(model) + self.sp_group = get_draft_sp_group() + self.sp_world_size = dist.get_world_size(self.sp_group) + self.ulysses_pg = get_sp_ulysses_group() + self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg) + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + usp_chunk_size = seq_length - ttt_length + if usp_chunk_size <= 0: + raise ValueError( + f"USP local seq_length ({seq_length}) must be larger than " + f"ttt_length ({ttt_length})" + ) + target_p = target_p_padded[:, idx : idx + usp_chunk_size, :] + return StepState( + input_ids=global_input_ids[:, :usp_chunk_size], + hidden_states=hidden_states[:, :usp_chunk_size, :], + position_ids=position_ids[:, : usp_chunk_size * self.sp_ulysses_degree], + attention_mask=attention_mask[:, :usp_chunk_size], + target_p=target_p, + position_mask=position_mask[:, :usp_chunk_size, :], + loss_mask=loss_mask[:, :usp_chunk_size, :], + ) + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + local_correct = dist_nn.all_reduce( + local_correct, op=dist.ReduceOp.SUM, group=self.sp_group + ) + local_denom = dist_nn.all_reduce( + local_denom, op=dist.ReduceOp.SUM, group=self.sp_group + ) + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + loss = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.sp_group) + loss = loss / self.sp_world_size + return loss diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 648c0d1a3..46f3d8647 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -20,6 +20,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gzip +import io import os import re import warnings @@ -27,11 +29,14 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.nn.functional as F from tqdm import tqdm from transformers import ImageProcessingMixin, PreTrainedTokenizer from datasets import Dataset as HFDataset +from ..distributed import get_draft_sp_group, get_sp_ring_group + try: from qwen_vl_utils import process_vision_info @@ -432,23 +437,53 @@ def preprocess_function(examples): # Offline Eagle3 Dataset # ============================== # modified from https://github.com/NickL77/BaldEagle/blob/master/train/modules/data/data.py -def list_local_files(path, suffixes=[".ckpt"]): +def list_local_files(path, suffixes=None): + if suffixes is None: + suffixes = [".ckpt", ".ckpt.gz"] datapaths = [] for root, directories, files in os.walk(path): for file in files: file_path = os.path.join(root, file) datapaths.append(file_path) - for suffix in suffixes: - datapaths = [f_name for f_name in datapaths if f_name.endswith(suffix)] + if suffixes: + datapaths = [ + f_name + for f_name in datapaths + if any(f_name.endswith(suffix) for suffix in suffixes) + ] return datapaths class OfflineEagle3Dataset(torch.utils.data.Dataset): - def __init__(self, datapath, transform=None, max_len=2048): + def __init__( + self, + datapath, + transform=None, + max_len=2048, + ttt_length=1, + use_usp_preprocess=False, + ): + """ + Args: + datapath: List of file paths. + transform: Optional transform to apply. + max_len: Maximum sequence length to load. + ttt_length: TTT overlap length used in USP preprocessing. + use_usp_preprocess: Whether to shard all sequences with USP overlap in preprocessing. + """ self.datapaths = datapath self.transform = transform self._epoch = 0 self.max_len = max_len + self.ttt_length = ttt_length + self.use_usp_preprocess = use_usp_preprocess + if use_usp_preprocess: + sp_group = get_draft_sp_group() + self.sp_rank = torch.distributed.get_rank(sp_group) + self.sp_size = torch.distributed.get_world_size(sp_group) + ring_group = get_sp_ring_group() + self.ring_rank = torch.distributed.get_rank(ring_group) + self.sp_ring_size = torch.distributed.get_world_size(ring_group) @staticmethod def process_data(data, max_len, transform=None): @@ -470,11 +505,98 @@ def process_data(data, max_len, transform=None): new_data = transform(new_data) return new_data + @staticmethod + def process_data_usp( + data, + max_len, + ttt_length=1, + transform=None, + sp_rank=0, + sp_size=1, + ring_rank=0, + sp_ring_size=1, + ): + """ + USP preprocess: shard all sequences by sp_rank and add TTT overlap. + Each local sequence length = ceil(max_len / sp_size) + ttt_length. + """ + new_data = {} + + input_ids = data["input_ids"] + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + global_len = min(max_len, input_ids.shape[1]) + chunk_size = (global_len + sp_size - 1) // sp_size + start = sp_rank * chunk_size + local_len = chunk_size + ttt_length + + end = min(start + local_len, global_len) + + def _slice_and_pad(tensor): + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + tensor = tensor[:, :global_len] + sliced = tensor[:, start : min(end, tensor.shape[1])] + valid_len = sliced.shape[1] + if valid_len < local_len: + pad_len = local_len - valid_len + if tensor.ndim == 2: + sliced = F.pad(sliced, (0, pad_len)) + else: + sliced = F.pad(sliced, (0, 0, 0, pad_len)) + return sliced.contiguous(), valid_len + + if "aux_hidden_state" not in data or data["aux_hidden_state"] is None: + raise KeyError("aux_hidden_state is required for OfflineEagle3Dataset") + new_data["hidden_state"], _ = _slice_and_pad(data["aux_hidden_state"]) + new_data["target"], _ = _slice_and_pad(data["hidden_state"]) + + new_data["input_ids"], valid_len = _slice_and_pad(input_ids) + + full_loss_mask = data["loss_mask"] + if full_loss_mask.ndim == 1: + full_loss_mask = full_loss_mask.unsqueeze(0) + + full_loss_mask = full_loss_mask[:, :global_len].clone() + if full_loss_mask.numel() > 0: + full_loss_mask[0, -1] = 0 + new_data["loss_mask"], _ = _slice_and_pad(full_loss_mask) + + local_len = new_data["input_ids"].shape[1] + attention_mask = torch.zeros((1, local_len), dtype=torch.long) + attention_mask[:, :valid_len] = 1 + new_data["attention_mask"] = attention_mask + + # Position ids should align with Ulysses all2all-expanded sequence length. + # Local seq_len (per sp_rank) = local_len; attention uses (local_len - ttt_length). + sp_ulysses_size = max(1, sp_size // sp_ring_size) + usp_chunk_size = max(local_len - ttt_length, 0) + ring_chunk = usp_chunk_size * sp_ulysses_size + ring_start = ring_rank * ring_chunk + new_data["position_ids"] = torch.arange( + ring_start, ring_start + ring_chunk, dtype=torch.long + ).unsqueeze(0) + + if transform: + new_data = transform(new_data) + + return new_data + def __len__(self): return len(self.datapaths) def _open_file(self, index): - return torch.load(self.datapaths[index], weights_only=False) + """ + Opens the file with memory mapping. + This operation is virtually instant and consumes negligible RAM + because no data is actually read from disk yet. + """ + data_path = self.datapaths[index] + if data_path.endswith(".gz"): + with gzip.open(data_path, "rb") as f: + return torch.load(io.BytesIO(f.read()), weights_only=False) + return torch.load(data_path, weights_only=False, mmap=True) def __getitem__(self, index): try: @@ -482,7 +604,24 @@ def __getitem__(self, index): except Exception as e: print(f"ERROR Failed to load {self.datapaths[index]} with error {e}") data = self._open_file(0) - return self.process_data(data, self.max_len, self.transform) + + # 2. Read only specific bytes from disk + if self.use_usp_preprocess: + return self.process_data_usp( + data, + self.max_len, + ttt_length=self.ttt_length, + transform=self.transform, + sp_rank=self.sp_rank, + sp_size=self.sp_size, + ring_rank=self.ring_rank, + sp_ring_size=self.sp_ring_size, + ) + return self.process_data( + data, + self.max_len, + self.transform, + ) def set_epoch(self, epoch): self._epoch = epoch @@ -491,10 +630,15 @@ def set_epoch(self, epoch): def build_offline_eagle3_dataset( hidden_states_path: str, max_len: int = 2048, + ttt_length: int = 1, + use_usp_preprocess: bool = False, ) -> torch.utils.data.Dataset: + return OfflineEagle3Dataset( list_local_files(hidden_states_path), max_len=max_len, + ttt_length=ttt_length, + use_usp_preprocess=use_usp_preprocess, ) @@ -521,7 +665,7 @@ def generate_vocab_mapping_file( Returns: The path to the vocab mapping file. """ - # prepare cache direcotory + # prepare cache directory os.makedirs(cache_dir, exist_ok=True) vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt") @@ -529,7 +673,7 @@ def generate_vocab_mapping_file( print(f"Loading vocab mapping from the cached file at: {vocab_mapping_path}") return vocab_mapping_path - # we first count the frequency of effectiev tokens in the dataset + # we first count the frequency of effective tokens in the dataset token_dict = Counter() for input_ids, loss_mask in tqdm( zip(dataset["input_ids"], dataset["loss_mask"]), diff --git a/specforge/data/utils.py b/specforge/data/utils.py index 9668680dc..93fd6f58a 100644 --- a/specforge/data/utils.py +++ b/specforge/data/utils.py @@ -26,7 +26,7 @@ from torch.utils.data import DataLoader, DistributedSampler from datasets import Dataset -from specforge.distributed import get_draft_sp_group +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group class DataCollatorWithPadding: @@ -36,6 +36,7 @@ class DataCollatorWithPadding: def __init__(self): self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group()) + self.ulysses_degree = torch.distributed.get_world_size(get_sp_ulysses_group()) def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: """ @@ -90,10 +91,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - loss_mask: torch.Tensor of shape (B, N) """ max_length = max(item["input_ids"].shape[1] for item in features) + # pad for sequence parrel max_length = ( (max_length + self.sp_degree - 1) // self.sp_degree ) * self.sp_degree + # position max len, ulysses do not need chuck position ids + position_max_len = max_length * self.ulysses_degree + batch_input_ids = torch.cat( [self.paddingtensor2D(item["input_ids"], max_length) for item in features] ) @@ -106,6 +111,15 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: batch_loss_mask = torch.cat( [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] ) + if "position_ids" in features[0]: + batch_position_ids = torch.cat( + [ + self.paddingtensor2D(item["position_ids"], position_max_len) + for item in features + ] + ) + else: + batch_position_ids = None batch = { "input_ids": batch_input_ids, "attention_mask": batch_attention_mask, @@ -113,16 +127,23 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: "hidden_state": None, "target": None, } + if batch_position_ids is not None: + batch["position_ids"] = batch_position_ids if all("hidden_state" in item for item in features): assert all( "target" in item for item in features ), "target is required when hidden_state is provided" - batch["hidden_state"] = torch.cat( - [ - self.paddingtensor(item["hidden_state"], max_length) - for item in features - ] - ) + if self.sp_degree > 1: # USP mode + batch["hidden_state"] = torch.cat( + [item["hidden_state"] for item in features] + ) + else: + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) batch["target"] = torch.cat( [self.paddingtensor(item["target"], max_length) for item in features] ) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 552a3cf86..4a1833072 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -977,6 +977,10 @@ def __init__(self, config): assert ( dist.is_initialized() ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + raise NotImplementedError( + f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." + ) self.ring_pg = get_sp_ring_group() self.ulysses_pg = get_sp_ulysses_group() self.sp_ring_degree = torch.distributed.get_world_size(self.ring_pg) @@ -1043,39 +1047,16 @@ def forward( # Global length calculation (for RoPE) global_q_len = q_len * self.sp_ring_degree * self.sp_ulysses_degree - # ============================================================= # 2. RoPE & Cache Management # ============================================================= - if self.sp_ring_degree > 1: - if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): - position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[ - self.ring_rank - ].clone() - else: - position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[ - self.ring_rank - ].clone() - lck = 0 if cache_hidden is None else len(cache_hidden[0]) - if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): - cos, sin = self.rotary_emb(query_states, position_ids + lck) - cos, sin = cos.to(query_states.device), sin.to(query_states.device) - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - self.config.rope_scaling["mrope_section"], - unsqueeze_dim=2, - ) - else: - cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) - cos, sin = cos.to(query_states.device), sin.to(query_states.device) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 - ) + cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) # Update Cache (Eagle3 Logic: Cache is a list of tensors for tree branches) if cache_hidden is not None: diff --git a/tests/test_layers/test_decoder.py b/tests/test_layers/test_decoder.py index aca4f3c24..db6d785db 100644 --- a/tests/test_layers/test_decoder.py +++ b/tests/test_layers/test_decoder.py @@ -1,4 +1,5 @@ import os +import time import unittest import torch @@ -8,6 +9,9 @@ from transformers import PretrainedConfig from yunchang import EXTRACT_FUNC_DICT +from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter +from specforge.data.preprocessing import build_offline_eagle3_dataset + # Project-specific imports from specforge.distributed import destroy_distributed, init_distributed from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer @@ -57,6 +61,19 @@ def setup_env(rank, world_size, port): torch.cuda.set_device(rank) +def dbg(rank, msg): + print(f"[rank{rank}] {msg}", flush=True) + + +def wait_for_file(path, timeout_s=60, poll_s=0.1): + start = time.time() + while time.time() - start < timeout_s: + if os.path.exists(path): + return True + time.sleep(poll_s) + return False + + def run_iterative_pass( decoder_layer, embed_tokens, @@ -113,6 +130,7 @@ def run_test_case(rank, world_size, port): setup_env(rank, world_size, port) device = torch.device(f"cuda:{rank}") set_seed(42) + dbg(rank, "env setup complete") # --- Data & Config Preparation --- config = get_model_config() @@ -135,13 +153,34 @@ def run_test_case(rank, world_size, port): config.vocab_size, config.hidden_size, config.pad_token_id ).to(device) - # --- Phase 1: Golden Run (SDPA) --- + # --- Phase 1: Golden Run (FA) --- # Init dist briefly for internal checks, even if running single-device logic init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1) + dbg(rank, "init_distributed (FA) done") sdpa_decoder = ( LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16) ) + dbg(rank, "FA decoder created") + # Adapter smoke test for FA/SDPA-style path + dummy_model = type("Dummy", (), {})() + sdpa_adapter = SdpaLikeAdapter(dummy_model) + sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32) + sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32) + sdpa_state = sdpa_adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=data_input_ids, + attention_mask=attention_mask, + loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32), + position_ids=position_ids, + hidden_states=data_hidden_states, + target_p_padded=sdpa_target_p, + position_mask=sdpa_position_mask, + seq_length=seq_len, + ) + assert sdpa_state.input_ids.shape[1] == seq_len + assert sdpa_state.hidden_states.shape[1] == seq_len with torch.no_grad(): sdpa_output = run_iterative_pass( @@ -153,11 +192,13 @@ def run_test_case(rank, world_size, port): position_ids=position_ids, ttt_length=ttt_length, ) + dbg(rank, "FA forward done") # Save weights for alignment and cleanup SDPA model state_dict = sdpa_decoder.state_dict() del sdpa_decoder destroy_distributed() + dbg(rank, "destroy_distributed (FA) done") # --- Phase 2: Distributed Run (USP) --- def subtest_usp(sp_ulysses_degree, sp_ring_degree): @@ -168,6 +209,89 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): sp_ulysses_size=sp_ulysses_degree, sp_ring_size=sp_ring_degree, ) + dbg( + rank, + f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done", + ) + # Dataset + adapter smoke test (USP path) + tmp_dir = "./tmp/usp_dataset_shared" + try: + if rank == 0: + os.makedirs(tmp_dir, exist_ok=True) + sample = { + "input_ids": data_input_ids[0].cpu(), + "loss_mask": torch.ones_like(data_input_ids[0].cpu()), + "hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + "aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + } + torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt")) + dbg(rank, "wrote sample ckpt") + ready_flag = os.path.join(tmp_dir, "ready.flag") + with open(ready_flag, "w", encoding="utf-8") as f: + f.write("ready\n") + if rank != 0: + ready_flag = os.path.join(tmp_dir, "ready.flag") + assert wait_for_file( + ready_flag, timeout_s=60 + ), "timeout waiting for ready flag" + dbg(rank, "dataset sync done") + assert os.path.exists( + os.path.join(tmp_dir, "data_0.ckpt") + ), f"Expected sample not found at {tmp_dir}" + dbg(rank, "sample exists") + + ds = build_offline_eagle3_dataset( + tmp_dir, + max_len=seq_len, + ttt_length=ttt_length, + use_usp_preprocess=True, + ) + dbg(rank, "dataset built") + item = ds[0] + dbg(rank, "dataset item loaded") + assert "position_ids" in item + + dummy_model = type("Dummy", (), {})() + adapter = UspAdapter(dummy_model) + local_seq_len = item["input_ids"].shape[1] + target_p_padded = torch.zeros( + (1, local_seq_len, 8), device=device, dtype=torch.float32 + ) + position_mask = torch.ones( + (1, local_seq_len, 1), device=device, dtype=torch.float32 + ) + state = adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=item["input_ids"].to(device), + attention_mask=item["attention_mask"].to(device), + loss_mask=item["loss_mask"].to(device).unsqueeze(-1), + position_ids=item["position_ids"].to(device), + hidden_states=item["hidden_state"].to(device), + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=local_seq_len, + ) + assert state.input_ids.shape[1] == local_seq_len - ttt_length + assert state.hidden_states.shape[1] == local_seq_len - ttt_length + dbg(rank, "adapter step_view ok") + finally: + if rank == 0: + done_flag = os.path.join(tmp_dir, "done.flag") + assert wait_for_file( + done_flag, timeout_s=60 + ), "timeout waiting for done flag" + try: + for root, _, files in os.walk(tmp_dir): + for name in files: + os.remove(os.path.join(root, name)) + os.rmdir(tmp_dir) + except OSError: + pass + else: + done_flag = os.path.join(tmp_dir, "done.flag") + with open(done_flag, "w", encoding="utf-8") as f: + f.write("done\n") # Init USP model and load golden weights usp_decoder = ( @@ -176,6 +300,7 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): .to(torch.bfloat16) ) usp_decoder.load_state_dict(state_dict) + dbg(rank, "USP decoder loaded") # Shard data (Split Input) extract_func = EXTRACT_FUNC_DICT["basic"] @@ -203,24 +328,41 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): .detach() .clone() ) + dbg(rank, "USP local inputs prepared") + total_degree = sp_ring_degree * sp_ulysses_degree + chunk_size = sdpa_output.shape[1] // total_degree + start_idx = (rank % total_degree) * chunk_size + local_len = local_input_ids.shape[1] + local_position_ids = ( + torch.arange(start_idx, start_idx + local_len, device=device) + .unsqueeze(0) + .long() + ) + local_attention_mask = torch.tril( + torch.ones(local_len, local_len, device=device) + ).view(1, 1, local_len, local_len) # Run USP forward + if sp_ring_degree > 1: + usp_attention_mask = local_attention_mask + usp_position_ids = local_position_ids + else: + usp_attention_mask = attention_mask + usp_position_ids = position_ids with torch.no_grad(): usp_output = run_iterative_pass( decoder_layer=usp_decoder, embed_tokens=embed_tokens, input_ids=local_input_ids, hidden_states=local_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + attention_mask=usp_attention_mask, + position_ids=usp_position_ids, ttt_length=ttt_length, ) + dbg(rank, "USP forward done") # Verify results # Slice the golden output to match the current rank's chunk - total_degree = sp_ring_degree * sp_ulysses_degree - chunk_size = sdpa_output.shape[1] // total_degree - start_idx = (rank % total_degree) * chunk_size end_idx = start_idx + chunk_size golden_chunk = sdpa_output[:, start_idx:end_idx, :] @@ -229,9 +371,11 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n" f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}" ) + dbg(rank, "USP output verified") finally: destroy_distributed() + dbg(rank, "destroy_distributed (USP) done") # Case 1: Hybrid (Ulysses=2, Ring=1) subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1)