Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions scripts/prepare_hidden_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

import argparse
import gc
import gzip
import hashlib
import os
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -58,6 +59,7 @@
)
from specforge.modeling.target import Eagle3TargetModel, get_eagle3_target_model
from specforge.utils import (
print_args_with_dots,
print_with_rank,
rank_0_priority,
safe_conversations_generator,
Expand Down Expand Up @@ -123,8 +125,8 @@ def parse_args():
others_group.add_argument(
"--num-io-threads",
type=int,
default=4,
help="Number of threads for async I/O operations",
default=None,
help="Number of threads for async I/O operations (default: all of CPU cores).",
)
others_group.add_argument(
"--num-workers", type=int, default=4, help="Number of workers for DataLoader"
Expand All @@ -141,6 +143,17 @@ def parse_args():
default=2000,
help="Number of files per subdirectory.",
)
others_group.add_argument(
"--compress",
action="store_true",
help="Compress hidden state files on disk (gzip).",
)
others_group.add_argument(
"--compression-level",
type=int,
default=6,
help="Gzip compression level (1-9).",
)

sglang_group = parser.add_argument_group("sglang")
SGLangBackendArgs.add_args(sglang_group)
Expand Down Expand Up @@ -215,6 +228,8 @@ def __init__(
num_io_threads: int = 4,
io_queue_size: int = 50,
file_group_size: int = 2000,
compress: bool = False,
compression_level: int = 6,
):
"""
Args:
Expand All @@ -231,6 +246,9 @@ def __init__(
self.num_io_threads = num_io_threads
self.io_queue_size = io_queue_size
self.file_group_size = file_group_size
self.compress = compress
self.compression_level = compression_level
self.file_extension = ".ckpt.gz" if self.compress else ".ckpt"

# progress bar should only be shown on TP rank = 0
self.show_progress = dist.get_rank(get_tp_group()) == 0
Expand Down Expand Up @@ -282,7 +300,13 @@ def _save_tensor_sync(self, data_point: DataPoint, output_file: str) -> None:
)
return

torch.save(asdict(data_point), output_file)
if self.compress:
with gzip.open(
output_file, "wb", compresslevel=self.compression_level
) as f:
torch.save(asdict(data_point), f)
else:
torch.save(asdict(data_point), output_file)

def _save_tensor_async(self, data_point: DataPoint, output_file: str) -> None:
"""
Expand Down Expand Up @@ -365,14 +389,22 @@ def _check_existing_files_batch(
return [False] * len(global_indices)

def check_single_file(idx):
    """Return True if a saved data point for ``idx`` exists in any known format.

    Both the plain (".ckpt") and the gzip-compressed (".ckpt.gz") file names
    are probed, so output written with a different ``--compress`` setting is
    still recognised as existing. The configured ``self.file_extension`` is
    always one of these two values, so it needs no separate check.
    """
    return any(
        os.path.exists(self._get_file_path(output_path, idx, extension=ext))
        for ext in (".ckpt", ".ckpt.gz")
    )

# Parallel file existence check
with ThreadPoolExecutor(max_workers=self.num_io_threads) as executor:
exists = list(executor.map(check_single_file, global_indices))
return exists

def _get_file_path(self, output_path: str, idx: int) -> str:
def _get_file_path(
self, output_path: str, idx: int, extension: Optional[str] = None
) -> str:
"""
A helper function to get the standard file path for the data point with the given index.

Expand All @@ -383,9 +415,10 @@ def _get_file_path(self, output_path: str, idx: int) -> str:
Returns:
str: The file path for the data point.
"""
ext = self.file_extension if extension is None else extension
group_idx = (idx // self.file_group_size) * self.file_group_size
grouped_subdir = f"rows_{group_idx}-{group_idx + self.file_group_size}"
return os.path.join(output_path, grouped_subdir, f"data_{idx}.ckpt")
return os.path.join(output_path, grouped_subdir, f"data_{idx}{ext}")

@torch.no_grad()
def generate(
Expand Down Expand Up @@ -553,9 +586,12 @@ def main():
args.aux_hidden_states_layers = [
int(x) for x in args.aux_hidden_states_layers.split(",")
]

if args.num_io_threads is None:
cpu_cores = os.cpu_count() or 1
args.num_io_threads = max(1, cpu_cores)
# Initialize distributed environment (TP + DP)
init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
print_args_with_dots(args)

# Build target model (with TP)
target_model_config = AutoConfig.from_pretrained(
Expand Down Expand Up @@ -657,6 +693,8 @@ def main():
num_io_threads=args.num_io_threads,
io_queue_size=args.io_queue_size,
file_group_size=args.file_group_size,
compress=args.compress,
compression_level=args.compression_level,
# Other params like io_queue_size can also be added to argparse
) as hidden_states_generator:

Expand Down
82 changes: 32 additions & 50 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
destroy_distributed,
get_dp_group,
get_draft_dp_group,
get_draft_sp_group,
get_tp_group,
init_distributed,
)
Expand Down Expand Up @@ -340,9 +339,24 @@ def sanity_check(args: Namespace) -> None:
"""
args.dp_size = dist.get_world_size() // args.tp_size
args.target_batch_size = args.tp_size * args.batch_size
if args.attention_backend == "usp":
sp_sanity_check(args)


def sp_sanity_check(args: Namespace) -> None:
args.draft_accumulation_steps = (
args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
)
assert (
args.batch_size == 1
), f"USP only supports batch_size=1, got batch_size={args.batch_size}"

assert args.sp_ring_size * args.sp_ulysses_size > 1, (
f"USP requires sp_ring_size * sp_ulysses_size > 1. "
f"Got sp_ring_size={args.sp_ring_size}, sp_ulysses_size={args.sp_ulysses_size}."
)

assert args.train_hidden_states_path is not None, f"USP only support offline mode"

if args.eval_data_path is not None and args.eval_hidden_states_path is not None:
raise ValueError(
Expand Down Expand Up @@ -453,6 +467,8 @@ def build_dataloaders(
train_eagle3_dataset = build_offline_eagle3_dataset(
args.train_hidden_states_path,
args.max_length,
ttt_length=args.ttt_length,
use_usp_preprocess=(args.attention_backend == "usp"),
)

train_dataloader = prepare_dp_dataloaders(
Expand Down Expand Up @@ -488,6 +504,8 @@ def build_dataloaders(
eval_eagle3_dataset = build_offline_eagle3_dataset(
args.eval_hidden_states_path,
args.max_length,
ttt_length=args.ttt_length,
use_usp_preprocess=(args.attention_backend == "usp"),
)
eval_dataloader = prepare_dp_dataloaders(
eval_eagle3_dataset,
Expand Down Expand Up @@ -619,6 +637,9 @@ def run_forward(
loss_mask=loss_mask,
target=target,
hidden_states=hidden_states,
position_ids=(
data["position_ids"].cuda() if "position_ids" in data else None
),
image_grid_thw=image_grid_thw,
is_vlm=args.is_vlm,
)
Expand Down Expand Up @@ -675,56 +696,13 @@ def record_metrcs(
tracker.log(logdict, step=global_step)


def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get this tensor-parallel rank's data shard from the tensor.

    The tensor is split along dim 0 into ``tp_size`` chunks, where ``tp_size``
    is the world size of the TP process group, and the chunk at this process's
    rank within that group is returned.

    Args:
        tensor: The tensor to shard along its first dimension.

    Returns:
        torch.Tensor: The chunk of ``tensor`` that belongs to the current TP rank.
    """
    # Look the TP group up once and reuse it for both queries.
    tp_group = get_tp_group()
    tp_size = dist.get_world_size(tp_group)
    tp_rank = dist.get_rank(tp_group)
    # NOTE(review): torch.chunk can return fewer than tp_size chunks when dim 0
    # is smaller than tp_size; callers are presumed to pass a dim 0 that is
    # divisible by tp_size — confirm.
    return tensor.chunk(tp_size, dim=0)[tp_rank]


def main():
Expand Down Expand Up @@ -890,7 +868,11 @@ def main():
# 7.1 Training Step
# ================================================
plosses, acces = run_forward(
args, eagle3_model, data, target_model, is_online
args,
eagle3_model,
data,
target_model,
is_online,
)
run_backward_and_update(args, plosses, optimizer, global_step)

Expand Down
Loading