From 41d96a342071894c547e0023ecf4afe44c6af647 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Mon, 5 Aug 2024 16:05:33 -0400 Subject: [PATCH 01/43] make mamba --- fms_fsdp/utils/config_utils.py | 221 ++++++++++++++------------------- fms_to_hf.py | 156 ++--------------------- main_training.py | 27 ++-- 3 files changed, 118 insertions(+), 286 deletions(-) diff --git a/fms_fsdp/utils/config_utils.py b/fms_fsdp/utils/config_utils.py index c0389b12..5642da8b 100644 --- a/fms_fsdp/utils/config_utils.py +++ b/fms_fsdp/utils/config_utils.py @@ -23,132 +23,101 @@ def update_config(config, **kwargs): def get_model_config(model_variant): - if model_variant == "llama2_70b": - llama_config = LLaMAConfig( - emb_dim=8192, - multiple_of=4096, - nheads=64, - kvheads=8, - nlayers=80, - hidden_grow_factor=28672 / 8192, - ) - elif model_variant == "llama2_34b": - llama_config = LLaMAConfig( - emb_dim=8192, - nheads=64, - kvheads=8, - nlayers=48, - hidden_grow_factor=22016 / 8192, - ) - elif model_variant == "llama2_13b": - llama_config = LLaMAConfig( - emb_dim=5120, - nheads=40, - nlayers=40, - hidden_grow_factor=13824 / 5120, - ) - elif model_variant == "llama2_7b": - llama_config = LLaMAConfig( - hidden_grow_factor=3, - kvheads=8, - ) - elif model_variant == "llama2_1.4b": - llama_config = LLaMAConfig( - emb_dim=2048, - nheads=16, - nlayers=24, - hidden_grow_factor=3, - kvheads=4, - ) - elif model_variant == "llama3_8b": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=4096, - nheads=32, - kvheads=8, - nlayers=32, - hidden_grow_factor=3.5, - max_expected_seq_len=8192, - ) - elif model_variant == "llama3_8b_4k": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=4096, - nheads=32, - kvheads=8, - nlayers=32, - hidden_grow_factor=3.5, - max_expected_seq_len=4096, - ) - elif model_variant == "llama3_1.8b": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=2048, - nheads=16, - kvheads=8, - nlayers=24, - hidden_grow_factor=3.5, - max_expected_seq_len=8192, - ) - elif model_variant == "llama3_1.8b_4k": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=2048, - nheads=16, - kvheads=8, - nlayers=24, - hidden_grow_factor=3.5, - max_expected_seq_len=4096, - ) - elif model_variant == "llama3_3.2b": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=3072, - nheads=24, - kvheads=8, - nlayers=24, - hidden_grow_factor=8 / 3, - max_expected_seq_len=8192, - ) - elif model_variant == "llama3_3.2b_4k": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=3072, - nheads=24, - kvheads=8, - nlayers=24, - hidden_grow_factor=8 / 3, - max_expected_seq_len=4096, - ) - elif model_variant == "llama3_70b": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=8192, - nheads=64, - kvheads=8, - nlayers=80, - hidden_grow_factor=3.5, - max_expected_seq_len=8192, - ) - elif model_variant == "llama3_70b_4k": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=8192, - nheads=64, - kvheads=8, - nlayers=80, - hidden_grow_factor=3.5, - max_expected_seq_len=4096, - ) - elif model_variant == "llama3_194m_4k": - llama_config = LLaMAConfig( - src_vocab_size=128256, - emb_dim=1024, - nheads=8, - nlayers=10, - max_expected_seq_len=4096, - ) + if model_variant == "mamba_1.5b": + config_data = { + "d_model": 2048, + "d_intermediate": 0, + "n_layer": 48, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [8, 16, 24, 32, 40], + "attn_cfg": { + "causal": True, + "d_conv": 4, + "head_dim": 128, + "num_heads": 24, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": True, + } + elif model_variant == "mamba_2.9b": + config_data = { + "d_model": 2560, + "d_intermediate": 0, + "n_layer": 64, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27, 36, 45, 56], + "attn_cfg": { + "causal": True, + "d_conv": 4, + "head_dim": 128, + "num_heads": 30, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": True, + } + elif model_variant == "mamba_9.8b": + config_data = { + "d_model": 4096, + "d_intermediate": 14336, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 0, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } + elif model_variant == "mamba_debug": + config_data = { + "d_model": 4096, + "d_intermediate": 0, + "n_layer": 32, + "vocab_size": 128256, + "ssm_cfg": {"layer": "Mamba2"}, + "attn_layer_idx": [9, 18, 27], + "attn_cfg": { + "causal": True, + "d_conv": 4, + "head_dim": 128, + "num_heads": 32, + "num_heads_kv": 8, + "out_proj_bias": False, + "qkv_proj_bias": False, + "rotary_emb_dim": 64, + }, + "rms_norm": True, + "residual_in_fp32": True, + "fused_add_norm": True, + "pad_vocab_size_multiple": 16, + "tie_embeddings": False, + } else: raise ValueError(f"model variant {model_variant} not supported.") - return llama_config + return config_data diff --git a/fms_to_hf.py b/fms_to_hf.py index d03582a3..a3fdfc87 100644 --- a/fms_to_hf.py +++ b/fms_to_hf.py @@ -1,162 +1,28 @@ import fire -import torch -from fms.models.hf import to_hf_api -from fms.models.llama import LLaMA +from mamba_ssm.models.config_mamba import MambaConfig +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict -from transformers import LlamaConfig, LlamaForCausalLM from fms_fsdp.utils.config_utils import get_model_config -def convert_to_hf(model: LLaMA, model_variant, is_old_fms) -> LlamaForCausalLM: - fms_hf_model = to_hf_api(model) - hf_config = fms_hf_model.config - if "llama3" in model_variant: - hf_config.bos_token_id = 128000 - hf_config.eos_token_id = 128001 - oss_hf_model = LlamaForCausalLM( - LlamaConfig( - vocab_size=hf_config.vocab_size, - hidden_size=hf_config.hidden_size, - rms_norm_eps=hf_config.norm_eps, - num_attention_heads=hf_config.nheads, - num_key_value_heads=None if hf_config.kvheads == 0 else hf_config.kvheads, - num_hidden_layers=hf_config.nlayers, - intermediate_size=hf_config.multiple_of - * ( - ( - int(hf_config.hidden_grow_factor * hf_config.hidden_size) - + hf_config.multiple_of - - 1 - ) - // hf_config.multiple_of - ), - pad_token_id=( - None if hf_config.pad_token_id == -1 else hf_config.pad_token_id - ), - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - max_position_embeddings=hf_config.max_expected_seq_len, - ) - ) - - # compute the freq from rot_emb since it is gathered lazily - rot_emb = fms_hf_model.decoder.model.rot_emb - max_seq_len = rot_emb.max_seq_len - alpha = rot_emb._alpha(max_seq_len) - ratio = rot_emb.ratio - dim = rot_emb.dim - if rot_emb.ntk_scaling: - ratio = ratio * alpha ** (dim / (dim - 2)) - freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - - with torch.no_grad(): - oss_hf_model.model.embed_tokens.weight.copy_(fms_hf_model.embedding.weight) - i = 0 - for oss_hf_layer in oss_hf_model.model.layers: - fms_hf_layer = fms_hf_model.decoder.model.layers[i] - - # self attn - if is_old_fms: - oss_hf_layer.self_attn.q_proj.weight.copy_( - fms_hf_layer.attn.query.weight - ) - oss_hf_layer.self_attn.k_proj.weight.copy_(fms_hf_layer.attn.key.weight) - oss_hf_layer.self_attn.v_proj.weight.copy_( - fms_hf_layer.attn.value.weight - ) - else: - q, k, v = torch.split( - fms_hf_layer.attn.in_proj.qkv_fused.weight, - fms_hf_layer.attn.in_proj.splits, - dim=0, - ) - oss_hf_layer.self_attn.q_proj.weight.copy_(q) - oss_hf_layer.self_attn.k_proj.weight.copy_(k) - oss_hf_layer.self_attn.v_proj.weight.copy_(v) - oss_hf_layer.self_attn.o_proj.weight.copy_(fms_hf_layer.attn.dense.weight) - oss_hf_layer.self_attn.rotary_emb.inv_freqs = freqs - - # mlp - if is_old_fms: - oss_hf_layer.mlp.gate_proj.weight.copy_( - fms_hf_layer.ff_sub_layer.wg.weight - ) - oss_hf_layer.mlp.up_proj.weight.copy_( - fms_hf_layer.ff_sub_layer.w1.weight - ) - else: - wg1_fused = fms_hf_layer.ff_sub_layer.wg1_fused.weight - wg_splits = [wg1_fused.size(0) // 2, wg1_fused.size(0) // 2] - wg, w1 = torch.split( - fms_hf_layer.ff_sub_layer.wg1_fused.weight, wg_splits, dim=0 - ) - oss_hf_layer.mlp.gate_proj.weight.copy_(wg) - oss_hf_layer.mlp.up_proj.weight.copy_(w1) - oss_hf_layer.mlp.down_proj.weight.copy_(fms_hf_layer.ff_sub_layer.w2.weight) - - # layer norm - oss_hf_layer.input_layernorm.weight.copy_(fms_hf_layer.ln.weight) - oss_hf_layer.post_attention_layernorm.weight.copy_( - fms_hf_layer.ff_ln.weight - ) - - # adjust q, k - q = oss_hf_layer.self_attn.q_proj.weight.data - q = ( - q.view(hf_config.nheads, -1, 2, q.size(1)) - .transpose(1, 2) - .reshape(*q.size()) - ) - oss_hf_layer.self_attn.q_proj.weight.copy_(q) - - k = oss_hf_layer.self_attn.k_proj.weight.data - k = ( - k.view( - hf_config.nheads if hf_config.kvheads == 0 else hf_config.kvheads, - -1, - 2, - k.size(1), - ) - .transpose(1, 2) - .reshape(*k.size()) - ) - oss_hf_layer.self_attn.k_proj.weight.copy_(k) - - i = i + 1 - oss_hf_model.model.norm.weight = fms_hf_model.decoder.model.dec_norm.weight - oss_hf_model.lm_head.weight = fms_hf_model.lm_head.weight - - return oss_hf_model - - -def main( - model_variant, compiled, is_old_fms, load_path, save_path, tokenizer_name_or_path -): +def main(model_variant, load_path, save_path, tokenizer_name_or_path): print("Initializing model...") - llama_config = get_model_config(model_variant) - with torch.device("meta"): - model = LLaMA(llama_config) - model.to_empty(device="cpu") + config_data = get_model_config(model_variant) + mamba_config = MambaConfig(**config_data) + model = MambaLMHeadModel(mamba_config) print(f"Reading state dict from {load_path}") - if not compiled: - state_dict = {"model_state": model.state_dict()} - else: - state_dict = {"model_state": {"_orig_mod": model.state_dict()}} + state_dict = {"model_state": model.state_dict()} load_state_dict( state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True ) print("Loading state dict into the model...") - if not compiled: - model.load_state_dict(state_dict["model_state"]) - else: - model.load_state_dict(state_dict["model_state"]["_orig_mod"]) + model.load_state_dict(state_dict["model_state"]) - print("Converting to HF model..") - hf_model = convert_to_hf(model, model_variant, is_old_fms) - hf_model.save_pretrained(save_path) + print("Saving model to HF-compatible format...") + model.save_pretrained(save_path) print("Copying tokenizer...") from transformers import AutoTokenizer @@ -164,7 +30,7 @@ def main( tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) tokenizer.save_pretrained(save_path) - print(f"Model converted to HF model, saving at {save_path}") + print(f"Model saving at {save_path}") if __name__ == "__main__": diff --git a/main_training.py b/main_training.py index 67cccee2..200bc6ea 100644 --- a/main_training.py +++ b/main_training.py @@ -1,10 +1,13 @@ import math import os +from pathlib import Path import fire import torch import torch.optim as optim -from fms.models.llama import LLaMA, LLaMABlock +from mamba_ssm.models.config_mamba import MambaConfig +from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel +from mamba_ssm.modules.block import Block from torch import distributed as dist from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.optim.lr_scheduler import LambdaLR @@ -44,9 +47,12 @@ def main(**kwargs): torch.cuda.set_device(local_rank) torch.cuda.empty_cache() setup_environ_flags() + os.environ["TRITON_CACHE_DIR"] = os.path.join( + Path.home(), ".triton", "cache", str(local_rank) + ) # get policy - block = LLaMABlock + block = Block ( mixed_precision_policy, wrapping_policy, @@ -55,14 +61,10 @@ def main(**kwargs): param_init_fn, ) = get_policies(cfg, rank, block) - # get fms model - llama_config = get_model_config(cfg.model_variant) - if cfg.low_cpu_fsdp: - with torch.device("meta"): - model = LLaMA(llama_config) - else: - model = LLaMA(llama_config) - model.reset_parameters() + # get model + config_data = get_model_config(cfg.model_variant) + mamba_config = MambaConfig(**config_data) + model = MambaLMHeadModel(mamba_config) if rank == 0: total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) @@ -89,11 +91,6 @@ def main(**kwargs): limit_all_gathers=True, param_init_fn=param_init_fn, ) - # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution. - model.rot_emb.compute_freqs_cis( - torch.device("cuda", torch.cuda.current_device()), - model.config.max_expected_seq_len, - ) # fsdp activation checkpointing if cfg.fsdp_activation_checkpointing: From 81ca3c2ff2cc8ff1e87a0d65bd2ae3a0831b5490 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Sat, 10 Aug 2024 09:43:10 -0400 Subject: [PATCH 02/43] add quick debug --- fms_fsdp/utils/checkpointing_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 5381dc9d..90400ff9 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -179,9 +179,12 @@ def load( Returns model, optimizer, dataloader, current step, and current tokens seen. """ is_resuming = False + print(self.ckp_path) if self._validate_ckp_path(self.ckp_path) is not None: + print("yyyyyyyy") path = self.ckp_path is_resuming = True + print(path) load_path = self._validate_ckp_path(path) if load_path is None: self.report( From 0817ddcace669bd20771c01664b09b243ede4189 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Sat, 10 Aug 2024 09:50:42 -0400 Subject: [PATCH 03/43] add quick debug --- fms_fsdp/utils/checkpointing_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 90400ff9..e05f7421 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -24,6 +24,7 @@ def get_latest(targdir, qualifier=lambda x: True): """Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn. If directory is empty or nonexistent or no items qualify, return None.""" if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: + print(targdir) latest = max( [ os.path.join(targdir, x) From 5d7e936ccf9f9f31b1b6c56b0502798f4c571963 Mon Sep 17 00:00:00 2001 From: Linsong Chu Date: Sat, 10 Aug 2024 10:01:50 -0400 Subject: [PATCH 04/43] revert debug verbosity --- fms_fsdp/utils/checkpointing_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index e05f7421..5381dc9d 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -24,7 +24,6 @@ def get_latest(targdir, qualifier=lambda x: True): """Fetch the latest file or folder written to target directory, subject to name passing the qualifier fn. If directory is empty or nonexistent or no items qualify, return None.""" if os.path.exists(targdir) and len(os.listdir(targdir)) > 0: - print(targdir) latest = max( [ os.path.join(targdir, x) @@ -180,12 +179,9 @@ def load( Returns model, optimizer, dataloader, current step, and current tokens seen. """ is_resuming = False - print(self.ckp_path) if self._validate_ckp_path(self.ckp_path) is not None: - print("yyyyyyyy") path = self.ckp_path is_resuming = True - print(path) load_path = self._validate_ckp_path(path) if load_path is None: self.report( From bcad3adbc449d9664385632c5149ca2ef1e69338 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 11 Nov 2024 15:47:09 -0500 Subject: [PATCH 05/43] Learning rate scheduler changed (Constant) --- main_training.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/main_training.py b/main_training.py index 200bc6ea..179a4687 100644 --- a/main_training.py +++ b/main_training.py @@ -134,14 +134,21 @@ def main(**kwargs): if cfg.training_stage == "annealing": schedule = lambda x: 1 - x / cfg.num_steps else: - warmup_interval = min(2000, cfg.num_steps // 20) - schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, - 0.1 - + 0.5 - * (1 - 0.1) - * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + + warmup_interval = 1000 + schedule = lambda x: ( + min(x, warmup_interval) / warmup_interval ) + + # warmup_interval = min(2000, cfg.num_steps // 20) + # schedule = lambda x: min( + # 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + # 0.1 + # + 0.5 + # * (1 - 0.1) + # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + # ) + scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) # profiler From e974212422c0d0925f0a267614321aefa1457e06 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 12 Nov 2024 11:06:40 -0500 Subject: [PATCH 06/43] Add AutoHandler --- fms_fsdp/utils/dataset_utils.py | 47 +++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index d1d442d7..095cc847 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -402,6 +402,53 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: return doc[index : index + n_pull] +class AutoHandler(_ShardFileHandler): + def __init__(self, tokenizer_path: str, col_name: str = "text"): + self.PHandler = ParquetHandler(tokenizer_path, col_name) + self.AHandler = ArrowHandler() + self.current = None + + def is_legal(self, filepath: str): + return "parquet" in os.path.splitext(filepath)[1] or "arrow" in os.path.splitext(filepath)[1] + + def open(self, path: str): + """ + Open the file, to be indexed via self.get() method. + Avoid reading entire multi-Gb files when possible! + """ + if "arrow" in os.path.splitext(filepath)[1]: + self.current = self.AHandler + else: + self.current = self.PHandler + return self.current.open(path) + + def length(self, path: str): + """ + Calculate the number of documents in the given file. + Avoid reading entire multi-Gb files when possible! + """ + return self.current.length(path) + + def get(self, reader, index: int, drop_tokens: Set): + """ + Given the output of self.open() and an index, return the document at that index. + Then, remove the first and/or last items if they appear in drop_tokens. + Try to avoid reading entire documents at a time in case of long documents, + but this is less important than avoiding reading entire files as above. + Output must support len(). + """ + return self.current.get(reader, index, drop_tokens) + + def slice(self, doc, index: int, n_pull: int) -> List: + """ + Given a long document, retrieve n_pull consecutive items starting from index. + Again, try to be memory-efficient when doing so, but efficiency in self.get() + and self.open() is far more important. + Must return a python list. + """ + return self.current.slice(doc, index, n_pull) + + #### ------------------------- PIPELINE LAYERS ------------------------- #### From cd872c23ae70c9c2e47a45e6b28363daf3c62f28 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 12 Nov 2024 11:08:28 -0500 Subject: [PATCH 07/43] Add Auto cfg option for AutoHAndler --- fms_fsdp/utils/dataloader_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 2faeffb7..4b811d6d 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -2,6 +2,7 @@ from fms_fsdp.utils.dataset_utils import ( ArrowHandler, + AutoHandler, BufferDataset, CheckpointDataset, ParquetHandler, @@ -16,6 +17,7 @@ _handler_map = { "arrow": ArrowHandler, "hf_parquet": ParquetHandler, + "auto": AutoHandler, } @@ -84,10 +86,10 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): assert ( cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" - if cfg.file_type == "hf_parquet": - filehandler = ParquetHandler(cfg.tokenizer_path, cfg.col_name) + if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": + filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) else: - filehandler = _handler_map[cfg.file_type](cfg.col_name) + filehandler = _handler_map[cfg.file_type] # Base reader layer data = StreamingDocDataset( cfg.data_path, From e7a7179eb6648e9fe45e19934379d38f4709dbd9 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 12 Nov 2024 13:37:22 -0500 Subject: [PATCH 08/43] Len gets called before open --- fms_fsdp/utils/dataset_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 095cc847..0a9d8c31 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -427,7 +427,10 @@ def length(self, path: str): Calculate the number of documents in the given file. Avoid reading entire multi-Gb files when possible! """ - return self.current.length(path) + if "arrow" in os.path.splitext(filepath)[1]: + return self.AHandler.length(path) + else: + return self.PHandler.length(path) def get(self, reader, index: int, drop_tokens: Set): """ From 3ab51740e835c751c203dc25acaf0527e5d947f5 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 12 Nov 2024 13:51:40 -0500 Subject: [PATCH 09/43] path/filepath typo fix --- fms_fsdp/utils/dataset_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 0a9d8c31..3b4c8122 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -416,7 +416,7 @@ def open(self, path: str): Open the file, to be indexed via self.get() method. Avoid reading entire multi-Gb files when possible! """ - if "arrow" in os.path.splitext(filepath)[1]: + if "arrow" in os.path.splitext(path)[1]: self.current = self.AHandler else: self.current = self.PHandler @@ -427,7 +427,7 @@ def length(self, path: str): Calculate the number of documents in the given file. Avoid reading entire multi-Gb files when possible! """ - if "arrow" in os.path.splitext(filepath)[1]: + if "arrow" in os.path.splitext(path)[1]: return self.AHandler.length(path) else: return self.PHandler.length(path) From 06cda8a6514d9c71afdfe7c8fef1b02bd7cc0ee7 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 12 Nov 2024 14:30:02 -0500 Subject: [PATCH 10/43] Partitioning fix from mup-search --- fms_fsdp/utils/dataset_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 3b4c8122..1a17dd75 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -384,10 +384,10 @@ def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] def open(self, path: str): - return pq.read_pandas(path, columns=[self.col_name])[self.col_name] + return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[self.col_name] def length(self, path: str): - return pq.read_pandas(path, columns=[]).num_rows + return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): doc = self.tokenizer(str(reader[index]))["input_ids"] From 5ab4af600d2eb9ac2d0a28f3f49c06bcaa03430d Mon Sep 17 00:00:00 2001 From: divykum2 Date: Sat, 16 Nov 2024 22:43:35 -0500 Subject: [PATCH 11/43] Cosine 0.01 decay --- main_training.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/main_training.py b/main_training.py index 179a4687..6076a68f 100644 --- a/main_training.py +++ b/main_training.py @@ -135,11 +135,23 @@ def main(**kwargs): schedule = lambda x: 1 - x / cfg.num_steps else: - warmup_interval = 1000 - schedule = lambda x: ( - min(x, warmup_interval) / warmup_interval + # (cosine 0.01 decay) + warmup_interval = min(2000, cfg.num_steps // 20) + schedule = lambda x: min( + 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + 0.01 + + 0.5 + * (1 - 0.1) + * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) + + # (constant schedule) + # warmup_interval = 1000 + # schedule = lambda x: ( + # min(x, warmup_interval) / warmup_interval + # ) + # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) # schedule = lambda x: min( # 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, From 0210fd330a51ed769334a8f804e7668198a8ab66 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Sun, 17 Nov 2024 10:06:15 -0500 Subject: [PATCH 12/43] Warmup interval change --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index 6076a68f..a4d4a8b1 100644 --- a/main_training.py +++ b/main_training.py @@ -136,7 +136,7 @@ def main(**kwargs): else: # (cosine 0.01 decay) - warmup_interval = min(2000, cfg.num_steps // 20) + warmup_interval = min(1000, cfg.num_steps // 10) schedule = lambda x: min( 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, 0.01 From 98ec15e874f3a388328c1535413eb3bc188ea279 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Sun, 17 Nov 2024 12:58:56 -0500 Subject: [PATCH 13/43] Schedule change --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index a4d4a8b1..54708fc6 100644 --- a/main_training.py +++ b/main_training.py @@ -141,7 +141,7 @@ def main(**kwargs): 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, 0.01 + 0.5 - * (1 - 0.1) + * (1 - 0.01) * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) From ab7fb58467ae8f33f77fc745e0306ca03f66edf4 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 27 Nov 2024 17:15:15 -0500 Subject: [PATCH 14/43] Constant schedule --- main_training.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/main_training.py b/main_training.py index 54708fc6..0e8b5c2a 100644 --- a/main_training.py +++ b/main_training.py @@ -136,20 +136,20 @@ def main(**kwargs): else: # (cosine 0.01 decay) - warmup_interval = min(1000, cfg.num_steps // 10) - schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, - 0.01 - + 0.5 - * (1 - 0.01) - * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), - ) + # warmup_interval = min(1000, cfg.num_steps // 10) + # schedule = lambda x: min( + # 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + # 0.01 + # + 0.5 + # * (1 - 0.01) + # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + # ) # (constant schedule) - # warmup_interval = 1000 - # schedule = lambda x: ( - # min(x, warmup_interval) / warmup_interval - # ) + warmup_interval = 1000 + schedule = lambda x: ( + min(x, warmup_interval) / warmup_interval + ) # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) From 2b23cc63767c7b52862ad339762cdee184456475 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Tue, 10 Dec 2024 15:57:28 -0500 Subject: [PATCH 15/43] LR schedule change (cool down and constant lr) --- main_training.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/main_training.py b/main_training.py index 0e8b5c2a..f61c6bd3 100644 --- a/main_training.py +++ b/main_training.py @@ -146,10 +146,10 @@ def main(**kwargs): # ) # (constant schedule) - warmup_interval = 1000 - schedule = lambda x: ( - min(x, warmup_interval) / warmup_interval - ) + # warmup_interval = 1000 + # schedule = lambda x: ( + # min(x, warmup_interval) / warmup_interval + # ) # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) @@ -160,6 +160,9 @@ def main(**kwargs): # * (1 - 0.1) # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), # ) + + # linear decay to 50b tokens and then constant lr + schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) From f87ca63dd149f87151f6a47463a14abe98cb24b2 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Sat, 14 Dec 2024 21:50:04 +0300 Subject: [PATCH 16/43] Update dataset_utils.py Added a check for length of doc --- fms_fsdp/utils/dataset_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 1a17dd75..c4efde2c 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -394,6 +394,8 @@ def get(self, reader, index: int, drop_tokens: Set): if len(doc) > 0: if doc[0] in drop_tokens: doc = doc[1:] + # check the length again after removing the first token + if len(doc) > 0: if doc[-1] in drop_tokens: doc = doc[:-1] return doc From e39f56164c2b0a8fb908bcebeacd461ec2c4da28 Mon Sep 17 00:00:00 2001 From: divya-kumari32 Date: Mon, 16 Dec 2024 19:42:27 +0300 Subject: [PATCH 17/43] LR schedule change (Warmup + constant) --- main_training.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/main_training.py b/main_training.py index f61c6bd3..f4988a1d 100644 --- a/main_training.py +++ b/main_training.py @@ -146,10 +146,10 @@ def main(**kwargs): # ) # (constant schedule) - # warmup_interval = 1000 - # schedule = lambda x: ( - # min(x, warmup_interval) / warmup_interval - # ) + warmup_interval = 1000 + schedule = lambda x: ( + min(x, warmup_interval) / warmup_interval + ) # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) @@ -162,7 +162,7 @@ def main(**kwargs): # ) # linear decay to 50b tokens and then constant lr - schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 + # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) From 2402c88d57eddb9fd016c7f5d367bc68a643234d Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Wed, 18 Dec 2024 21:45:46 +0300 Subject: [PATCH 18/43] Update dataset_utils.py --- fms_fsdp/utils/dataset_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index c4efde2c..9bd1081a 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -360,6 +360,7 @@ def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): if len(doc) > 0: if doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) + if len(doc) > 0: if doc[-1].as_py() in drop_tokens: doc = doc.slice(0, len(doc) - 1) return doc From e43510217c54cdce9ac0e5c953fe5a377b604f9a Mon Sep 17 00:00:00 2001 From: divya-kumari32 Date: Thu, 19 Dec 2024 20:21:06 +0300 Subject: [PATCH 19/43] Cosine schedule --- main_training.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/main_training.py b/main_training.py index f4988a1d..86b8857e 100644 --- a/main_training.py +++ b/main_training.py @@ -146,20 +146,20 @@ def main(**kwargs): # ) # (constant schedule) - warmup_interval = 1000 - schedule = lambda x: ( - min(x, warmup_interval) / warmup_interval - ) + # warmup_interval = 1000 + # schedule = lambda x: ( + # min(x, warmup_interval) / warmup_interval + # ) # (cosine 0.1 decay) - # warmup_interval = min(2000, cfg.num_steps // 20) - # schedule = lambda x: min( - # 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, - # 0.1 - # + 0.5 - # * (1 - 0.1) - # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), - # ) + warmup_interval = min(2000, cfg.num_steps // 20) + schedule = lambda x: min( + 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + 0.1 + + 0.5 + * (1 - 0.1) + * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + ) # linear decay to 50b tokens and then constant lr # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 From 499a83074a0d68b0c332a5520c1c991744a887af Mon Sep 17 00:00:00 2001 From: divya-kumari32 Date: Mon, 13 Jan 2025 19:16:55 +0530 Subject: [PATCH 20/43] For constant lr 1.5e5 --- main_training.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/main_training.py b/main_training.py index 86b8857e..f2976c03 100644 --- a/main_training.py +++ b/main_training.py @@ -150,16 +150,17 @@ def main(**kwargs): # schedule = lambda x: ( # min(x, warmup_interval) / warmup_interval # ) + schedule = lambda x: 1.0 # (cosine 0.1 decay) - warmup_interval = min(2000, cfg.num_steps // 20) - schedule = lambda x: min( - 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, - 0.1 - + 0.5 - * (1 - 0.1) - * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), - ) + # warmup_interval = min(2000, cfg.num_steps // 20) + # schedule = lambda x: min( + # 1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2, + # 0.1 + # + 0.5 + # * (1 - 0.1) + # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), + # ) # linear decay to 50b tokens and then constant lr # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 From 99a6543771a8831c3952bdfd0edcb059736150bc Mon Sep 17 00:00:00 2001 From: divya-kumari32 Date: Mon, 13 Jan 2025 21:17:39 +0530 Subject: [PATCH 21/43] Schedule change --- main_training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main_training.py b/main_training.py index f2976c03..d5dac1d8 100644 --- a/main_training.py +++ b/main_training.py @@ -150,7 +150,7 @@ def main(**kwargs): # schedule = lambda x: ( # min(x, warmup_interval) / warmup_interval # ) - schedule = lambda x: 1.0 + schedule = 1.0 # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) @@ -160,7 +160,7 @@ def main(**kwargs): # + 0.5 # * (1 - 0.1) # * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), - # ) + # # ) # linear decay to 50b tokens and then constant lr # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 From 5ff2e864a49f2e5484bf4bdd3e25d16ac2eba730 Mon Sep 17 00:00:00 2001 From: divya-kumari32 Date: Mon, 13 Jan 2025 22:34:24 +0530 Subject: [PATCH 22/43] Schedule change --- main_training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_training.py b/main_training.py index d5dac1d8..4bd2a3c3 100644 --- a/main_training.py +++ b/main_training.py @@ -150,7 +150,7 @@ def main(**kwargs): # schedule = lambda x: ( # min(x, warmup_interval) / warmup_interval # ) - schedule = 1.0 + schedule = lambda x: min(1.0, x) # (cosine 0.1 decay) # warmup_interval = min(2000, cfg.num_steps // 20) From 0fdb43dcfd31ab093f8d873b58b0b531dd0818b1 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Tue, 14 Jan 2025 18:58:54 -0500 Subject: [PATCH 23/43] Final singlefile checkpoint saves one folder up (#127) * Final singlefile checkpoint saves one folder up Signed-off-by: Davis Wertheimer * save file under new pth subfolder Signed-off-by: Davis Wertheimer * Repath for easier consumption/conversion Signed-off-by: Davis Wertheimer --------- Signed-off-by: Davis Wertheimer --- fms_fsdp/utils/checkpointing_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 2bbeef18..65b18838 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -324,7 +324,9 @@ def save_single_file( ): # Note: metadata kwargs cannot contain any of: # (step, model) - save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth") + pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step)) + os.makedirs(pth_path, exist_ok=True) + save_name = os.path.join(pth_path, "consolidated.00.pth") save_time = time.time() with FSDP.state_dict_type( model, From ee2eb626338c280f740ae6b66f4c1a37da15323c Mon Sep 17 00:00:00 2001 From: divykum2 Date: Fri, 24 Jan 2025 13:38:57 -0500 Subject: [PATCH 24/43] Added cool down --- main_training_mamba.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/main_training_mamba.py b/main_training_mamba.py index 3619ea25..58999768 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -134,7 +134,7 @@ def main(**kwargs): # linear decay for annealing if cfg.training_stage == "annealing": schedule = lambda x: 1 - x / cfg.num_steps - else: + elif cfg.training_stage == "cosine": # cosine decay warmup_interval = min(2000, cfg.num_steps // 20) schedule = lambda x: min( @@ -144,6 +144,8 @@ def main(**kwargs): * (1 - 0.1) * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) + else: + schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) From f425e502a350c7605886330a185069fd68d42b9f Mon Sep 17 00:00:00 2001 From: divykum2 Date: Fri, 31 Jan 2025 16:02:13 -0500 Subject: [PATCH 25/43] length of doc check --- fms_fsdp/utils/dataset_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index b26f1913..aedc5862 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -393,11 +393,11 @@ def length(self, path: str): def get(self, reader, index: int, drop_tokens: Set): doc = self.tokenizer(str(reader[index]))["input_ids"] - if len(doc) > 0: - if doc[0] in drop_tokens: - doc = doc[1:] - if doc[-1] in drop_tokens: - doc = doc[:-1] + if len(doc) > 0 and doc[0] in drop_tokens: + doc = doc[1:] + # Recheck len for edge case where doc=[eos] + if len(doc) > 0 and doc[-1] in drop_tokens: + doc = doc[:-1] return doc def slice(self, doc: List, index: int, n_pull: int) -> List: From 8b8f6883eb2112351f05fcce8cbafbb81525fb85 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 3 Feb 2025 16:56:45 -0500 Subject: [PATCH 26/43] splitstrip cols and pass to fhandler --- fms_fsdp/utils/dataloader_utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 4b811d6d..afce42ec 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -75,7 +75,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - datasets, weights = parse_data_args(cfg.datasets, cfg.weights) + datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. # Implements dataset sampling and rescalability. @@ -87,9 +87,9 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): cfg.file_type in _handler_map ), f"File type {cfg.file_type} is not recognized ({list(_handler_map.keys())})" if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": - filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cfg.col_name) + filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols) else: - filehandler = _handler_map[cfg.file_type] + filehandler = _handler_map[cfg.file_type, cols] # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -146,7 +146,7 @@ def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): ) -def parse_data_args(datas, weights): +def parse_data_args(datas, weights, cols): # Convert csv inputs into corresponding lists of values def splitstrip(x): if isinstance(x, str): @@ -160,4 +160,5 @@ def splitstrip(x): datas = splitstrip(datas) weights = [float(x) for x in splitstrip(weights)] - return datas, weights + cols = splitstrip(cols) + return datas, weights, cols From c328c56e4e5c7882de1e8fbab6c4932cadcf6511 Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Mon, 3 Feb 2025 17:16:14 -0500 Subject: [PATCH 27/43] fhandler col_names support --- fms_fsdp/utils/dataset_utils.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index aedc5862..087d0394 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -343,8 +343,8 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. """ - def __init__(self, col_name: str = "tokens"): - self.col_name = col_name + def __init__(self, col_names: List[str] = ["tokens"]): + self.col_names = col_names def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,7 +356,13 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - doc = reader.get_batch(index)[self.col_name] + frame = reader.get_batch(index) + doc = None + for name in self.col_names: + if name in frame.column_names: + doc = frame[name] + break + assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -376,17 +382,22 @@ class ParquetHandler(_ShardFileHandler): before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__(self, tokenizer_path: str, col_name: str = "text"): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) - self.col_name = col_name + self.col_names = col_names def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] def open(self, path: str): - return pq.read_pandas(path, columns=[self.col_name], partitioning=None)[ - self.col_name - ] + names = pq.read_metadata(path).schema.names + match = None + for name in self.col_names: + if name in names: + match = name + break + assert match is not None, f"None of column names {self.col_names} found in file headers {names}" + return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows @@ -405,9 +416,9 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_name: str = "text"): - self.PHandler = ParquetHandler(tokenizer_path, col_name) - self.AHandler = ArrowHandler() + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "tokens"]): + self.PHandler = ParquetHandler(tokenizer_path, col_names) + self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() def is_legal(self, filepath: str): From 9764afcfb2492644645627e3a739cc509a40fda0 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 3 Feb 2025 22:34:07 -0500 Subject: [PATCH 28/43] Warmup for annealing --- fms_fsdp/utils/dataset_utils.py | 4 +++- main_training_mamba.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 087d0394..a072fde6 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -345,7 +345,8 @@ class ArrowHandler(_ShardFileHandler): def __init__(self, col_names: List[str] = ["tokens"]): self.col_names = col_names - + # print(self.col_names) + def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -385,6 +386,7 @@ class ParquetHandler(_ShardFileHandler): def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_names = col_names + # print(self.col_names) def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] diff --git a/main_training_mamba.py b/main_training_mamba.py index 58999768..6c44133b 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -133,7 +133,8 @@ def main(**kwargs): # LR schedule # linear decay for annealing if cfg.training_stage == "annealing": - schedule = lambda x: 1 - x / cfg.num_steps + warmup_interval = 1000 + schedule = lambda x: x / warmup_interval if x < warmup_interval else 1 - (x - warmup_interval) / (cfg.num_steps - warmup_interval) elif cfg.training_stage == "cosine": # cosine decay warmup_interval = min(2000, cfg.num_steps // 20) From 43902daa22c66adf2840d3779e4e1e0c173ee088 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 3 Feb 2025 23:01:38 -0500 Subject: [PATCH 29/43] Debugging --- fms_fsdp/utils/dataset_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index a072fde6..499e382f 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -358,6 +358,8 @@ def length(self, path: str): def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): frame = reader.get_batch(index) + print(f"Printing column names in frame: {frame.column_names}") + doc = None for name in self.col_names: if name in frame.column_names: @@ -418,7 +420,7 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "tokens"]): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]): self.PHandler = ParquetHandler(tokenizer_path, col_names) self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() From c1320c245ead5ccb7bb2fc0161b71e7b7e73867e Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 3 Feb 2025 23:07:19 -0500 Subject: [PATCH 30/43] Debugging II --- fms_fsdp/utils/dataset_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 499e382f..d556e994 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -345,7 +345,6 @@ class ArrowHandler(_ShardFileHandler): def __init__(self, col_names: List[str] = ["tokens"]): self.col_names = col_names - # print(self.col_names) def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -358,7 +357,6 @@ def length(self, path: str): def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): frame = reader.get_batch(index) - print(f"Printing column names in frame: {frame.column_names}") doc = None for name in self.col_names: @@ -388,7 +386,6 @@ class ParquetHandler(_ShardFileHandler): def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_names = col_names - # print(self.col_names) def is_legal(self, filepath: str): return "parquet" in os.path.splitext(filepath)[1] From 701d0ebb24c56b6fc01d5100eca6c62faf4c4dea Mon Sep 17 00:00:00 2001 From: Davis Wertheimer Date: Wed, 19 Feb 2025 11:12:08 -0500 Subject: [PATCH 31/43] Empty shard check --- fms_fsdp/utils/dataset_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index d556e994..f7c0cd28 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -891,6 +891,7 @@ def __init__( # Position self.docset_index = 0 self.chunk_index = -1 + self.has_yielded = False # Stats self.epochs_seen = -1 @@ -1122,6 +1123,7 @@ def __iter__(self): self.percent_seen = ( self.docs_seen * 100 / (self._len + 1e-9) ) + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) # Advance RNG state @@ -1142,8 +1144,12 @@ def __iter__(self): n_chunks = math.ceil(doclen / self.chunksize) for j in range(residual_chunks): self.chunk_index = j + self.has_yielded = True yield self._construct_chunk(j, doc, n_chunks) + # Check that epoch was non-empty + assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}" + def load_state_dict(self, state_dicts, sharded_input=False): self.setup() assert ( From c90acb9962788c26a034e8312da3dc60b26b54ba Mon Sep 17 00:00:00 2001 From: divykum2 Date: Sun, 23 Feb 2025 15:28:08 -0500 Subject: [PATCH 32/43] Added constant lr schedule with warmup --- main_training_mamba.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/main_training_mamba.py b/main_training_mamba.py index 6c44133b..3ff12a60 100644 --- a/main_training_mamba.py +++ b/main_training_mamba.py @@ -145,6 +145,9 @@ def main(**kwargs): * (1 - 0.1) * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) + elif cfg.training_stage == "constant": + warmup_interval = 2000 + schedule = lambda x: (min(x, warmup_interval) / warmup_interval) else: schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 From 55fd5ac5322a8df80ed41e2dbee208f20b5fb526 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 26 Feb 2025 17:45:59 -0500 Subject: [PATCH 33/43] added print for lenght of doc --- fms_fsdp/utils/dataset_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index f7c0cd28..5865b2ed 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -404,7 +404,11 @@ def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): + doc = self.tokenizer(str(reader[index]))["input_ids"] + + print(f"length of doc {len(doc)}") + if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] From 97eade553569f8ad90dbb79abde4dcd1018a0500 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Wed, 26 Feb 2025 18:00:08 -0500 Subject: [PATCH 34/43] added print for lenght of doc II --- fms_fsdp/utils/dataset_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 5865b2ed..19289706 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -405,9 +405,10 @@ def length(self, path: str): def get(self, reader, index: int, drop_tokens: Set): - doc = self.tokenizer(str(reader[index]))["input_ids"] + document_str = str(reader[index]) + print(f"Length of document in characters: {len(document_str)}") - print(f"length of doc {len(doc)}") + doc = self.tokenizer(str(reader[index]))["input_ids"] if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] From 1b50708ec4c08b94df31c8866730870041ba4520 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Thu, 6 Mar 2025 09:50:43 -0500 Subject: [PATCH 35/43] Update dataset_utils.py --- fms_fsdp/utils/dataset_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 19289706..8f5d3f36 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -406,7 +406,6 @@ def length(self, path: str): def get(self, reader, index: int, drop_tokens: Set): document_str = str(reader[index]) - print(f"Length of document in characters: {len(document_str)}") doc = self.tokenizer(str(reader[index]))["input_ids"] From c2acc2846dd5a130c82a083c53ddb6308babb70b Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Sat, 24 May 2025 15:16:59 -0400 Subject: [PATCH 36/43] Update dataset_utils.py From 87bdab0d8bd1a088a39793a12de6b7e1dd4a37f7 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Sat, 24 May 2025 15:25:59 -0400 Subject: [PATCH 37/43] Update dataset_utils.py --- fms_fsdp/utils/dataset_utils.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 029d450c..467543ff 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -1123,20 +1123,21 @@ def setup(self): ] tally += shard_sizes[i] # Count file exists, use it - with open(countpath, "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find(dataset) - if prefix >= 0: - key = fullpath[prefix + len(dataset) + 1 :] - doc_counts[key] = int(row["documents"]) - else: - # Count file does not exist, touch every owned file for length - doc_counts = { - shard: self.filehandler.length(os.path.join(datapath, shard)) - for shard in shardset - } + if len(countfiles) > 0: + with open(countpath, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find(dataset) + if prefix >= 0: + key = fullpath[prefix + len(dataset) + 1 :] + doc_counts[key] = int(row["documents"]) + else: + # Count file does not exist, touch every owned file for length + doc_counts = { + shard: self.filehandler.length(os.path.join(datapath, shard)) + for shard in shardset + } # Assemble doc list for each file shard # Create docset of form [shardid, min docid, max docid] From 34d59743772ac30499c043be4eea1d9cd0274c80 Mon Sep 17 00:00:00 2001 From: divya-kumari32 <72085811+divya-kumari32@users.noreply.github.com> Date: Sat, 24 May 2025 15:48:41 -0400 Subject: [PATCH 38/43] Update dataset_utils.py --- fms_fsdp/utils/dataset_utils.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 467543ff..029d450c 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -1123,21 +1123,20 @@ def setup(self): ] tally += shard_sizes[i] # Count file exists, use it - if len(countfiles) > 0: - with open(countpath, "r") as csvfile: - reader = csv.DictReader(csvfile) - for row in reader: - fullpath = row["dataset/filename"] - prefix = fullpath.find(dataset) - if prefix >= 0: - key = fullpath[prefix + len(dataset) + 1 :] - doc_counts[key] = int(row["documents"]) - else: - # Count file does not exist, touch every owned file for length - doc_counts = { - shard: self.filehandler.length(os.path.join(datapath, shard)) - for shard in shardset - } + with open(countpath, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + fullpath = row["dataset/filename"] + prefix = fullpath.find(dataset) + if prefix >= 0: + key = fullpath[prefix + len(dataset) + 1 :] + doc_counts[key] = int(row["documents"]) + else: + # Count file does not exist, touch every owned file for length + doc_counts = { + shard: self.filehandler.length(os.path.join(datapath, shard)) + for shard in shardset + } # Assemble doc list for each file shard # Create docset of form [shardid, min docid, max docid] From 2f302d485649444c4023989551a2f84740020ac8 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 26 May 2025 14:03:00 -0400 Subject: [PATCH 39/43] Adding print for debug --- fms_fsdp/utils/dataset_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index 029d450c..ad5da9c7 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -1069,6 +1069,8 @@ def setup(self): pardir, dataset = pathsplit self.dataset = dataset + print(f"printing datapath for this file: {datapath}") + # Assemble document set owned by this worker: # listdir, assemble shardfraglist (ind -> shard, frag) shards = [ From 9558e2e42a5ad5c84014369bbc9c511cb3f24da9 Mon Sep 17 00:00:00 2001 From: divykum2 Date: Mon, 26 May 2025 14:24:53 -0400 Subject: [PATCH 40/43] Revert "Pulled from data-fixes branch" This reverts commit ac5194bbf117900b21f924db76a5fb07c2d3bb25, reversing changes made to 1b50708ec4c08b94df31c8866730870041ba4520. reverting changes --- fms_fsdp/config/training.py | 7 - fms_fsdp/utils/dataloader_utils.py | 36 +--- fms_fsdp/utils/dataset_utils.py | 294 +++++++---------------------- main_training_llama.py | 8 +- main_training_mamba.py | 8 +- tests/test_datasets.py | 4 - 6 files changed, 86 insertions(+), 271 deletions(-) diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 20eb0b76..1d072958 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -72,10 +72,3 @@ class train_config: stage2_prompt_length: int = 64 stage2_batch_size: int = 96 stage2_seq_length: int = 256 - - # FIM training - psm_rate: float = 0.0 - spm_rate: float = 0.0 - fim_pre: int = 1 - fim_mid: int = 2 - fim_suf: int = 3 diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 5c2016bd..afce42ec 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -5,7 +5,6 @@ AutoHandler, BufferDataset, CheckpointDataset, - FIMDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -58,9 +57,9 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size): +def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): """ - Pytorch dataloader for stateful, distributed, and rescalable language model training. + Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. Assumes underlying data is sequences of integer values. ... Args @@ -71,12 +70,11 @@ def get_data_loader(cfg, rank, world_size): Rank of current distributed worker. Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. + postprocess : List[Callable] + Any task-specific postprocessing to apply before handing over data. Steps will apply in + the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - fim_training = cfg.psm_rate + cfg.spm_rate > 0 - if fim_training: - assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?" - datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. @@ -91,7 +89,7 @@ def get_data_loader(cfg, rank, world_size): if cfg.file_type == "hf_parquet" or cfg.file_type == "auto": filehandler = _handler_map[cfg.file_type](cfg.tokenizer_path, cols) else: - filehandler = _handler_map[cfg.file_type](cols) + filehandler = _handler_map[cfg.file_type, cols] # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -120,10 +118,9 @@ def get_data_loader(cfg, rank, world_size): verbose=(rank == 0), ) # Wrap above dataset in packing logic to form constant-length lines. - # Increment seq len to counteract CLM's one token removal. data = BufferDataset( data, - cfg.seq_length + 1, + cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, @@ -131,23 +128,10 @@ def get_data_loader(cfg, rank, world_size): # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. data = PreloadBufferDataset(data, 10000) - # Apply FIM transformation if needed - if fim_training: - data = FIMDataset( - data, - cfg.eos_token, - cfg.psm_rate, - cfg.spm_rate, - pre_token=cfg.fim_pre, - mid_token=cfg.fim_mid, - suf_token=cfg.fim_suf, - ) - - # Transform to tensors + # Apply desired postprocessing steps in sequence data = PreprocessDataset(data, torch.IntTensor) - - # Apply CLM transformation - data = PreprocessDataset(data, causal_lm) + for p in postprocess: + data = PreprocessDataset(data, p) # Enable auto-saving data = CheckpointDataset( diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index ad5da9c7..2522a536 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -343,9 +343,9 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. """ - def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): + def __init__(self, col_names: List[str] = ["tokens"]): self.col_names = col_names - + def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,18 +356,14 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - assert ( - index < reader.num_record_batches - ), f"Illegal index {index} in set of {reader.num_record_batches} documents" frame = reader.get_batch(index) + doc = None for name in self.col_names: if name in frame.column_names: doc = frame[name] break - assert ( - doc is not None - ), f"None of column names {self.col_names} found in file headers {frame.column_names}" + assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -387,9 +383,7 @@ class ParquetHandler(_ShardFileHandler): before getting/slicing. However, this is a standard and widely-used data format. """ - def __init__( - self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"] - ): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) self.col_names = col_names @@ -403,19 +397,18 @@ def open(self, path: str): if name in names: match = name break - assert ( - match is not None - ), f"None of column names {self.col_names} found in file headers {names}" + assert match is not None, f"None of column names {self.col_names} found in file headers {names}" return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): return pq.read_metadata(path).num_rows def get(self, reader, index: int, drop_tokens: Set): - assert ( - index < reader.length() - ), f"Illegal index {index} in set of {reader.length()} documents" - doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"] + + document_str = str(reader[index]) + + doc = self.tokenizer(str(reader[index]))["input_ids"] + if len(doc) > 0 and doc[0] in drop_tokens: doc = doc[1:] # Recheck len for edge case where doc=[eos] @@ -428,9 +421,7 @@ def slice(self, doc: List, index: int, n_pull: int) -> List: class AutoHandler(_ShardFileHandler): - def __init__( - self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"] - ): + def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]): self.PHandler = ParquetHandler(tokenizer_path, col_names) self.AHandler = ArrowHandler(col_names) self.current = _ShardFileHandler() @@ -721,128 +712,6 @@ def load_state_dict(self, state_dicts, sharded_input=False): return sharded_dicts -class FIMDataset(_WrapperDataset): - """ - Wrapper for a StatefulDataset that implements Fill-In-the-Middle training - (https://arxiv.org/pdf/2207.14255). - Input should be a packed sequence (i.e. call BufferDataset before FIMDataset). - Breaks sequence apart into component document spans, and for each document span - of sufficient length, transforms with specified probability into: - PSM mode:
 (prefix)  (suffix)  (middle) 
-    SPM mode: 
  (suffix)  (prefix) (middle) 
-    The new delimiter tokens can be omitted by passing in None.
-    Any extra tokens after transformation are dropped from the end of the sequence.
-    ...
-    Args
-    ----
-    dataset : _StatefulDataset
-        Fully instantiated dataset
-    delimiter_token : any
-        Token used to indicate document boundaries
-    psm_rate : float
-        Chance to transform into PSM. Cannot exceed 1.
-    spm_rate : float
-        Chance to transform into SPM. Cannot exceed 1.
-    min_len : int
-        Minimum document length to perform FIM transformation
-    pre_token : any | none
-        Token used to indicate prefix section of the document
-    mid_token : any | none
-        Token used to indicate middle infill section of the document
-    suf_token : any | none
-        Token used to indicate suffix section of the document
-    """
-
-    def __init__(
-        self,
-        dataset: _StatefulDataset,
-        delimiter_token: Any,
-        psm_rate: float = 0.0,
-        spm_rate: float = 0.0,
-        min_len: int = 10,
-        pre_token=None,
-        mid_token=None,
-        suf_token=None,
-    ):
-        super().__init__(dataset)
-        assert (
-            psm_rate + spm_rate > 0
-        ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
-        assert (
-            psm_rate + spm_rate <= 1
-        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
-        self.psm = psm_rate
-        self.spm = spm_rate
-        self.delimiter = delimiter_token
-        self.min_len = min_len
-        self.pref = pre_token
-        self.suff = suf_token
-        self.midd = mid_token
-
-        self.g_state = None
-        self.generator = torch.Generator().manual_seed(self.rank)
-        self.state_params = ["g_state"]
-
-    def __iter__(self):
-        dataset = iter(self.dataset)
-        while True:
-            inp = next(dataset)
-            len_ = len(inp)
-            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
-            docs = [
-                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
-            ]  # list[list[any]]
-            out = []
-            for i in range(len(docs)):
-                doc = docs[i]
-                if len(docs[i]) >= self.min_len:
-                    # decide psm, spm, or nothing
-                    thresh = torch.rand([1], generator=self.generator).item()
-                    if thresh < self.psm + self.spm:
-                        # Split doc
-                        doc = []
-                        if self.pref:
-                            doc = [self.pref]
-                        splits = torch.randint(
-                            0, len(docs[i]), [2], generator=self.generator
-                        ).tolist()
-                        pre = docs[i][: min(splits)]
-                        mid = docs[i][min(splits) : max(splits)]
-                        suf = docs[i][max(splits) :]
-
-                        if thresh < self.psm:
-                            # PSM transformation
-                            doc += pre
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += mid
-                        else:
-                            # SPM transformation
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += pre + mid
-                out += doc + [self.delimiter]
-            yield out[:len_]
-
-    def state_dict(self):
-        # Write generator state manually
-        self.g_state = self.generator.get_state()
-        return super().state_dict()
-
-    def load_state_dict(self, state_dicts, sharded_input=False):
-        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
-        # Manually set generator state if it exists
-        if self.g_state is not None:
-            self.generator.set_state(self.g_state)
-        return sharded_dicts
-
-
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -988,10 +857,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
-    max_consecutive_chunks : int
-        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
+    shuffle : bool
+        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -1006,7 +875,6 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
-        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1019,10 +887,10 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
-        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
-        # Map of doc indices to (shardid, min docid, max docid)
-        self.docset: List[Any] = []
+        self.docset: List[
+            Any
+        ] = []  # map of doc indices to (shardid, min docid, max docid)
 
         # Position
         self.docset_index = 0
@@ -1034,7 +902,6 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
-        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1045,7 +912,6 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
-            "consec",
         ]
 
         # Setup flags
@@ -1080,8 +946,17 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
+            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
+            end_frag = (
+                (self.rank + 1) * self.worldsize * len(shards)
+            ) // self.worldsize
+            shardfrags = [
+                (shards[i // self.worldsize], i % self.worldsize)
+                for i in range(start_frag, end_frag)
+            ]
+
+            # Assemble length of each owned shard file
 
-            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
@@ -1089,74 +964,55 @@ def setup(self):
                     for x in os.listdir(os.path.join(pardir, "meta"))
                     if "counts" in x and "csv" in x
                 ]
+            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
-            else:
-                countpath = ""
-
-            # Use shard file sizes to perform partitioning
-            # Create shardlist of form shardid -> [start%, end%]
-            if len(countfiles) > 0:
-                sizes = {}
-                with open(countpath, "r") as csvfile:
-                    reader = csv.DictReader(csvfile)
-                    for row in reader:
-                        fullpath = row["dataset/filename"]
-                        prefix = fullpath.find(dataset + "/")
-                        if prefix >= 0:
-                            key = fullpath[prefix + len(dataset) + 1 :]
-                            sizes[key] = int(row["size"])
-                shard_sizes = [sizes[shard] for shard in shards]
-            else:
-                shard_sizes = [
-                    os.path.getsize(os.path.join(datapath, shard)) for shard in shards
-                ]
-            shard_sizes = [s / sum(shard_sizes) for s in shard_sizes]
-            start = self.rank / self.worldsize
-            end = (self.rank + 1) / self.worldsize
-            shardset = {}
-            tally = 0
-            for i in range(len(shards)):
-                if tally <= end and tally + shard_sizes[i] >= start:
-                    shardset[shards[i]] = [
-                        min(max((start - tally) / shard_sizes[i], 0), 1),
-                        min(max((end - tally) / shard_sizes[i], 0), 1),
-                    ]
-                tally += shard_sizes[i]
-                # Count file exists, use it
                 with open(countpath, "r") as csvfile:
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find(dataset)
-                        if prefix >= 0:
+                        prefix = fullpath.find("/" + dataset) + 1
+                        if prefix > 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
+                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in shardset
+                    for shard in unique_shardfiles
                 }
 
-            # Assemble doc list for each file shard
-            # Create docset of form [shardid, min docid, max docid]
-            doccount = 0
-            for shard in shardset:
+            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
+            ndocs = -1
+            docset = {}  # shardid -> (min docid, max docid)
+            for i, (shard, frag) in enumerate(shardfrags):
                 ndocs = doc_counts[shard]
-                if ndocs > 0:
-                    doc_start = int(ndocs * shardset[shard][0])
-                    doc_end = max(
-                        doc_start, int(ndocs * shardset[shard][1]) - 1
-                    )  # inclusive upper bound
-                    self.docset.append([shard, doc_start, doc_end])
-                    doccount += doc_end - doc_start + 1
+                doc_start = (ndocs * frag) // self.worldsize
+                doc_end = (
+                    ndocs * frag + ndocs
+                ) // self.worldsize - 1  # Inclusive upper bound
+                if shard not in docset:
+                    docset[shard] = [doc_start, doc_end]
+                min_d, max_d = docset[shard]
+                if doc_start < min_d:
+                    docset[shard][0] = doc_start
+                if doc_end > max_d:
+                    docset[shard][1] = doc_end
+
+            # Add shard entries to self.docset
+            doccount = 0
+            for shardid in docset:
+                min_d = docset[shardid][0]
+                max_d = docset[shardid][1]
+                self.docset.append((shardid, min_d, max_d))
+                doccount += max_d - min_d + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"    Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
+                    f"    Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
@@ -1211,11 +1067,8 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1 or self.consec == self.max_consec:
+        if j == n_chunks - 1:
             chunk = chunk + [self.eos]
-            self.consec = 0
-        else:
-            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
@@ -1260,8 +1113,10 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
+                if len(doc) == 0:
+                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if len(doc) > 0 and doclen >= self.min_length:
+                if doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1288,8 +1143,10 @@ def __iter__(self):
             newpath = os.path.join(self.datapath, shardid)
             path, reader = self._get_reader(path, newpath, reader)
             doc = self.filehandler.get(reader, docid, self.drop)
+            if len(doc) == 0:
+                continue
             doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-            if len(doc) > 0 and doclen >= self.min_length:
+            if doclen >= self.min_length:
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j
@@ -1297,9 +1154,7 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert (
-                self.has_yielded
-            ), f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1372,12 +1227,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
-            assert (
-                n_logical_shards % self.worldsize == 0
-            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
+            assert (
+                len(self.logicals_owned) == self.n_logicals
+            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1394,9 +1249,6 @@ def setup(self):
                     )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-            assert (
-                sum(self.n_docs_remaining) > 0
-            ), f"No documents detected in shard {self.rank} of {self.datapath}"
 
             self.generator = torch.Generator().manual_seed(self.rank)
 
@@ -1404,16 +1256,14 @@ def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
-        # Reset if we're rescaling into a prematurely finished epoch
-        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
-        if sum(self.n_docs_remaining) == 0:
-            self.n_docs_remaining = [d._len for d in self.data]
-            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
+                assert (
+                    sum(self.n_docs_remaining) > 0
+                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
@@ -1505,10 +1355,6 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
-        for d in datasets:
-            assert os.path.exists(
-                os.path.join(datapath, d)
-            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(
diff --git a/main_training_llama.py b/main_training_llama.py
index bc9191b5..4bd2a3c3 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -119,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 6578e374..3ff12a60 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 40bef481..83b2426b 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -632,10 +632,6 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
-    # Add FIM dataset
-    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
-    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
-
 
 # SCALABLEDATASET TESTS
 

From db7d4ad95c05c84e74e6e4e50799e42f67e88f9f Mon Sep 17 00:00:00 2001
From: divykum2 
Date: Mon, 26 May 2025 14:29:57 -0400
Subject: [PATCH 41/43] Revert all changes made after March 6 (before merge)

---
 fms_fsdp/utils/dataloader_utils.py |  2 +-
 fms_fsdp/utils/dataset_utils.py    | 15 +++++++++++++--
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index afce42ec..48648b05 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -161,4 +161,4 @@ def splitstrip(x):
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
     cols = splitstrip(cols)
-    return datas, weights, cols
+    return datas, weights, cols
\ No newline at end of file
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 2522a536..073c57b0 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,7 +15,6 @@
 
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
-
 """
 The following distributed dataloaders are designed around 3 main principles:
 
@@ -356,6 +355,10 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
+<<<<<<< HEAD
+=======
+        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
+>>>>>>> parent of 15f4d7e (Blacking)
         frame = reader.get_batch(index)
         
         doc = None
@@ -383,7 +386,11 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
+<<<<<<< HEAD
     def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]):
+=======
+    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
+>>>>>>> parent of 15f4d7e (Blacking)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
 
@@ -404,11 +411,16 @@ def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
+<<<<<<< HEAD
         
         document_str = str(reader[index])
         
         doc = self.tokenizer(str(reader[index]))["input_ids"]
         
+=======
+        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
+        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
+>>>>>>> parent of 15f4d7e (Blacking)
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]
@@ -1249,7 +1261,6 @@ def setup(self):
                     )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):

