diff --git a/scripts/prepare_hidden_states.py b/scripts/prepare_hidden_states.py index d201ca479..304677d26 100644 --- a/scripts/prepare_hidden_states.py +++ b/scripts/prepare_hidden_states.py @@ -34,6 +34,7 @@ import argparse import gc +import gzip import hashlib import os from concurrent.futures import ThreadPoolExecutor @@ -58,6 +59,7 @@ ) from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model from specforge.utils import ( + print_args_with_dots, print_with_rank, rank_0_priority, safe_conversations_generator, @@ -123,8 +125,8 @@ def parse_args(): others_group.add_argument( "--num-io-threads", type=int, - default=4, - help="Number of threads for async I/O operations", + default=None, + help="Number of threads for async I/O operations (default: all of CPU cores).", ) others_group.add_argument( "--num-workers", type=int, default=4, help="Number of workers for DataLoader" @@ -141,6 +143,17 @@ def parse_args(): default=2000, help="Number of files per subdirectory.", ) + others_group.add_argument( + "--compress", + action="store_true", + help="Compress hidden state files on disk (gzip).", + ) + others_group.add_argument( + "--compression-level", + type=int, + default=6, + help="Gzip compression level (1-9).", + ) sglang_group = parser.add_argument_group("sglang") SGLangBackendArgs.add_args(sglang_group) @@ -215,6 +228,8 @@ def __init__( num_io_threads: int = 4, io_queue_size: int = 50, file_group_size: int = 2000, + compress: bool = False, + compression_level: int = 6, ): """ Args: @@ -231,6 +246,9 @@ def __init__( self.num_io_threads = num_io_threads self.io_queue_size = io_queue_size self.file_group_size = file_group_size + self.compress = compress + self.compression_level = compression_level + self.file_extension = ".ckpt.gz" if self.compress else ".ckpt" # progress bar should only shown on TP rank = 0 self.show_progress = dist.get_rank(get_tp_group()) == 0 @@ -282,7 +300,13 @@ def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None: ) return - torch.save(asdict(data_point), output_file) + if self.compress: + with gzip.open( + output_file, "wb", compresslevel=self.compression_level + ) as f: + torch.save(asdict(data_point), f) + else: + torch.save(asdict(data_point), output_file) def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None: """ @@ -365,14 +389,22 @@ def _check_existing_files_batch( return [False] * len(global_indices) def check_single_file(idx): - return os.path.exists(self._get_file_path(output_path, idx)) + if os.path.exists(self._get_file_path(output_path, idx)): + return True + legacy_ckpt = self._get_file_path(output_path, idx, extension=".ckpt") + compressed_ckpt = self._get_file_path( + output_path, idx, extension=".ckpt.gz" + ) + return os.path.exists(legacy_ckpt) or os.path.exists(compressed_ckpt) # Parallel file existence check with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor: exists = list(executor.map(check_single_file, global_indices)) return exists - def _get_file_path(self, output_path: str, idx: int) -> str: + def _get_file_path( + self, output_path: str, idx: int, extension: Optional[str] = None + ) -> str: """ A helper function to get the standard file path for the data point with the given index. @@ -383,9 +415,10 @@ def _get_file_path(self, output_path: str, idx: int) -> str: Returns: str: The file path for the data point. """ + ext = self.file_extension if extension is None else extension group_idx = (idx // self.file_group_size) * self.file_group_size grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}" - return os.path.join(output_path, grouped_subdir, f"data_{idx}.ckpt") + return os.path.join(output_path, grouped_subdir, f"data_{idx}{ext}") @torch.no_grad() def generate( @@ -553,9 +586,12 @@ def main(): args.aux_hidden_states_layers = [ int(x) for x in args.aux_hidden_states_layers.split(",") ] - + if args.num_io_threads is None: + cpu_cores = os.cpu_count() or 1 + args.num_io_threads = max(1, cpu_cores) # Initialize distributed environment (TP + DP) init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) + print_args_with_dots(args) # Build target model (with TP) target_model_config = AutoConfig.from_pretrained( @@ -657,6 +693,8 @@ def main(): num_io_threads=args.num_io_threads, io_queue_size=args.io_queue_size, file_group_size=args.file_group_size, + compress=args.compress, + compression_level=args.compression_level, # Other params like io_queue_size can also be added to argparse ) as hidden_states_generator: diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 4237dc94a..0d87c33ad 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -35,7 +35,6 @@ destroy_distributed, get_dp_group, get_draft_dp_group, - get_draft_sp_group, get_tp_group, init_distributed, ) @@ -340,9 +339,24 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + if args.attention_backend == "usp": + sp_sanity_check(args) + + +def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) + assert ( + args.batch_size == 1 + ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" + + assert args.sp_ring_size * args.sp_ulysses_size > 1, ( + f"USP requires sp_ring_size * sp_ulysses_size > 1. " + f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}." + ) + + assert args.train_hidden_states_path is not None, f"USP only support offline mode" if args.eval_data_path is not None and args.eval_hidden_states_path is not None: raise ValueError( @@ -453,6 +467,8 @@ def build_dataloaders( train_eagle3_dataset = build_offline_eagle3_dataset( args.train_hidden_states_path, args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), ) train_dataloader = prepare_dp_dataloaders( @@ -488,6 +504,8 @@ def build_dataloaders( eval_eagle3_dataset = build_offline_eagle3_dataset( args.eval_hidden_states_path, args.max_length, + ttt_length=args.ttt_length, + use_usp_preprocess=(args.attention_backend == "usp"), ) eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, @@ -619,6 +637,9 @@ def run_forward( loss_mask=loss_mask, target=target, hidden_states=hidden_states, + position_ids=( + data["position_ids"].cuda() if "position_ids" in data else None + ), image_grid_thw=image_grid_thw, is_vlm=args.is_vlm, ) @@ -675,56 +696,13 @@ def record_metrcs( tracker.log(logdict, step=global_step) -def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor: +def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor: """ - Process: TP split -> Pad to Max Len -> SP gather. + Get the data shard from the tensor. """ - # 1. TP: Slice the tensor along the batch dimension - tp_group = get_tp_group() - tp_size = dist.get_world_size(tp_group) - tp_rank = dist.get_rank(tp_group) - - local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank] - - # 2. SP: Handle dynamic sequence lengths and Gather - sp_group = get_draft_sp_group() - - if sp_group is not None and dist.get_world_size(sp_group) > 1: - sp_world_size = dist.get_world_size(sp_group) - local_seq_len = local_tp_shard.size(sp_dim) - - # Find global max sequence length in SP group - len_tensor = torch.tensor( - [local_seq_len], device=local_tp_shard.device, dtype=torch.long - ) - dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group) - max_seq_len = len_tensor.item() - - # Pad local tensor if necessary - # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1 - if local_seq_len < max_seq_len: - pad_size = max_seq_len - local_seq_len - - pad_config = [0] * (local_tp_shard.ndim * 2) - - pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1 - pad_config[pad_idx] = pad_size - - # Pad value: 0 is standard, ensure it matches your pad_token_id logic if needed - local_tp_shard_padded = nn.F.pad(local_tp_shard, pad_config, value=0) - else: - local_tp_shard_padded = local_tp_shard - - gathered_shards = [ - torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size) - ] - dist.all_gather( - gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group - ) - - return torch.cat(gathered_shards, dim=sp_dim) - - return local_tp_shard + tp_size = dist.get_world_size(get_tp_group()) + tp_rank = dist.get_rank(get_tp_group()) + return tensor.chunk(tp_size, dim=0)[tp_rank] def main(): @@ -890,7 +868,11 @@ def main(): # 7.1 Training Step # ================================================ plosses, acces = run_forward( - args, eagle3_model, data, target_model, is_online + args, + eagle3_model, + data, + target_model, + is_online, ) run_backward_and_update(args, plosses, optimizer, global_step) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index abf43527f..1e2f04e7e 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -25,16 +25,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.checkpoint import checkpoint from transformers.cache_utils import DynamicCache -from yunchang import EXTRACT_FUNC_DICT +from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter from specforge.core.loss import LogSoftmaxLoss -from specforge.distributed import ( - gather_outputs_and_unpad, - get_sp_ring_group, - get_sp_ulysses_group, -) from specforge.modeling.draft import Eagle3DraftModel from specforge.utils import padding @@ -74,23 +68,66 @@ def __init__( self.attention_backend = attention_backend self.target_model = target_model + def _make_adapter(self) -> BackendAdapter: if self.attention_backend == "usp": - self.extract_func = EXTRACT_FUNC_DICT["basic"] - self.sp_ring_degree = torch.distributed.get_world_size(get_sp_ring_group()) - self.sp_ulysses_degree = torch.distributed.get_world_size( - get_sp_ulysses_group() + return UspAdapter(self) + return SdpaLikeAdapter(self) + + def _acc_and_loss( + self, + *, + logits: torch.Tensor, + target_p: torch.Tensor, + position_mask: torch.Tensor, + loss_mask: torch.Tensor, + adapter: BackendAdapter, + ) -> Tuple[torch.Tensor, torch.Tensor]: + with torch.no_grad(): + local_correct = ( + (logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1) + ).sum() + local_denom = loss_mask.sum().clamp_min(1e-6) + local_correct, local_denom = adapter.reduce_metrics( + local_correct=local_correct, local_denom=local_denom + ) + acc = local_correct / local_denom + + loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) + loss = adapter.reduce_loss(loss) + return acc, loss + + def _prepare_position_ids( + self, + position_ids: Optional[torch.Tensor], + *, + seq_length: int, + past_key_values_length: int, + device: torch.device, + is_vlm: bool, + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + ) -> torch.Tensor: + if self.attention_backend == "usp": + return position_ids + if position_ids is None: + if is_vlm: + mrope_positions_ids, _ = self.target_model.get_rope_index( + input_ids=input_ids, image_grid_thw=image_grid_thw + ) + return mrope_positions_ids + return ( + torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + .unsqueeze(0) + .view(-1, seq_length) ) - self.sp_world_size = self.sp_ring_degree * self.sp_ulysses_degree - self.sp_rank = torch.distributed.get_rank() % self.sp_world_size - - @torch.compile() - def prepare_usp_input(self, full_input): - shared_input = self.extract_func( - full_input, - rank=self.sp_rank, - world_size=self.sp_world_size, - ).clone() - return shared_input + + position_ids = position_ids.long() + return position_ids.view(-1, seq_length) def forward( self, @@ -131,35 +168,21 @@ def forward( past_key_values_length = 0 # Step 2: project the concatenated hidden states to the target hidden size - if self.attention_backend == "usp": - # NOTE: Split first for USP to parallelize computation and ensure - # gradient consistency without redundant full-sequence projection. - hidden_states = self.prepare_usp_input(hidden_states) hidden_states = self.draft_model.project_hidden_states(hidden_states) # Step 3: process kv cache, position ids and position ids if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - if is_vlm: - mrope_positions_ids, mrope_position_delta = ( - self.target_model.get_rope_index( - input_ids=input_ids, image_grid_thw=image_grid_thw - ) - ) - position_ids = mrope_positions_ids - else: - device = hidden_states.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + position_ids = self._prepare_position_ids( + position_ids=position_ids, + seq_length=seq_length, + past_key_values_length=past_key_values_length, + device=hidden_states.device, + is_vlm=is_vlm, + input_ids=input_ids, + image_grid_thw=image_grid_thw, + ) # Step 4: handle attention mask if attention_mask is None: @@ -177,28 +200,11 @@ def forward( past_key_values_length=past_key_values_length, ) - def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): - # 1. Compute Logits(The part that consumes the most VRAM.) - logits_ = self.draft_model.compute_logits(hs) - logits = gather_outputs_and_unpad(logits_, gather_dim=1) - - # 2. Compute Loss - loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask) - - # 3. Compute Accuracy - with torch.no_grad(): - acc_val = _compute_metric_acc( - logits=logits, - target_p=tgt_p, - position_mask=pos_mask, - loss_mask=l_mask, - ) - return loss_val, acc_val - # Step 5: run TTT plosses = [] vlosses = [] acces = [] + adapter = self._make_adapter() # for sequence paralle, position mask and input ids will split by sequence dim, need to keep origin for ttt shift global_input_ids = input_ids if self.attention_backend in ["sdpa", "fa", "usp"]: @@ -211,25 +217,31 @@ def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): raise ValueError(f"Unknown attention backend: {self.attention_backend}") for idx in range(self.length): - target_p = target_p_padded[:, idx : idx + seq_length, :] - if self.attention_backend == "usp": - input_ids = self.prepare_usp_input(global_input_ids) - else: - input_ids = global_input_ids - + state = adapter.step_view( + idx=idx, + ttt_length=self.length, + global_input_ids=global_input_ids, + attention_mask=attention_mask, + loss_mask=loss_mask, + position_ids=position_ids, + hidden_states=hidden_states, + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=seq_length, + ) is_last = idx == self.length - 1 # Step 5.1: embed the input ids - inputs_embeds = self.draft_model.embed_input_ids(input_ids) + inputs_embeds = self.draft_model.embed_input_ids(state.input_ids) inputs_embeds = inputs_embeds.to(hidden_states.dtype) # Step 5.2: run the draft model backbone hidden_states_out = self.draft_model.backbone( input_embeds=inputs_embeds, - hidden_states=hidden_states, + hidden_states=state.hidden_states, cache_hidden=cache_hidden, - attention_mask=attention_mask, - position_ids=position_ids, + attention_mask=state.attention_mask, + position_ids=state.position_ids, past_key_values=past_key_values, use_cache=True, ) @@ -237,22 +249,20 @@ def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask): # update hidden states for next step hidden_states = hidden_states_out - if hidden_states.requires_grad: - loss, acc = checkpoint( - compute_loss_and_acc_checkpointed, - hidden_states, - target_p, - position_mask, - loss_mask, - use_reentrant=False, - ) - else: - loss, acc = compute_loss_and_acc_checkpointed( - hidden_states, target_p, position_mask, loss_mask - ) + # Step 5.4: get logits + logits = self.draft_model.compute_logits(hidden_states) - plosses.append(loss) + # Step 5.5 + 5.6: metric and loss + acc, loss = self._acc_and_loss( + logits=logits, + target_p=state.target_p, + position_mask=state.position_mask, + loss_mask=state.loss_mask, + adapter=adapter, + ) acces.append(acc) + plosses.append(loss) + if not is_last: # Step 5.7: we need to update the loss mask global_input_ids = padding(global_input_ids, left=False) diff --git a/specforge/core/eagle3_adapters.py b/specforge/core/eagle3_adapters.py new file mode 100644 index 000000000..555c16efc --- /dev/null +++ b/specforge/core/eagle3_adapters.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch +import torch.distributed as dist +import torch.distributed.nn.functional as dist_nn + +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group + + +@dataclass +class StepState: + input_ids: torch.Tensor + hidden_states: torch.Tensor + position_ids: torch.Tensor + attention_mask: torch.Tensor + target_p: torch.Tensor + position_mask: torch.Tensor + loss_mask: torch.Tensor + + +class BackendAdapter: + def __init__(self, model: "OnlineEagle3Model"): + self.m = model + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + raise NotImplementedError + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + return loss + + +class SdpaLikeAdapter(BackendAdapter): + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + target_p = target_p_padded[:, idx : idx + seq_length, :].contiguous() + return StepState( + input_ids=global_input_ids, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + target_p=target_p, + position_mask=position_mask, + loss_mask=loss_mask, + ) + + +class UspAdapter(BackendAdapter): + def __init__(self, model: "OnlineEagle3Model"): + super().__init__(model) + self.sp_group = get_draft_sp_group() + self.sp_world_size = dist.get_world_size(self.sp_group) + self.ulysses_pg = get_sp_ulysses_group() + self.sp_ulysses_degree = dist.get_world_size(self.ulysses_pg) + + def step_view( + self, + *, + idx: int, + ttt_length: int, + global_input_ids: torch.Tensor, + attention_mask: torch.Tensor, + loss_mask: torch.Tensor, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + target_p_padded: torch.Tensor, + position_mask: torch.Tensor, + seq_length: int, + ) -> StepState: + usp_chunk_size = seq_length - ttt_length + if usp_chunk_size <= 0: + raise ValueError( + f"USP local seq_length ({seq_length}) must be larger than " + f"ttt_length ({ttt_length})" + ) + target_p = target_p_padded[:, idx : idx + usp_chunk_size, :] + return StepState( + input_ids=global_input_ids[:, :usp_chunk_size], + hidden_states=hidden_states[:, :usp_chunk_size, :], + position_ids=position_ids[:, : usp_chunk_size * self.sp_ulysses_degree], + attention_mask=attention_mask[:, :usp_chunk_size], + target_p=target_p, + position_mask=position_mask[:, :usp_chunk_size, :], + loss_mask=loss_mask[:, :usp_chunk_size, :], + ) + + def reduce_metrics( + self, *, local_correct: torch.Tensor, local_denom: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + local_correct = dist_nn.all_reduce( + local_correct, op=dist.ReduceOp.SUM, group=self.sp_group + ) + local_denom = dist_nn.all_reduce( + local_denom, op=dist.ReduceOp.SUM, group=self.sp_group + ) + return local_correct, local_denom + + def reduce_loss(self, loss: torch.Tensor) -> torch.Tensor: + loss = dist_nn.all_reduce(loss, op=dist.ReduceOp.SUM, group=self.sp_group) + loss = loss / self.sp_world_size + return loss diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 648c0d1a3..46f3d8647 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -20,6 +20,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gzip +import io import os import re import warnings @@ -27,11 +29,14 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.nn.functional as F from tqdm import tqdm from transformers import ImageProcessingMixin, PreTrainedTokenizer from datasets import Dataset as HFDataset +from ..distributed import get_draft_sp_group, get_sp_ring_group + try: from qwen_vl_utils import process_vision_info @@ -432,23 +437,53 @@ def preprocess_function(examples): # Offline Eagle3 Dataset # ============================== # modified from https://github.com/NickL77/BaldEagle/blob/master/train/modules/data/data.py -def list_local_files(path, suffixes=[".ckpt"]): +def list_local_files(path, suffixes=None): + if suffixes is None: + suffixes = [".ckpt", ".ckpt.gz"] datapaths = [] for root, directories, files in os.walk(path): for file in files: file_path = os.path.join(root, file) datapaths.append(file_path) - for suffix in suffixes: - datapaths = [f_name for f_name in datapaths if f_name.endswith(suffix)] + if suffixes: + datapaths = [ + f_name + for f_name in datapaths + if any(f_name.endswith(suffix) for suffix in suffixes) + ] return datapaths class OfflineEagle3Dataset(torch.utils.data.Dataset): - def __init__(self, datapath, transform=None, max_len=2048): + def __init__( + self, + datapath, + transform=None, + max_len=2048, + ttt_length=1, + use_usp_preprocess=False, + ): + """ + Args: + datapath: List of file paths. + transform: Optional transform to apply. + max_len: Maximum sequence length to load. + ttt_length: TTT overlap length used in USP preprocessing. + use_usp_preprocess: Whether to shard all sequences with USP overlap in preprocessing. + """ self.datapaths = datapath self.transform = transform self._epoch = 0 self.max_len = max_len + self.ttt_length = ttt_length + self.use_usp_preprocess = use_usp_preprocess + if use_usp_preprocess: + sp_group = get_draft_sp_group() + self.sp_rank = torch.distributed.get_rank(sp_group) + self.sp_size = torch.distributed.get_world_size(sp_group) + ring_group = get_sp_ring_group() + self.ring_rank = torch.distributed.get_rank(ring_group) + self.sp_ring_size = torch.distributed.get_world_size(ring_group) @staticmethod def process_data(data, max_len, transform=None): @@ -470,11 +505,98 @@ def process_data(data, max_len, transform=None): new_data = transform(new_data) return new_data + @staticmethod + def process_data_usp( + data, + max_len, + ttt_length=1, + transform=None, + sp_rank=0, + sp_size=1, + ring_rank=0, + sp_ring_size=1, + ): + """ + USP preprocess: shard all sequences by sp_rank and add TTT overlap. + Each local sequence length = ceil(max_len / sp_size) + ttt_length. + """ + new_data = {} + + input_ids = data["input_ids"] + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + global_len = min(max_len, input_ids.shape[1]) + chunk_size = (global_len + sp_size - 1) // sp_size + start = sp_rank * chunk_size + local_len = chunk_size + ttt_length + + end = min(start + local_len, global_len) + + def _slice_and_pad(tensor): + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + tensor = tensor[:, :global_len] + sliced = tensor[:, start : min(end, tensor.shape[1])] + valid_len = sliced.shape[1] + if valid_len < local_len: + pad_len = local_len - valid_len + if tensor.ndim == 2: + sliced = F.pad(sliced, (0, pad_len)) + else: + sliced = F.pad(sliced, (0, 0, 0, pad_len)) + return sliced.contiguous(), valid_len + + if "aux_hidden_state" not in data or data["aux_hidden_state"] is None: + raise KeyError("aux_hidden_state is required for OfflineEagle3Dataset") + new_data["hidden_state"], _ = _slice_and_pad(data["aux_hidden_state"]) + new_data["target"], _ = _slice_and_pad(data["hidden_state"]) + + new_data["input_ids"], valid_len = _slice_and_pad(input_ids) + + full_loss_mask = data["loss_mask"] + if full_loss_mask.ndim == 1: + full_loss_mask = full_loss_mask.unsqueeze(0) + + full_loss_mask = full_loss_mask[:, :global_len].clone() + if full_loss_mask.numel() > 0: + full_loss_mask[0, -1] = 0 + new_data["loss_mask"], _ = _slice_and_pad(full_loss_mask) + + local_len = new_data["input_ids"].shape[1] + attention_mask = torch.zeros((1, local_len), dtype=torch.long) + attention_mask[:, :valid_len] = 1 + new_data["attention_mask"] = attention_mask + + # Position ids should align with Ulysses all2all-expanded sequence length. + # Local seq_len (per sp_rank) = local_len; attention uses (local_len - ttt_length). + sp_ulysses_size = max(1, sp_size // sp_ring_size) + usp_chunk_size = max(local_len - ttt_length, 0) + ring_chunk = usp_chunk_size * sp_ulysses_size + ring_start = ring_rank * ring_chunk + new_data["position_ids"] = torch.arange( + ring_start, ring_start + ring_chunk, dtype=torch.long + ).unsqueeze(0) + + if transform: + new_data = transform(new_data) + + return new_data + def __len__(self): return len(self.datapaths) def _open_file(self, index): - return torch.load(self.datapaths[index], weights_only=False) + """ + Opens the file with memory mapping. + This operation is virtually instant and consumes negligible RAM + because no data is actually read from disk yet. + """ + data_path = self.datapaths[index] + if data_path.endswith(".gz"): + with gzip.open(data_path, "rb") as f: + return torch.load(io.BytesIO(f.read()), weights_only=False) + return torch.load(data_path, weights_only=False, mmap=True) def __getitem__(self, index): try: @@ -482,7 +604,24 @@ def __getitem__(self, index): except Exception as e: print(f"ERROR Failed to load {self.datapaths[index]} with error {e}") data = self._open_file(0) - return self.process_data(data, self.max_len, self.transform) + + # 2. Read only specific bytes from disk + if self.use_usp_preprocess: + return self.process_data_usp( + data, + self.max_len, + ttt_length=self.ttt_length, + transform=self.transform, + sp_rank=self.sp_rank, + sp_size=self.sp_size, + ring_rank=self.ring_rank, + sp_ring_size=self.sp_ring_size, + ) + return self.process_data( + data, + self.max_len, + self.transform, + ) def set_epoch(self, epoch): self._epoch = epoch @@ -491,10 +630,15 @@ def set_epoch(self, epoch): def build_offline_eagle3_dataset( hidden_states_path: str, max_len: int = 2048, + ttt_length: int = 1, + use_usp_preprocess: bool = False, ) -> torch.utils.data.Dataset: + return OfflineEagle3Dataset( list_local_files(hidden_states_path), max_len=max_len, + ttt_length=ttt_length, + use_usp_preprocess=use_usp_preprocess, ) @@ -521,7 +665,7 @@ def generate_vocab_mapping_file( Returns: The path to the vocab mapping file. """ - # prepare cache direcotory + # prepare cache directory os.makedirs(cache_dir, exist_ok=True) vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt") @@ -529,7 +673,7 @@ def generate_vocab_mapping_file( print(f"Loading vocab mapping from the cached file at: {vocab_mapping_path}") return vocab_mapping_path - # we first count the frequency of effectiev tokens in the dataset + # we first count the frequency of effective tokens in the dataset token_dict = Counter() for input_ids, loss_mask in tqdm( zip(dataset["input_ids"], dataset["loss_mask"]), diff --git a/specforge/data/utils.py b/specforge/data/utils.py index 9668680dc..93fd6f58a 100644 --- a/specforge/data/utils.py +++ b/specforge/data/utils.py @@ -26,7 +26,7 @@ from torch.utils.data import DataLoader, DistributedSampler from datasets import Dataset -from specforge.distributed import get_draft_sp_group +from specforge.distributed import get_draft_sp_group, get_sp_ulysses_group class DataCollatorWithPadding: @@ -36,6 +36,7 @@ class DataCollatorWithPadding: def __init__(self): self.sp_degree = torch.distributed.get_world_size(get_draft_sp_group()) + self.ulysses_degree = torch.distributed.get_world_size(get_sp_ulysses_group()) def paddingtensor(self, intensors: torch.Tensor, N: int) -> torch.Tensor: """ @@ -90,10 +91,14 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: - loss_mask: torch.Tensor of shape (B, N) """ max_length = max(item["input_ids"].shape[1] for item in features) + # pad for sequence parrel max_length = ( (max_length + self.sp_degree - 1) // self.sp_degree ) * self.sp_degree + # position max len, ulysses do not need chuck position ids + position_max_len = max_length * self.ulysses_degree + batch_input_ids = torch.cat( [self.paddingtensor2D(item["input_ids"], max_length) for item in features] ) @@ -106,6 +111,15 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: batch_loss_mask = torch.cat( [self.paddingtensor2D(item["loss_mask"], max_length) for item in features] ) + if "position_ids" in features[0]: + batch_position_ids = torch.cat( + [ + self.paddingtensor2D(item["position_ids"], position_max_len) + for item in features + ] + ) + else: + batch_position_ids = None batch = { "input_ids": batch_input_ids, "attention_mask": batch_attention_mask, @@ -113,16 +127,23 @@ def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: "hidden_state": None, "target": None, } + if batch_position_ids is not None: + batch["position_ids"] = batch_position_ids if all("hidden_state" in item for item in features): assert all( "target" in item for item in features ), "target is required when hidden_state is provided" - batch["hidden_state"] = torch.cat( - [ - self.paddingtensor(item["hidden_state"], max_length) - for item in features - ] - ) + if self.sp_degree > 1: # USP mode + batch["hidden_state"] = torch.cat( + [item["hidden_state"] for item in features] + ) + else: + batch["hidden_state"] = torch.cat( + [ + self.paddingtensor(item["hidden_state"], max_length) + for item in features + ] + ) batch["target"] = torch.cat( [self.paddingtensor(item["target"], max_length) for item in features] ) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 552a3cf86..4a1833072 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -977,6 +977,10 @@ def __init__(self, config): assert ( dist.is_initialized() ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): + raise NotImplementedError( + f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." + ) self.ring_pg = get_sp_ring_group() self.ulysses_pg = get_sp_ulysses_group() self.sp_ring_degree = torch.distributed.get_world_size(self.ring_pg) @@ -1043,39 +1047,16 @@ def forward( # Global length calculation (for RoPE) global_q_len = q_len * self.sp_ring_degree * self.sp_ulysses_degree - # ============================================================= # 2. RoPE & Cache Management # ============================================================= - if self.sp_ring_degree > 1: - if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): - position_ids = position_ids.chunk(self.sp_ring_degree, dim=2)[ - self.ring_rank - ].clone() - else: - position_ids = position_ids.chunk(self.sp_ring_degree, dim=1)[ - self.ring_rank - ].clone() - lck = 0 if cache_hidden is None else len(cache_hidden[0]) - if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): - cos, sin = self.rotary_emb(query_states, position_ids + lck) - cos, sin = cos.to(query_states.device), sin.to(query_states.device) - query_states, key_states = apply_multimodal_rotary_pos_emb( - query_states, - key_states, - cos, - sin, - self.config.rope_scaling["mrope_section"], - unsqueeze_dim=2, - ) - else: - cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) - cos, sin = cos.to(query_states.device), sin.to(query_states.device) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 - ) + cos, sin = self.rotary_emb(query_states, seq_len=global_q_len + lck) + cos, sin = cos.to(query_states.device), sin.to(query_states.device) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + lck, unsqueeze_dim=2 + ) # Update Cache (Eagle3 Logic: Cache is a list of tensors for tree branches) if cache_hidden is not None: diff --git a/tests/test_layers/test_decoder.py b/tests/test_layers/test_decoder.py index aca4f3c24..db6d785db 100644 --- a/tests/test_layers/test_decoder.py +++ b/tests/test_layers/test_decoder.py @@ -1,4 +1,5 @@ import os +import time import unittest import torch @@ -8,6 +9,9 @@ from transformers import PretrainedConfig from yunchang import EXTRACT_FUNC_DICT +from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter +from specforge.data.preprocessing import build_offline_eagle3_dataset + # Project-specific imports from specforge.distributed import destroy_distributed, init_distributed from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer @@ -57,6 +61,19 @@ def setup_env(rank, world_size, port): torch.cuda.set_device(rank) +def dbg(rank, msg): + print(f"[rank{rank}] {msg}", flush=True) + + +def wait_for_file(path, timeout_s=60, poll_s=0.1): + start = time.time() + while time.time() - start < timeout_s: + if os.path.exists(path): + return True + time.sleep(poll_s) + return False + + def run_iterative_pass( decoder_layer, embed_tokens, @@ -113,6 +130,7 @@ def run_test_case(rank, world_size, port): setup_env(rank, world_size, port) device = torch.device(f"cuda:{rank}") set_seed(42) + dbg(rank, "env setup complete") # --- Data & Config Preparation --- config = get_model_config() @@ -135,13 +153,34 @@ def run_test_case(rank, world_size, port): config.vocab_size, config.hidden_size, config.pad_token_id ).to(device) - # --- Phase 1: Golden Run (SDPA) --- + # --- Phase 1: Golden Run (FA) --- # Init dist briefly for internal checks, even if running single-device logic init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1) + dbg(rank, "init_distributed (FA) done") sdpa_decoder = ( LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16) ) + dbg(rank, "FA decoder created") + # Adapter smoke test for FA/SDPA-style path + dummy_model = type("Dummy", (), {})() + sdpa_adapter = SdpaLikeAdapter(dummy_model) + sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32) + sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32) + sdpa_state = sdpa_adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=data_input_ids, + attention_mask=attention_mask, + loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32), + position_ids=position_ids, + hidden_states=data_hidden_states, + target_p_padded=sdpa_target_p, + position_mask=sdpa_position_mask, + seq_length=seq_len, + ) + assert sdpa_state.input_ids.shape[1] == seq_len + assert sdpa_state.hidden_states.shape[1] == seq_len with torch.no_grad(): sdpa_output = run_iterative_pass( @@ -153,11 +192,13 @@ def run_test_case(rank, world_size, port): position_ids=position_ids, ttt_length=ttt_length, ) + dbg(rank, "FA forward done") # Save weights for alignment and cleanup SDPA model state_dict = sdpa_decoder.state_dict() del sdpa_decoder destroy_distributed() + dbg(rank, "destroy_distributed (FA) done") # --- Phase 2: Distributed Run (USP) --- def subtest_usp(sp_ulysses_degree, sp_ring_degree): @@ -168,6 +209,89 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): sp_ulysses_size=sp_ulysses_degree, sp_ring_size=sp_ring_degree, ) + dbg( + rank, + f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done", + ) + # Dataset + adapter smoke test (USP path) + tmp_dir = "./tmp/usp_dataset_shared" + try: + if rank == 0: + os.makedirs(tmp_dir, exist_ok=True) + sample = { + "input_ids": data_input_ids[0].cpu(), + "loss_mask": torch.ones_like(data_input_ids[0].cpu()), + "hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + "aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0), + } + torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt")) + dbg(rank, "wrote sample ckpt") + ready_flag = os.path.join(tmp_dir, "ready.flag") + with open(ready_flag, "w", encoding="utf-8") as f: + f.write("ready\n") + if rank != 0: + ready_flag = os.path.join(tmp_dir, "ready.flag") + assert wait_for_file( + ready_flag, timeout_s=60 + ), "timeout waiting for ready flag" + dbg(rank, "dataset sync done") + assert os.path.exists( + os.path.join(tmp_dir, "data_0.ckpt") + ), f"Expected sample not found at {tmp_dir}" + dbg(rank, "sample exists") + + ds = build_offline_eagle3_dataset( + tmp_dir, + max_len=seq_len, + ttt_length=ttt_length, + use_usp_preprocess=True, + ) + dbg(rank, "dataset built") + item = ds[0] + dbg(rank, "dataset item loaded") + assert "position_ids" in item + + dummy_model = type("Dummy", (), {})() + adapter = UspAdapter(dummy_model) + local_seq_len = item["input_ids"].shape[1] + target_p_padded = torch.zeros( + (1, local_seq_len, 8), device=device, dtype=torch.float32 + ) + position_mask = torch.ones( + (1, local_seq_len, 1), device=device, dtype=torch.float32 + ) + state = adapter.step_view( + idx=0, + ttt_length=ttt_length, + global_input_ids=item["input_ids"].to(device), + attention_mask=item["attention_mask"].to(device), + loss_mask=item["loss_mask"].to(device).unsqueeze(-1), + position_ids=item["position_ids"].to(device), + hidden_states=item["hidden_state"].to(device), + target_p_padded=target_p_padded, + position_mask=position_mask, + seq_length=local_seq_len, + ) + assert state.input_ids.shape[1] == local_seq_len - ttt_length + assert state.hidden_states.shape[1] == local_seq_len - ttt_length + dbg(rank, "adapter step_view ok") + finally: + if rank == 0: + done_flag = os.path.join(tmp_dir, "done.flag") + assert wait_for_file( + done_flag, timeout_s=60 + ), "timeout waiting for done flag" + try: + for root, _, files in os.walk(tmp_dir): + for name in files: + os.remove(os.path.join(root, name)) + os.rmdir(tmp_dir) + except OSError: + pass + else: + done_flag = os.path.join(tmp_dir, "done.flag") + with open(done_flag, "w", encoding="utf-8") as f: + f.write("done\n") # Init USP model and load golden weights usp_decoder = ( @@ -176,6 +300,7 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): .to(torch.bfloat16) ) usp_decoder.load_state_dict(state_dict) + dbg(rank, "USP decoder loaded") # Shard data (Split Input) extract_func = EXTRACT_FUNC_DICT["basic"] @@ -203,24 +328,41 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): .detach() .clone() ) + dbg(rank, "USP local inputs prepared") + total_degree = sp_ring_degree * sp_ulysses_degree + chunk_size = sdpa_output.shape[1] // total_degree + start_idx = (rank % total_degree) * chunk_size + local_len = local_input_ids.shape[1] + local_position_ids = ( + torch.arange(start_idx, start_idx + local_len, device=device) + .unsqueeze(0) + .long() + ) + local_attention_mask = torch.tril( + torch.ones(local_len, local_len, device=device) + ).view(1, 1, local_len, local_len) # Run USP forward + if sp_ring_degree > 1: + usp_attention_mask = local_attention_mask + usp_position_ids = local_position_ids + else: + usp_attention_mask = attention_mask + usp_position_ids = position_ids with torch.no_grad(): usp_output = run_iterative_pass( decoder_layer=usp_decoder, embed_tokens=embed_tokens, input_ids=local_input_ids, hidden_states=local_hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, + attention_mask=usp_attention_mask, + position_ids=usp_position_ids, ttt_length=ttt_length, ) + dbg(rank, "USP forward done") # Verify results # Slice the golden output to match the current rank's chunk - total_degree = sp_ring_degree * sp_ulysses_degree - chunk_size = sdpa_output.shape[1] // total_degree - start_idx = (rank % total_degree) * chunk_size end_idx = start_idx + chunk_size golden_chunk = sdpa_output[:, start_idx:end_idx, :] @@ -229,9 +371,11 @@ def subtest_usp(sp_ulysses_degree, sp_ring_degree): f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n" f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}" ) + dbg(rank, "USP output verified") finally: destroy_distributed() + dbg(rank, "destroy_distributed (USP) done") # Case 1: Hybrid (Ulysses=2, Ring=1) subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1)