diff --git a/fms_fsdp/config/training.py b/fms_fsdp/config/training.py index 22b9a840..ca942bc6 100644 --- a/fms_fsdp/config/training.py +++ b/fms_fsdp/config/training.py @@ -76,10 +76,3 @@ class train_config: stage2_prompt_length: int = 64 stage2_batch_size: int = 96 stage2_seq_length: int = 256 - - # FIM training - psm_rate: float = 0.0 - spm_rate: float = 0.0 - fim_pre: int = 1 - fim_mid: int = 2 - fim_suf: int = 3 diff --git a/fms_fsdp/utils/checkpointing_utils.py b/fms_fsdp/utils/checkpointing_utils.py index 2bbeef18..65b18838 100644 --- a/fms_fsdp/utils/checkpointing_utils.py +++ b/fms_fsdp/utils/checkpointing_utils.py @@ -324,7 +324,9 @@ def save_single_file( ): # Note: metadata kwargs cannot contain any of: # (step, model) - save_name = os.path.join(self.ckp_path, "step_" + str(step) + "_ckp.pth") + pth_path = os.path.join(self.ckp_path[:-12], "pth", "step_" + str(step)) + os.makedirs(pth_path, exist_ok=True) + save_name = os.path.join(pth_path, "consolidated.00.pth") save_time = time.time() with FSDP.state_dict_type( model, diff --git a/fms_fsdp/utils/dataloader_utils.py b/fms_fsdp/utils/dataloader_utils.py index 6078ae7c..c648f40a 100644 --- a/fms_fsdp/utils/dataloader_utils.py +++ b/fms_fsdp/utils/dataloader_utils.py @@ -6,7 +6,6 @@ AutoHandler, BufferDataset, CheckpointDataset, - FIMDataset, ParquetHandler, PreloadBufferDataset, PreprocessDataset, @@ -59,9 +58,9 @@ def __iter__(self): return torch.utils.data.DataLoader(data, batch_size=cfg.batch_size) -def get_data_loader(cfg, rank, world_size): +def get_data_loader(cfg, rank, world_size, postprocess=[causal_lm]): """ - Pytorch dataloader for stateful, distributed, and rescalable language model training. + Pytorch dataloader for stateful, distributed, and rescalable causal language model (CLM) training. Assumes underlying data is sequences of integer values. ... Args @@ -72,12 +71,11 @@ def get_data_loader(cfg, rank, world_size): Rank of current distributed worker. 
Used for handling dataset sharding logic. world_size : int Number of distributed workers. Used for handling dataset sharding logic. + postprocess : List[Callable] + Any task-specific postprocessing to apply before handing over data. Steps will apply in + the order provided by the user. For CLM training, use postprocess=[causal_lm]. """ - fim_training = cfg.psm_rate + cfg.spm_rate > 0 - if fim_training: - assert cfg.bos_token is None, "No BOS in FIM training. Did you mean fim_pre?" - datasets, weights, cols = parse_data_args(cfg.datasets, cfg.weights, cfg.col_name) # Base streaming dataset. Returns doc chunks in sequence. @@ -94,7 +92,7 @@ def get_data_loader(cfg, rank, world_size): cfg.tokenizer_path, cols, cfg.doc_cutoff ) else: - filehandler = _handler_map[cfg.file_type](cols) + filehandler = _handler_map[cfg.file_type](cols) # Base reader layer data = StreamingDocDataset( cfg.data_path, @@ -124,10 +122,9 @@ def get_data_loader(cfg, rank, world_size): verbose=(rank == 0), ) # Wrap above dataset in packing logic to form constant-length lines. - # Increment seq len to counteract CLM's one token removal. data = BufferDataset( data, - cfg.seq_length + 1, + cfg.seq_length if causal_lm not in postprocess else cfg.seq_length + 1, bos_token=cfg.bol_token, eos_token=cfg.eol_token, pack_hard=True, ) # Shuffle outputs in length 10k buffer. Consecutive lines appear 10k steps apart on average. 
data = PreloadBufferDataset(data, 10000) - # Apply FIM transformation if needed - if fim_training: - data = FIMDataset( - data, - cfg.eos_token, - cfg.psm_rate, - cfg.spm_rate, - pre_token=cfg.fim_pre, - mid_token=cfg.fim_mid, - suf_token=cfg.fim_suf, - ) - - # Transform to tensors + # Apply desired postprocessing steps in sequence data = PreprocessDataset(data, torch.IntTensor) - - # Apply CLM transformation - data = PreprocessDataset(data, causal_lm) + for p in postprocess: + data = PreprocessDataset(data, p) # Enable auto-saving data = CheckpointDataset( @@ -181,4 +165,4 @@ def splitstrip(x): datas = splitstrip(datas) weights = [float(x) for x in splitstrip(weights)] cols = splitstrip(cols) - return datas, weights, cols + return datas, weights, cols \ No newline at end of file diff --git a/fms_fsdp/utils/dataset_utils.py b/fms_fsdp/utils/dataset_utils.py index bc5ed772..0152687f 100644 --- a/fms_fsdp/utils/dataset_utils.py +++ b/fms_fsdp/utils/dataset_utils.py @@ -15,7 +15,6 @@ from fms_fsdp.utils.checkpointing_utils import get_latest - """ The following distributed dataloaders are designed around 3 main principles: @@ -343,9 +342,9 @@ class ArrowHandler(_ShardFileHandler): Non-standard data format, though. 
""" - def __init__(self, col_names: List[str] = ["text", "contents", "tokens"]): + def __init__(self, col_names: List[str] = ["tokens"]): self.col_names = col_names - + def is_legal(self, filepath: str): return "arrow" in os.path.splitext(filepath)[1] @@ -356,18 +355,14 @@ def length(self, path: str): return self.open(path).num_record_batches def get(self, reader: pa.RecordBatchFileReader, index: int, drop_tokens: Set): - assert ( - index < reader.num_record_batches - ), f"Illegal index {index} in set of {reader.num_record_batches} documents" frame = reader.get_batch(index) + doc = None for name in self.col_names: if name in frame.column_names: doc = frame[name] break - assert ( - doc is not None - ), f"None of column names {self.col_names} found in file headers {frame.column_names}" + assert doc is not None, f"None of column names {self.col_names} found in file headers {frame.column_names}" if len(doc) > 0 and doc[0].as_py() in drop_tokens: doc = doc.slice(1, len(doc) - 1) # Recheck len for edge case where doc=[eos] @@ -407,9 +402,7 @@ def open(self, path: str): if name in names: match = name break - assert ( - match is not None - ), f"None of column names {self.col_names} found in file headers {names}" + assert match is not None, f"None of column names {self.col_names} found in file headers {names}" return pq.read_pandas(path, columns=[match], partitioning=None)[match] def length(self, path: str): @@ -728,128 +721,6 @@ def load_state_dict(self, state_dicts, sharded_input=False): return sharded_dicts -class FIMDataset(_WrapperDataset): - """ - Wrapper for a StatefulDataset that implements Fill-In-the-Middle training - (https://arxiv.org/pdf/2207.14255). - Input should be a packed sequence (i.e. call BufferDataset before FIMDataset). - Breaks sequence apart into component document spans, and for each document span - of sufficient length, transforms with specified probability into: - PSM mode:
 (prefix)  (suffix)  (middle) 
-    SPM mode: 
  (suffix)  (prefix) (middle) 
-    The new delimiter tokens can be omitted by passing in None.
-    Any extra tokens after transformation are dropped from the end of the sequence.
-    ...
-    Args
-    ----
-    dataset : _StatefulDataset
-        Fully instantiated dataset
-    delimiter_token : any
-        Token used to indicate document boundaries
-    psm_rate : float
-        Chance to transform into PSM. Cannot exceed 1.
-    spm_rate : float
-        Chance to transform into SPM. Cannot exceed 1.
-    min_len : int
-        Minimum document length to perform FIM transformation
-    pre_token : any | none
-        Token used to indicate prefix section of the document
-    mid_token : any | none
-        Token used to indicate middle infill section of the document
-    suf_token : any | none
-        Token used to indicate suffix section of the document
-    """
-
-    def __init__(
-        self,
-        dataset: _StatefulDataset,
-        delimiter_token: Any,
-        psm_rate: float = 0.0,
-        spm_rate: float = 0.0,
-        min_len: int = 10,
-        pre_token=None,
-        mid_token=None,
-        suf_token=None,
-    ):
-        super().__init__(dataset)
-        assert (
-            psm_rate + spm_rate > 0
-        ), f"FIM training requires SPM or PSM transformation. Please specify a nonzero psm_rate or spm_rate."
-        assert (
-            psm_rate + spm_rate <= 1
-        ), f"Combined psm_rate {psm_rate} and spm_rate {spm_rate} probabilities cannot exceed 1."
-        self.psm = psm_rate
-        self.spm = spm_rate
-        self.delimiter = delimiter_token
-        self.min_len = min_len
-        self.pref = pre_token
-        self.suff = suf_token
-        self.midd = mid_token
-
-        self.g_state = None
-        self.generator = torch.Generator().manual_seed(self.rank)
-        self.state_params = ["g_state"]
-
-    def __iter__(self):
-        dataset = iter(self.dataset)
-        while True:
-            inp = next(dataset)
-            len_ = len(inp)
-            i_eos = [0] + [i for i, x in enumerate(inp) if x == self.delimiter] + [len_]
-            docs = [
-                inp[i_eos[j] + 1 : i_eos[j + 1]] for j in range(len(i_eos) - 1)
-            ]  # list[list[any]]
-            out = []
-            for i in range(len(docs)):
-                doc = docs[i]
-                if len(docs[i]) >= self.min_len:
-                    # decide psm, spm, or nothing
-                    thresh = torch.rand([1], generator=self.generator).item()
-                    if thresh < self.psm + self.spm:
-                        # Split doc
-                        doc = []
-                        if self.pref:
-                            doc = [self.pref]
-                        splits = torch.randint(
-                            0, len(docs[i]), [2], generator=self.generator
-                        ).tolist()
-                        pre = docs[i][: min(splits)]
-                        mid = docs[i][min(splits) : max(splits)]
-                        suf = docs[i][max(splits) :]
-
-                        if thresh < self.psm:
-                            # PSM transformation
-                            doc += pre
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += mid
-                        else:
-                            # SPM transformation
-                            if self.suff:
-                                doc.append(self.suff)
-                            doc += suf
-                            if self.midd:
-                                doc.append(self.midd)
-                            doc += pre + mid
-                out += doc + [self.delimiter]
-            yield out[:len_]
-
-    def state_dict(self):
-        # Write generator state manually
-        self.g_state = self.generator.get_state()
-        return super().state_dict()
-
-    def load_state_dict(self, state_dicts, sharded_input=False):
-        sharded_dicts = super().load_state_dict(state_dicts, sharded_input)
-        # Manually set generator state if it exists
-        if self.g_state is not None:
-            self.generator.set_state(self.g_state)
-        return sharded_dicts
-
-
 class BufferDataset(_WrapperDataset):
     """
     Wrapper for a _StatefulDataset that takes in sequences of varying lengths, and packs/pads them
@@ -995,10 +866,10 @@ class StreamingDocDataset(_StatefulDataset):
         Documents below this length are skipped
     max_chunksize : int
         Maximum sequence length to return. Break long docs into chunks of this size or shorter.
-    max_consecutive_chunks : int
-        Number of doc chunks to emit before manually inserting EOS and resuming later.
     verbose : bool
         Track setup progress?
+    shuffle : bool
+        Shuffle shard file and document orders? (Disable for simple testing)
     """
 
     def __init__(
@@ -1013,7 +884,6 @@ def __init__(
         seed: int = 42,
         min_length: int = 1,
         max_chunksize: int = 1024,
-        max_consecutive_chunks: int = 64,
         verbose: bool = False,
     ):
         super().__init__(datapath, rank, worldsize)
@@ -1026,10 +896,10 @@ def __init__(
         self.eos = delimiter_token
         self.bos = bos_token
         self.drop = strip_tokens
-        self.max_consec = max_consecutive_chunks
         self.verbose = verbose
-        # Map of doc indices to (shardid, min docid, max docid)
-        self.docset: List[Any] = []
+        self.docset: List[
+            Any
+        ] = []  # map of doc indices to (shardid, min docid, max docid)
 
         # Position
         self.docset_index = 0
@@ -1041,7 +911,6 @@ def __init__(
         self.tokens_seen = 0
         self.docs_seen = 0
         self.percent_seen = 0
-        self.consec = 0
 
         self.state_params = [
             "dataset",
@@ -1052,7 +921,6 @@ def __init__(
             "docs_seen",
             "percent_seen",
             "lcg_state",
-            "consec",
         ]
 
         # Setup flags
@@ -1075,7 +943,7 @@ def setup(self):
                 pathsplit = os.path.split(pathsplit[0])
             pardir, dataset = pathsplit
             self.dataset = dataset
-
+
             # Assemble document set owned by this worker:
             # listdir, assemble shardfraglist (ind -> shard, frag)
             shards = [
@@ -1085,8 +953,17 @@ def setup(self):
                 if self.filehandler.is_legal(os.path.join(root, name))
             ]
             shards.sort()  # Ensure consistent sharding across machines
+            start_frag = (self.rank * self.worldsize * len(shards)) // self.worldsize
+            end_frag = (
+                (self.rank + 1) * self.worldsize * len(shards)
+            ) // self.worldsize
+            shardfrags = [
+                (shards[i // self.worldsize], i % self.worldsize)
+                for i in range(start_frag, end_frag)
+            ]
+
+            # Assemble length of each owned shard file
 
-            # Find metadata file
             countfiles = []
             if os.path.exists(os.path.join(pardir, "meta")):
                 countfiles = [
@@ -1094,6 +971,7 @@ def setup(self):
                     for x in os.listdir(os.path.join(pardir, "meta"))
                     if "counts" in x and "csv" in x
                 ]
+            doc_counts = {}
             if len(countfiles) > 0:
                 # Count file exists, use it
                 countpath = os.path.join(pardir, "meta", countfiles[0])
@@ -1138,34 +1016,47 @@ def setup(self):
                     reader = csv.DictReader(csvfile)
                     for row in reader:
                         fullpath = row["dataset/filename"]
-                        prefix = fullpath.find(dataset)
-                        if prefix >= 0:
+                        prefix = fullpath.find("/" + dataset) + 1
+                        if prefix > 0:
                             key = fullpath[prefix + len(dataset) + 1 :]
                             doc_counts[key] = int(row["documents"])
             else:
                 # Count file does not exist, touch every owned file for length
+                unique_shardfiles = set(shard for shard, frag in shardfrags)
                 doc_counts = {
                     shard: self.filehandler.length(os.path.join(datapath, shard))
-                    for shard in shardset
+                    for shard in unique_shardfiles
                 }
 
-            # Assemble doc list for each file shard
-            # Create docset of form [shardid, min docid, max docid]
-            doccount = 0
-            for shard in shardset:
+            # Read shardfrags, assemble doc list for each file shard (aggregating over fragments):
+            ndocs = -1
+            docset = {}  # shardid -> (min docid, max docid)
+            for i, (shard, frag) in enumerate(shardfrags):
                 ndocs = doc_counts[shard]
-                if ndocs > 0:
-                    doc_start = int(ndocs * shardset[shard][0])
-                    doc_end = max(
-                        doc_start, int(ndocs * shardset[shard][1]) - 1
-                    )  # inclusive upper bound
-                    self.docset.append([shard, doc_start, doc_end])
-                    doccount += doc_end - doc_start + 1
+                doc_start = (ndocs * frag) // self.worldsize
+                doc_end = (
+                    ndocs * frag + ndocs
+                ) // self.worldsize - 1  # Inclusive upper bound
+                if shard not in docset:
+                    docset[shard] = [doc_start, doc_end]
+                min_d, max_d = docset[shard]
+                if doc_start < min_d:
+                    docset[shard][0] = doc_start
+                if doc_end > max_d:
+                    docset[shard][1] = doc_end
+
+            # Add shard entries to self.docset
+            doccount = 0
+            for shardid in docset:
+                min_d = docset[shardid][0]
+                max_d = docset[shardid][1]
+                self.docset.append((shardid, min_d, max_d))
+                doccount += max_d - min_d + 1
             self._len = doccount
 
             if self.verbose:
                 logging.info(
-                    f"    Worker {self.rank} ingested {len(self.docset)} shard fragments from {dataset}"
+                    f"    Worker {self.rank} ingested {len(shardfrags)} shard fragments from {dataset}"
                 )
 
             # Shuffle shard files - guaranteed inconsistent across workers
@@ -1220,11 +1111,8 @@ def _construct_chunk(self, j, doc, n_chunks):
         # Add bos/eos tokens if needed
         if self.bos is not None and j == 0:
             chunk = [self.bos] + chunk
-        if j == n_chunks - 1 or self.consec == self.max_consec:
+        if j == n_chunks - 1:
             chunk = chunk + [self.eos]
-            self.consec = 0
-        else:
-            self.consec += 1
         return chunk
 
     def _random_map_docid(self, size):
@@ -1269,8 +1157,10 @@ def __iter__(self):
                 doclcg = self._random_map_docid(docrange)
                 docid = doclcg + mindoc
                 doc = self.filehandler.get(reader, docid, self.drop)
+                if len(doc) == 0:
+                    continue
                 doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-                if len(doc) > 0 and doclen >= self.min_length:
+                if doclen >= self.min_length:
                     n_chunks = math.ceil(doclen / self.chunksize)
                     for j in range(n_chunks):
                         if i == 0 and j < residual_chunks:
@@ -1297,8 +1187,10 @@ def __iter__(self):
             newpath = os.path.join(self.datapath, shardid)
             path, reader = self._get_reader(path, newpath, reader)
             doc = self.filehandler.get(reader, docid, self.drop)
+            if len(doc) == 0:
+                continue
             doclen = len(doc) + 1 if self.bos is None else len(doc) + 2
-            if len(doc) > 0 and doclen >= self.min_length:
+            if doclen >= self.min_length:
                 n_chunks = math.ceil(doclen / self.chunksize)
                 for j in range(residual_chunks):
                     self.chunk_index = j
@@ -1306,9 +1198,7 @@ def __iter__(self):
                     yield self._construct_chunk(j, doc, n_chunks)
 
             # Check that epoch was non-empty
-            assert (
-                self.has_yielded
-            ), f"Empty logical shard detected: {self.dataset, self.docset}"
+            assert self.has_yielded, f"Empty logical shard detected: {self.dataset, self.docset}"
 
     def load_state_dict(self, state_dicts, sharded_input=False):
         self.setup()
@@ -1381,12 +1271,12 @@ def setup(self):
         if not self.is_setup:
             _StatefulDataset.setup(self)
             n_logical_shards = self.total_shards
-            assert (
-                n_logical_shards % self.worldsize == 0
-            ), f"Total workers {self.worldsize} must divide n_logical_shards {n_logical_shards} evenly"
             logicals = list(range(n_logical_shards))
             self.logicals_owned = _shard_partition(logicals, self.rank, self.worldsize)
             self.n_logicals = n_logical_shards // self.worldsize
+            assert (
+                len(self.logicals_owned) == self.n_logicals
+            ), "(world size * num workers) does not divide logical shards evenly"
 
             # Build logical shards
             for i in range(self.n_logicals):
@@ -1403,26 +1293,20 @@ def setup(self):
                     )
             [d.setup() for d in self.data]
             self.n_docs_remaining = [d._len for d in self.data]
-            assert (
-                sum(self.n_docs_remaining) > 0
-            ), f"No documents detected in shard {self.rank} of {self.datapath}"
-
             self.generator = torch.Generator().manual_seed(self.rank)
 
     def __iter__(self):
         self.setup()
         # Grab one doc at a time in random order
         data = [iter(d) for d in self.data]
-        # Reset if we're rescaling into a prematurely finished epoch
-        # (i.e. [1,1,0,0,0,0] into [1,1,0] [0,0,0] )
-        if sum(self.n_docs_remaining) == 0:
-            self.n_docs_remaining = [d._len for d in self.data]
-            self.generator.manual_seed(self.rank)
         while True:
             # Sample logical shard (or load from ckp)
             if self.current_reader is not None:
                 ind = self.current_reader
             else:
+                assert (
+                    sum(self.n_docs_remaining) > 0
+                ), f"No documents detected in {self.datapath}"
                 ind = torch.multinomial(
                     torch.tensor(self.n_docs_remaining, dtype=torch.float),
                     1,
@@ -1514,10 +1398,6 @@ def __init__(
             ]
         )
         assert len(self.datasets) > 0, "You must specify at least one dataset"
-        for d in datasets:
-            assert os.path.exists(
-                os.path.join(datapath, d)
-            ), f"Invalid subdataset path: {os.path.join(datapath, d)}"
 
         if weights is not None:
             assert len(weights) == len(
diff --git a/fms_to_hf_llama.py b/fms_to_hf_llama.py
index 76042eee..6a81947a 100644
--- a/fms_to_hf_llama.py
+++ b/fms_to_hf_llama.py
@@ -3,160 +3,27 @@
-from fms.models.hf.utils import to_hf_api
-from fms.models.llama import LLaMA
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
 from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
-from transformers import LlamaConfig, LlamaForCausalLM
 
 from fms_fsdp.utils.config_utils import get_model_config
 
 
-def convert_to_hf(model: LLaMA, model_variant, is_old_fms) -> LlamaForCausalLM:
-    fms_hf_model = to_hf_api(model)
-    hf_config = fms_hf_model.config
-    if "llama3" in model_variant:
-        hf_config.bos_token_id = 128000
-        hf_config.eos_token_id = 128001
-    oss_hf_model = LlamaForCausalLM(
-        LlamaConfig(
-            vocab_size=hf_config.vocab_size,
-            hidden_size=hf_config.hidden_size,
-            rms_norm_eps=hf_config.norm_eps,
-            num_attention_heads=hf_config.nheads,
-            num_key_value_heads=None if hf_config.kvheads == 0 else hf_config.kvheads,
-            num_hidden_layers=hf_config.nlayers,
-            intermediate_size=hf_config.multiple_of
-            * (
-                (
-                    int(hf_config.hidden_grow_factor * hf_config.hidden_size)
-                    + hf_config.multiple_of
-                    - 1
-                )
-                // hf_config.multiple_of
-            ),
-            pad_token_id=(
-                None if hf_config.pad_token_id == -1 else hf_config.pad_token_id
-            ),
-            bos_token_id=hf_config.bos_token_id,
-            eos_token_id=hf_config.eos_token_id,
-            max_position_embeddings=hf_config.max_expected_seq_len,
-        )
-    )
-
-    # compute the freq from rot_emb since it is gathered lazily
-    rot_emb = fms_hf_model.decoder.model.rot_emb
-    max_seq_len = rot_emb.max_seq_len
-    alpha = rot_emb._alpha(max_seq_len)
-    ratio = rot_emb.ratio
-    dim = rot_emb.dim
-    if rot_emb.ntk_scaling:
-        ratio = ratio * alpha ** (dim / (dim - 2))
-    freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-
-    with torch.no_grad():
-        oss_hf_model.model.embed_tokens.weight.copy_(fms_hf_model.embedding.weight)
-        i = 0
-        for oss_hf_layer in oss_hf_model.model.layers:
-            fms_hf_layer = fms_hf_model.decoder.model.layers[i]
-
-            # self attn
-            if is_old_fms:
-                oss_hf_layer.self_attn.q_proj.weight.copy_(
-                    fms_hf_layer.attn.query.weight
-                )
-                oss_hf_layer.self_attn.k_proj.weight.copy_(fms_hf_layer.attn.key.weight)
-                oss_hf_layer.self_attn.v_proj.weight.copy_(
-                    fms_hf_layer.attn.value.weight
-                )
-            else:
-                q, k, v = torch.split(
-                    fms_hf_layer.attn.in_proj.qkv_fused.weight,
-                    fms_hf_layer.attn.in_proj.splits,
-                    dim=0,
-                )
-                oss_hf_layer.self_attn.q_proj.weight.copy_(q)
-                oss_hf_layer.self_attn.k_proj.weight.copy_(k)
-                oss_hf_layer.self_attn.v_proj.weight.copy_(v)
-            oss_hf_layer.self_attn.o_proj.weight.copy_(fms_hf_layer.attn.dense.weight)
-            oss_hf_layer.self_attn.rotary_emb.inv_freqs = freqs
-
-            # mlp
-            if is_old_fms:
-                oss_hf_layer.mlp.gate_proj.weight.copy_(
-                    fms_hf_layer.ff_sub_layer.wg.weight
-                )
-                oss_hf_layer.mlp.up_proj.weight.copy_(
-                    fms_hf_layer.ff_sub_layer.w1.weight
-                )
-            else:
-                wg1_fused = fms_hf_layer.ff_sub_layer.wg1_fused.weight
-                wg_splits = [wg1_fused.size(0) // 2, wg1_fused.size(0) // 2]
-                wg, w1 = torch.split(
-                    fms_hf_layer.ff_sub_layer.wg1_fused.weight, wg_splits, dim=0
-                )
-                oss_hf_layer.mlp.gate_proj.weight.copy_(wg)
-                oss_hf_layer.mlp.up_proj.weight.copy_(w1)
-            oss_hf_layer.mlp.down_proj.weight.copy_(fms_hf_layer.ff_sub_layer.w2.weight)
-
-            # layer norm
-            oss_hf_layer.input_layernorm.weight.copy_(fms_hf_layer.ln.weight)
-            oss_hf_layer.post_attention_layernorm.weight.copy_(
-                fms_hf_layer.ff_ln.weight
-            )
-
-            # adjust q, k
-            q = oss_hf_layer.self_attn.q_proj.weight.data
-            q = (
-                q.view(hf_config.nheads, -1, 2, q.size(1))
-                .transpose(1, 2)
-                .reshape(*q.size())
-            )
-            oss_hf_layer.self_attn.q_proj.weight.copy_(q)
-
-            k = oss_hf_layer.self_attn.k_proj.weight.data
-            k = (
-                k.view(
-                    hf_config.nheads if hf_config.kvheads == 0 else hf_config.kvheads,
-                    -1,
-                    2,
-                    k.size(1),
-                )
-                .transpose(1, 2)
-                .reshape(*k.size())
-            )
-            oss_hf_layer.self_attn.k_proj.weight.copy_(k)
-
-            i = i + 1
-        oss_hf_model.model.norm.weight = fms_hf_model.decoder.model.dec_norm.weight
-        oss_hf_model.lm_head.weight = fms_hf_model.lm_head.weight
-
-    return oss_hf_model
-
-
-def main(
-    model_variant, compiled, is_old_fms, load_path, save_path, tokenizer_name_or_path
-):
+def main(model_variant, load_path, save_path, tokenizer_name_or_path):
     print("Initializing model...")
-    llama_config = get_model_config(model_variant)
-    with torch.device("meta"):
-        model = LLaMA(llama_config)
-    model.to_empty(device="cpu")
+    config_data = get_model_config(model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
 
     print(f"Reading state dict from {load_path}")
-    if not compiled:
-        state_dict = {"model_state": model.state_dict()}
-    else:
-        state_dict = {"model_state": {"_orig_mod": model.state_dict()}}
+    state_dict = {"model_state": model.state_dict()}
     load_state_dict(
         state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
     )
 
     print("Loading state dict into the model...")
-    if not compiled:
-        model.load_state_dict(state_dict["model_state"])
-    else:
-        model.load_state_dict(state_dict["model_state"]["_orig_mod"])
+    model.load_state_dict(state_dict["model_state"])
 
-    print("Converting to HF model..")
-    hf_model = convert_to_hf(model, model_variant, is_old_fms)
-    hf_model.save_pretrained(save_path)
+    print("Saving model to HF-compatible format...")
+    model.save_pretrained(save_path)
 
     print("Copying tokenizer...")
     from transformers import AutoTokenizer
@@ -164,7 +31,7 @@ def main(
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
     tokenizer.save_pretrained(save_path)
 
-    print(f"Model converted to HF model, saving at {save_path}")
+    print(f"Model saving at {save_path}")
 
 
 if __name__ == "__main__":
diff --git a/main_training_llama.py b/main_training_llama.py
index a7e1020f..4bd2a3c3 100644
--- a/main_training_llama.py
+++ b/main_training_llama.py
@@ -1,10 +1,13 @@
 import math
 import os
+from pathlib import Path
 
 import fire
 import torch
 import torch.optim as optim
-from fms.models.llama import LLaMA, LLaMABlock
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
+from mamba_ssm.modules.block import Block
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.optim.lr_scheduler import LambdaLR
@@ -44,9 +47,12 @@ def main(**kwargs):
     torch.cuda.set_device(local_rank)
     torch.cuda.empty_cache()
     setup_environ_flags()
+    os.environ["TRITON_CACHE_DIR"] = os.path.join(
+        Path.home(), ".triton", "cache", str(local_rank)
+    )
 
     # get policy
-    block = LLaMABlock
+    block = Block
     (
         mixed_precision_policy,
         wrapping_policy,
@@ -55,14 +61,10 @@ def main(**kwargs):
         param_init_fn,
     ) = get_policies(cfg, rank, block)
 
-    # get fms model
-    llama_config = get_model_config(cfg.model_variant)
-    if cfg.low_cpu_fsdp:
-        with torch.device("meta"):
-            model = LLaMA(llama_config)
-    else:
-        model = LLaMA(llama_config)
-        model.reset_parameters()
+    # get model
+    config_data = get_model_config(cfg.model_variant)
+    mamba_config = MambaConfig(**config_data)
+    model = MambaLMHeadModel(mamba_config)
 
     if rank == 0:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
@@ -89,11 +91,6 @@ def main(**kwargs):
         limit_all_gathers=True,
         param_init_fn=param_init_fn,
     )
-    # we need this post-fsdp call to avoid graph break with torch.compile, until we figure out a better solution.
-    model.rot_emb.compute_freqs_cis(
-        torch.device("cuda", torch.cuda.current_device()),
-        model.config.max_expected_seq_len,
-    )
 
     # fsdp activation checkpointing
     if cfg.fsdp_activation_checkpointing:
@@ -122,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
@@ -139,14 +134,37 @@ def main(**kwargs):
     if cfg.training_stage == "annealing":
         schedule = lambda x: 1 - x / cfg.num_steps
     else:
-        warmup_interval = min(2000, cfg.num_steps // 20)
-        schedule = lambda x: min(
-            1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
-            0.1
-            + 0.5
-            * (1 - 0.1)
-            * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
-        )
+
+        # (cosine 0.01 decay)
+        # warmup_interval = min(1000, cfg.num_steps // 10)
+        # schedule = lambda x: min(
+        #     1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
+        #     0.01
+        #     + 0.5
+        #     * (1 - 0.01)
+        #     * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
+        # )
+
+        # (constant schedule)
+        # warmup_interval = 1000
+        # schedule = lambda x: (
+        #     min(x, warmup_interval) / warmup_interval
+        # )
+        schedule = lambda x: min(1.0, x)  # NOTE(review): ramps 0->1 in a single step, i.e. effectively no warmup -- confirm this is intended vs. min(x, warmup_interval) / warmup_interval
+
+        # (cosine 0.1 decay)
+        # warmup_interval = min(2000, cfg.num_steps // 20)
+        # schedule = lambda x: min(
+        #     1 - (1 - min(x, warmup_interval) / warmup_interval) ** 2,
+        #     0.1
+        #     + 0.5
+        #     * (1 - 0.1)
+        #     * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
+        # # )
+
+        # linear decay to 50b tokens and then constant lr
+        # schedule = lambda x: 1.0 + (0.75 - 1.0) * (x / 32000) if x <= 32000 else 0.75
+
     scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
 
     # profiler
diff --git a/main_training_mamba.py b/main_training_mamba.py
index 68a3c830..3ff12a60 100644
--- a/main_training_mamba.py
+++ b/main_training_mamba.py
@@ -119,11 +119,9 @@ def main(**kwargs):
         model,
         optimizer,
         None,
-        path=(
-            os.path.join(cfg.ckpt_load_path, "checkpoints/")
-            if not os.path.isfile(cfg.ckpt_load_path)
-            else cfg.ckpt_load_path
-        ),
+        path=os.path.join(cfg.ckpt_load_path, "checkpoints/")
+        if not os.path.isfile(cfg.ckpt_load_path)
+        else cfg.ckpt_load_path,
         strict=False,
     )
     if not is_resuming:
@@ -135,8 +133,9 @@ def main(**kwargs):
     # LR schedule
     # linear decay for annealing
     if cfg.training_stage == "annealing":
-        schedule = lambda x: 1 - x / cfg.num_steps
-    else:
+        warmup_interval = 1000
+        schedule = lambda x: x / warmup_interval if x < warmup_interval else 1 - (x - warmup_interval) / (cfg.num_steps - warmup_interval)
+    elif cfg.training_stage == "cosine":
         # cosine decay
         warmup_interval = min(2000, cfg.num_steps // 20)
         schedule = lambda x: min(
@@ -146,6 +145,11 @@ def main(**kwargs):
             * (1 - 0.1)
             * (1 + math.cos(min(x, cfg.num_steps) / cfg.num_steps * math.pi)),
         )
+    elif cfg.training_stage == "constant":
+        warmup_interval = 2000
+        schedule = lambda x: (min(x, warmup_interval) / warmup_interval)
+    else:
+        schedule = lambda x: (1.0 - 0.25 * (x / 32000)) if x <= 32000 else 0.75  # linear decay to 0.75 over 32k steps, then constant
 
     scheduler = LambdaLR(optimizer, lambda x: schedule(x + start_step))
 
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 40bef481..83b2426b 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -632,10 +632,6 @@ def test_multi_reload_stress():
     # preload / sample / scale / doc pipeline
     multi_reload_stress_check(lambda: d6(d5(d4())))
 
-    # Add FIM dataset
-    d7 = lambda x: [FIMDataset(d, -1, 0.25, 0.25, 10, -2, -3, -4) for d in x]
-    multi_reload_stress_check(lambda: d7(d6(d5(d4()))))
-
 
 # SCALABLEDATASET TESTS