diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py
index 22b9a840..ca942bc6 100644
--- a/fms_fsdp/config/training.py
+++ b/fms_fsdp/config/training.py
@@ -76,10 +76,3 @@ class train_config:
     stage2_prompt_length: int = 64
     stage2_batch_size: int = 96
     stage2_seq_length: int = 256
-
-    # FIM training
-    psm_rate: float = 0.0
-    spm_rate: float = 0.0
-    fim_pre: int = 1
-    fim_mid: int = 2
-    fim_suf: int = 3
diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py
index 2bbeef18..65b18838 100644
--- a/fms_fsdp/utils/checkpointing_utils.py
+++ b/fms_fsdp/utils/checkpointing_utils.py
@@ -324,7 +324,9 @@ def save_single_file(
     ):
         # Note: metadata kwargs cannot contain any of:
         # (step, model)
-        save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth")
+        pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step))
+        os.makedirs(pth_path, exist_ok=True)
+        save_name = os.path.join(pth_path, "consolidated.00.pth")
         save_time = time.time()
         with FSDP.state_dict_type(
             model,
diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py
index 6078ae7c..c648f40a 100644
--- a/fms_fsdp/utils/dataloader_utils.py
+++ b/fms_fsdp/utils/dataloader_utils.py
@@ -6,7 +6,6 @@
     AutoHandler,
     BufferDataset,
     CheckpointDataset,
-    FIMDataset,
     ParquetHandler,
     PreloadBufferDataset,
     PreprocessDataset,
@@ -59,9 +58,9 @@ def __iter__(self):
         return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size)
 
 
-def get_data_loader(cfg, rank, world_size):
+def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]):
     """
-    Pytorch dataloader for stateful, distributed, and rescalable language model training.
+    Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training.
     Assumes underlying data is sequences of integer values.
     ...
     Args
@@ -72,12 +71,11 @@ def get_data_loader(cfg, rank, world_size):
         Rank of current distributed worker. Used for handling dataset sharding logic.
     world_size : int
         Number of distributed workers. Used for handling dataset sharding logic.
+    postprocess : List[Callable]
+        Any task-specific postprocessing to apply before handing over data. Steps will apply in
+        the order provided by the user. For CLM training, use postprocess=[causal_lm].
     """
 
-    fim_training = cfg.psm_rate + cfg.spm_rate > 0
-    if fim_training:
-        assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?"
-
     datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name)
 
     # Base streaming dataset. Returns doc chunks in sequence.
@@ -94,7 +92,7 @@ def get_data_loader(cfg, rank, world_size):
                 cfg.tokenizer_path, cols, cfg.doc_cutoff
             )
         else:
-            filehandler = _handler_map[cfg.file_type](cols)
+            filehandler = _handler_map[cfg.file_type, cols]
     # Base reader layer
     data = StreamingDocDataset(
         cfg.data_path,
@@ -124,10 +122,9 @@ def get_data_loader(cfg, rank, world_size):
         verbose=(rank == 0),
     )
     # Wrap above dataset in packing logic to form constant-length lines.
-    # Increment seq len to counteract CLM's one token removal.
     data = BufferDataset(
         data,
-        cfg.seq_length + 1,
+        cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1,
         bos_token=cfg.bol_token,
         eos_token=cfg.eol_token,
         pack_hard=True,
     )
     # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average.
     data = PreloadBufferDataset(data, 10000)
 
-    # Apply FIM transformation if needed
-    if fim_training:
-        data = FIMDataset(
-            data,
-            cfg.eos_token,
-            cfg.psm_rate,
-            cfg.spm_rate,
-            pre_token=cfg.fim_pre,
-            mid_token=cfg.fim_mid,
-            suf_token=cfg.fim_suf,
-        )
-
-    # Transform to tensors
+    # Apply desired postprocessing steps in sequence
     data = PreprocessDataset(data, torch.IntTensor)
-
-    # Apply CLM transformation
-    data = PreprocessDataset(data, causal_lm)
+    for p in postprocess:
+        data = PreprocessDataset(data, p)
 
     # Enable auto-saving
     data = CheckpointDataset(
@@ -181,4 +165,4 @@ def splitstrip(x):
     datas = splitstrip(datas)
     weights = [float(x) for x in splitstrip(weights)]
     cols = splitstrip(cols)
-    return datas, weights, cols
+    return datas, weights, cols
\ No newline at end of file
diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py
index bc5ed772..0152687f 100644
--- a/fms_fsdp/utils/dataset_utils.py
+++ b/fms_fsdp/utils/dataset_utils.py
@@ -15,7 +15,6 @@
 from fms_fsdp.utils.checkpointing_utils import get_latest
 
 
-
 """
 The following distributed dataloaders are designed around 3 main principles:
 
@@ -343,9 +342,9 @@ class ArrowHandler(_ShardFileHandler):
     Non-standard data format, though.
     """
 
-    def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]):
+    def __init__(self, col_names: List[str] = ["tokens"]):
         self.col_names = col_names
