Mini Trainer exposes a memory-efficient loading strategy for large models by splitting initialization into three phases. This document focuses on the contributor-facing details behind the lazy-init + FSDP2 flow used by both SFT and OSFT.
```mermaid
graph LR
    A[Phase 1 prepare_model_for_fsdp2] --> B[Phase 2 wrap_fsdp2]
    B --> C[Phase 3 finalize_model_initialization]
    A -->|ModelInitializationContext| C
    subgraph "Phase 1: Entry Points"
        A1[setup_sft_model_distributed]:::sft
        A2[setup_osft_model_distributed]:::osft
    end
    classDef sft fill:#e0f7ff,stroke:#0aa3d8,color:#0a4b6f;
    classDef osft fill:#f9e4ff,stroke:#a949c5,color:#4a155e;
```
- Phase 1 – `prepare_model_for_fsdp2`
  - Detects whether the model was created via `setup_sft_model_distributed` (SFT) or `setup_osft_model_distributed` (OSFT).
  - Rank 0 loads the full checkpoint on CPU, records the state dict + buffers, and stores that payload inside `ModelInitializationContext`.
  - All other ranks instantiate the model on the `meta` device and rely on the context for later population.
- Phase 2 – `wrap_fsdp2`
  - Applies activation checkpointing to each transformer block, builds the device mesh, and calls `fully_shard` on both individual blocks and the top-level module (a minimal sketch follows this list).
  - No weight loading happens here; it only prepares the structure for sharded weights.
- Phase 3 – `finalize_model_initialization`
  - Consumes the context from Phase 1 to populate the sharded parameters:
    - SFT: broadcasts the entire state dict after performing dtype conversions on rank 0.
    - OSFT: handles non-OSFT tensors first, then launches distributed SVD to populate the OSFT factors.
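For orientation, here is a minimal sketch of the block-by-block wrapping pattern in Phase 2. It is not the actual `wrap_fsdp2` implementation: the module path `model.model.layers` assumes a Hugging Face-style decoder, and the import location of `fully_shard` depends on the PyTorch version (it moved to `torch.distributed.fsdp` around 2.6; older releases expose it under `torch.distributed._composable.fsdp`).

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # PyTorch >= 2.6 import path
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


def wrap_fsdp2_sketch(model: torch.nn.Module) -> torch.nn.Module:
    # One 1-D mesh over all ranks; the real code may build a different topology.
    mesh = init_device_mesh("cuda", (dist.get_world_size(),))
    layers = model.model.layers  # assumes a HF-style decoder; adjust for other layouts
    for i, block in enumerate(layers):
        block = checkpoint_wrapper(block)  # activation checkpointing per transformer block
        layers[i] = block
        fully_shard(block, mesh=mesh)      # shard each transformer block individually
    fully_shard(model, mesh=mesh)          # then shard the top-level module
    return model
```

As in the real flow, no weights are loaded at this point; parameters stay on the `meta` device until Phase 3 populates them.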
- Entry point: `setup_sft_model_distributed`
  - Rank 0 loads the pretrained model, aligns tokenizer tokens, saves the `state_dict` + buffers, then deletes the large model to free memory.
  - `_synchronize_state_dict_fsdp2` is used in Phase 3 to broadcast the state dict through the FSDP2 shards (sketched below). Once finished, all tensor references are released to keep the memory footprint flat.
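The core idea behind `_synchronize_state_dict_fsdp2` can be sketched with PyTorch's distributed state-dict helpers. This illustrates the rank-0-broadcast pattern, not the helper itself; `broadcast_from_rank0` requires a reasonably recent PyTorch release.

```python
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict


def synchronize_state_dict_sketch(model, full_state_dict):
    # Rank 0 passes the full CPU state dict recorded in Phase 1; other ranks pass {}.
    state_dict = full_state_dict if dist.get_rank() == 0 else {}
    set_model_state_dict(
        model,
        state_dict,
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
    # Drop the CPU copy afterwards so the memory footprint stays flat.
    del state_dict
```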
- Entry point: `setup_osft_model_distributed`
  - The OSFT wrapper (`create_osft_model_class`) creates meta tensors, registers logical keys, and retains rank 0's dense parameters inside `model._lazy_init_og_state_dict`.
  - During finalization:
    - `post_fsdp2_wrap_synchronize_state_dict_across_procs` broadcasts every non-OSFT tensor (embedding tables, layer norms, etc.).
    - `compute_distributed_svd` carves out each OSFT target from the rank-0 payload, assigns shards of work to every rank, and returns the SVD factor tensors to rank 0 (see the sketch after this list).
    - The caller feeds those factors into `set_model_state_dict`, ensuring each shard receives the appropriate high-/low-rank parameters.
    - `mark_fsdp2_initialized` clears the lazy flags so optimizers and checkpoint saves treat the model as fully materialized.
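To make the division of labor concrete, here is a hedged sketch of a round-robin version of the work-splitting idea behind `compute_distributed_svd`. The real helper may distribute tensors, choose targets, and collect results differently; `target_names` and `rank0_payload` are illustrative parameters.

```python
import torch
import torch.distributed as dist


def compute_distributed_svd_sketch(target_names, rank0_payload):
    """Round-robin SVD over the OSFT targets; factors end up on rank 0."""
    rank, world = dist.get_rank(), dist.get_world_size()
    local_factors = {}
    for i, name in enumerate(target_names):
        # Ship the dense weight from the rank-0 payload to every rank.
        obj = [rank0_payload[name] if rank == 0 else None]
        dist.broadcast_object_list(obj, src=0)
        if i % world == rank:  # this rank owns this slice of the work
            U, S, Vh = torch.linalg.svd(obj[0].float(), full_matrices=False)
            local_factors[name] = (U, S, Vh)
    # Collect every rank's factors back onto rank 0 for set_model_state_dict.
    gathered = [None] * world if rank == 0 else None
    dist.gather_object(local_factors, gathered, dst=0)
    if rank == 0:
        merged = {}
        for part in gathered:
            merged.update(part)
        return merged
    return None
```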
```mermaid
sequenceDiagram
    participant R0 as Rank 0
    participant R* as Other Ranks
    participant Model as FSDP2 Model
    R0->>Model: post_fsdp2_wrap_synchronize_state_dict_across_procs(non-OSFT tensors)
    R0->>R*: Broadcast non-OSFT shards
    note over R0,R*: Non-OSFT weights ready
    R0->>R*: Send OSFT target assignments
    R*-->>R0: Return SVD factors
    R0->>Model: set_model_state_dict(SVD factors)
    Model-->>R0: mark_fsdp2_initialized()
```
Most training scripts follow the template below:
```python
context = prepare_model_for_fsdp2(model, tokenizer, base_model_args, ...)
model = wrap_fsdp2(model)
model = finalize_model_initialization(model, context)
```

Switch between SFT and OSFT by toggling the CLI flags (`--osft`, `--osft-unfreeze-rank-ratio`, etc.). Both paths reuse the same Phase 2/3 implementation, so the rest of the training loop stays identical.
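As a rough sketch of how a script might wire that flag to the two Phase 1 entry points, the branch below is illustrative only: the keyword argument names and the `args` object are hypothetical, so consult the actual entry-point signatures in Mini Trainer.

```python
# Illustrative sketch: keyword argument names are hypothetical.
if args.osft:
    model = setup_osft_model_distributed(
        base_model_args,
        unfreeze_rank_ratio=args.osft_unfreeze_rank_ratio,  # hypothetical kwarg name
    )
else:
    model = setup_sft_model_distributed(base_model_args)
# ...then the three-phase template above runs unchanged for either path.
```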