
USP Offline Long‑Sequence Refactor: SP Sharded Inputs, Adapter Unification, and Compressed Hidden States (#454)

Merged
jiapingW merged 1 commit into sgl-project:main from uygnef:opt/sp on Feb 5, 2026

Conversation

@uygnef
Collaborator

@uygnef uygnef commented Jan 27, 2026

Motivation

  • For long sequences, storing full hidden states and targets on GPU is too memory‑heavy. We now shard along the sequence‑parallel (SP) dimension at input time, so memory scales down with the SP degree.
  • The USP/sequence‑parallel paths had diverged between offline and online training, which made loss/metric computation and position_ids handling fragile. We therefore introduce adapters (sketched below) that unify the backend‑specific step views and the distributed loss/metric reductions, keeping position_ids handling consistent.

With these changes, 8 GPUs can now support training on sequences of up to 128K tokens.
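
The adapter layer itself is not reproduced in this description. As a rough sketch of the pattern (class and method names below are hypothetical, not the actual eagle3_adapters API; an NCCL-style AVG all-reduce is assumed), each backend pins down how position_ids are built and how loss/metrics are reduced:

import torch
import torch.distributed as dist


class Eagle3StepAdapter:
    """Backend-specific view of one training step plus distributed reductions."""

    def position_ids(self, local_len: int, device: torch.device) -> torch.Tensor:
        raise NotImplementedError

    def reduce_metric(self, value: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class SDPAAdapter(Eagle3StepAdapter):
    # Full sequence on every rank: plain positions, nothing to reduce.
    def position_ids(self, local_len, device):
        return torch.arange(local_len, device=device).unsqueeze(0)

    def reduce_metric(self, value):
        return value


class USPAdapter(Eagle3StepAdapter):
    # Each SP rank holds one contiguous chunk of the sequence: offset the
    # positions by the chunk start and average metrics over the SP group.
    def __init__(self, sp_group: dist.ProcessGroup, sp_rank: int, chunk_len: int):
        self.sp_group = sp_group
        self.sp_rank = sp_rank
        self.chunk_len = chunk_len

    def position_ids(self, local_len, device):
        start = self.sp_rank * self.chunk_len
        return torch.arange(start, start + local_len, device=device).unsqueeze(0)

    def reduce_metric(self, value):
        dist.all_reduce(value, op=dist.ReduceOp.AVG, group=self.sp_group)
        return value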

Modifications

  1. Add specforge/core/eagle3_adapters.py to abstract the SDPA/USP step views and the distributed loss/metric reductions.
  2. The offline dataset reads hidden states via mmap and loads only the local shard needed by the current SP rank (see the sketch below).
  3. Remove long‑sequence support in online mode (target models are too large for VRAM). Long‑sequence online training needs a redesign; in the near term we focus on offline mode and add compression to mitigate disk usage.
  4. prepare_hidden_states.py adds --compress to gzip hidden states, reducing storage by ~20%.
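
As a minimal sketch of modification 2's read path (the helper below is illustrative, not the actual OfflineEagle3Dataset code): torch.load(..., mmap=True) keeps the saved tensor on disk, so slicing out the current SP rank's chunk only faults in that shard's pages.

import torch

def load_local_shard(path: str, sp_rank: int, sp_degree: int) -> torch.Tensor:
    """Hypothetical helper: mmap a saved hidden-states tensor and return
    only the chunk owned by the current SP rank."""
    # mmap=True defers reading; pages come off disk only when touched.
    # A [seq_len, hidden_size] layout is assumed here.
    full = torch.load(path, map_location="cpu", mmap=True)
    chunk_len = full.shape[0] // sp_degree
    start = sp_rank * chunk_len
    # .clone() materializes just the local shard and detaches it from the mmap.
    return full[start : start + chunk_len].clone()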

Related Issues

Accuracy Test

Split input vs. baseline:

[images: training loss and accuracy curves, identical for split-input and baseline runs]

Benchmark & Profiling

Compression is accuracy-neutral:

  • Accept length stays the same for compressed vs. uncompressed runs (1.93 vs. 1.93).
  • Training throughput is slightly lower due to decompression overhead: time/step increases from 0.45 s to 0.47 s (~4% slower).

| setting           | accept length | time/step (s) | disk usage |
| ----------------- | ------------- | ------------- | ---------- |
| gzip (compressed) | 1.93          | 0.47          | 659G       |
| uncompressed      | 1.93          | 0.45          | 822G       |
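
For reference, a --compress read/write path could look like the following sketch (function names are illustrative, not the actual prepare_hidden_states.py code; the .ckpt.gz suffix matches the one used later in this thread). A gzip stream cannot be memory-mapped, so compressed shards are fully decompressed on read, which lines up with the ~4% time/step overhead above.

import gzip
import io
import torch

def save_hidden_states(tensor: torch.Tensor, path: str, compress: bool) -> None:
    if compress:
        # Serialize to an in-memory buffer, then gzip the bytes to *.ckpt.gz.
        buf = io.BytesIO()
        torch.save(tensor, buf)
        with gzip.open(path + ".gz", "wb") as f:
            f.write(buf.getvalue())
    else:
        torch.save(tensor, path)

def load_hidden_states(path: str) -> torch.Tensor:
    # Dispatch on the suffix; gzip files are decompressed on every read.
    if path.endswith(".gz"):
        with gzip.open(path, "rb") as f:
            return torch.load(io.BytesIO(f.read()), map_location="cpu")
    return torch.load(path, map_location="cpu")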

SP-sharded inputs: accept length is unchanged (detailed numbers in the comments below).

Checklist

@gemini-code-assist
Contributor

Summary of Changes

Hello @uygnef, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization to the offline dataloader, particularly for training with sequence parallelism. By integrating memory-mapped loading and moving the hidden state sharding logic into the dataset itself, the system can now handle larger sequence lengths more efficiently, drastically reducing memory footprint and improving data loading performance during training. These changes ensure that each processing unit only handles the necessary data segments, streamlining the overall training pipeline.

Highlights

  • Memory-Mapped Data Loading: The OfflineEagle3Dataset now utilizes memory-mapped loading (mmap=True) for hidden states, significantly reducing initial RAM consumption and I/O overhead by only loading data into memory when accessed.
  • Integrated Sequence Parallelism (SP) Sharding: Hidden state sharding for Sequence Parallelism (SP) is now performed directly within the OfflineEagle3Dataset during data loading. This ensures each GPU only loads its specific chunk of hidden states, optimizing VRAM and RAM usage.
  • Dynamic Position ID Generation for USP: The eagle3.py core logic has been updated to dynamically generate position IDs for the USP (Unified Sequence Parallelism) attention backend, ensuring correct positional embeddings for sharded sequences.
  • Data Collator Adaptation: The DataCollatorForEagle3 has been modified to handle pre-sharded hidden states when SP is active, concatenating them directly without additional padding, as sharding is now managed by the dataset.
  • USP Configuration Validation: New validation checks have been added to train_eagle3.py to enforce batch_size=1 and sp_ring_size * sp_ulysses_size > 1 when using the USP attention backend, preventing misconfigurations.
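
The last highlight corresponds to a plain argument check. A sketch of what such validation might look like (the exact wording and structure in train_eagle3.py may differ):

def validate_usp_args(args) -> None:
    # USP currently supports only per-rank batch size 1, and only makes
    # sense when at least one SP dimension is actually parallel.
    if args.attention_backend == "usp":
        if args.batch_size != 1:
            raise ValueError("--attention-backend usp requires --batch-size 1")
        if args.sp_ring_size * args.sp_ulysses_size <= 1:
            raise ValueError("usp requires sp_ring_size * sp_ulysses_size > 1")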


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces optimizations to the offline data loader for sequence parallelism (SP) by moving hidden state sharding into the dataset with memory-mapped file loading. This aims to reduce memory usage and I/O overhead. The changes include argument validation, modifications to the Eagle3 model's forward pass, and significant updates to the data preprocessing pipeline to support sharded hidden states. The code has been reviewed, and suggestions have been provided to address potential issues with batch size constraints and to improve code clarity.

Comment thread specforge/data/preprocessing.py
Comment thread scripts/train_eagle3.py Outdated
Comment thread specforge/core/eagle3.py Outdated
Comment thread specforge/data/utils.py Outdated
Comment thread specforge/modeling/draft/llama3_eagle.py Outdated
@uygnef uygnef force-pushed the opt/sp branch 3 times, most recently from 37f8da5 to 22c8a14 on February 3, 2026 11:47
@uygnef uygnef changed the title optimize offline dataloader for sp USP Offline Long‑Sequence Refactor: SP Sharded Inputs, Adapter Unification, and Compressed Hidden States Feb 3, 2026
@uygnef uygnef marked this pull request as ready for review February 4, 2026 06:45
@uygnef uygnef force-pushed the opt/sp branch 2 times, most recently from 74e565a to 27f626b on February 4, 2026 06:50
@FrankLeeeee
Collaborator

Seems the test fails; does it pass in your local env?

Commits pushed to this branch:

  2. usp does not support online mode
  3. loss shared for usp, split target
  4. Tighten USP batching and padding behavior
  5. optimize offline dataloader for sp
@uygnef
Collaborator Author

uygnef commented Feb 4, 2026

Seems the test fails; does it pass in your local env?

@FrankLeeeee Fixed. The tests were already passing locally; I just missed the import for the new debug function when committing. All good now, tests are passing.

Comment on lines +64 to +65
def dbg(rank, msg):
print(f"[rank{rank}] {msg}", flush=True)
Collaborator


replace with a logger instead?
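
For example, the helper could route through the standard logging module instead of printing directly (a sketch, not the committed fix):

import logging

logger = logging.getLogger(__name__)

def dbg(rank: int, msg: str) -> None:
    # Same output shape as the print version, but routed through logging
    # so verbosity can be controlled per module.
    logger.debug("[rank%d] %s", rank, msg)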

@FrankLeeeee
Collaborator

Have you run some experiments to show that the accuracy does not get affected?

@uygnef
Collaborator Author

uygnef commented Feb 4, 2026

Have you run some experiments to show that the accuracy does not get affected?

Based on experiments, both training loss and accuracy remain unchanged. These results have been added to the Accuracy Test section above for reference.

The accept length metrics are also nearly identical (note: this dataset differs from the one used in the gzip test):

Split: 1.8971143677292643
Baseline: 1.8911144578313253

@jiapingW jiapingW merged commit 6c27152 into sgl-project:main Feb 5, 2026
2 checks passed
@uygnef uygnef deleted the opt/sp branch February 6, 2026 08:06
@jiahang01
Contributor

jiahang01 commented Feb 11, 2026

@uygnef
Hello, could you provide an example training script or the SP parameters to set? I've tried offline SP training, and it keeps failing with this error:

[rank0]:   File "/ossfs/workspace/SpecForge/specforge/modeling/draft/llama3_eagle.py", line 1283, in forward
[rank0]:     hidden_states = torch.cat((input_emb, hidden_states), dim=-1)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 8072 but got size 8067 for tensor number 1 in the list.

and my config is:

--batch-size 1 
--tp-size 1 
--learning-rate 5e-5 
--max-length 8192 
--chat-template deepseek-v3 
--cache-dir ./cache 
--sp-ulysses-size 2
--attention-backend usp

@uygnef
Collaborator Author

uygnef commented Feb 11, 2026

torchrun \
    --standalone \
    --nproc_per_node ${NPROC_PER_NODE:-8} \
    scripts/train_eagle3.py \
    --target-model-path models/Meta-Llama-3.1-8B-Instruct \
    --draft-model-config $ROOT_DIR/configs/llama3-8B-eagle3.json \
    --train-data-path datasets/sharegpt_train_1w.jsonl \
    --train-hidden-states-path hidden_states/sharegpt_train_Llama-3.1-8B-Instruct-1w \
    --build-dataset-num-proc 128 \
    --output-dir $ROOT_DIR/outputs/baseline \
    --num-epochs 10 \
    --batch-size 1 \
    --tp-size 1 \
    --target-model-backend sglang \
    --learning-rate 1e-4 \
    --max-length 4096 \
    --chat-template llama3 \
    --report-to tensorboard \
    --sp-ulysses-size 2 \
    --attention-backend usp \
    --cache-dir $ROOT_DIR/cache 

There doesn’t seem to be an issue with your parameters, but the error indicates that the sequence hasn’t been split. Can you paste all the parameters printed in the log?

@jiahang01
Contributor

@uygnef
Here are my parameters:

target_model_path                       /dsv3_SFT_nothink_0819_multi_round_aug_backup-iter_0000078_fp8
trust_remote_code                      ······· False
draft_model_config                     ········ None
embedding_key                           model.embed_tokens.weight
lm_head_key                             lm_head.weight
is_vlm                                 ······· False
target_model_backend                   ······ sglang
train_data_path                         /eagle3_train_dsv3_chat_data_202509_and_202510_8k.jsonl
train_hidden_states_path                /dsv3_eagle3_hideen_states
eval_hidden_states_path                ········ None
eval_data_path                         ········ None
chat_template                          · deepseek-v3
is_preformatted                        ······· False
train_only_last_turn                   ······· False
build_dataset_num_proc                 ··········· 8
dataloader_num_workers                 ··········· 4
num_epochs                             ··········· 2
max_num_steps                          ········ None
batch_size                             ··········· 1
learning_rate                          ······· 5e-05
max_length                             ········ 8192
warmup_ratio                           ······· 0.015
total_steps                            ········ None
max_grad_norm                          ········· 0.5
ttt_length                             ··········· 7
resume                                 ······· False
ckpt_dir                                /DeepSeek-V3-EAGLE3
eval_interval                          ········ 5000
save_interval                          ········· 500
log_interval                           ·········· 50
seed                                   ··········· 0
draft_accumulation_steps               ··········· 4
tp_size                                ··········· 1
sp_ulysses_size                        ··········· 2
sp_ring_size                           ··········· 1
attention_backend                      ········· usp
cache_key                              ········ None
cache_dir                              ····· ./cache
output_dir                              /DeepSeek-V3-EAGLE3-ContinueSFT/20260209
verbose                                ······· False
dist_timeout                           ·········· 20
model_download_dir                     ········ None
min_pixels                             ······· 50176
max_pixels                             ······ 802816
profile                                ······· False
profile_start_step                     ·········· 30
profile_num_steps                      ··········· 4
profile_record_shapes                  ······· False
sglang_attention_backend               ·· flashinfer
sglang_mem_fraction_static             ········· 0.4
sglang_context_length                  ········ None
sglang_enable_nccl_nvls                ······· False
sglang_enable_symm_mem                 ······· False
sglang_enable_torch_compile            ······· False
sglang_enable_dp_attention             ······· False
sglang_enable_dp_lm_head               ······· False
sglang_enable_piecewise_cuda_graph     ······· False
sglang_piecewise_cuda_graph_max_tokens ········ 4096
sglang_piecewise_cuda_graph_tokens     ········ None
sglang_ep_size                         ··········· 1
report_to                              ········ none
wandb_project                          ········ None
wandb_name                             ········ None
wandb_key                              ········ None
swanlab_project                        ········ None
swanlab_name                           ········ None
swanlab_key                            ········ None
mlflow_tracking_uri                    ········ None
mlflow_experiment_name                 ········ None
mlflow_run_name                        ········ None
dp_size                                ··········· 8
target_batch_size                      ··········· 1

The reason seems to lie in the hidden-states file listing code, which may produce a different datapath order on different ranks. Gemini changed the function as below:

def list_local_files(path, suffixes=None):
    if suffixes is None:
        suffixes = [".ckpt", ".ckpt.gz"]
    datapaths = []
    for root, directories, files in os.walk(path):
        for file in files:
            file_path = os.path.join(root, file)
            datapaths.append(file_path)
    if suffixes:
        datapaths = [
            f_name
            for f_name in datapaths
            if any(f_name.endswith(suffix) for suffix in suffixes)
        ]
    datapaths.sort()  # Sort to ensure deterministic order across ranks
    return datapaths

After that, everything works. Maybe it's a quirk of my environment? I'm not sure.
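
For context: os.walk reports directory entries in an arbitrary, filesystem-dependent order, so without the final sort different ranks can enumerate the shard files in different orders and pair mismatched shards; that is consistent with the sequence-length mismatch reported by the torch.cat error above.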

@uygnef
Collaborator Author

uygnef commented Feb 12, 2026

@jiahang01 Thank you for pointing this out. Would you be willing to open a pull request with this fix?

@jiahang01
Contributor

@jiahang01 Thank you for pointing this out. Would you be willing to open a pull request with this fix?

Sure, I'll work on it
