Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
78f2dd1
minor tweaks to script
finbarrtimbers Jun 1, 2026
ac6a4ad
using ai2/linear-rnns workspace
finbarrtimbers Jun 1, 2026
d0d8ea1
modified sweep
finbarrtimbers Jun 1, 2026
86160a6
only one lr
finbarrtimbers Jun 2, 2026
7b34e23
Add Olmo Hybrid DPO sweep (olmo-core) and GDN-aware ModelDims FLOPs/m…
finbarrtimbers Jun 2, 2026
1fbac3f
Simplify ModelDims GDN handling: zero-default linear-attn dims, dedup…
finbarrtimbers Jun 2, 2026
2c23960
Bump olmo-core to hybrid-dpo-conversion branch for Olmo-Hybrid DPO su…
finbarrtimbers Jun 2, 2026
600f83e
Fully shard (fsdp_shard_degree=32) Olmo-Hybrid DPO sweep to fix OOM, …
finbarrtimbers Jun 2, 2026
73ad560
Add full-block activation checkpointing mode for olmo-core DPO to fit…
finbarrtimbers Jun 2, 2026
2daaa2a
Disable torch.compile in Olmo-Hybrid DPO sweep: compile+full-block ch…
finbarrtimbers Jun 2, 2026
e2a385c
Bump flash-linear-attention 0.4.2 -> 0.5.0
finbarrtimbers Jun 2, 2026
53e2af3
Add selected_modules activation checkpointing to enable compile with …
finbarrtimbers Jun 2, 2026
ed6b218
Checkpoint all Olmo-Hybrid block submodules except the GDN mixer for …
finbarrtimbers Jun 2, 2026
66bbd9c
Use full AC + compile for Olmo-Hybrid DPO by skipping checkpoint dete…
finbarrtimbers Jun 2, 2026
872ad77
DPO: checkpoint GDN mixer via selected_modules to keep compile outsid…
finbarrtimbers Jun 2, 2026
63348c8
Add tilelang dep so fla routes GDN chunk_bwd_dqkwg around the broken …
finbarrtimbers Jun 2, 2026
f0c8b07
DPO: align dpo.py wandb metric keys with dpo_tune_cache.py (rename tr…
finbarrtimbers Jun 3, 2026
7856a45
committed changes
finbarrtimbers Jun 3, 2026
9e00c78
Added scripts
finbarrtimbers Jun 3, 2026
0ba1ec8
DPO: bucket-pad packed microbatches to next power-of-two (not max_seq…
finbarrtimbers Jun 3, 2026
e1dfe41
DPO: pack microbatches to the max_seq_length token budget instead of …
finbarrtimbers Jun 3, 2026
5cf1528
DPO: add configurable per-microbatch sample cap + real gradient accum…
finbarrtimbers Jun 4, 2026
7949208
DPO: bound get_mock_batch rows by token budget so a large microbatch_…
finbarrtimbers Jun 4, 2026
ac010f0
set flag
finbarrtimbers Jun 4, 2026
7957881
disable HF upload
finbarrtimbers Jun 4, 2026
91afdf0
moved flag
finbarrtimbers Jun 4, 2026
efd9ad4
set flags correctly
finbarrtimbers Jun 4, 2026
a59d9d2
cleaned up pr
finbarrtimbers Jun 4, 2026
a4ccbfd
DPO: restore per-step train/token_count record so PerfCallback can co…
finbarrtimbers Jun 4, 2026
6d807e6
Drop leftover TRITON_PRINT_AUTOTUNING debug env from oc DPO sweep scr…
finbarrtimbers Jun 5, 2026
066265b
Fix stale SFT_LR reference in DeepSpeed sweep description (PR review)…
finbarrtimbers Jun 5, 2026
9d99e3e
Simplify: delegate global_num_flops_in_batch token count to data_load…
finbarrtimbers Jun 5, 2026
1cab2fa
Merge remote-tracking branch 'origin/main' into finbarr/oc-hybrid-dpo
finbarrtimbers Jun 5, 2026
cecab92
DPO packing: yield rectangular stacked-row batches (stack/unstack_pac…
finbarrtimbers Jun 5, 2026
de4ff38
Update CHANGELOG entry for rectangular stacked packed-row DPO batches…
finbarrtimbers Jun 5, 2026
e224447
Simplify: get_num_sequences always returns int (counts input_ids rows…
finbarrtimbers Jun 5, 2026
0f5c4e4
Merge remote-tracking branch 'origin/main' into finbarr/oc-hybrid-dpo
finbarrtimbers Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.


### Changed
- Improve OLMo-core DPO MFU for the Olmo Hybrid (GDN) model: pack DPO microbatches to the `max_seq_length` token budget instead of capping at `per_device_train_batch_size` sequences, and yield rectangular stacked packed-row batches (`stack_packed_rows`/`unstack_packed_rows`) so OLMo-core's dict batch contract, batch-size validation, and token accounting work natively (gradient accumulation = packed rows per rank per step) (https://github.com/allenai/open-instruct/pull/1713).
- Add minimal support for DPO-training the Olmo Hybrid 7B (GDN linear attention) model with `open_instruct/dpo.py` (OLMo-core): bump OLMo-core to a rev with the `olmo3_hybrid_7B` preset and bidirectional HF weight conversion, bump `flash-linear-attention` to 0.5.0 and add `tilelang` (correct GDN gradients on Hopper), add `selected_modules` activation checkpointing (forwarding `determinism_check="none"` through OLMo-core's activation-checkpointing config so `torch.compile` and the opaque `fla` kernels coexist), extend `ModelDims` with GDN FLOPs/params accounting, and add `scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh` (https://github.com/allenai/open-instruct/pull/1715).
- Record a `_metrics_keepalive` metric on every rank every GRPO+OLMo-core step to keep `_metrics` non-empty, preventing OLMo-core's empty-skip in `_log_metrics` from desyncing the bookkeeping process group and deadlocking gloo for 30 minutes at save-time flushes (https://github.com/allenai/open-instruct/pull/1708).
- Expand type-checking coverage by replacing `# ty: ignore` directives with typed casts and fixing related type issues (https://github.com/allenai/open-instruct/pull/1688).
Expand Down
83 changes: 54 additions & 29 deletions open_instruct/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,17 @@ def __init__(
drop_last: bool = True,
fs_local_rank: int | None = None,
max_seq_length: int = 1,
microbatch_sample_cap: int | None = None,
) -> None:
"""Initialize the HFDataLoader.

Args:
dataset: The HuggingFace Dataset to load data from. Must have an 'index' column.
batch_size: The global batch size (in sequences).
batch_size: The global batch size in instances. Without packing, an instance is one
example. With packing (a collator with max_seq_length set), an instance is one
packed row holding a variable number of examples, and each rank's batch is the
stack of its `batch_size // dp_world_size` rows (see
padding_free_collator.stack_packed_rows).
seed: Random seed for shuffling.
dp_rank: The rank of the current process in the distributed setup.
dp_world_size: Total number of data-parallel processes in the distributed setup.
Expand All @@ -108,8 +113,11 @@ def __init__(
drop_last: If True, drop the last incomplete batch. If False, pad the last batch
with repeated indices to fill a complete batch.
fs_local_rank: File system local rank. Defaults to dp_rank when None.
max_seq_length: Maximum sequence length. Used to report global_batch_size in tokens
max_seq_length: Tokens per instance. Used to report global_batch_size in tokens
to the trainer for batch-size validation.
microbatch_sample_cap: When packing, the maximum number of examples per packed
row. A row closes when either the token budget or this cap is reached.
None means pack purely to the token budget.

Note:
The dataset must have an 'index' column for tracking samples across epochs.
Expand Down Expand Up @@ -144,7 +152,9 @@ def __init__(
f"The effective global batch size will be {batch_size // dp_world_size * dp_world_size}."
)
self._per_rank_batch_size = batch_size // dp_world_size
self._microbatch_sample_cap = microbatch_sample_cap
self._collator = collator if collator is not None else (lambda x: {"examples": x})
self._collator_max_seq_length = getattr(self._collator, "max_seq_length", None)
self._automatic_reshuffle = automatic_reshuffle
self._drop_last = drop_last
self._excluded_indices: set[int] = set()
Expand Down Expand Up @@ -176,24 +186,28 @@ def __next__(self) -> dict[str, Any]:

def _iter_batches(self) -> Iterable[dict[str, Any]]:
"""Return an iterable over all batches in the epoch."""
# World-aware packing: batch boundaries were precomputed by
# World-aware packing: row boundaries were precomputed by
# _reshard_with_packing so that every rank has the same number of
# batches. Each entry in _precomputed_batch_sizes is the number of
# examples in that batch (variable due to packing).
# rows. Each entry in _precomputed_batch_sizes is the number of
# examples in that row (variable due to packing). Each yielded batch
# stacks per_rank_batch_size rows into one rectangular dict; the train
# module splits it back into per-row microbatches.
if self._precomputed_batch_sizes is not None:
rows_per_batch = self._per_rank_batch_size
num_real = len(self._precomputed_batch_sizes) - self._num_padding_batches
offset = 0
for batch_idx, batch_size in enumerate(self._precomputed_batch_sizes):
if batch_idx < self.batches_processed:
offset += batch_size
continue
examples = []
for i in range(offset, offset + batch_size):
example = self.dataset[i]
examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"})
batch = to_device(self._collator(examples), self._device) | {"is_padding": batch_idx >= num_real}
offset += batch_size
yield batch
num_batches = len(self._precomputed_batch_sizes) // rows_per_batch
offsets = [0]
for row_size in self._precomputed_batch_sizes:
offsets.append(offsets[-1] + row_size)
for batch_idx in range(self.batches_processed, num_batches):
rows = []
for row_idx in range(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch):
examples = []
for i in range(offsets[row_idx], offsets[row_idx + 1]):
example = self.dataset[i]
examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"})
rows.append(self._collator(examples) | {"is_padding": row_idx >= num_real})
yield to_device(padding_free_collator.stack_packed_rows(rows), self._device)
return

start_example = self.batches_processed * self._per_rank_batch_size
Expand All @@ -219,7 +233,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]:
def total_batches(self) -> int:
"""Return the total number of batches in an epoch."""
if self._precomputed_batch_sizes is not None:
return len(self._precomputed_batch_sizes)
return len(self._precomputed_batch_sizes) // self._per_rank_batch_size
return self.effective_size // self._per_rank_batch_size

def state_dict(self) -> dict[str, Any]:
Expand Down Expand Up @@ -272,9 +286,8 @@ def _reshard(self, epoch: int) -> None:
mask = np.isin(all_indices, list(self._excluded_indices), invert=True)
all_indices = all_indices[mask]

packing_enabled = hasattr(self._collator, "max_seq_length") and self._collator.max_seq_length is not None
if packing_enabled:
self._reshard_with_packing(all_indices)
if self._collator_max_seq_length is not None:
self._reshard_with_packing(all_indices, self._collator_max_seq_length)
return

self._precomputed_batch_sizes = None
Expand All @@ -300,15 +313,14 @@ def _reshard(self, epoch: int) -> None:
self.effective_size = len(rank_indices)
self.dataset = self._full_dataset.select(rank_indices.tolist())

def _reshard_with_packing(self, all_indices: np.ndarray) -> None:
def _reshard_with_packing(self, all_indices: np.ndarray, max_seq_length: int) -> None:
"""Reshard with world-aware packing so all ranks get the same batch count.

Instead of distributing examples to ranks and letting each rank pack
independently (which can produce different batch counts due to variable
overflow), this packs globally first and then distributes packed batches
round-robin to ranks.
"""
max_seq_length = self._collator.max_seq_length
column_names = self._full_dataset.column_names
subset = self._full_dataset.select(all_indices.tolist())
if "chosen_input_ids" in column_names:
Expand All @@ -324,9 +336,9 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None:
for i in range(len(all_indices)):
new_totals = [running_totals[s] + lengths[i][s] for s in range(num_streams)]
would_exceed = len(current_batch) > 0 and any(t > max_seq_length for t in new_totals)
at_max_samples = len(current_batch) >= self._per_rank_batch_size
at_cap = self._microbatch_sample_cap is not None and len(current_batch) >= self._microbatch_sample_cap

if would_exceed or at_max_samples:
if would_exceed or at_cap:
batches.append(current_batch)
current_batch = [i]
running_totals = list(lengths[i])
Expand All @@ -337,14 +349,19 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None:
if current_batch:
batches.append(current_batch)

# Rows are distributed round-robin to ranks and then stacked into
# per_rank_batch_size-sized batches, so the global count must be a
# multiple of dp_world_size * per_rank_batch_size for every rank to have
# the same number of complete batches.
group_size = self.dp_world_size * self._per_rank_batch_size
num_batches = len(batches)
padding_start = num_batches
if self._drop_last:
num_batches = (num_batches // self.dp_world_size) * self.dp_world_size
num_batches = (num_batches // group_size) * group_size
batches = batches[:num_batches]
else:
if (remainder := num_batches % self.dp_world_size) > 0:
for _ in range(self.dp_world_size - remainder):
if (remainder := num_batches % group_size) > 0:
for _ in range(group_size - remainder):
batches.append(batches[-1])

rank_global_indices = list(range(self.dp_rank, len(batches), self.dp_world_size))
Expand All @@ -369,8 +386,16 @@ def get_mock_batch(self) -> dict[str, Any]:
forward and backward pass before training officially starts.
"""
num_examples = min(self._per_rank_batch_size, len(self.dataset))
# When packing, the collator consumes only as many examples as fit the token
# budget, so at most max_seq_length examples (each >= 1 token) can be used.
# Bound the rows loaded so a large microbatch_sample_cap doesn't load the dataset.
if self._collator_max_seq_length is not None:
num_examples = min(num_examples, self._collator_max_seq_length)
examples = [self.dataset[i] for i in range(num_examples)]
return to_device(self._collator(examples), self._device)
collated = self._collator(examples)
if self._collator_max_seq_length is not None:
collated = padding_free_collator.stack_packed_rows([collated])
return to_device(collated, self._device)

def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int:
"""Return the total number of tokens in the batch across all ranks.
Expand Down
29 changes: 18 additions & 11 deletions open_instruct/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from open_instruct import data_loader as data_loader_lib
from open_instruct import dataset_transformation, dpo_utils, logger_utils, model_utils, olmo_core_utils, utils
from open_instruct.olmo_core_callbacks import PerfCallback
from open_instruct.olmo_core_train_modules import DPOTrainModule
from open_instruct.olmo_core_train_modules import DPOMetricsCallback, DPOTrainModule
from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO

logger = logger_utils.setup_logger(__name__)
Expand Down Expand Up @@ -66,13 +66,13 @@ def _setup_callbacks(args: dpo_utils.DPOExperimentConfig, dp_world_size: int):
wandb_entity=args.wandb_entity,
save_async=False,
)
trainer_callbacks["dpo_metrics"] = DPOMetricsCallback()
slack_webhook_url = os.environ.get("SLACK_WEBHOOK_URL")
if args.send_slack_alerts and slack_webhook_url:
trainer_callbacks["slack"] = callbacks.SlackNotifierCallback(name=run_name, webhook_url=slack_webhook_url)
model_dims = utils.ModelDims.from_hf_config(args.model_name_or_path)
trainer_callbacks["perf"] = PerfCallback(
model_dims=model_dims,
per_device_train_batch_size=args.per_device_train_batch_size,
gradient_accumulation_steps=args.gradient_accumulation_steps,
dp_world_size=dp_world_size,
tensor_parallel_degree=args.tensor_parallel_degree,
Expand Down Expand Up @@ -191,8 +191,15 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz
else:
collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest")

rank_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
global_batch_size = rank_batch_size * dp_world_size
# With packing, an instance is one packed row (one microbatch), so a rank accumulates
# gradient_accumulation_steps rows per optimizer step. Without packing, an instance is
# one example and a rank consumes per_device_train_batch_size * gradient_accumulation_steps
# examples per step. Either way, each instance spans 2 * max_seq_length tokens
# (chosen + rejected streams).
if args.packing:
global_batch_size = args.gradient_accumulation_steps * dp_world_size
else:
global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * dp_world_size
data_loader = data_loader_lib.HFDataLoader(
dataset=dataset,
batch_size=global_batch_size,
Expand All @@ -204,14 +211,13 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz
device=device,
drop_last=True,
fs_local_rank=global_rank,
max_seq_length=2 * args.max_seq_length,
microbatch_sample_cap=args.per_device_train_batch_size,
)
# 4x batch size: forward-only (no backward), so no activation storage needed.
# With packing, the collator's token budget controls the actual forward-pass size
# and the overflow mechanism in HFDataLoader ensures no examples are dropped.
# We could probably have logic to use a longer sequence length here when packing
# is enabled, but for simplicity we just keep the 4x increase in batch size regardless of packing.
# We want the batch size to be as large as possible so that we always pack efficiently.
cache_batch_size = int(args.per_device_train_batch_size * 4 * dp_world_size)
# Forward-only (no backward), so no activation storage is needed. With packing each
# instance is already a full token-budget row, so one row per rank per batch suffices;
# without packing we use a 4x larger example batch.
cache_batch_size = dp_world_size if args.packing else int(args.per_device_train_batch_size * 4 * dp_world_size)
cache_data_loader = data_loader_lib.HFDataLoader(
dataset=dataset,
batch_size=cache_batch_size,
Expand All @@ -224,6 +230,7 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz
# We need to process every example to cache reference logprobs, so we can't drop the last batch.
drop_last=False,
fs_local_rank=global_rank,
max_seq_length=2 * args.max_seq_length,
)

forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo
Expand Down
51 changes: 26 additions & 25 deletions open_instruct/dpo_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,34 +459,35 @@ def build_reference_logprobs_cache(

with torch.no_grad():
for batch in (pbar := tqdm(dataloader, disable=not is_main_process, desc="Caching reference logprobs")):
batch_start = time.perf_counter()
if use_lora and disable_adapter_context is not None:
with disable_adapter_context():
for row in padding_free_collator.unstack_packed_rows(batch):
row_start = time.perf_counter()
if use_lora and disable_adapter_context is not None:
with disable_adapter_context():
chosen_logps, rejected_logps, _ = forward_fn(
model, row, average_log_prob=average_log_prob, **(forward_kwargs or {})
)
else:
chosen_logps, rejected_logps, _ = forward_fn(
model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {})
model, row, average_log_prob=average_log_prob, **(forward_kwargs or {})
)
else:
chosen_logps, rejected_logps, _ = forward_fn(
model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {})
)

if batch.get("is_padding", False):
continue

chosen_tensor[batch["index"]] = chosen_logps
rejected_tensor[batch["index"]] = rejected_logps

batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(batch)
total_tokens += batch_tokens
total_examples += batch_size
pbar.set_postfix(
{
"avg_tok/ex": f"{total_tokens / total_examples:.0f}",
"MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - batch_start):.1f}",
"mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}",
"mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}",
}
)
if row.get("is_padding", False):
continue

chosen_tensor[row["index"]] = chosen_logps
rejected_tensor[row["index"]] = rejected_logps

batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(row)
total_tokens += batch_tokens
total_examples += batch_size
pbar.set_postfix(
{
"avg_tok/ex": f"{total_tokens / total_examples:.0f}",
"MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - row_start):.1f}",
"mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}",
"mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}",
}
)

dist.all_reduce(chosen_tensor, op=dist.ReduceOp.MAX)
dist.all_reduce(rejected_tensor, op=dist.ReduceOp.MAX)
Expand Down
Loading
Loading