USP Offline Long-Sequence Refactor: SP Sharded Inputs, Adapter Unification, and Compressed Hidden States #454
Conversation
Summary of Changes
Hello @uygnef, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a significant optimization to the offline dataloader, particularly for training with sequence parallelism. By integrating memory-mapped loading and moving the hidden-state sharding logic into the dataset itself, the system can now handle larger sequence lengths more efficiently, drastically reducing memory footprint and improving data-loading performance during training. These changes ensure that each processing unit only handles the necessary data segments, streamlining the overall training pipeline.
Code Review
This pull request introduces optimizations to the offline data loader for sequence parallelism (SP) by moving hidden state sharding into the dataset with memory-mapped file loading. This aims to reduce memory usage and I/O overhead. The changes include argument validation, modifications to the Eagle3 model's forward pass, and significant updates to the data preprocessing pipeline to support sharded hidden states. The code has been reviewed, and suggestions have been provided to address potential issues with batch size constraints and to improve code clarity.
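For readers not familiar with the approach under review, here is a minimal sketch of the idea, under stated assumptions (per-sample .npy files of shape [seq_len, hidden_dim]; the class name HiddenStatesShardDataset is illustrative, not the actual SpecForge API):

    # Sketch only: dataset-side SP sharding over memory-mapped hidden states.
    # File layout, class name, and sp_rank/sp_size handling are assumptions.
    import numpy as np
    from torch.utils.data import Dataset

    class HiddenStatesShardDataset(Dataset):
        def __init__(self, files, sp_rank, sp_size):
            self.files = sorted(files)  # deterministic order across ranks
            self.sp_rank = sp_rank
            self.sp_size = sp_size

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            # mmap_mode="r" keeps the full tensor on disk; only this rank's
            # slice is materialized into memory.
            arr = np.load(self.files[idx], mmap_mode="r")
            shard_len = arr.shape[0] // self.sp_size  # assumes seq_len is padded to a multiple of sp_size
            start = self.sp_rank * shard_len
            return np.asarray(arr[start:start + shard_len])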
Seems the test fails. Does it pass in your local env?
2. usp does not support online mode
3. loss shared for usp, split target
4. Tighten USP batching and padding behavior
5. optimize offline dataloader for sp
@FrankLeeeee Fixed. The tests were already passing locally – I just missed the import for the new debug function when committing. All good now, tests are passing.
def dbg(rank, msg):
    print(f"[rank{rank}] {msg}", flush=True)
replace with a logger instead?
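For reference, a minimal sketch of the logger-based variant (the logger name is an assumption; SpecForge may already provide its own logging utility):

    import logging

    logger = logging.getLogger("specforge.usp")  # logger name is illustrative

    def dbg(rank, msg):
        # Same information as the print-based helper, but routed through logging
        # so it can be filtered by level or silenced per handler.
        logger.debug("[rank%s] %s", rank, msg)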
Have you run any experiments to show that the accuracy is not affected?
Based on experiments, both training loss and accuracy remain unchanged. The accept-length metrics are also nearly identical (note: the dataset differs from the zip test): Split: 1.8971143677292643
@uygnef
[rank0]: File "/ossfs/workspace/SpecForge/specforge/modeling/draft/llama3_eagle.py", line 1283, in forward
[rank0]:     hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
[rank0]: RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 8072 but got size 8067 for tensor number 1 in the list.

and my config is:

--batch-size 1
--tp-size 1
--learning-rate 5e-5
--max-length 8192
--chat-template deepseek-v3
--cache-dir ./cache
--sp-ulysses-size 2
--attention-backend usp
There doesn't seem to be an issue with your parameters, but the error indicates that the sequence hasn't been split. Can you paste all the parameters printed in the log?
@uygnef
target_model_path /dsv3_SFT_nothink_0819_multi_round_aug_backup-iter_0000078_fp8
trust_remote_code ······· False
draft_model_config ········ None
embedding_key model.embed_tokens.weight
lm_head_key lm_head.weight
is_vlm ······· False
target_model_backend ······ sglang
train_data_path /eagle3_train_dsv3_chat_data_202509_and_202510_8k.jsonl
train_hidden_states_path /dsv3_eagle3_hideen_states
eval_hidden_states_path ········ None
eval_data_path ········ None
chat_template · deepseek-v3
is_preformatted ······· False
train_only_last_turn ······· False
build_dataset_num_proc ··········· 8
dataloader_num_workers ··········· 4
num_epochs ··········· 2
max_num_steps ········ None
batch_size ··········· 1
learning_rate ······· 5e-05
max_length ········ 8192
warmup_ratio ······· 0.015
total_steps ········ None
max_grad_norm ········· 0.5
ttt_length ··········· 7
resume ······· False
ckpt_dir /DeepSeek-V3-EAGLE3
eval_interval ········ 5000
save_interval ········· 500
log_interval ·········· 50
seed ··········· 0
draft_accumulation_steps ··········· 4
tp_size ··········· 1
sp_ulysses_size ··········· 2
sp_ring_size ··········· 1
attention_backend ········· usp
cache_key ········ None
cache_dir ····· ./cache
output_dir /DeepSeek-V3-EAGLE3-ContinueSFT/20260209
verbose ······· False
dist_timeout ·········· 20
model_download_dir ········ None
min_pixels ······· 50176
max_pixels ······ 802816
profile ······· False
profile_start_step ·········· 30
profile_num_steps ··········· 4
profile_record_shapes ······· False
sglang_attention_backend ·· flashinfer
sglang_mem_fraction_static ········· 0.4
sglang_context_length ········ None
sglang_enable_nccl_nvls ······· False
sglang_enable_symm_mem ······· False
sglang_enable_torch_compile ······· False
sglang_enable_dp_attention ······· False
sglang_enable_dp_lm_head ······· False
sglang_enable_piecewise_cuda_graph ······· False
sglang_piecewise_cuda_graph_max_tokens ········ 4096
sglang_piecewise_cuda_graph_tokens ········ None
sglang_ep_size ··········· 1
report_to ········ none
wandb_project ········ None
wandb_name ········ None
wandb_key ········ None
swanlab_project ········ None
swanlab_name ········ None
swanlab_key ········ None
mlflow_tracking_uri ········ None
mlflow_experiment_name ········ None
mlflow_run_name ········ None
dp_size ··········· 8
target_batch_size ··········· 1

The reason seems to lie in the hidden-states reading code, which may produce a different datapath order on different ranks:

import os

def list_local_files(path, suffixes=None):
    if suffixes is None:
        suffixes = [".ckpt", ".ckpt.gz"]
    datapaths = []
    for root, directories, files in os.walk(path):
        for file in files:
            file_path = os.path.join(root, file)
            datapaths.append(file_path)
    if suffixes:
        datapaths = [
            f_name
            for f_name in datapaths
            if any(f_name.endswith(suffix) for suffix in suffixes)
        ]
    datapaths.sort()  # Sort to ensure deterministic order across ranks
    return datapaths

And then everything goes well. Maybe it's my env's fault? I'm not sure.
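As a quick way to confirm that diagnosis, one possible sketch is to fingerprint each rank's file list and compare across the process group (assumes torch.distributed is already initialized and reuses the list_local_files helper above):

    import hashlib

    import torch.distributed as dist

    def check_file_list_consistency(path):
        files = list_local_files(path)
        digest = hashlib.sha256("\n".join(files).encode()).hexdigest()
        # Gather every rank's digest and make sure they all agree.
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, digest)
        if len(set(gathered)) != 1:
            raise RuntimeError(f"Hidden-states file lists differ across ranks: {gathered}")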
@jiahang01 Thank you for pointing this out. Would you be willing to open a pull request with this fix?
Sure, I'll work on it.
Motivation
Loss/metric reductions are shared across the SP group, keeping position_ids handling consistent.
With these changes, 8 GPUs can now support training on sequences up to 128K length.
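To make the sequence-parallel input sharding concrete, here is a minimal sketch under the assumption of Ulysses-style splitting (the function and parameter names are illustrative, not the PR's actual implementation):

    import torch

    def shard_sequence(input_ids: torch.Tensor, sp_rank: int, sp_size: int, pad_id: int = 0):
        # Pad the sequence so it divides evenly across the SP group.
        seq_len = input_ids.shape[-1]
        remainder = seq_len % sp_size
        if remainder:
            pad_shape = (*input_ids.shape[:-1], sp_size - remainder)
            pad = torch.full(pad_shape, pad_id, dtype=input_ids.dtype, device=input_ids.device)
            input_ids = torch.cat([input_ids, pad], dim=-1)
        shard_len = input_ids.shape[-1] // sp_size
        # Each rank keeps only its contiguous slice: with sp_size=8, a 128K-token
        # sequence becomes a 16K-token slice per GPU.
        return input_ids[..., sp_rank * shard_len:(sp_rank + 1) * shard_len]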
Modifications
Related Issues
Accuracy Test
split input
Benchmark & Profiling
Compression is accuracy-neutral:
SP-sharded inputs: accept length unchanged
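On the compressed hidden states: the .ckpt.gz suffix handled by list_local_files above suggests gzip-compressed files. A minimal sketch of writing and reading hidden states that way (the helper names and on-disk layout are assumptions, not the exact SpecForge format):

    import gzip

    import torch

    def save_hidden_states(hidden_states: torch.Tensor, path: str) -> None:
        # gzip.open returns a file-like object that torch.save can write to directly.
        with gzip.open(path, "wb") as f:
            torch.save(hidden_states.cpu(), f)

    def load_hidden_states(path: str) -> torch.Tensor:
        with gzip.open(path, "rb") as f:
            return torch.load(f)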
Checklist