-
+
     def is_legal(self, filepath: str):
         return "arrow" in os.path.splitext(filepath)[1]
 
@@ -356,18 +355,14 @@ def length(self, path: str):
         return self.open(path).num_record_batches
 
     def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set):
-        assert (
-            index < reader.num_record_batches
-        ), f"Illegal index {index} in set of {reader.num_record_batches} documents"
         frame = reader.get_batch(index)
+        doc = None
         for name in self.col_names:
             if name in frame.column_names:
                 doc = frame[name]
                 break
-        assert (
-            doc is not None
-        ), f"None of column names {self.col_names} found in file headers {frame.column_names}"
+        assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}"
         if len(doc) > 0 and doc[0].as_py() in drop_tokens:
             doc = doc.slice(1, len(doc) - 1)
         # Recheck len for edge case where doc=[eos]
@@ -407,9 +402,7 @@ def open(self, path: str):
             if name in names:
                 match = name
                 break
-        assert (
-            match is not None
-        ), f"None of column names {self.col_names} found in file headers {names}"
+        assert match is not None, f"None of column names {self.col_names} found in file headers {names}"
         return pq.read_pandas(path, columns=[match], partitioning=None)[match]
 
     def length(self, path: str):
@@ -728,128 +721,6 @@ def load_state_dict(self, state_dicts, sharded_input=False):
         return sharded_dicts
 
 
