Merged

50 commits
41d96a3
make mamba
lchu6 Aug 5, 2024
81ca3c2
add quick debug
lchu6 Aug 10, 2024
0817ddc
add quick debug
lchu6 Aug 10, 2024
5d7e936
revert debug verbosity
lchu6 Aug 10, 2024
bcad3ad
Learning rate scheduler changed (Constant)
divya-kumari32 Nov 11, 2024
f8c1651
Mamba config restore
divya-kumari32 Nov 12, 2024
e974212
Add AutoHandler
daviswer Nov 12, 2024
cd872c2
Add Auto cfg option for AutoHAndler
daviswer Nov 12, 2024
e7a7179
Len gets called before open
daviswer Nov 12, 2024
3ab5174
path/filepath typo fix
daviswer Nov 12, 2024
06cda8a
Partitioning fix from mup-search
daviswer Nov 12, 2024
5ab4af6
Cosine 0.01 decay
divya-kumari32 Nov 17, 2024
fd0cc02
Merge branch 'mamba-new' of https://github.com/foundation-model-stack…
divya-kumari32 Nov 17, 2024
0210fd3
Warmup interval change
divya-kumari32 Nov 17, 2024
98ec15e
Schedule change
divya-kumari32 Nov 17, 2024
ab7fb58
Constant schedule
divya-kumari32 Nov 27, 2024
2b23cc6
LR schedule change (cool down and constant lr)
divya-kumari32 Dec 10, 2024
f87ca63
Update dataset_utils.py
divya-kumari32 Dec 14, 2024
e39f561
LR schedule change (Warmup + constant)
divya-kumari32 Dec 16, 2024
2402c88
Update dataset_utils.py
divya-kumari32 Dec 18, 2024
e435102
Cosine schedule
divya-kumari32 Dec 19, 2024
499a830
For constant lr 1.5e5
divya-kumari32 Jan 13, 2025
99a6543
Schedule change
divya-kumari32 Jan 13, 2025
5ff2e86
Schedule change
divya-kumari32 Jan 13, 2025
0fdb43d
Final singlefile checkpoint saves one folder up (#127)
daviswer Jan 14, 2025
1b85011
Merge from main
divya-kumari32 Jan 24, 2025
06ab5e6
Merge from main
divya-kumari32 Jan 24, 2025
ee2eb62
Added cool down
divya-kumari32 Jan 24, 2025
f425e50
length of doc check
divya-kumari32 Jan 31, 2025
8b8f688
splitstrip cols and pass to fhandler
daviswer Feb 3, 2025
c328c56
fhandler col_names support
daviswer Feb 3, 2025
9764afc
Warmup for annealing
divya-kumari32 Feb 4, 2025
43902da
Debugging
divya-kumari32 Feb 4, 2025
c1320c2
Debugging II
divya-kumari32 Feb 4, 2025
701d0eb
Empty shard check
daviswer Feb 19, 2025
c90acb9
Added constant lr schedule with warmup
divya-kumari32 Feb 23, 2025
45c636f
Merge branch 'mamba-new' of https://github.com/foundation-model-stack…
divya-kumari32 Feb 23, 2025
55fd5ac
added print for lenght of doc
divya-kumari32 Feb 26, 2025
97eade5
added print for lenght of doc II
divya-kumari32 Feb 26, 2025
1b50708
Update dataset_utils.py
divya-kumari32 Mar 6, 2025
ac5194b
Pulled from data-fixes branch
divya-kumari32 May 24, 2025
c2acc28
Update dataset_utils.py
divya-kumari32 May 24, 2025
87bdab0
Update dataset_utils.py
divya-kumari32 May 24, 2025
34d5974
Update dataset_utils.py
divya-kumari32 May 24, 2025
2f302d4
Adding print for debug
divya-kumari32 May 26, 2025
9558e2e
Revert "Pulled from data-fixes branch"
divya-kumari32 May 26, 2025
db7d4ad
Revert all changes made after March 6 (before merge)
divya-kumari32 May 26, 2025
149b6fa
Revert all changes made after March 6 (before merge)
divya-kumari32 May 26, 2025
385c98d
removed print
divya-kumari32 May 26, 2025
df36194
Merge branch 'data-fixes-temp' into mamba-new-data-fixes2
daviswer Jun 3, 2025
7 changes: 0 additions & 7 deletions fms_fsdp/config/training.py
@@ -76,10 +76,3 @@ class train_config:
stage2_prompt_length: int = 64
stage2_batch_size: int = 96
stage2_seq_length: int = 256

# FIM training
psm_rate: float = 0.0
spm_rate: float = 0.0
fim_pre: int = 1
fim_mid: int = 2
fim_suf: int = 3
4 changes: 3 additions & 1 deletion fms_fsdp/utils/checkpointing_utils.py
@@ -324,7 +324,9 @@ def save_single_file(
):
# Note: metadata kwargs cannot contain any of:
# (step, model)
save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth")
pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step))
os.makedirs(pth_path, exist_ok=True)
save_name = os.path.join(pth_path, "consolidated.00.pth")
save_time = time.time()
with FSDP.state_dict_type(
model,
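The new save-path logic drops the trailing `/checkpoints` directory via the `[:-12]` slice (12 characters), so the consolidated single-file checkpoint lands in a sibling `pth/step_<N>` folder one level up. A minimal sketch of the resulting path, assuming `ckp_path` ends in `/checkpoints`; the helper name and example paths below are illustrative, not from this PR:

```python
import os

def single_file_save_path(ckp_path: str, step: int) -> str:
    """Hypothetical helper mirroring the diff: strip the trailing
    '/checkpoints' (12 characters) and place the consolidated file
    under a sibling 'pth/step_<step>' folder."""
    assert ckp_path.endswith("/checkpoints"), "the [:-12] slice assumes this suffix"
    pth_path = os.path.join(ckp_path[:-12], "pth", "step_" + str(step))
    # the real code also calls os.makedirs(pth_path, exist_ok=True) here
    return os.path.join(pth_path, "consolidated.00.pth")

# '/runs/exp1/checkpoints' -> '/runs/exp1/pth/step_1000/consolidated.00.pth'
print(single_file_save_path("/runs/exp1/checkpoints", 1000))
```

Note that the bare slice silently misfires if `ckp_path` carries any other suffix; the assert above only makes that implicit assumption visible.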
38 changes: 11 additions & 27 deletions fms_fsdp/utils/dataloader_utils.py
@@ -6,7 +6,6 @@
AutoHandler,
BufferDataset,
CheckpointDataset,
FIMDataset,
ParquetHandler,
PreloadBufferDataset,
PreprocessDataset,
@@ -59,9 +58,9 @@ def __iter__(self):
return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)


def get_data_loader(cfg, rank, world_size):
def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
"""
Pytorch dataloader for stateful, distributed, and rescalable language model training.
Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
Assumes underlying data is sequences of integer values.
...
Args
@@ -72,12 +71,11 @@ def get_data_loader(cfg, rank, world_size):
Rank of current distributed worker. Used for handling dataset sharding logic.
world_size : int
Number of distributed workers. Used for handling dataset sharding logic.
postprocess : List[Callable]
Any task-specific postprocessing to apply before handing over data. Steps will apply in
the order provided by the user. For CLM training, use postprocess=[causal_lm].
"""

fim_training = cfg.psm_rate + cfg.spm_rate > 0
if fim_training:
assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"

datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)

# Base streaming dataset. Returns doc chunks in sequence.
@@ -94,7 +92,7 @@ def get_data_loader(cfg, rank, world_size):
cfg.tokenizer_path, cols, cfg.doc_cutoff
)
else:
filehandler = _handler_map[cfg.file_type]
filehandler = _handler_map[cfg.file_type](cols)
# Base reader layer
data = StreamingDocDataset(
cfg.data_path,
@@ -124,34 +122,20 @@
verbose=(rank == 0),
)
# Wrap above dataset in packing logic to form constant-length lines.
# Increment seq len to counteract CLM's one token removal.
data = BufferDataset(
data,
cfg.seq_length + 1,
cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
bos_token=cfg.bol_token,
eos_token=cfg.eol_token,
pack_hard=True,
)
# Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
data = PreloadBufferDataset(data, 10000)

# Apply FIM transformation if needed
if fim_training:
data = FIMDataset(
data,
cfg.eos_token,
cfg.psm_rate,
cfg.spm_rate,
pre_token=cfg.fim_pre,
mid_token=cfg.fim_mid,
suf_token=cfg.fim_suf,
)

# Transform to tensors
# Apply desired postprocessing steps in sequence
data = PreprocessDataset(data, torch.IntTensor)

# Apply CLM transformation
data = PreprocessDataset(data, causal_lm)
for p in postprocess:
data = PreprocessDataset(data, p)

# Enable auto-saving
data = CheckpointDataset(
@@ -181,4 +165,4 @@ def splitstrip(x):
datas = splitstrip(datas)
weights = [float(x) for x in splitstrip(weights)]
cols = splitstrip(cols)
return datas, weights, cols
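With the FIM-specific wrapper gone, task-specific transforms now arrive through the new `postprocess` argument and are applied in order after tensorization, with `causal_lm` as the default. The `seq_length + 1` buffer size pairs with that default: the CLM split consumes one token to align inputs with next-token labels. A toy sketch of both points; `causal_lm_example` is illustrative, not the library's implementation, and the commented call uses assumed config names:

```python
import torch

# Toy stand-in for the CLM postprocess step (not the library's code):
# split a (seq_length + 1)-token line into inputs and shifted labels.
def causal_lm_example(seq: torch.Tensor):
    return seq[:-1], seq[1:]  # inputs, next-token labels

buf = torch.arange(9)               # BufferDataset output for seq_length = 8
x, y = causal_lm_example(buf)
assert x.numel() == y.numel() == 8  # both match seq_length after the split

# Hypothetical usage of the new signature: the default postprocess=[causal_lm]
# keeps prior behavior; other callables would be applied in the given order.
# loader = get_data_loader(cfg, rank, world_size, postprocess=[causal_lm])
```

If `causal_lm` is absent from `postprocess`, the buffer stays at `cfg.seq_length`, matching the conditional added to the `BufferDataset` call above.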