From 149b6fa14944211dcec0355d1318d842054024d8 Mon Sep 17 00:00:00 2001
From: divykum2 
Date: Mon, 26 May 2025 14:32:21 -0400
Subject: [PATCH 42/43] Revert all changes made after March 6 (before merge)

---
 fms_fsdp/utils/dataset_utils.py | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 073c57b0..994c080c 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -355,10 +355,6 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-<<<<<<< HEAD
-=======
-        assert index < reader.num_record_batches, f"Illegal index {index} in set of {reader.num_record_batches} documents"
->>>>>>> parent of 15f4d7e (Blacking)
         frame = reader.get_batch(index)
         
         doc = None
@@ -386,11 +382,7 @@ class ParquetHandler(_ShardFileHandler):
     before getting/slicing. However, this is a standard and widely-used data format.
     """
 
-<<<<<<< HEAD
     def __init__(self, tokenizer_path: str, col_names: List[str] = ["text"]):
-=======
-    def __init__(self, tokenizer_path: str, col_names: List[str] = ["text", "contents", "tokens"]):
->>>>>>> parent of 15f4d7e (Blacking)
         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
         self.col_names = col_names
 
@@ -411,16 +403,11 @@ def length(self, path: str):
         return pq.read_metadata(path).num_rows
 
     def get(self, reader, index: int, drop_tokens: Set):
-<<<<<<< HEAD
         
         document_str = str(reader[index])
         
         doc = self.tokenizer(str(reader[index]))["input_ids"]
         
-=======
-        assert index < reader.length(), f"Illegal index {index} in set of {reader.length()} documents"
-        doc = self.tokenizer(str(reader[index])[:1_000_000])["input_ids"]
->>>>>>> parent of 15f4d7e (Blacking)
         if len(doc) > 0 and doc[0] in drop_tokens:
             doc = doc[1:]
         # Recheck len for edge case where doc=[eos]

From 385c98d17d08ec01b82082901d61d28f4b23fa05 Mon Sep 17 00:00:00 2001
From: divykum2 
Date: Mon, 26 May 2025 14:44:16 -0400
Subject: [PATCH 43/43] removed print

---
 fms_fsdp/utils/dataset_utils.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index 994c080c..72abb1fc 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -933,8 +933,6 @@ def setup(self):
                 pathsplit = os.path.split(pathsplit[0])
             pardir, dataset = pathsplit
             self.dataset = dataset
-
-            print(f"printing datapath for this file: {datapath}")
             
             # Assemble document set owned by this worker:
             # listdir, assemble shardfraglist (ind -> shard, frag)