-class FIMDataset(_WrapperDataset):
-    """
-    Wrapper for a StatefulDataset that implements Fill-In-the-Middle training
-    (https://arxiv.org/pdf/2207.14255).
-    Input should be a packed sequence (i.e. call BufferDataset before FIMDataset).
-    Breaks sequence apart into component document spans, and for each document span
-    of sufficient length, transforms with specified probability into:
-    PSM mode: <pre>(prefix)<suf>(suffix)<mid>(middle)
-    SPM mode: <pre><suf>(suffix)<mid>(prefix)(middle)
-    The new delimiter tokens can be omitted by passing in None.
-    Any extra tokens after transformation are dropped from the end of the sequence.
-    ...
-    Args
-    ----
-    dataset : _StatefulDataset
-        Fully instantiated dataset
-    delimiter_token : any
-        Token used to indicate document boundaries
-    psm_rate : float
-        Chance to transform into PSM. Cannot exceed 1.
-    spm_rate : float
-        Chance to transform into SPM. Cannot exceed 1.
-    min_len : int
-        Minimum document length to perform FIM transformation
-    pre_token : any | none
-        Token used to indicate prefix section of the document
-    mid_token : any | none
-        Token used to indicate middle infill section of the document
-    suf_token : any | none
-        Token used to indicate suffix section of the document
-    """
-
-    def __init__(
-        self,
-        dataset: _StatefulDataset,
-        delimiter_token: Any,
-        psm_rate: float = 0.0,
-        spm_rate: float = 0.0,
-        min_len: int = 10,
-        pre_token=None,
-        mid_token=None,
-        suf_token=None,
-    ):
-        super().__init__(dataset)
-        assert (
-            psm_rate + spm_rate > 0
-        ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
-        assert (
-            psm_rate + spm_rate <= 1
-        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
-        self.psm = psm_rate
-        self.spm = spm_rate
-        self.delimiter = delimiter_token
-        self.min_len = min_len
-        self.pref = pre_token
-        self.suff = suf_token
-        self.midd = mid_token
-
-        self.g_state = None
-        self.generator = torch.Generator().manual_seed(self.rank)
-        self.state_params = ["g_state"]
-
-    def __iter__(self):
-        dataset = iter(self.dataset)
-        while True:
-            inp = next(dataset)
-            len_ = len(inp)
-            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
-            docs = [
-                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
-            ]  # list[list[any]]
-            out = []
-            for i in range(len(docs)):
-                doc = docs[i]
-                if len(docs[i]) >= self.min_len:
-                    # decide psm, spm, or nothing
-                    thresh = torch.rand([1], generator=self.generator).item()
-                    if thresh < self.psm + self.spm:
-                        # Split doc
-                        doc = []
-                        if self.pref:
-                            doc = [self.pref]
-                        splits = torch.randint(
-                            0, len(docs[i]), [2], generator=self.generator
-                        ).tolist()
-                        pre = docs[i][: min(splits)]
-                        mid = docs[i][min(splits) : max(splits)]
-                        suf = docs[i][max(splits) :]
-
-                        if thresh < self.psm:
-                            # PSM transformation
-                            doc += pre
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += mid
-                        else:
-                            # SPM transformation
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += pre + mid
-                out += doc + [self.delimiter]
-            yield out[:len_]
-
-    def state_dict(self):
-        # Write generator state manually
-        self.g_state = self.generator.get_state()
-        return super().state_dict()
-
-    def load_state_dict(self, state_dicts, sharded_input=False):
-        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
-        # Manually set generator state if it exists
-        if self.g_state is not None:
-            self.generator.set_state(self.g_state)
-        return sharded_dicts
-
-
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -995,10 +866,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
-    max_consecutive_chunks : int
-        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
+    shuffle : bool
+        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -1013,7 +884,6 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
-        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1026,10 +896,10 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
-        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
-        # Map of doc indices to (shardid, min docid, max docid)
-        self.docset: List[Any] = []
+        self.docset: List[
+            Any
+        ] = []  # map of doc indices to (shardid, min docid, max docid)
 
         # Position
         self.docset_index = 0
@@ -1041,7 +911,6 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
-        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1052,7 +921,6 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
-            "consec",
         ]
 
         # Setup flags
@@ -1075,7 +943,7 @@ def setup(self):
                 pathsplit = os.path.split(pathsplit[0])
                 pardir, dataset = pathsplit
             self.dataset = dataset
-
+
             # Assemble document set owned by this worker:
             # listdir, assemble shardfraglist (ind -> shard, frag)
             shards = [
@@ -1085,8 +953,17 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
 
+            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
+            end_frag = (
+                (self.rank + 1) * self.worldsize * len(shards)
+            ) // self.worldsize
+            shardfrags = [
+                (shards[i // self.worldsize], i % self.worldsize)
+                for i in range(start_frag, end_frag)
+            ]
+
+            # Assemble length of each owned shard file
-            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
@@ -1094,6 +971,7 @@ def setup(self):
                     for x in os.listdir(os.path.join(pardir, "meta"))
                    if "counts" in x and "csv" in x
                 ]
+            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
@@ -1138,34 +1016,47 @@ def setup(self):
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find(dataset)
-                        if prefix >= 0:
+                        prefix = fullpath.find("/" + dataset) + 1
+                        if prefix > 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
+                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in shardset
+                    for shard in unique_shardfiles
                 }
 
-            # Assemble doc list for each file shard
-            # Create docset of form [shardid, min docid, max docid]
-            doccount = 0
-            for shard in shardset:
+            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
+            ndocs = -1
+            docset = {}  # shardid -> (min docid, max docid)
+            for i, (shard, frag) in enumerate(shardfrags):
                 ndocs = doc_counts[shard]
-                if ndocs > 0:
-                    doc_start = int(ndocs * shardset[shard][0])
-                    doc_end = max(
-                        doc_start, int(ndocs * shardset[shard][1]) - 1
-                    )  # inclusive upper bound
-                    self.docset.append([shard, doc_start, doc_end])
-                    doccount += doc_end - doc_start + 1
+                doc_start = (ndocs * frag) // self.worldsize
+                doc_end = (
+                    ndocs * frag + ndocs
+                ) // self.worldsize - 1  # Inclusive upper bound
+                if shard not in docset:
+                    docset[shard] = [doc_start, doc_end]
+                min_d, max_d = docset[shard]
+                if doc_start < min_d:
+                    docset[shard][0] = doc_start
+                if doc_end > max_d:
+                    docset[shard][1] = doc_end
+
+            # Add shard entries to self.docset
+            doccount = 0
+            for shardid in docset:
+                min_d = docset[shardid][0]
+                max_d = docset[shardid][1]
+                self.docset.append((shardid, min_d, max_d))
+                doccount += max_d - min_d + 1
             self._len = doccount
             if self.verbose:
                 logging.info(
-                    f" Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
+                    f" Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
@@ -1220,11 +1111,8 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1 or self.consec == self.max_consec:
+        if j == n_chunks - 1:
             chunk = chunk + [self.eos]
-            self.consec = 0
-        else:
-            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
@@ -1269,8 +1157,10 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
+                if len(doc) == 0:
+                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if len(doc) > 0 and doclen >= self.min_length:
+                if doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1297,8 +1187,10 @@ def __iter__(self):
             newpath = os.path.join(self.datapath, shardid)
             path, reader = self._get_reader(path, newpath, reader)
             doc = self.filehandler.get(reader, docid, self.drop)
+            if len(doc) == 0:
+                continue
             doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-            if len(doc) > 0 and doclen >= self.min_length:
+            if doclen >= self.min_length:
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j
@@ -1306,9 +1198,7 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert (
-                self.has_yielded
-            ), f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1381,12 +1271,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
-            assert (
-                n_logical_shards % self.worldsize == 0
-            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
+            assert (
+                len(self.logicals_owned) == self.n_logicals
+            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1403,26 +1293,20 @@ def setup(self):
             )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-            assert (
-                sum(self.n_docs_remaining) > 0
-            ), f"No documents detected in shard {self.rank} of {self.datapath}"
-
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
-        # Reset if we're rescaling into a prematurely finished epoch
-        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
-        if sum(self.n_docs_remaining) == 0:
-            self.n_docs_remaining = [d._len for d in self.data]
-            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
+                assert (
+                    sum(self.n_docs_remaining) > 0
+                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
@@ -1514,10 +1398,6 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
-        for d in datasets:
-            assert os.path.exists(
-                os.path.join(datapath, d)
-            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(
diff --git a/fms_to_hf_llama.py b/fms_to_hf_llama.py
index 76042eee..6a81947a 100644
--- a/fms_to_hf_llama.py
+++ b/fms_to_hf_llama.py
@@ -3,160 +3,27 @@
 from fms.models.hf.utils import to_hf_api
 from fms.models.llama import LLaMA
 from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
-from transformers import LlamaConfig, LlamaForCausalLM
 
 from fms_fsdp.utils.config_utils import get_model_config
 
 
-def convert_to_hf(model: LLaMA, model_variant, is_old_fms) -> LlamaForCausalLM:
-    fms_hf_model = to_hf_api(model)
-    hf_config = fms_hf_model.config
-    if "llama3" in model_variant:
-        hf_config.bos_token_id = 128000
-        hf_config.eos_token_id = 128001
-    oss_hf_model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=hf_config.vocab_size,
-            hidden_size=hf_config.hidden_size,
-            rms_norm_eps=hf_config.norm_eps,
-            num_attention_heads=hf_config.nheads,
-            num_key_value_heads=None if hf_config.kvheads == 0 else hf_config.kvheads,
-            num_hidden_layers=hf_config.nlayers,
-            intermediate_size=hf_config.multiple_of
-            * (
-                (
-                    int(hf_config.hidden_grow_factor * hf_config.hidden_size)
-                    + hf_config.multiple_of
-                    - 1
-                )
-                // hf_config.multiple_of
-            ),
-            pad_token_id=(
-                None if hf_config.pad_token_id == -1 else hf_config.pad_token_id
-            ),
-            bos_token_id=hf_config.bos_token_id,
-            eos_token_id=hf_config.eos_token_id,
-            max_position_embeddings=hf_config.max_expected_seq_len,
-        )
-    )
-
-    # compute the freq from rot_emb since it is gathered lazily
-    rot_emb = fms_hf_model.decoder.model.rot_emb
-    max_seq_len = rot_emb.max_seq_len
-    alpha = rot_emb._alpha(max_seq_len)
-    ratio = rot_emb.ratio
-    dim = rot_emb.dim
-    if rot_emb.ntk_scaling:
-        ratio = ratio * alpha ** (dim / (dim - 2))
-    freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-
-    with torch.no_grad():
-        oss_hf_model.model.embed_tokens.weight.copy_(fms_hf_model.embedding.weight)
-        i = 0
-        for oss_hf_layer in oss_hf_model.model.layers:
-            fms_hf_layer = fms_hf_model.decoder.model.layers[i]
-
-            # self attn
-            if is_old_fms:
-                oss_hf_layer.self_attn.q_proj.weight.copy_(
-                    fms_hf_layer.attn.query.weight
-                )
-                oss_hf_layer.self_attn.k_proj.weight.copy_(fms_hf_layer.attn.key.weight)
-                oss_hf_layer.self_attn.v_proj.weight.copy_(
-                    fms_hf_layer.attn.value.weight
-                )
-            else:
-                q, k, v = torch.split(
-                    fms_hf_layer.attn.in_proj.qkv_fused.weight,
-                    fms_hf_layer.attn.in_proj.splits,
-                    dim=0,
-                )
-                oss_hf_layer.self_attn.q_proj.weight.copy_(q)
-                oss_hf_layer.self_attn.k_proj.weight.copy_(k)
-                oss_hf_layer.self_attn.v_proj.weight.copy_(v)
-            oss_hf_layer.self_attn.o_proj.weight.copy_(fms_hf_layer.attn.dense.weight)
-            oss_hf_layer.self_attn.rotary_emb.inv_freqs = freqs
-
-            # mlp
-            if is_old_fms:
-                oss_hf_layer.mlp.gate_proj.weight.copy_(
-                    fms_hf_layer.ff_sub_layer.wg.weight
-                )
-                oss_hf_layer.mlp.up_proj.weight.copy_(
-                    fms_hf_layer.ff_sub_layer.w1.weight
-                )
-            else:
-                wg1_fused = fms_hf_layer.ff_sub_layer.wg1_fused.weight
-                wg_splits = [wg1_fused.size(0) // 2, wg1_fused.size(0) // 2]
-                wg, w1 = torch.split(
-                    fms_hf_layer.ff_sub_layer.wg1_fused.weight, wg_splits, dim=0
-                )
-                oss_hf_layer.mlp.gate_proj.weight.copy_(wg)
-                oss_hf_layer.mlp.up_proj.weight.copy_(w1)
-            oss_hf_layer.mlp.down_proj.weight.copy_(fms_hf_layer.ff_sub_layer.w2.weight)
-
-            # layer norm
-            oss_hf_layer.input_layernorm.weight.copy_(fms_hf_layer.ln.weight)
-            oss_hf_layer.post_attention_layernorm.weight.copy_(
-                fms_hf_layer.ff_ln.weight
-            )
-
-            # adjust q, k
-            q = oss_hf_layer.self_attn.q_proj.weight.data
-            q = (
-                q.view(hf_config.nheads, -1, 2, q.size(1))
-                .transpose(1, 2)
-                .reshape(*q.size())
-            )
-            oss_hf_layer.self_attn.q_proj.weight.copy_(q)
-
-            k = oss_hf_layer.self_attn.k_proj.weight.data
-            k = (
-                k.view(
-                    hf_config.nheads if hf_config.kvheads == 0 else hf_config.kvheads,
-                    -1,
-                    2,
-                    k.size(1),
-                )
-                .transpose(1, 2)
-                .reshape(*k.size())
-            )
-            oss_hf_layer.self_attn.k_proj.weight.copy_(k)
-
-            i = i + 1
-        oss_hf_model.model.norm.weight = fms_hf_model.decoder.model.dec_norm.weight
-        oss_hf_model.lm_head.weight = fms_hf_model.lm_head.weight
-
-    return oss_hf_model
-
-
-def main(
-    model_variant, compiled, is_old_fms, load_path, save_path, tokenizer_name_or_path
-):
+def main(model_variant, load_path, save_path, tokenizer_name_or_path):
     print("Initializing model...")
-    llama_config = get_model_config(model_variant)
-    with torch.device("meta"):
-        model = LLaMA(llama_config)
-    model.to_empty(device="cpu")
+    config_data = get_model_config(model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
 
     print(f"Reading state dict from {load_path}")
-    if not compiled:
-        state_dict = {"model_state": model.state_dict()}
-    else:
-        state_dict = {"model_state": {"_orig_mod": model.state_dict()}}
+    state_dict = {"model_state": model.state_dict()}
     load_state_dict(
         state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
     )
 
     print("Loading state dict into the model...")
-    if not compiled:
-        model.load_state_dict(state_dict["model_state"])
-    else:
-        model.load_state_dict(state_dict["model_state"]["_orig_mod"])
+    model.load_state_dict(state_dict["model_state"])
 
-    print("Converting to HF model..")
-    hf_model = convert_to_hf(model, model_variant, is_old_fms)
-    hf_model.save_pretrained(save_path)
+    print("Saving model to HF-compatible format...")
+    model.save_pretrained(save_path)
 
     print("Copying tokenizer...")
     from transformers import AutoTokenizer
@@ -164,7 +31,7 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
     tokenizer.save_pretrained(save_path)
 
-    print(f"Model converted to HF model, saving at {save_path}")
+    print(f"Model saving at {save_path}")
 
 
 if __name__ == "__main__":
diff --git a/main_training_llama.py b/main_training_llama.py
index a7e1020f..4bd2a3c3 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -1,10 +1,13 @@
 import math
 import os
+from pathlib import Path
 
 import fire
 import torch
 import torch.optim as optim
-from fms.models.llama import LLaMA, LLaMABlock
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from mamba_ssm.modules.block import Block
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim.lr_scheduler import LambdaLR
@@ -44,9 +47,12 @@ def main(**kwargs):
     torch.cuda.set_device(local_rank)
     torch.cuda.empty_cache()
     setup_environ_flags()
+    os.environ["TRITON_CACHE_DIR"] = os.path.join(
+        Path.home(), ".triton", "cache", str(local_rank)
+    )
 
     # get policy
-    block = LLaMABlock
+    block = Block
     (
         mixed_precision_policy,
         wrapping_policy,
@@ -55,14 +61,10 @@ def main(**kwargs):
         param_init_fn,
     ) = get_policies(cfg, rank, block)
 
-    # get fms model
-    llama_config = get_model_config(cfg.model_variant)
-    if cfg.low_cpu_fsdp:
-        with torch.device("meta"):
-            model = LLaMA(llama_config)
-    else:
-        model = LLaMA(llama_config)
-        model.reset_parameters()
+    # get model
+    config_data = get_model_config(cfg.model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
 
     if rank == 0:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -89,11 +91,6 @@ def main(**kwargs):
         limit_all_gathers=True,
         param_init_fn=param_init_fn,
     )
-    # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution.
-    model.rot_emb.compute_freqs_cis(
-        torch.device("cuda", torch.cuda.current_device()),
-        model.config.max_expected_seq_len,
-    )
 
     # fsdp activation checkpointing
     if cfg.fsdp_activation_checkpointing:
@@ -122,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
@@ -139,14 +134,37 @@ def main(**kwargs):
     if cfg.training_stage == "annealing":
         schedule = lambda x: 1 - x / cfg.num_steps
     else:
-        warmup_interval = min(2000, cfg.num_steps // 20)
-        schedule = lambda x: min(
-            1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
-            0.1
-            + 0.5
-            * (1 - 0.1)
-            * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
-        )
+
+        # (cosine 0.01 decay)
+        # warmup_interval = min(1000, cfg.num_steps // 10)
+        # schedule = lambda x: min(
+        #     1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
+        #     0.01
+        #     + 0.5
+        #     * (1 - 0.01)
+        #     * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
+        # )
+
+        # (constant schedule)
+        # warmup_interval = 1000
+        # schedule = lambda x: (
+        #     min(x, warmup_interval) / warmup_interval
+        # )
+        schedule = lambda x: min(1.0, x)
+
+        # (cosine 0.1 decay)
+        # warmup_interval = min(2000, cfg.num_steps // 20)
+        # schedule = lambda x: min(
+        #     1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
+        #     0.1
+        #     + 0.5
+        #     * (1 - 0.1)
+        #     * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
+        #     # )
+
+        # linear decay to 50b tokens and then constant lr
+        # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75
+
     scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
 
     # profiler
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 68a3c830..3ff12a60 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
@@ -135,8 +133,9 @@ def main(**kwargs):
     # LR schedule
     # linear decay for annealing
     if cfg.training_stage == "annealing":
"annealing": - schedule = lambda x: 1 - x / cfg.num_steps - else: + warmup_interval = 1000 + schedule = lambda x: x / warmup_interval if x < warmup_interval else 1 - (x - warmup_interval) / (cfg.num_steps - warmup_interval) + elif cfg.training_stage == "cosine": # cosine decay warmup_interval = min(2000, cfg.num_steps // 20) schedule = lambda x: min( @@ -146,6 +145,11 @@ def main(**kwargs): * (1 - 0.1) * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)), ) + elif cfg.training_stage == "constant": + warmup_interval = 2000 + schedule = lambda x: (min(x, warmup_interval) / warmup_interval) + else: + schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75 scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step)) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 40bef481..83b2426b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -632,10 +632,6 @@ def test_multi_reload_stress(): # preload / sample / scale / doc pipeline multi_reload_stress_check(lambda: d6(d5(d4()))) - # Add FIM dataset - d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x] - multi_reload_stress_check(lambda: d7(d6(d5(d4())))) - # SCALABLEDATASET TESTS