diff --git a/CHANGELOG.md b/CHANGELOG.md index f0fca7936b..62ade526da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Improve OLMo-core DPO MFU for the Olmo Hybrid (GDN) model: pack DPO microbatches to the `max_seq_length` token budget instead of capping at `per_device_train_batch_size` sequences, and yield rectangular stacked packed-row batches (`stack_packed_rows`/`unstack_packed_rows`) so OLMo-core's dict batch contract, batch-size validation, and token accounting work natively (gradient accumulation = packed rows per rank per step) (https://github.com/allenai/open-instruct/pull/1713). - Add minimal support for DPO-training the Olmo Hybrid 7B (GDN linear attention) model with `open_instruct/dpo.py` (OLMo-core): bump OLMo-core to a rev with the `olmo3_hybrid_7B` preset and bidirectional HF weight conversion, bump `flash-linear-attention` to 0.5.0 and add `tilelang` (correct GDN gradients on Hopper), add `selected_modules` activation checkpointing (forwarding `determinism_check="none"` through OLMo-core's activation-checkpointing config so `torch.compile` and the opaque `fla` kernels coexist), extend `ModelDims` with GDN FLOPs/params accounting, and add `scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh` (https://github.com/allenai/open-instruct/pull/1715). - Record a `_metrics_keepalive` metric on every rank every GRPO+OLMo-core step to keep `_metrics` non-empty, preventing OLMo-core's empty-skip in `_log_metrics` from desyncing the bookkeeping process group and deadlocking gloo for 30 minutes at save-time flushes (https://github.com/allenai/open-instruct/pull/1708). - Expand type-checking coverage by replacing `# ty: ignore` directives with typed casts and fixing related type issues (https://github.com/allenai/open-instruct/pull/1688). diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 2d1660c793..8f43909ea6 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -91,12 +91,17 @@ def __init__( drop_last: bool = True, fs_local_rank: int | None = None, max_seq_length: int = 1, + microbatch_sample_cap: int | None = None, ) -> None: """Initialize the HFDataLoader. Args: dataset: The HuggingFace Dataset to load data from. Must have an 'index' column. - batch_size: The global batch size (in sequences). + batch_size: The global batch size in instances. Without packing, an instance is one + example. With packing (a collator with max_seq_length set), an instance is one + packed row holding a variable number of examples, and each rank's batch is the + stack of its `batch_size // dp_world_size` rows (see + padding_free_collator.stack_packed_rows). seed: Random seed for shuffling. dp_rank: The rank of the current process in the distributed setup. dp_world_size: Total number of data-parallel processes in the distributed setup. @@ -108,8 +113,11 @@ def __init__( drop_last: If True, drop the last incomplete batch. If False, pad the last batch with repeated indices to fill a complete batch. fs_local_rank: File system local rank. Defaults to dp_rank when None. - max_seq_length: Maximum sequence length. Used to report global_batch_size in tokens + max_seq_length: Tokens per instance. Used to report global_batch_size in tokens to the trainer for batch-size validation. + microbatch_sample_cap: When packing, the maximum number of examples per packed + row. A row closes when either the token budget or this cap is reached. + None means pack purely to the token budget. Note: The dataset must have an 'index' column for tracking samples across epochs. @@ -144,7 +152,9 @@ def __init__( f"The effective global batch size will be {batch_size // dp_world_size * dp_world_size}." ) self._per_rank_batch_size = batch_size // dp_world_size + self._microbatch_sample_cap = microbatch_sample_cap self._collator = collator if collator is not None else (lambda x: {"examples": x}) + self._collator_max_seq_length = getattr(self._collator, "max_seq_length", None) self._automatic_reshuffle = automatic_reshuffle self._drop_last = drop_last self._excluded_indices: set[int] = set() @@ -176,24 +186,28 @@ def __next__(self) -> dict[str, Any]: def _iter_batches(self) -> Iterable[dict[str, Any]]: """Return an iterable over all batches in the epoch.""" - # World-aware packing: batch boundaries were precomputed by + # World-aware packing: row boundaries were precomputed by # _reshard_with_packing so that every rank has the same number of - # batches. Each entry in _precomputed_batch_sizes is the number of - # examples in that batch (variable due to packing). + # rows. Each entry in _precomputed_batch_sizes is the number of + # examples in that row (variable due to packing). Each yielded batch + # stacks per_rank_batch_size rows into one rectangular dict; the train + # module splits it back into per-row microbatches. if self._precomputed_batch_sizes is not None: + rows_per_batch = self._per_rank_batch_size num_real = len(self._precomputed_batch_sizes) - self._num_padding_batches - offset = 0 - for batch_idx, batch_size in enumerate(self._precomputed_batch_sizes): - if batch_idx < self.batches_processed: - offset += batch_size - continue - examples = [] - for i in range(offset, offset + batch_size): - example = self.dataset[i] - examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"}) - batch = to_device(self._collator(examples), self._device) | {"is_padding": batch_idx >= num_real} - offset += batch_size - yield batch + num_batches = len(self._precomputed_batch_sizes) // rows_per_batch + offsets = [0] + for row_size in self._precomputed_batch_sizes: + offsets.append(offsets[-1] + row_size) + for batch_idx in range(self.batches_processed, num_batches): + rows = [] + for row_idx in range(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch): + examples = [] + for i in range(offsets[row_idx], offsets[row_idx + 1]): + example = self.dataset[i] + examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"}) + rows.append(self._collator(examples) | {"is_padding": row_idx >= num_real}) + yield to_device(padding_free_collator.stack_packed_rows(rows), self._device) return start_example = self.batches_processed * self._per_rank_batch_size @@ -219,7 +233,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: def total_batches(self) -> int: """Return the total number of batches in an epoch.""" if self._precomputed_batch_sizes is not None: - return len(self._precomputed_batch_sizes) + return len(self._precomputed_batch_sizes) // self._per_rank_batch_size return self.effective_size // self._per_rank_batch_size def state_dict(self) -> dict[str, Any]: @@ -272,9 +286,8 @@ def _reshard(self, epoch: int) -> None: mask = np.isin(all_indices, list(self._excluded_indices), invert=True) all_indices = all_indices[mask] - packing_enabled = hasattr(self._collator, "max_seq_length") and self._collator.max_seq_length is not None - if packing_enabled: - self._reshard_with_packing(all_indices) + if self._collator_max_seq_length is not None: + self._reshard_with_packing(all_indices, self._collator_max_seq_length) return self._precomputed_batch_sizes = None @@ -300,7 +313,7 @@ def _reshard(self, epoch: int) -> None: self.effective_size = len(rank_indices) self.dataset = self._full_dataset.select(rank_indices.tolist()) - def _reshard_with_packing(self, all_indices: np.ndarray) -> None: + def _reshard_with_packing(self, all_indices: np.ndarray, max_seq_length: int) -> None: """Reshard with world-aware packing so all ranks get the same batch count. Instead of distributing examples to ranks and letting each rank pack @@ -308,7 +321,6 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: overflow), this packs globally first and then distributes packed batches round-robin to ranks. """ - max_seq_length = self._collator.max_seq_length column_names = self._full_dataset.column_names subset = self._full_dataset.select(all_indices.tolist()) if "chosen_input_ids" in column_names: @@ -324,9 +336,9 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: for i in range(len(all_indices)): new_totals = [running_totals[s] + lengths[i][s] for s in range(num_streams)] would_exceed = len(current_batch) > 0 and any(t > max_seq_length for t in new_totals) - at_max_samples = len(current_batch) >= self._per_rank_batch_size + at_cap = self._microbatch_sample_cap is not None and len(current_batch) >= self._microbatch_sample_cap - if would_exceed or at_max_samples: + if would_exceed or at_cap: batches.append(current_batch) current_batch = [i] running_totals = list(lengths[i]) @@ -337,14 +349,19 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: if current_batch: batches.append(current_batch) + # Rows are distributed round-robin to ranks and then stacked into + # per_rank_batch_size-sized batches, so the global count must be a + # multiple of dp_world_size * per_rank_batch_size for every rank to have + # the same number of complete batches. + group_size = self.dp_world_size * self._per_rank_batch_size num_batches = len(batches) padding_start = num_batches if self._drop_last: - num_batches = (num_batches // self.dp_world_size) * self.dp_world_size + num_batches = (num_batches // group_size) * group_size batches = batches[:num_batches] else: - if (remainder := num_batches % self.dp_world_size) > 0: - for _ in range(self.dp_world_size - remainder): + if (remainder := num_batches % group_size) > 0: + for _ in range(group_size - remainder): batches.append(batches[-1]) rank_global_indices = list(range(self.dp_rank, len(batches), self.dp_world_size)) @@ -369,8 +386,16 @@ def get_mock_batch(self) -> dict[str, Any]: forward and backward pass before training officially starts. """ num_examples = min(self._per_rank_batch_size, len(self.dataset)) + # When packing, the collator consumes only as many examples as fit the token + # budget, so at most max_seq_length examples (each >= 1 token) can be used. + # Bound the rows loaded so a large microbatch_sample_cap doesn't load the dataset. + if self._collator_max_seq_length is not None: + num_examples = min(num_examples, self._collator_max_seq_length) examples = [self.dataset[i] for i in range(num_examples)] - return to_device(self._collator(examples), self._device) + collated = self._collator(examples) + if self._collator_max_seq_length is not None: + collated = padding_free_collator.stack_packed_rows([collated]) + return to_device(collated, self._device) def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int: """Return the total number of tokens in the batch across all ranks. diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index b279988958..58bfddd948 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -25,7 +25,7 @@ from open_instruct import data_loader as data_loader_lib from open_instruct import dataset_transformation, dpo_utils, logger_utils, model_utils, olmo_core_utils, utils from open_instruct.olmo_core_callbacks import PerfCallback -from open_instruct.olmo_core_train_modules import DPOTrainModule +from open_instruct.olmo_core_train_modules import DPOMetricsCallback, DPOTrainModule from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO logger = logger_utils.setup_logger(__name__) @@ -66,13 +66,13 @@ def _setup_callbacks(args: dpo_utils.DPOExperimentConfig, dp_world_size: int): wandb_entity=args.wandb_entity, save_async=False, ) + trainer_callbacks["dpo_metrics"] = DPOMetricsCallback() slack_webhook_url = os.environ.get("SLACK_WEBHOOK_URL") if args.send_slack_alerts and slack_webhook_url: trainer_callbacks["slack"] = callbacks.SlackNotifierCallback(name=run_name, webhook_url=slack_webhook_url) model_dims = utils.ModelDims.from_hf_config(args.model_name_or_path) trainer_callbacks["perf"] = PerfCallback( model_dims=model_dims, - per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, dp_world_size=dp_world_size, tensor_parallel_degree=args.tensor_parallel_degree, @@ -191,8 +191,15 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz else: collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest") - rank_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps - global_batch_size = rank_batch_size * dp_world_size + # With packing, an instance is one packed row (one microbatch), so a rank accumulates + # gradient_accumulation_steps rows per optimizer step. Without packing, an instance is + # one example and a rank consumes per_device_train_batch_size * gradient_accumulation_steps + # examples per step. Either way, each instance spans 2 * max_seq_length tokens + # (chosen + rejected streams). + if args.packing: + global_batch_size = args.gradient_accumulation_steps * dp_world_size + else: + global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * dp_world_size data_loader = data_loader_lib.HFDataLoader( dataset=dataset, batch_size=global_batch_size, @@ -204,14 +211,13 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz device=device, drop_last=True, fs_local_rank=global_rank, + max_seq_length=2 * args.max_seq_length, + microbatch_sample_cap=args.per_device_train_batch_size, ) - # 4x batch size: forward-only (no backward), so no activation storage needed. - # With packing, the collator's token budget controls the actual forward-pass size - # and the overflow mechanism in HFDataLoader ensures no examples are dropped. - # We could probably have logic to use a longer sequence length here when packing - # is enabled, but for simplicity we just keep the 4x increase in batch size regardless of packing. - # We want the batch size to be as large as possible so that we always pack efficiently. - cache_batch_size = int(args.per_device_train_batch_size * 4 * dp_world_size) + # Forward-only (no backward), so no activation storage is needed. With packing each + # instance is already a full token-budget row, so one row per rank per batch suffices; + # without packing we use a 4x larger example batch. + cache_batch_size = dp_world_size if args.packing else int(args.per_device_train_batch_size * 4 * dp_world_size) cache_data_loader = data_loader_lib.HFDataLoader( dataset=dataset, batch_size=cache_batch_size, @@ -224,6 +230,7 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz # We need to process every example to cache reference logprobs, so we can't drop the last batch. drop_last=False, fs_local_rank=global_rank, + max_seq_length=2 * args.max_seq_length, ) forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo diff --git a/open_instruct/dpo_utils.py b/open_instruct/dpo_utils.py index 7f1b8f3553..0a1c9c4d10 100644 --- a/open_instruct/dpo_utils.py +++ b/open_instruct/dpo_utils.py @@ -459,34 +459,35 @@ def build_reference_logprobs_cache( with torch.no_grad(): for batch in (pbar := tqdm(dataloader, disable=not is_main_process, desc="Caching reference logprobs")): - batch_start = time.perf_counter() - if use_lora and disable_adapter_context is not None: - with disable_adapter_context(): + for row in padding_free_collator.unstack_packed_rows(batch): + row_start = time.perf_counter() + if use_lora and disable_adapter_context is not None: + with disable_adapter_context(): + chosen_logps, rejected_logps, _ = forward_fn( + model, row, average_log_prob=average_log_prob, **(forward_kwargs or {}) + ) + else: chosen_logps, rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {}) + model, row, average_log_prob=average_log_prob, **(forward_kwargs or {}) ) - else: - chosen_logps, rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {}) - ) - if batch.get("is_padding", False): - continue - - chosen_tensor[batch["index"]] = chosen_logps - rejected_tensor[batch["index"]] = rejected_logps - - batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(batch) - total_tokens += batch_tokens - total_examples += batch_size - pbar.set_postfix( - { - "avg_tok/ex": f"{total_tokens / total_examples:.0f}", - "MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - batch_start):.1f}", - "mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}", - "mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}", - } - ) + if row.get("is_padding", False): + continue + + chosen_tensor[row["index"]] = chosen_logps + rejected_tensor[row["index"]] = rejected_logps + + batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(row) + total_tokens += batch_tokens + total_examples += batch_size + pbar.set_postfix( + { + "avg_tok/ex": f"{total_tokens / total_examples:.0f}", + "MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - row_start):.1f}", + "mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}", + "mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}", + } + ) dist.all_reduce(chosen_tensor, op=dist.ReduceOp.MAX) dist.all_reduce(rejected_tensor, op=dist.ReduceOp.MAX) diff --git a/open_instruct/olmo_core_callbacks.py b/open_instruct/olmo_core_callbacks.py index afbb5eada8..cc1e35529b 100644 --- a/open_instruct/olmo_core_callbacks.py +++ b/open_instruct/olmo_core_callbacks.py @@ -123,7 +123,6 @@ class PerfCallback(Callback): """Calculates MFU and tokens_per_second using same formula as dpo_tune_cache.py.""" model_dims: utils.ModelDims - per_device_train_batch_size: int gradient_accumulation_steps: int dp_world_size: int tensor_parallel_degree: int = 1 @@ -161,8 +160,6 @@ def pre_step(self, batch: dict[str, Any]) -> None: self._pre_step_time = time.perf_counter() self._step_start_time = self._pre_step_time num_seqs = padding_free_collator.get_num_sequences(batch) - if num_seqs is None: - num_seqs = self.per_device_train_batch_size * 2 self._interval_num_sequences += num_seqs * self.dp_world_size def post_step(self) -> None: @@ -201,11 +198,11 @@ def post_step(self) -> None: seconds_per_step = interval_end - self._step_start_time - self.trainer.record_metric("perf/mfu", mfu_result["mfu"], reduce_type=None) + self.trainer.record_metric("perf/mfu_step", mfu_result["mfu"], reduce_type=None) self.trainer.record_metric("perf/mfu_avg", mfu_avg, reduce_type=None) self.trainer.record_metric("perf/seconds_per_step", seconds_per_step, reduce_type=None) - self.trainer.record_metric("perf/tokens_per_second", tokens_per_second, reduce_type=None) - self.trainer.record_metric("perf/tokens_per_second_avg", tokens_per_second_avg, reduce_type=None) + self.trainer.record_metric("perf/tokens_per_second_step", tokens_per_second, reduce_type=None) + self.trainer.record_metric("perf/tokens_per_second_total", tokens_per_second_avg, reduce_type=None) self.trainer.record_metric( "perf/tokens_per_second_per_gpu", tokens_per_second / (self.dp_world_size * self.tensor_parallel_degree), diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 036b93d441..e329d4b383 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -5,7 +5,7 @@ """ import math -from typing import Any, Literal +from typing import Any, Literal, cast import numpy as np import torch @@ -18,6 +18,8 @@ from olmo_core.nn.transformer import Transformer from olmo_core.optim import OptimConfig from olmo_core.optim.scheduler import Scheduler +from olmo_core.train.callbacks import Callback +from olmo_core.train.common import ReduceType from olmo_core.train.train_module import TransformerTrainModule from olmo_core.train.train_module.transformer import config as transformer_config from torch.distributed.tensor import DTensor, Replicate, Shard @@ -34,6 +36,36 @@ {AttentionBackendName.flash_2, AttentionBackendName.flash_3, AttentionBackendName.flash_4} ) +# DPO token-weighted metrics are ratios sum_ranks(sum_mb(metric*tokens)) / sum_ranks(tokens). +# DPOTrainModule records each numerator under the _DPO_REDUCE_NS namespace and the shared +# denominators (real and padded token counts) under the keys below, all with ReduceType.sum so +# the trainer reduces them in its batched per-interval all-reduce. DPOMetricsCallback then divides +# numerator/denominator after reduction, avoiding a per-step host sync and explicit all-reduce. +_DPO_REDUCE_NS = "_dpo_reduce" +_DPO_TOKENS_KEY = f"{_DPO_REDUCE_NS}/__tokens__" +_DPO_PADDED_KEY = f"{_DPO_REDUCE_NS}/__padded__" + + +class DPOMetricsCallback(Callback): + """Reconstructs token-weighted DPO metrics from reduced numerator/denominator sums.""" + + priority = 10 + + def pre_log_metrics(self, step: int, metrics: dict[str, float]) -> None: + if _DPO_TOKENS_KEY not in metrics: + return + tokens = metrics.pop(_DPO_TOKENS_KEY) + padded = metrics.pop(_DPO_PADDED_KEY) + prefix = f"{_DPO_REDUCE_NS}/" + for key in [k for k in metrics if k.startswith(prefix)]: + metrics[key[len(prefix) :]] = metrics.pop(key) / tokens + metrics["train/padding_fraction"] = 1.0 - tokens / padded + metrics["training_step"] = float(step) + if self.trainer.steps_per_epoch is not None: + metrics["epoch"] = step / self.trainer.steps_per_epoch + if "optim/LR (group 0)" in metrics: + metrics["learning_rate"] = metrics["optim/LR (group 0)"] + class DPOLMHead(LMHead): """LM head that returns per-token log-probabilities for DPO training. @@ -138,7 +170,13 @@ def __init__( ) # TODO(finbarrtimbers): Remove this hack once Transformer supports configuring the LM head. model.lm_head.__class__ = DPOLMHead - rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length * 2 + # With packing, a microbatch is one packed row: chosen + rejected streams, each + # padded to max_sequence_length tokens. Without packing, it is sample_microbatch_size + # example pairs of up to max_sequence_length tokens per stream. + if dpo_config.packing: + rank_microbatch_size_tokens = 2 * max_sequence_length + else: + rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length * 2 super().__init__( model=model, optim=optim, @@ -175,13 +213,9 @@ def __init__( if dpo_config.packing: self._forward_kwargs["packing"] = True - def pre_train(self): - pass - def global_num_flops_in_batch(self, batch: dict[str, Any]) -> int | None: - global_num_tokens = self.trainer.data_loader.global_num_tokens_in_batch(batch) - if global_num_tokens is None: - return None + data_loader = cast(data_loader_lib.HFDataLoader, self.trainer.data_loader) + global_num_tokens = data_loader.global_num_tokens_in_batch(batch) seq_len = batch["chosen_input_ids"].shape[1] flops_per_token = self.num_flops_per_token(seq_len=seq_len) return flops_per_token * global_num_tokens if flops_per_token is not None else None @@ -223,10 +257,14 @@ def _compute_microbatch_loss(self, micro_batch: dict[str, Any]) -> tuple[torch.T def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: self.model.train() - micro_batches = split_batch_dpo(batch, self.sample_microbatch_size) + if self.dpo_config.packing: + micro_batches = padding_free_collator.unstack_packed_rows(batch) + else: + micro_batches = split_batch_dpo(batch, self.sample_microbatch_size) num_micro_batches = len(micro_batches) device = batch["chosen_input_ids"].device - total_tokens = padding_free_collator.get_num_tokens(batch) + micro_token_counts = [padding_free_collator.get_num_tokens(mb) for mb in micro_batches] + total_tokens = sum(micro_token_counts) for v in self._metrics.values(): v.zero_() @@ -234,7 +272,7 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: for micro_batch_idx, micro_batch in enumerate(micro_batches): with self._train_microbatch_context(micro_batch_idx, num_micro_batches): loss, step_metrics = self._compute_microbatch_loss(micro_batch) - micro_tokens = padding_free_collator.get_num_tokens(micro_batch) + micro_tokens = micro_token_counts[micro_batch_idx] weight = micro_tokens / total_tokens for k, v in step_metrics.items(): self._metrics[k] += v.detach() * micro_tokens @@ -243,39 +281,42 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: self.model.post_batch(dry_run=dry_run) if not dry_run: - metric_keys = sorted(self._metrics.keys()) - local_sums_list = [torch.tensor(total_tokens, dtype=torch.float32, device=device)] + [ - self._metrics[k] for k in metric_keys - ] - local_sums = torch.stack(local_sums_list) - dist.all_reduce(local_sums, op=dist.ReduceOp.SUM, group=self.trainer.dp_process_group) - - global_total_tokens = local_sums[0] - global_metrics = {k: local_sums[i + 1] / global_total_tokens for i, k in enumerate(metric_keys)} - - self.record_metric("train/loss", global_metrics["loss"].item(), reduce_type=None) - self.record_metric("train/logps_chosen", global_metrics["chosen_logps"].item(), reduce_type=None) - self.record_metric("train/logps_rejected", global_metrics["rejected_logps"].item(), reduce_type=None) - token_count = self.trainer.data_loader.global_num_tokens_in_batch(batch) - assert token_count is not None - self.record_metric("train/token_count", token_count, reduce_type=None) + local_padded_tokens = padding_free_collator.get_num_padded_tokens(batch) + local_num_sequences = padding_free_collator.get_num_sequences(batch) + tokens_tensor = torch.tensor(float(total_tokens), device=device) + self.record_metric(_DPO_TOKENS_KEY, tokens_tensor, reduce_type=ReduceType.sum) + self.record_metric( + _DPO_PADDED_KEY, torch.tensor(float(local_padded_tokens), device=device), reduce_type=ReduceType.sum + ) + # PerfCallback reads train/token_count from the per-step buffer (before the deferred + # reduction), so it must be recorded here rather than reconstructed in DPOMetricsCallback. + self.record_metric( + "train/token_count", float(total_tokens) * self.trainer.data_loader.dp_world_size, reduce_type=None + ) + weighted_sums = { + "train_loss": self._metrics["loss"], + "logps/chosen": self._metrics["chosen_logps"], + "logps/rejected": self._metrics["rejected_logps"], + } if self.dpo_config.loss_type.computes_reward_metrics: - margin = global_metrics["chosen_rewards"] - global_metrics["rejected_rewards"] - self.record_metric("train/rewards_chosen", global_metrics["chosen_rewards"].item(), reduce_type=None) - self.record_metric( - "train/rewards_rejected", global_metrics["rejected_rewards"].item(), reduce_type=None - ) - self.record_metric( - "train/rewards_average", - ((global_metrics["chosen_rewards"] + global_metrics["rejected_rewards"]) / 2).item(), - reduce_type=None, - ) - self.record_metric("train/rewards_accuracy", global_metrics["accuracy"].item(), reduce_type=None) - self.record_metric("train/rewards_margin", margin.item(), reduce_type=None) - - if "aux_loss" in global_metrics: - self.record_metric("train/aux_loss", global_metrics["aux_loss"].item(), reduce_type=None) + chosen_rewards = self._metrics["chosen_rewards"] + rejected_rewards = self._metrics["rejected_rewards"] + weighted_sums["rewards/chosen"] = chosen_rewards + weighted_sums["rewards/rejected"] = rejected_rewards + weighted_sums["rewards/average"] = (chosen_rewards + rejected_rewards) / 2 + weighted_sums["rewards/accuracy"] = self._metrics["accuracy"] + weighted_sums["rewards/margin"] = chosen_rewards - rejected_rewards + if "aux_loss" in self._metrics: + weighted_sums["aux_loss"] = self._metrics["aux_loss"] + for name, value in weighted_sums.items(): + self.record_metric(f"{_DPO_REDUCE_NS}/{name}", value, reduce_type=ReduceType.sum) + + self.record_metric( + "train/sequences_per_step", + torch.tensor(float(local_num_sequences), device=device), + reduce_type=ReduceType.sum, + ) class GRPOTrainModule(TransformerTrainModule): diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 85cf9fbd20..119f8e3368 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -229,19 +229,81 @@ def get_batch_logps( return segment_sums +def stack_packed_rows(rows: list[dict[str, Any]]) -> dict[str, Any]: + """Stack collated packed rows into one rectangular batch dict. + + Each row is a collator output where every stream tensor has shape + (1, max_seq_length). Stream tensors are concatenated along dim 0 to + (num_rows, max_seq_length). Per-row cu_seq_lens tensors are padded to a + common length by repeating their final value (the row's total real token + count), producing zero-length phantom boundaries that unstack_packed_rows + trims away. `index` is padded with -1 and `max_length_q/k` ints are reduced + with max. + """ + out: dict[str, Any] = {} + for k, v in rows[0].items(): + if k.endswith(("cu_seq_lens_q", "cu_seq_lens_k")): + width = max(len(r[k]) for r in rows) + out[k] = torch.stack([torch.cat([r[k], r[k][-1:].expand(width - len(r[k]))]) for r in rows]) + elif k.endswith(("max_length_q", "max_length_k")): + out[k] = max(r[k] for r in rows) + elif k == "index": + width = max(len(r[k]) for r in rows) + out[k] = torch.stack([torch.cat([r[k], r[k].new_full((width - len(r[k]),), -1)]) for r in rows]) + elif k == "is_padding": + out[k] = torch.tensor([r[k] for r in rows]) + elif isinstance(v, torch.Tensor): + out[k] = torch.cat([r[k] for r in rows], dim=0) + else: + out[k] = v + return out + + +def unstack_packed_rows(batch: dict[str, Any]) -> list[dict[str, Any]]: + """Invert stack_packed_rows, recovering the per-row collated dicts. + + Trims the repeated-value padding from cu_seq_lens rows and the -1 padding + from index rows. Batches without stacked (2-D) cu_seq_lens tensors are + returned unchanged as a single-element list, so callers can treat stacked + and unstacked batches uniformly. + """ + cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] + if not cu_keys or batch[cu_keys[0]].dim() == 1: + return [batch] + rows = [] + for i in range(batch[cu_keys[0]].shape[0]): + row: dict[str, Any] = {} + for k, v in batch.items(): + if k.endswith(("cu_seq_lens_q", "cu_seq_lens_k")): + cu = v[i] + num_seqs = int((cu[1:] > cu[:-1]).sum().item()) + row[k] = cu[: num_seqs + 1] + elif k == "index": + row[k] = v[i][v[i] >= 0] + elif k == "is_padding": + row[k] = bool(v[i].item()) + elif isinstance(v, torch.Tensor): + row[k] = v[i : i + 1] + else: + row[k] = v + rows.append(row) + return rows + + def get_num_tokens(batch: dict[str, Any]) -> int: """Return total non-padding token count from a training batch. For packed batches (DPO or GRPO), reads cu_seq_lens_k tensors whose last - element is the total token count for that branch. For padded batches, sums - the attention_mask. Falls back to counting input_ids elements. + element (per row, for stacked batches) is the total token count for that + branch. For padded batches, sums the attention_mask. Falls back to counting + input_ids elements. """ # cu_seq_lens_k is a cumulative sequence length tensor from the padding-free # collator. Its last element equals the total token count for that branch. # DPO has chosen_cu_seq_lens_k + rejected_cu_seq_lens_k; GRPO has cu_seq_lens_k. cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: - return sum(batch[k][-1].item() for k in cu_keys) + return sum(batch[k][..., -1].sum().item() for k in cu_keys) # DPO batches have chosen_attention_mask and rejected_attention_mask; sum both branches. attn_keys = [k for k in batch if k.endswith("attention_mask")] if attn_keys: @@ -249,14 +311,23 @@ def get_num_tokens(batch: dict[str, Any]) -> int: return sum(v.numel() for k, v in batch.items() if "input_ids" in k and isinstance(v, torch.Tensor)) -def get_num_sequences(batch: dict[str, Any]) -> int | None: - """Return total sequence count from a training batch, or None for non-packing batches. +def get_num_padded_tokens(batch: dict[str, Any]) -> int: + """Return total token count including padding from a training batch. + + Counts all elements of input_ids tensors. + """ + return sum(v.numel() for k, v in batch.items() if k.endswith("input_ids") and isinstance(v, torch.Tensor)) + + +def get_num_sequences(batch: dict[str, Any]) -> int: + """Return total sequence count from a training batch. - For packed batches, reads cu_seq_lens_k tensors which each have num_seqs + 1 - elements (including a leading 0). Returns None if no cu_seq_lens_k keys are found. + For packed batches, counts strictly-increasing boundaries in each + cu_seq_lens_k tensor, which works for both 1-D rows and stacked 2-D rows + (whose repeated-value padding contributes no increase). For non-packed + batches, counts rows of each input_ids tensor. """ cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: - # Each cu_seq_lens tensor has num_seqs + 1 elements (leading 0 boundary). - return sum(len(batch[k]) - 1 for k in cu_keys) - return None + return int(sum((batch[k][..., 1:] > batch[k][..., :-1]).sum().item() for k in cu_keys)) + return sum(batch[k].shape[0] for k in batch if k.endswith("input_ids")) diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 2bb6390fde..fc2b93ad71 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -5,7 +5,7 @@ import torch from datasets import Dataset -from open_instruct import data_loader +from open_instruct import data_loader, padding_free_collator from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO @@ -72,12 +72,152 @@ def test_packing_equal_batches_across_ranks( for loader in loaders: for batch in loader: if "index" in batch: - all_indices.update(batch["index"].tolist()) + all_indices.update(batch["index"][batch["index"] >= 0].tolist()) if not drop_last: expected_indices = set(range(num_samples)) self.assertEqual(all_indices, expected_indices, f"Missing indices: {expected_indices - all_indices}") +def _make_fixed_length_dpo_dataset(num_samples: int, seq_len: int) -> Dataset: + rng = torch.Generator().manual_seed(42) + data = { + "chosen_input_ids": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "chosen_labels": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "rejected_input_ids": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "rejected_labels": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "index": list(range(num_samples)), + } + ds = Dataset.from_dict(data) + ds.set_format(type="pt") + return ds + + +class TestTokenBudgetPacking(unittest.TestCase): + def test_packs_to_token_budget_not_sample_cap(self): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + global_batch_size = 4 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loader = data_loader.HFDataLoader( + dataset=dataset, + batch_size=global_batch_size, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + collator=collator, + drop_last=False, + ) + + row_sizes = [] + seen_indices = set() + for batch in loader: + for row in padding_free_collator.unstack_packed_rows(batch): + row_sizes.append(len(row["index"])) + seen_indices.update(row["index"].tolist()) + self.assertLessEqual(row["chosen_cu_seq_lens_k"][-1].item(), max_seq_length) + self.assertLessEqual(row["rejected_cu_seq_lens_k"][-1].item(), max_seq_length) + + self.assertGreater(max(row_sizes), global_batch_size) + self.assertEqual(seen_indices, set(range(num_samples))) + + def test_microbatch_sample_cap_binds(self): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + cap = 3 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loader = data_loader.HFDataLoader( + dataset=dataset, + batch_size=4, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + collator=collator, + drop_last=False, + microbatch_sample_cap=cap, + ) + + for batch in loader: + for row in padding_free_collator.unstack_packed_rows(batch): + self.assertLessEqual(len(row["index"]), cap) + + +class TestStackedPackedBatches(unittest.TestCase): + @parameterized.parameterized.expand([("rows2_dp1", 2, 1), ("rows4_dp1", 4, 1), ("rows2_dp2", 2, 2)]) + def test_yields_per_rank_rows_per_batch(self, _name, rows_per_rank, dp_world_size): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + cap = 2 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loaders = [ + data_loader.HFDataLoader( + dataset=dataset, + batch_size=rows_per_rank * dp_world_size, + seed=42, + dp_rank=rank, + dp_world_size=dp_world_size, + work_dir=work_dir, + collator=collator, + drop_last=True, + microbatch_sample_cap=cap, + ) + for rank in range(dp_world_size) + ] + + batch_counts = [loader.total_batches for loader in loaders] + self.assertTrue(all(c == batch_counts[0] for c in batch_counts), f"Step counts differ: {batch_counts}") + + for loader in loaders: + num_batches = 0 + for batch in loader: + self.assertIsInstance(batch, dict) + self.assertEqual(batch["chosen_input_ids"].shape, (rows_per_rank, max_seq_length)) + rows = padding_free_collator.unstack_packed_rows(batch) + self.assertEqual(len(rows), rows_per_rank) + for row in rows: + self.assertLessEqual(len(row["index"]), cap) + num_batches += 1 + self.assertEqual(num_batches, loader.total_batches) + + def test_stack_unstack_round_trip(self): + max_seq_length = 512 + dataset = _make_dpo_dataset(num_samples=7, max_seq_length=max_seq_length) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + rows = [ + collator([dataset[0], dataset[1]]) | {"is_padding": False}, + collator([dataset[2]]) | {"is_padding": False}, + collator([dataset[3], dataset[4], dataset[5]]) | {"is_padding": True}, + ] + + stacked = padding_free_collator.stack_packed_rows(rows) + unstacked = padding_free_collator.unstack_packed_rows(stacked) + + self.assertEqual(len(unstacked), len(rows)) + for original, restored in zip(rows, unstacked): + self.assertEqual(set(original.keys()), set(restored.keys())) + for k, v in original.items(): + if k.endswith(("max_length_q", "max_length_k")): + # Stacking reduces max_length to a batch-level max (a safe upper bound). + self.assertEqual(restored[k], max(r[k] for r in rows)) + elif isinstance(v, torch.Tensor): + torch.testing.assert_close(restored[k], v) + else: + self.assertEqual(restored[k], v) + + if __name__ == "__main__": unittest.main() diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index 992ff01645..743333456e 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -16,28 +16,21 @@ BEAKER_IMAGE="${1:-finbarrt/hybrid-dpo-stable}" BASE_PATH="/weka/oe-adapt-default/nathanl/checkpoints" SFT_MODELS=( - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_6e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_5e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_8e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_9e-5/step3256-hf" - "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_2.5e-5/step3256-hf" # Final Instruct SFT Model - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_1e-4/step3256-hf" + allenai/Olmo-Hybrid-Instruct-SFT-7B ) DPO_LRS=( - 2e-6 - # 1e-6 - 8.5e-7 - 7e-7 - 5e-7 - 2.5e-7 + #2e-6 + 1e-6 + #8.5e-7 + #7e-7 + #5e-7 + #2.5e-7 ) for MODEL_PATH in "${SFT_MODELS[@]}"; do - # Extract SFT LR from path, e.g. HYBRID_INSTRUCT_SFT_0218_6e-5 -> 6e-5 - SFT_LR=$(basename "$(dirname "$MODEL_PATH")" | sed 's/.*_\([0-9.e-]*\)$/\1/') for LR in "${DPO_LRS[@]}"; do - EXP_NAME="hybrid-7b-DPO-0219-SFT-${SFT_LR}-LR-${LR}" + EXP_NAME="hybrid-7b-DPO-0219-SFT-public-LR-${LR}" echo "=====================================" echo "Launching: ${EXP_NAME}" echo " SFT model: ${MODEL_PATH}" @@ -46,8 +39,8 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do uv run python mason.py \ --cluster ai2/jupiter \ - --description "Hybrid 7B DPO sweep, SFT-${SFT_LR}, LR=${LR}, 4 nodes, 16k seq, ZeRO-3." \ - --workspace ai2/olmo-instruct \ + --description "Hybrid 7B DPO sweep, SFT-public, LR=${LR}, 4 nodes, 16k seq, ZeRO-3." \ + --workspace ai2/linear-rnns \ --priority urgent \ --max_retries 0 \ --preemptible \ diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh new file mode 100755 index 0000000000..994a624fe8 --- /dev/null +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -0,0 +1,112 @@ +#!/bin/bash +# DPO sweep for hybrid instruct models, using OLMo-core (dpo.py) instead of +# dpo_tune_cache.py (Accelerate + DeepSpeed ZeRO-3). +# +# Usage (with pre-built image, no Docker build needed): +# bash scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +# +# Usage (with build_image_and_launch.sh, slow ~1hr Docker build): +# ./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +# +# NOTE: dpo.py builds the model with OLMo-core's native TransformerConfig, so the +# hybrid architecture must be resolvable from --config_name. See OLMO_MODEL_CONFIG_MAP +# / get_transformer_config in open_instruct/olmo_core_utils.py. + +BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" + +# MFU-tuning knobs (defaults reproduce the original fully-sharded config). +FSDP_SHARD_DEGREE="${FSDP_SHARD_DEGREE:-32}" +FSDP_NUM_REPLICAS="${FSDP_NUM_REPLICAS:-1}" +ACTIVATION_CHECKPOINTING_MODE="${ACTIVATION_CHECKPOINTING_MODE:-selected_modules}" +PER_DEVICE_TRAIN_BATCH_SIZE="${PER_DEVICE_TRAIN_BATCH_SIZE:-1}" +GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-4}" +EXP_TAG="${EXP_TAG:-}" +PROFILING="${PROFILING:-false}" +PROFILING_FLAG="" +if [ "$PROFILING" = "true" ]; then PROFILING_FLAG="--profiling"; fi +MAX_TRAIN_STEPS="${MAX_TRAIN_STEPS:-}" +MAX_TRAIN_STEPS_FLAG="" +if [ -n "$MAX_TRAIN_STEPS" ]; then MAX_TRAIN_STEPS_FLAG="--max_train_steps $MAX_TRAIN_STEPS"; fi +AC_MODULES_FLAG="" +if [ -n "$AC_MODULES" ]; then AC_MODULES_FLAG="--activation_checkpointing_modules $AC_MODULES"; fi +TENSOR_PARALLEL_DEGREE="${TENSOR_PARALLEL_DEGREE:-1}" + +SFT_MODELS=( + allenai/Olmo-Hybrid-Instruct-SFT-7B +) + +DPO_LRS=( + 1e-6 +) + +# OLMo-core TransformerConfig preset for the hybrid 7B model. Must be a config +# name registered with olmo-core's TransformerConfig (see olmo_core_utils.py). +CONFIG_NAME=olmo3_hybrid_7B + +for MODEL_PATH in "${SFT_MODELS[@]}"; do + for LR in "${DPO_LRS[@]}"; do + EXP_NAME="hybrid-7b-DPO-oc-0219-SFT-public-LR-${LR}${EXP_TAG}" + echo "=====================================" + echo "Launching: ${EXP_NAME}" + echo " SFT model: ${MODEL_PATH}" + echo " DPO LR: ${LR}" + echo "=====================================" + + uv run python mason.py \ + --cluster ai2/jupiter \ + --description "Hybrid 7B DPO sweep (OLMo-core), LR=${LR}, 4 nodes, 16k seq." \ + --workspace ai2/linear-rnns \ + --priority urgent \ + --max_retries 0 \ + --artifact_ttl 1d \ + --preemptible \ + --image "$BEAKER_IMAGE" \ + --pure_docker_mode \ + --no_auto_dataset_cache \ + --env OLMO_SHARED_FS=1 \ + --env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + --env NCCL_IB_HCA=^=mlx5_bond_0 \ + --env NCCL_SOCKET_IFNAME=ib \ + --env TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + --env TORCH_DIST_INIT_BARRIER=1 \ + --env TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800 \ + --num_nodes 4 \ + --gpus 8 -- torchrun \ + --nnodes=4 \ + --node_rank=\$BEAKER_REPLICA_RANK \ + --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ + --master_port=29400 \ + --nproc_per_node=8 \ + open_instruct/dpo.py \ + --exp_name "$EXP_NAME" \ + --model_name_or_path "$MODEL_PATH" \ + --config_name "$CONFIG_NAME" \ + --chat_template_name olmo123 \ + --mixer_list allenai/Dolci-Instruct-DPO-fixed 259922 \ + --max_seq_length 16384 \ + --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \ + --gradient_accumulation_steps "$GRADIENT_ACCUMULATION_STEPS" \ + --fsdp_shard_degree "$FSDP_SHARD_DEGREE" \ + --fsdp_num_replicas "$FSDP_NUM_REPLICAS" \ + --tensor_parallel_degree "$TENSOR_PARALLEL_DEGREE" \ + --learning_rate "$LR" \ + --lr_scheduler_type linear \ + --checkpointing_steps 500 \ + --keep_last_n_checkpoints -1 \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_epochs 1 \ + --logging_steps 1 \ + --loss_type dpo_norm \ + --beta 5 \ + --packing \ + --push_to_hub False \ + --try_launch_beaker_eval_jobs False \ + --activation_checkpointing_mode "$ACTIVATION_CHECKPOINTING_MODE" \ + $AC_MODULES_FLAG \ + --compile_model true \ + $PROFILING_FLAG \ + $MAX_TRAIN_STEPS_FLAG \ + --with_tracking + done +done