From 78f2dd1c5cf6db10e4af3dc68dba1180f2fb8783 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 1 Jun 2026 11:54:53 -0600 Subject: [PATCH 01/35] minor tweaks to script --- .../olmo-hybrid/7b_instruct_dpo_sweep.sh | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index 992ff01645..b134186c8d 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -16,28 +16,21 @@ BEAKER_IMAGE="${1:-finbarrt/hybrid-dpo-stable}" BASE_PATH="/weka/oe-adapt-default/nathanl/checkpoints" SFT_MODELS=( - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_6e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_5e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_8e-5/step3256-hf" - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_9e-5/step3256-hf" - "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_2.5e-5/step3256-hf" # Final Instruct SFT Model - # "${BASE_PATH}/HYBRID_INSTRUCT_SFT_0218_1e-4/step3256-hf" + allenai/Olmo-Hybrid-Instruct-SFT-7B ) DPO_LRS=( - 2e-6 - # 1e-6 - 8.5e-7 - 7e-7 - 5e-7 - 2.5e-7 + #2e-6 + 1e-6 + #8.5e-7 + #7e-7 + #5e-7 + #2.5e-7 ) for MODEL_PATH in "${SFT_MODELS[@]}"; do - # Extract SFT LR from path, e.g. HYBRID_INSTRUCT_SFT_0218_6e-5 -> 6e-5 - SFT_LR=$(basename "$(dirname "$MODEL_PATH")" | sed 's/.*_\([0-9.e-]*\)$/\1/') for LR in "${DPO_LRS[@]}"; do - EXP_NAME="hybrid-7b-DPO-0219-SFT-${SFT_LR}-LR-${LR}" + EXP_NAME="hybrid-7b-DPO-0219-SFT-public-LR-${LR}" echo "=====================================" echo "Launching: ${EXP_NAME}" echo " SFT model: ${MODEL_PATH}" From ac6a4adc713167ac323ccdfa02bfe08f57224dfb Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 1 Jun 2026 12:06:16 -0600 Subject: [PATCH 02/35] using ai2/linear-rnns workspace --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index b134186c8d..28aa919c8f 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -40,7 +40,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do uv run python mason.py \ --cluster ai2/jupiter \ --description "Hybrid 7B DPO sweep, SFT-${SFT_LR}, LR=${LR}, 4 nodes, 16k seq, ZeRO-3." \ - --workspace ai2/olmo-instruct \ + --workspace ai2/linear-rnns \ --priority urgent \ --max_retries 0 \ --preemptible \ From d0d8ea1d8d46f89d09b15a10aa3bbb8303dff45d Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 1 Jun 2026 15:49:11 -0600 Subject: [PATCH 03/35] modified sweep --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index 28aa919c8f..07e7b861b1 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -20,12 +20,12 @@ SFT_MODELS=( ) DPO_LRS=( - #2e-6 - 1e-6 - #8.5e-7 - #7e-7 - #5e-7 - #2.5e-7 + 2e-6 + #1e-6 + 8.5e-7 + 7e-7 + 5e-7 + 2.5e-7 ) for MODEL_PATH in "${SFT_MODELS[@]}"; do From 86160a65fe4ec34ae62614cc5dae34eaf6273674 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 1 Jun 2026 19:27:51 -0600 Subject: [PATCH 04/35] only one lr --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index 07e7b861b1..28aa919c8f 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -20,12 +20,12 @@ SFT_MODELS=( ) DPO_LRS=( - 2e-6 - #1e-6 - 8.5e-7 - 7e-7 - 5e-7 - 2.5e-7 + #2e-6 + 1e-6 + #8.5e-7 + #7e-7 + #5e-7 + #2.5e-7 ) for MODEL_PATH in "${SFT_MODELS[@]}"; do From 7b34e2362bfdd76995592667b1091668bda2c826 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 08:35:34 -0600 Subject: [PATCH 05/35] Add Olmo Hybrid DPO sweep (olmo-core) and GDN-aware ModelDims FLOPs/memory Co-Authored-By: Claude Opus 4.8 --- open_instruct/test_utils.py | 105 +++++++++++++ open_instruct/utils.py | 140 ++++++++++++++++-- .../7b_instruct_dpo_sweep_olmo_core.sh | 86 +++++++++++ 3 files changed, 318 insertions(+), 13 deletions(-) create mode 100755 scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 29db0b61c1..4e1f673e61 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -739,6 +739,111 @@ def test_from_hf_config_cpu_only(self): self.assertIsNone(model_dims.device_name) + def test_from_hf_config_hybrid(self): + config = SimpleNamespace( + model_type="olmo_hybrid", + hidden_size=3840, + intermediate_size=11008, + num_attention_heads=30, + num_key_value_heads=30, + head_dim=128, + num_hidden_layers=8, + vocab_size=100352, + layer_types=["linear_attention", "linear_attention", "linear_attention", "full_attention"] * 2, + linear_num_key_heads=30, + linear_num_value_heads=30, + linear_key_head_dim=96, + linear_value_head_dim=192, + linear_conv_kernel_dim=4, + ) + config.get_text_config = lambda: config + + with ( + mock.patch("transformers.AutoConfig.from_pretrained", return_value=config), + mock.patch("torch.cuda.is_available", return_value=False), + ): + model_dims = utils.ModelDims.from_hf_config("test/hybrid") + + self.assertEqual(model_dims.num_linear_attn_layers, 6) + self.assertEqual(model_dims.linear_attn_key_dim, 30 * 96) + self.assertEqual(model_dims.linear_attn_value_dim, 30 * 192) + self.assertEqual(model_dims.linear_attn_conv_size, 4) + + +class TestModelDimsHybrid(unittest.TestCase): + def _hybrid_dims(self) -> utils.ModelDims: + return utils.ModelDims( + num_layers=8, + hidden_size=3840, + intermediate_size=11008, + vocab_size=100352, + num_attn_heads=30, + head_dim=128, + num_kv_heads=30, + num_linear_attn_layers=6, + linear_attn_num_k_heads=30, + linear_attn_num_v_heads=30, + linear_attn_key_head_dim=96, + linear_attn_value_head_dim=192, + linear_attn_conv_size=4, + device_name="h100", + ) + + def test_linear_attn_flops_scale_linearly(self): + dims = self._hybrid_dims() + self.assertEqual(dims.linear_attn_flops(2000), 2 * dims.linear_attn_flops(1000)) + + def test_decode_flops_constant_per_prompt_length_for_gdn(self): + # A purely linear-attention model has no growing context, so decode FLOPs are independent + # of prompt length (unlike softmax attention, where they grow with kv_len). + gdn_only = utils.ModelDims( + num_layers=6, + hidden_size=3840, + intermediate_size=11008, + vocab_size=100352, + num_attn_heads=30, + head_dim=128, + num_kv_heads=30, + num_linear_attn_layers=6, + linear_attn_num_k_heads=30, + linear_attn_num_v_heads=30, + linear_attn_key_head_dim=96, + linear_attn_value_head_dim=192, + linear_attn_conv_size=4, + device_name="h100", + ) + self.assertEqual(gdn_only.decode_flops([10], [32]), gdn_only.decode_flops([10000], [32])) + + def test_gdn_layers_excluded_from_kv_cache(self): + dims = self._hybrid_dims() + # Only the 2 full-attention layers (8 total - 6 GDN) write a KV cache. + per_token = 2 * dims.num_kv_heads * dims.head_dim * 2 + self.assertEqual(dims.kv_cache_write_bytes(100), 2 * 100 * per_token) + + def test_gdn_state_bytes_present(self): + dims = self._hybrid_dims() + state_elems = 30 * 96 * 192 + self.assertEqual(dims.gdn_state_bytes(10), 6 * 10 * 2 * state_elems * 2) + + def test_utilization_under_100(self): + dims = self._hybrid_dims() + prompt_lengths = [256] * 8 + response_lengths = [256] * 16 + metrics = utils.calculate_utilization_metrics( + model_dims=dims, + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + total_generation_time=8.0, + samples_per_prompt=2, + num_engines=2, + num_gpus_per_engine=2, + training_time=4.0, + num_training_gpus=4, + ) + self.assertLessEqual(metrics["actor_mfu"], 100) + self.assertLessEqual(metrics["actor_mbu"], 100) + self.assertLessEqual(metrics["learner_mfu"], 100) + # useful for checking if public datasets are still available # class CheckTuluDatasetsTest(unittest.TestCase): diff --git a/open_instruct/utils.py b/open_instruct/utils.py index efc0cdb061..56fb6aa08a 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -1823,6 +1823,9 @@ def check_oe_eval_internal(): # Approximate softmax cost per attention score: # ~4 scalar ops/score: exp + subtract max (stabilization) + sum + divide. SOFTMAX_FLOPS_PER_SCORE = 4 +# Approximate number of state-matmul-equivalent passes in the gated delta rule +# recurrence per token (state readout, delta error, decayed state update, gating). +GDN_RECURRENCE_PASSES = 4 @dataclasses.dataclass @@ -1838,11 +1841,25 @@ class ModelDims: device_name: str | None = None sliding_window: int | None = None num_sliding_window_layers: int = 0 + num_linear_attn_layers: int = 0 + linear_attn_num_k_heads: int | None = None + linear_attn_num_v_heads: int | None = None + linear_attn_key_head_dim: int | None = None + linear_attn_value_head_dim: int | None = None + linear_attn_conv_size: int = 4 def __post_init__(self): if self.num_kv_heads is None: self.num_kv_heads = self.num_attn_heads + if self.num_linear_attn_layers > 0: + assert None not in ( + self.linear_attn_num_k_heads, + self.linear_attn_num_v_heads, + self.linear_attn_key_head_dim, + self.linear_attn_value_head_dim, + ), "linear attention dims must be set when num_linear_attn_layers > 0" + self.num_params = self.num_params or self._calculate_num_params() if self.device_name is None and torch.cuda.is_available(): @@ -1852,21 +1869,44 @@ def __post_init__(self): assert self.num_attn_heads % self.num_kv_heads == 0, ( "num_attn_heads must be divisible by num_kv_heads (GQA/MQA)" ) - assert self.num_sliding_window_layers <= self.num_layers, ( - f"num_sliding_window_layers ({self.num_sliding_window_layers}) cannot exceed num_layers ({self.num_layers})" + assert self.num_sliding_window_layers + self.num_linear_attn_layers <= self.num_layers, ( + f"num_sliding_window_layers ({self.num_sliding_window_layers}) + num_linear_attn_layers " + f"({self.num_linear_attn_layers}) cannot exceed num_layers ({self.num_layers})" ) + @property + def linear_attn_key_dim(self) -> int: + return self.linear_attn_num_k_heads * self.linear_attn_key_head_dim + + @property + def linear_attn_value_dim(self) -> int: + return self.linear_attn_num_v_heads * self.linear_attn_value_head_dim + + def _gdn_layer_params(self) -> int: + """Parameter count for one Gated Delta Net (linear attention) layer, excluding MLP.""" + h = self.hidden_size + kd = self.linear_attn_key_dim + vd = self.linear_attn_value_dim + proj_params = 2 * h * kd + 3 * h * vd + 2 * h * self.linear_attn_num_v_heads + conv_params = (2 * kd + vd) * self.linear_attn_conv_size + return proj_params + conv_params + def _calculate_num_params(self) -> int: embedding_params = self.vocab_size * self.hidden_size q_params = self.hidden_size * (self.num_attn_heads * self.head_dim) kv_params = self.hidden_size * (self.num_kv_heads * self.head_dim) * 2 o_params = (self.num_attn_heads * self.head_dim) * self.hidden_size + attn_params = q_params + kv_params + o_params + mlp_up_params = self.hidden_size * self.intermediate_size * 2 mlp_down_params = self.intermediate_size * self.hidden_size + mlp_params = mlp_up_params + mlp_down_params - per_layer_params = q_params + kv_params + o_params + mlp_up_params + mlp_down_params - layer_params = self.num_layers * per_layer_params + num_attn_layers = self.num_layers - self.num_linear_attn_layers + layer_params = self.num_layers * mlp_params + num_attn_layers * attn_params + if self.num_linear_attn_layers > 0: + layer_params += self.num_linear_attn_layers * self._gdn_layer_params() lm_head_params = self.vocab_size * self.hidden_size @@ -1888,6 +1928,19 @@ def from_hf_config(cls, model_name_or_path: str) -> "ModelDims": else: num_sliding_window_layers = config.num_hidden_layers head_dim = getattr(config, "head_dim", hidden_size // config.num_attention_heads) + + layer_types = getattr(config, "layer_types", None) + num_linear_attn_layers = layer_types.count("linear_attention") if layer_types is not None else 0 + linear_attn_kwargs = {} + if num_linear_attn_layers > 0: + linear_attn_kwargs = { + "linear_attn_num_k_heads": config.linear_num_key_heads, + "linear_attn_num_v_heads": config.linear_num_value_heads, + "linear_attn_key_head_dim": config.linear_key_head_dim, + "linear_attn_value_head_dim": config.linear_value_head_dim, + "linear_attn_conv_size": config.linear_conv_kernel_dim, + } + return cls( num_layers=config.num_hidden_layers, hidden_size=hidden_size, @@ -1898,7 +1951,9 @@ def from_hf_config(cls, model_name_or_path: str) -> "ModelDims": head_dim=head_dim, sliding_window=sliding_window, num_sliding_window_layers=num_sliding_window_layers, + num_linear_attn_layers=num_linear_attn_layers, device_name=get_device_name(torch.cuda.get_device_name(0)) if torch.cuda.is_available() else None, + **linear_attn_kwargs, ) @property @@ -1952,9 +2007,39 @@ def mlp_flops(self, seq_len: int) -> int: second = mul * seq_len * self.intermediate_size * self.hidden_size return first + act + second + def linear_attn_flops(self, seq_len: int) -> int: + """FLOPs for one Gated Delta Net (linear attention) layer over seq_len tokens. + + Linear attention is O(seq_len): the per-token cost is constant and independent of + context length, so prefill and decode share the same per-token cost. Dominated by the + input/output projections plus the recurrent gated-delta state update; conv and + gate/decay projections to the head count are minor but included. + """ + mul = FLOP_PER_MAC + kd = self.linear_attn_key_dim + vd = self.linear_attn_value_dim + + qk_proj = mul * 2 * seq_len * self.hidden_size * kd + v_proj = mul * seq_len * self.hidden_size * vd + g_proj = mul * seq_len * self.hidden_size * vd + ab_proj = mul * 2 * seq_len * self.hidden_size * self.linear_attn_num_v_heads + conv = mul * seq_len * self.linear_attn_conv_size * (2 * kd + vd) + out_proj = mul * seq_len * vd * self.hidden_size + + recurrence = ( + mul + * GDN_RECURRENCE_PASSES + * seq_len + * self.linear_attn_num_v_heads + * self.linear_attn_key_head_dim + * self.linear_attn_value_head_dim + ) + + return qk_proj + v_proj + g_proj + ab_proj + conv + out_proj + recurrence + def prefill_flops(self, prompt_lengths: list[int]) -> int: """Prefill builds the KV cache; logits are computed once after each prompt.""" - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers + num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers num_sliding_layers = self.num_sliding_window_layers total = 0 @@ -1967,6 +2052,9 @@ def prefill_flops(self, prompt_lengths: list[int]) -> int: self.attn_flops(L, L, sliding_window=self.sliding_window) + self.mlp_flops(L) ) + if self.num_linear_attn_layers > 0: + total += self.num_linear_attn_layers * (self.linear_attn_flops(L) + self.mlp_flops(L)) + # LM head is applied to each token position during training total += L * FLOP_PER_MAC * self.hidden_size * self.vocab_size @@ -1986,7 +2074,7 @@ def decode_flops(self, prompt_lengths: list[int], response_lengths: list[int], s f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}" ) - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers + num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers num_sliding_layers = self.num_sliding_window_layers total = 0 @@ -1996,6 +2084,9 @@ def decode_flops(self, prompt_lengths: list[int], response_lengths: list[int], s for _ in range(samples_per_prompt): R = response_lengths[response_idx] total += R * self.num_layers * self.mlp_flops(seq_len=1) + # Linear attention per-token cost is constant in context length. + if self.num_linear_attn_layers > 0: + total += self.num_linear_attn_layers * self.linear_attn_flops(R) for t in range(R): kv_len = P + t + 1 # prompt + generated so far + current if num_full_attn_layers > 0: @@ -2046,16 +2137,24 @@ def weight_memory_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int: hidden_q = self.num_attn_heads * self.head_dim hidden_kv = self.num_kv_heads * self.head_dim - # Per-layer weight params (Q, K, V, O, MLP up, MLP down) + # Per-attention-layer weight params (Q, K, V, O) w_q = self.hidden_size * hidden_q w_k = self.hidden_size * hidden_kv w_v = self.hidden_size * hidden_kv w_o = hidden_q * self.hidden_size + attn_weights = w_q + w_k + w_v + w_o + + # Per-layer MLP weights (up, down), shared by all layer types w_up = self.hidden_size * (self.intermediate_size * 2) # times 2 due to SwiGLU w_dn = self.intermediate_size * self.hidden_size + mlp_weights = w_up + w_dn + + num_attn_layers = self.num_layers - self.num_linear_attn_layers + total_weights = self.num_layers * mlp_weights + num_attn_layers * attn_weights + if self.num_linear_attn_layers > 0: + total_weights += self.num_linear_attn_layers * self._gdn_layer_params() - per_layer_weight_bytes = (w_q + w_k + w_v + w_o + w_up + w_dn) * dtype_bytes - return self.num_layers * num_tokens * per_layer_weight_bytes + return num_tokens * total_weights * dtype_bytes def kv_cache_write_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int: """Memory bytes for writing KV cache for a given number of tokens. @@ -2067,9 +2166,10 @@ def kv_cache_write_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int: Returns: Total bytes for KV cache writes across all layers """ - # 2x for K and V + # 2x for K and V. Linear attention layers keep a fixed recurrent state, not a KV cache. + num_kv_cache_layers = self.num_layers - self.num_linear_attn_layers kv_write_bytes_per_token = 2 * self.num_kv_heads * self.head_dim * dtype_bytes - return self.num_layers * num_tokens * kv_write_bytes_per_token + return num_kv_cache_layers * num_tokens * kv_write_bytes_per_token def kv_cache_read_bytes( self, prompt_lengths: list[int], response_lengths: list[int], samples_per_prompt: int = 1, dtype_bytes: int = 2 @@ -2092,7 +2192,7 @@ def kv_cache_read_bytes( f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}" ) - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers + num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers num_sliding_layers = self.num_sliding_window_layers # For batched sampling with shared prompt KV cache: @@ -2128,6 +2228,19 @@ def kv_cache_read_bytes( kv_bytes_per_token = 2 * self.num_kv_heads * self.head_dim * dtype_bytes return kv_bytes_per_token * kv_read_terms + def gdn_state_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int: + """Memory bytes for reading and writing the Gated Delta Net recurrent state. + + Linear attention layers carry a fixed-size recurrent state instead of a growing KV + cache. Each decode step reads and writes the full state once per layer, independent of + context length. + """ + if self.num_linear_attn_layers == 0: + return 0 + state_elems = self.linear_attn_num_v_heads * self.linear_attn_key_head_dim * self.linear_attn_value_head_dim + # 2x for read + write + return self.num_linear_attn_layers * num_tokens * 2 * state_elems * dtype_bytes + def prefill_memory_bytes(self, prompt_lengths: list[int], dtype_bytes: int = 2) -> int: """Memory bytes for prefill phase. @@ -2187,7 +2300,8 @@ def decode_memory_bytes( kv_write_bytes = self.kv_cache_write_bytes(total_decode_tokens, dtype_bytes) kv_read_bytes = self.kv_cache_read_bytes(prompt_lengths, response_lengths, samples_per_prompt, dtype_bytes) - return weight_bytes + kv_write_bytes + kv_read_bytes + gdn_state_bytes = self.gdn_state_bytes(total_decode_tokens, dtype_bytes) + return weight_bytes + kv_write_bytes + kv_read_bytes + gdn_state_bytes def memory_bytes( self, diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh new file mode 100755 index 0000000000..0068b16644 --- /dev/null +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# DPO sweep for hybrid instruct models, using OLMo-core (dpo.py) instead of +# dpo_tune_cache.py (Accelerate + DeepSpeed ZeRO-3). +# +# Usage (with pre-built image, no Docker build needed): +# bash scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +# +# Usage (with build_image_and_launch.sh, slow ~1hr Docker build): +# ./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +# +# NOTE: dpo.py builds the model with OLMo-core's native TransformerConfig, so the +# hybrid architecture must be resolvable from --config_name. See OLMO_MODEL_CONFIG_MAP +# / get_transformer_config in open_instruct/olmo_core_utils.py. + +BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" + +SFT_MODELS=( + allenai/Olmo-Hybrid-Instruct-SFT-7B +) + +DPO_LRS=( + 1e-6 +) + +# OLMo-core TransformerConfig preset for the hybrid 7B model. Must be a config +# name registered with olmo-core's TransformerConfig (see olmo_core_utils.py). +CONFIG_NAME=olmo3_hybrid_7B + +for MODEL_PATH in "${SFT_MODELS[@]}"; do + for LR in "${DPO_LRS[@]}"; do + EXP_NAME="hybrid-7b-DPO-oc-0219-SFT-public-LR-${LR}" + echo "=====================================" + echo "Launching: ${EXP_NAME}" + echo " SFT model: ${MODEL_PATH}" + echo " DPO LR: ${LR}" + echo "=====================================" + + uv run python mason.py \ + --cluster ai2/jupiter \ + --description "Hybrid 7B DPO sweep (OLMo-core), LR=${LR}, 4 nodes, 16k seq." \ + --workspace ai2/linear-rnns \ + --priority urgent \ + --max_retries 0 \ + --preemptible \ + --image "$BEAKER_IMAGE" \ + --pure_docker_mode \ + --no_auto_dataset_cache \ + --env OLMO_SHARED_FS=1 \ + --env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + --env NCCL_IB_HCA=^=mlx5_bond_0 \ + --env NCCL_SOCKET_IFNAME=ib \ + --env TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ + --env TORCH_DIST_INIT_BARRIER=1 \ + --env TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800 \ + --env TRITON_PRINT_AUTOTUNING=1 \ + --num_nodes 4 \ + --gpus 8 -- torchrun \ + --nnodes=4 \ + --node_rank=\$BEAKER_REPLICA_RANK \ + --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ + --master_port=29400 \ + --nproc_per_node=8 \ + open_instruct/dpo.py \ + --exp_name "$EXP_NAME" \ + --model_name_or_path "$MODEL_PATH" \ + --config_name "$CONFIG_NAME" \ + --chat_template_name olmo123 \ + --mixer_list allenai/Dolci-Instruct-DPO-fixed 259922 \ + --max_seq_length 16384 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 4 \ + --learning_rate "$LR" \ + --lr_scheduler_type linear \ + --checkpointing_steps 500 \ + --keep_last_n_checkpoints -1 \ + --warmup_ratio 0.1 \ + --weight_decay 0.0 \ + --num_epochs 1 \ + --logging_steps 1 \ + --loss_type dpo_norm \ + --beta 5 \ + --packing \ + --activation_memory_budget 0.5 \ + --with_tracking + done +done From 1fbac3fbc70f30c98b7f1ee33113780a42da07d3 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 08:56:03 -0600 Subject: [PATCH 06/35] Simplify ModelDims GDN handling: zero-default linear-attn dims, dedup tests Co-Authored-By: Claude Opus 4.8 --- open_instruct/test_utils.py | 73 +++++++++++-------------------------- open_instruct/utils.py | 40 ++++++++------------ 2 files changed, 36 insertions(+), 77 deletions(-) diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 4e1f673e61..85651beef1 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # Copied from https://github.com/huggingface/alignment-handbook/blob/main/tests/test_data.py +import dataclasses import json import os import pathlib @@ -70,6 +71,22 @@ def _load_mbu_test_cases(): num_kv_heads=8, device_name="h100", ), + "olmo-hybrid-7b": utils.ModelDims( + num_layers=8, + hidden_size=3840, + intermediate_size=11008, + vocab_size=100352, + num_attn_heads=30, + head_dim=128, + num_kv_heads=30, + num_linear_attn_layers=6, + linear_attn_num_k_heads=30, + linear_attn_num_v_heads=30, + linear_attn_key_head_dim=96, + linear_attn_value_head_dim=192, + linear_attn_conv_size=4, + device_name="h100", + ), } @@ -592,6 +609,7 @@ def test_mbu_reproduction(self, name, case_data): ("two_engines_four_gpus_each", "Qwen/Qwen2.5-7B", 16, 2, 256, 256, 8, 2, 4, 4, 8.0, 4.0), ("four_engines_two_gpus_each", "Qwen/Qwen2.5-7B", 16, 2, 256, 256, 8, 4, 2, 4, 8.0, 4.0), ("single_engine_eight_gpus", "Qwen/Qwen2.5-7B", 16, 2, 256, 256, 8, 1, 8, 4, 8.0, 4.0), + ("hybrid_two_engines_two_gpus_each", "olmo-hybrid-7b", 8, 2, 256, 256, 4, 2, 2, 4, 8.0, 4.0), ] ) def test_multi_engine_utilization( @@ -771,23 +789,8 @@ def test_from_hf_config_hybrid(self): class TestModelDimsHybrid(unittest.TestCase): - def _hybrid_dims(self) -> utils.ModelDims: - return utils.ModelDims( - num_layers=8, - hidden_size=3840, - intermediate_size=11008, - vocab_size=100352, - num_attn_heads=30, - head_dim=128, - num_kv_heads=30, - num_linear_attn_layers=6, - linear_attn_num_k_heads=30, - linear_attn_num_v_heads=30, - linear_attn_key_head_dim=96, - linear_attn_value_head_dim=192, - linear_attn_conv_size=4, - device_name="h100", - ) + def _hybrid_dims(self, num_layers: int = 8) -> utils.ModelDims: + return dataclasses.replace(MODEL_DIMS["olmo-hybrid-7b"], num_layers=num_layers, num_params=None) def test_linear_attn_flops_scale_linearly(self): dims = self._hybrid_dims() @@ -796,22 +799,7 @@ def test_linear_attn_flops_scale_linearly(self): def test_decode_flops_constant_per_prompt_length_for_gdn(self): # A purely linear-attention model has no growing context, so decode FLOPs are independent # of prompt length (unlike softmax attention, where they grow with kv_len). - gdn_only = utils.ModelDims( - num_layers=6, - hidden_size=3840, - intermediate_size=11008, - vocab_size=100352, - num_attn_heads=30, - head_dim=128, - num_kv_heads=30, - num_linear_attn_layers=6, - linear_attn_num_k_heads=30, - linear_attn_num_v_heads=30, - linear_attn_key_head_dim=96, - linear_attn_value_head_dim=192, - linear_attn_conv_size=4, - device_name="h100", - ) + gdn_only = self._hybrid_dims(num_layers=6) self.assertEqual(gdn_only.decode_flops([10], [32]), gdn_only.decode_flops([10000], [32])) def test_gdn_layers_excluded_from_kv_cache(self): @@ -825,25 +813,6 @@ def test_gdn_state_bytes_present(self): state_elems = 30 * 96 * 192 self.assertEqual(dims.gdn_state_bytes(10), 6 * 10 * 2 * state_elems * 2) - def test_utilization_under_100(self): - dims = self._hybrid_dims() - prompt_lengths = [256] * 8 - response_lengths = [256] * 16 - metrics = utils.calculate_utilization_metrics( - model_dims=dims, - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - total_generation_time=8.0, - samples_per_prompt=2, - num_engines=2, - num_gpus_per_engine=2, - training_time=4.0, - num_training_gpus=4, - ) - self.assertLessEqual(metrics["actor_mfu"], 100) - self.assertLessEqual(metrics["actor_mbu"], 100) - self.assertLessEqual(metrics["learner_mfu"], 100) - # useful for checking if public datasets are still available # class CheckTuluDatasetsTest(unittest.TestCase): diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 56fb6aa08a..c136b3ffb4 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -1842,24 +1842,16 @@ class ModelDims: sliding_window: int | None = None num_sliding_window_layers: int = 0 num_linear_attn_layers: int = 0 - linear_attn_num_k_heads: int | None = None - linear_attn_num_v_heads: int | None = None - linear_attn_key_head_dim: int | None = None - linear_attn_value_head_dim: int | None = None + linear_attn_num_k_heads: int = 0 + linear_attn_num_v_heads: int = 0 + linear_attn_key_head_dim: int = 0 + linear_attn_value_head_dim: int = 0 linear_attn_conv_size: int = 4 def __post_init__(self): if self.num_kv_heads is None: self.num_kv_heads = self.num_attn_heads - if self.num_linear_attn_layers > 0: - assert None not in ( - self.linear_attn_num_k_heads, - self.linear_attn_num_v_heads, - self.linear_attn_key_head_dim, - self.linear_attn_value_head_dim, - ), "linear attention dims must be set when num_linear_attn_layers > 0" - self.num_params = self.num_params or self._calculate_num_params() if self.device_name is None and torch.cuda.is_available(): @@ -1882,6 +1874,10 @@ def linear_attn_key_dim(self) -> int: def linear_attn_value_dim(self) -> int: return self.linear_attn_num_v_heads * self.linear_attn_value_head_dim + @property + def num_full_attn_layers(self) -> int: + return self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers + def _gdn_layer_params(self) -> int: """Parameter count for one Gated Delta Net (linear attention) layer, excluding MLP.""" h = self.hidden_size @@ -1905,8 +1901,7 @@ def _calculate_num_params(self) -> int: num_attn_layers = self.num_layers - self.num_linear_attn_layers layer_params = self.num_layers * mlp_params + num_attn_layers * attn_params - if self.num_linear_attn_layers > 0: - layer_params += self.num_linear_attn_layers * self._gdn_layer_params() + layer_params += self.num_linear_attn_layers * self._gdn_layer_params() lm_head_params = self.vocab_size * self.hidden_size @@ -1920,16 +1915,15 @@ def from_hf_config(cls, model_name_or_path: str) -> "ModelDims": hidden_size = config.hidden_size intermediate_size = getattr(config, "intermediate_size", 4 * hidden_size) sliding_window = getattr(config, "sliding_window", None) + layer_types = getattr(config, "layer_types", None) num_sliding_window_layers = 0 if sliding_window is not None: - layer_types = getattr(config, "layer_types", None) if layer_types is not None: num_sliding_window_layers = layer_types.count("sliding_attention") else: num_sliding_window_layers = config.num_hidden_layers head_dim = getattr(config, "head_dim", hidden_size // config.num_attention_heads) - layer_types = getattr(config, "layer_types", None) num_linear_attn_layers = layer_types.count("linear_attention") if layer_types is not None else 0 linear_attn_kwargs = {} if num_linear_attn_layers > 0: @@ -2039,7 +2033,7 @@ def linear_attn_flops(self, seq_len: int) -> int: def prefill_flops(self, prompt_lengths: list[int]) -> int: """Prefill builds the KV cache; logits are computed once after each prompt.""" - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers + num_full_attn_layers = self.num_full_attn_layers num_sliding_layers = self.num_sliding_window_layers total = 0 @@ -2052,8 +2046,7 @@ def prefill_flops(self, prompt_lengths: list[int]) -> int: self.attn_flops(L, L, sliding_window=self.sliding_window) + self.mlp_flops(L) ) - if self.num_linear_attn_layers > 0: - total += self.num_linear_attn_layers * (self.linear_attn_flops(L) + self.mlp_flops(L)) + total += self.num_linear_attn_layers * (self.linear_attn_flops(L) + self.mlp_flops(L)) # LM head is applied to each token position during training total += L * FLOP_PER_MAC * self.hidden_size * self.vocab_size @@ -2074,7 +2067,7 @@ def decode_flops(self, prompt_lengths: list[int], response_lengths: list[int], s f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}" ) - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers + num_full_attn_layers = self.num_full_attn_layers num_sliding_layers = self.num_sliding_window_layers total = 0 @@ -2085,8 +2078,7 @@ def decode_flops(self, prompt_lengths: list[int], response_lengths: list[int], s R = response_lengths[response_idx] total += R * self.num_layers * self.mlp_flops(seq_len=1) # Linear attention per-token cost is constant in context length. - if self.num_linear_attn_layers > 0: - total += self.num_linear_attn_layers * self.linear_attn_flops(R) + total += self.num_linear_attn_layers * self.linear_attn_flops(R) for t in range(R): kv_len = P + t + 1 # prompt + generated so far + current if num_full_attn_layers > 0: @@ -2192,7 +2184,7 @@ def kv_cache_read_bytes( f"Expected {len(prompt_lengths) * samples_per_prompt} response lengths, got {len(response_lengths)}" ) - num_full_attn_layers = self.num_layers - self.num_sliding_window_layers - self.num_linear_attn_layers + num_full_attn_layers = self.num_full_attn_layers num_sliding_layers = self.num_sliding_window_layers # For batched sampling with shared prompt KV cache: @@ -2235,8 +2227,6 @@ def gdn_state_bytes(self, num_tokens: int, dtype_bytes: int = 2) -> int: cache. Each decode step reads and writes the full state once per layer, independent of context length. """ - if self.num_linear_attn_layers == 0: - return 0 state_elems = self.linear_attn_num_v_heads * self.linear_attn_key_head_dim * self.linear_attn_value_head_dim # 2x for read + write return self.num_linear_attn_layers * num_tokens * 2 * state_elems * dtype_bytes From 2c23960adea8c2af5cdc1baaad92f24f424a6bbe Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 09:35:02 -0600 Subject: [PATCH 07/35] Bump olmo-core to hybrid-dpo-conversion branch for Olmo-Hybrid DPO support Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 1 + pyproject.toml | 2 +- requirements.txt | 2 +- uv.lock | 4 ++-- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3cbe25d55..fb866204ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed +- Support training the Olmo-Hybrid (GDN) model with the OLMo-core DPO trainer (`dpo.py`): bump olmo-core to a commit that adds the `olmo3_hybrid_7B` config preset and HF→olmo-core hybrid weight conversion (`convert_hybrid_state_from_hf`), and add an OLMo-core hybrid DPO sweep script (https://github.com/allenai/open-instruct/pull/1713). - Expand type-checking coverage by replacing `# ty: ignore` directives with typed casts and fixing related type issues (https://github.com/allenai/open-instruct/pull/1688). - Add TV divergence rho filtering for GRPO (https://github.com/allenai/open-instruct/pull/1681). - Export `SETUPTOOLS_SCM_PRETEND_VERSION_FOR_OPEN_INSTRUCT=0.0.0+debug` in `scripts/train/debug/grpo.sh` and `grpo_fast.sh` (local Ray debug scripts that disable torch compile) so setuptools-scm can resolve the package version (https://github.com/allenai/open-instruct/pull/1696). diff --git a/pyproject.toml b/pyproject.toml index ed81a0c382..e7d11889a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ flash-attn-3 = { FLASH_ATTENTION_SKIP_CUDA_BUILD = "TRUE" } # pytorch related setups [tool.uv.sources] -ai2-olmo-core = { git = "https://github.com/allenai/OLMo-core.git", rev = "f1b69d796aa9691c0abbf29683ee04fe1e42aa50" } +ai2-olmo-core = { git = "https://github.com/allenai/OLMo-core.git", rev = "ae85c1100b81436ca7e29c50cccc45f6d206fd7b" } flash-attn = [ { url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.0/flash_attn-2.8.3+cu128torch2.10-cp312-cp312-linux_x86_64.whl", marker = "platform_system == 'Linux' and platform_machine == 'x86_64'" }, ] diff --git a/requirements.txt b/requirements.txt index 99287652d7..013fd4638d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ accelerate==1.12.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') # via # open-instruct # peft -ai2-olmo-core @ git+https://github.com/allenai/OLMo-core.git@f1b69d796aa9691c0abbf29683ee04fe1e42aa50 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' +ai2-olmo-core @ git+https://github.com/allenai/OLMo-core.git@ae85c1100b81436ca7e29c50cccc45f6d206fd7b ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via open-instruct aiofiles==25.1.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via crawl4ai diff --git a/uv.lock b/uv.lock index 5a8e3732b9..a956ca3325 100644 --- a/uv.lock +++ b/uv.lock @@ -47,7 +47,7 @@ wheels = [ [[package]] name = "ai2-olmo-core" version = "2.5.0" -source = { git = "https://github.com/allenai/OLMo-core.git?rev=f1b69d796aa9691c0abbf29683ee04fe1e42aa50#f1b69d796aa9691c0abbf29683ee04fe1e42aa50" } +source = { git = "https://github.com/allenai/OLMo-core.git?rev=ae85c1100b81436ca7e29c50cccc45f6d206fd7b#ae85c1100b81436ca7e29c50cccc45f6d206fd7b" } dependencies = [ { name = "bettermap", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "cached-path", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -2774,7 +2774,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "accelerate", specifier = ">=1.10.1" }, - { name = "ai2-olmo-core", git = "https://github.com/allenai/OLMo-core.git?rev=f1b69d796aa9691c0abbf29683ee04fe1e42aa50" }, + { name = "ai2-olmo-core", git = "https://github.com/allenai/OLMo-core.git?rev=ae85c1100b81436ca7e29c50cccc45f6d206fd7b" }, { name = "antlr4-python3-runtime", specifier = "==4.11" }, { name = "authlib", marker = "extra == 'dr-tulu'" }, { name = "backoff", specifier = ">=2.2.1" }, From 600f83e3d62610ca1980f2dd51d5cebcc80dfdc9 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 10:56:01 -0600 Subject: [PATCH 08/35] Fully shard (fsdp_shard_degree=32) Olmo-Hybrid DPO sweep to fix OOM, matching ZeRO-3 reference Co-Authored-By: Claude Opus 4.8 --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 0068b16644..784ed12fb5 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -69,6 +69,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --max_seq_length 16384 \ --per_device_train_batch_size 1 \ --gradient_accumulation_steps 4 \ + --fsdp_shard_degree 32 \ --learning_rate "$LR" \ --lr_scheduler_type linear \ --checkpointing_steps 500 \ From 73ad5601321d1d591e72d82a8d2f9a719bb671c8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 11:13:15 -0600 Subject: [PATCH 09/35] Add full-block activation checkpointing mode for olmo-core DPO to fit GDN at 16k seq Co-Authored-By: Claude Opus 4.8 --- open_instruct/dpo.py | 4 +++- open_instruct/olmo_core_utils.py | 12 +++++++++++- .../olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index bda1c479ed..c7c35ca8da 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -268,7 +268,9 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz reduce_dtype=DType.float32, wrapping_strategy=transformer_config.TransformerDataParallelWrappingStrategy.blocks, ) - ac_config = olmo_core_utils.build_ac_config(args.activation_memory_budget, args.compile_model) + ac_config = olmo_core_utils.build_ac_config( + args.activation_memory_budget, args.compile_model, args.activation_checkpointing_mode + ) train_module = DPOTrainModule( model=model, diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py index ef954b0bc6..70cbb69e1d 100644 --- a/open_instruct/olmo_core_utils.py +++ b/open_instruct/olmo_core_utils.py @@ -108,6 +108,14 @@ class TrainingConfig: typically faster; lower values use less memory and are typically slower, so use the highest value your hardware can support. See: https://pytorch.org/blog/activation-checkpointing-techniques/. """ + activation_checkpointing_mode: Literal["budget", "full"] = "budget" + """Activation checkpointing mode. + + "budget" uses torch.compile's partitioner with `activation_memory_budget` (requires compilation, + and cannot checkpoint through opaque custom ops such as the GDN `fla` kernels). "full" wraps every + transformer block in `torch.utils.checkpoint`, keeping only one block's activations live at a time, + which is required to fit linear-attention (GDN) models at long sequence lengths. + """ compile_model: bool = True """Whether to apply torch.compile to model blocks.""" fused_optimizer: bool = True @@ -119,8 +127,10 @@ class TrainingConfig: def build_ac_config( - activation_memory_budget: float, compile_model: bool + activation_memory_budget: float, compile_model: bool, mode: str = "budget" ) -> TransformerActivationCheckpointingConfig | None: + if mode == "full": + return TransformerActivationCheckpointingConfig(mode=TransformerActivationCheckpointingMode.full) if activation_memory_budget < 1.0 and compile_model: return TransformerActivationCheckpointingConfig( mode=TransformerActivationCheckpointingMode.budget, activation_memory_budget=activation_memory_budget diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 784ed12fb5..e9d0411f94 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -81,7 +81,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --activation_memory_budget 0.5 \ + --activation_checkpointing_mode full \ --with_tracking done done From 2daaa2a0b6780a19f17758f858fe9048f65decbf Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 11:28:37 -0600 Subject: [PATCH 10/35] Disable torch.compile in Olmo-Hybrid DPO sweep: compile+full-block checkpoint of GDN op fails recompute metadata check Co-Authored-By: Claude Opus 4.8 --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index e9d0411f94..99c5a403b5 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -82,6 +82,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --beta 5 \ --packing \ --activation_checkpointing_mode full \ + --compile_model false \ --with_tracking done done From e2a385ce59fb5e368f10aaf4649acbbf9263a8be Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 13:20:06 -0600 Subject: [PATCH 11/35] Bump flash-linear-attention 0.4.2 -> 0.5.0 --- pyproject.toml | 2 +- requirements.txt | 4 ++-- uv.lock | 14 +++++++------- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e7d11889a9..685511e00b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "mcp>=1.9.0", "openenv-core>=0.2.1", "docker>=7.0.0", - "flash-linear-attention>=0.4.2", + "flash-linear-attention>=0.5.0", ] [build-system] diff --git a/requirements.txt b/requirements.txt index 013fd4638d..75cc8f5f70 100644 --- a/requirements.txt +++ b/requirements.txt @@ -265,7 +265,7 @@ filelock==3.20.3 ; (platform_machine == 'aarch64' and sys_platform == 'linux') o # torch # virtualenv # vllm -fla-core==0.4.2 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' +fla-core==0.5.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via flash-linear-attention flash-attn @ https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.0/flash_attn-2.8.3+cu128torch2.10-cp312-cp312-linux_x86_64.whl ; platform_machine == 'x86_64' and sys_platform == 'linux' # via open-instruct @@ -273,7 +273,7 @@ flash-attn-3 @ https://github.com/windreamer/flash-attention3-wheels/releases/do # via open-instruct flash-attn-4 @ https://github.com/Dao-AILab/flash-attention/releases/download/fa4-v4.0.0.beta6/flash_attn_4-4.0.0b5-py3-none-any.whl ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') # via open-instruct -flash-linear-attention==0.4.2 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' +flash-linear-attention==0.5.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via open-instruct flashinfer-cubin==0.6.6 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') # via vllm diff --git a/uv.lock b/uv.lock index a956ca3325..be0fe24907 100644 --- a/uv.lock +++ b/uv.lock @@ -1169,7 +1169,7 @@ wheels = [ [[package]] name = "fla-core" -version = "0.4.2" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -1177,9 +1177,9 @@ dependencies = [ { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torch", version = "2.10.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/f9/9e05c48f92b1388a8a357141eb557ed0dd6d4bb936e1d05d35f01976657f/fla_core-0.4.2.tar.gz", hash = "sha256:e9fef6fcdf122029f9feb7dccfeb85eb9650e6aabc72d2a65b36558e9c590edd", size = 377722, upload-time = "2026-03-12T14:45:46.101Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/14/2aabd37839b9f3c6a67fbc5678f906d04d0c242c603ac234eefe02df99a6/fla_core-0.5.0.tar.gz", hash = "sha256:476dd94711702af81cc4827010d9209f6053d8cdceac8e43d3c8497071f07a81", size = 418171, upload-time = "2026-04-21T20:25:40.948Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/36/3c303f92bafea7c3f97d68bbb83d18cc42e30cd0bfb1b7cfe589360f11d6/fla_core-0.4.2-py3-none-any.whl", hash = "sha256:cba3db29380002da3cbfc0db94d6efac19aaf528900d19c05c2765e8f3cc485b", size = 510239, upload-time = "2026-03-12T14:45:43.708Z" }, + { url = "https://files.pythonhosted.org/packages/f4/03/96e6820d176256353670b41ca56dabbbebe129674b4f4ad7b54a152b7b36/fla_core-0.5.0-py3-none-any.whl", hash = "sha256:5c826ff32daf6b629658e3e4f6125d87cf8c32eea937e3be9ba85f51951d809a", size = 595276, upload-time = "2026-04-21T20:25:37.698Z" }, ] [[package]] @@ -1256,15 +1256,15 @@ provides-extras = ["dev"] [[package]] name = "flash-linear-attention" -version = "0.4.2" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "fla-core", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d1/cb/46cc27a829a10b308927c5dbc99176906a021bb0770253699e93f3cd81a0/flash_linear_attention-0.4.2.tar.gz", hash = "sha256:f97c01ebe7cf390323af07dd3fb65ade07da16724339bf70c78607bc0c007c34", size = 148464, upload-time = "2026-03-12T14:45:46.945Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/5c/1db76cc829c951117a3112f306d50333bd71399d2e35807fe7c99ffc2007/flash_linear_attention-0.5.0.tar.gz", hash = "sha256:22b789a47f07738b4382ecdf775d7bb40e0d803c467c34f8e2ecd6a1dc780938", size = 160419, upload-time = "2026-04-21T20:25:42.344Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/60/ee/a3cba17965482b35c4990af90bad108e82c32edcb59911c37f318b5f4198/flash_linear_attention-0.4.2-py3-none-any.whl", hash = "sha256:c08be006ce4dbe1be81f54938ee8e6fc7968cfba397c8d06c7669e97b8c44c0d", size = 284661, upload-time = "2026-03-12T14:45:44.905Z" }, + { url = "https://files.pythonhosted.org/packages/cc/16/7736db08806981562c728f32ea1dcb4565948fa9faffdbf4ffbf72522fbf/flash_linear_attention-0.5.0-py3-none-any.whl", hash = "sha256:92e64e989ed34355c1f838232597b2e39783ee0494ada3199b58e156aa1d8eb8", size = 319037, upload-time = "2026-04-21T20:25:39.473Z" }, ] [[package]] @@ -2789,7 +2789,7 @@ requires-dist = [ { name = "flash-attn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", url = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.9.0/flash_attn-2.8.3+cu128torch2.10-cp312-cp312-linux_x86_64.whl" }, { name = "flash-attn-3", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", url = "https://github.com/windreamer/flash-attention3-wheels/releases/download/2026.03.19-850211f/flash_attn_3-3.0.0%2B20260318.cu128torch2100cxx11abitrue.8afc61-cp39-abi3-linux_x86_64.whl" }, { name = "flash-attn-4", marker = "sys_platform != 'darwin'", url = "https://github.com/Dao-AILab/flash-attention/releases/download/fa4-v4.0.0.beta6/flash_attn_4-4.0.0b5-py3-none-any.whl" }, - { name = "flash-linear-attention", specifier = ">=0.4.2" }, + { name = "flash-linear-attention", specifier = ">=0.5.0" }, { name = "hf-transfer", specifier = ">=0.1.8" }, { name = "immutabledict", specifier = "==1.2.0" }, { name = "langdetect", specifier = "==1.0.9" }, From 53e2af3b823bc125c299a4b25979ceb1860fdfd4 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 13:48:15 -0600 Subject: [PATCH 12/35] Add selected_modules activation checkpointing to enable compile with GDN (checkpoint only compile-safe MLPs, leave opaque GDN mixer activations live) --- open_instruct/dpo.py | 5 ++++- open_instruct/olmo_core_utils.py | 15 ++++++++++++--- .../7b_instruct_dpo_sweep_olmo_core.sh | 4 ++-- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index c7c35ca8da..b279988958 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -269,7 +269,10 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz wrapping_strategy=transformer_config.TransformerDataParallelWrappingStrategy.blocks, ) ac_config = olmo_core_utils.build_ac_config( - args.activation_memory_budget, args.compile_model, args.activation_checkpointing_mode + args.activation_memory_budget, + args.compile_model, + args.activation_checkpointing_mode, + args.activation_checkpointing_modules, ) train_module = DPOTrainModule( diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py index 70cbb69e1d..0e9d5a9eee 100644 --- a/open_instruct/olmo_core_utils.py +++ b/open_instruct/olmo_core_utils.py @@ -108,14 +108,19 @@ class TrainingConfig: typically faster; lower values use less memory and are typically slower, so use the highest value your hardware can support. See: https://pytorch.org/blog/activation-checkpointing-techniques/. """ - activation_checkpointing_mode: Literal["budget", "full"] = "budget" + activation_checkpointing_mode: Literal["budget", "full", "selected_modules"] = "budget" """Activation checkpointing mode. "budget" uses torch.compile's partitioner with `activation_memory_budget` (requires compilation, and cannot checkpoint through opaque custom ops such as the GDN `fla` kernels). "full" wraps every transformer block in `torch.utils.checkpoint`, keeping only one block's activations live at a time, - which is required to fit linear-attention (GDN) models at long sequence lengths. + which is required to fit linear-attention (GDN) models at long sequence lengths. "selected_modules" + wraps only the modules in `activation_checkpointing_modules` (by default the feed-forward MLPs), + leaving the opaque GDN mixer activations live so they are never recomputed, which lets + `torch.compile` coexist with checkpointing (the recomputed MLP regions are compile-safe). """ + activation_checkpointing_modules: list[str] = field(default_factory=lambda: ["blocks.*.feed_forward"]) + """Module-name globs to wrap when `activation_checkpointing_mode` is "selected_modules".""" compile_model: bool = True """Whether to apply torch.compile to model blocks.""" fused_optimizer: bool = True @@ -127,10 +132,14 @@ class TrainingConfig: def build_ac_config( - activation_memory_budget: float, compile_model: bool, mode: str = "budget" + activation_memory_budget: float, compile_model: bool, mode: str = "budget", modules: list[str] | None = None ) -> TransformerActivationCheckpointingConfig | None: if mode == "full": return TransformerActivationCheckpointingConfig(mode=TransformerActivationCheckpointingMode.full) + if mode == "selected_modules": + return TransformerActivationCheckpointingConfig( + mode=TransformerActivationCheckpointingMode.selected_modules, modules=modules + ) if activation_memory_budget < 1.0 and compile_model: return TransformerActivationCheckpointingConfig( mode=TransformerActivationCheckpointingMode.budget, activation_memory_budget=activation_memory_budget diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 99c5a403b5..eebccb6f57 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -81,8 +81,8 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --activation_checkpointing_mode full \ - --compile_model false \ + --activation_checkpointing_mode selected_modules \ + --compile_model true \ --with_tracking done done From ed6b21854c26425ee1f530c7243dc777cc818f04 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 13:53:50 -0600 Subject: [PATCH 13/35] Checkpoint all Olmo-Hybrid block submodules except the GDN mixer for selected_modules AC Co-Authored-By: Claude Opus 4.8 --- open_instruct/olmo_core_utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py index 0e9d5a9eee..92fe211efe 100644 --- a/open_instruct/olmo_core_utils.py +++ b/open_instruct/olmo_core_utils.py @@ -115,12 +115,25 @@ class TrainingConfig: and cannot checkpoint through opaque custom ops such as the GDN `fla` kernels). "full" wraps every transformer block in `torch.utils.checkpoint`, keeping only one block's activations live at a time, which is required to fit linear-attention (GDN) models at long sequence lengths. "selected_modules" - wraps only the modules in `activation_checkpointing_modules` (by default the feed-forward MLPs), - leaving the opaque GDN mixer activations live so they are never recomputed, which lets - `torch.compile` coexist with checkpointing (the recomputed MLP regions are compile-safe). + wraps only the modules in `activation_checkpointing_modules` (by default every block submodule + except the GDN mixer), leaving the opaque GDN mixer activations live so they are never recomputed, + which lets `torch.compile` coexist with checkpointing (the recomputed regions are compile-safe). + """ + activation_checkpointing_modules: list[str] = field( + default_factory=lambda: [ + "blocks.*.attention_norm", + "blocks.*.attention_residual_stream", + "blocks.*.feed_forward_norm", + "blocks.*.feed_forward", + "blocks.*.feed_forward_residual_stream", + ] + ) + """Module-name globs to wrap when `activation_checkpointing_mode` is "selected_modules". + + Defaults to every transformer-block submodule except the GDN mixer (`blocks.*.attention`), so the + opaque `fla` kernel activations stay resident (never recomputed) while everything else is + checkpointed, which both recovers memory and keeps `torch.compile` compatible. """ - activation_checkpointing_modules: list[str] = field(default_factory=lambda: ["blocks.*.feed_forward"]) - """Module-name globs to wrap when `activation_checkpointing_mode` is "selected_modules".""" compile_model: bool = True """Whether to apply torch.compile to model blocks.""" fused_optimizer: bool = True From 66bbd9c0afb172c5042af73845b410aaeb6a9ed1 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 15:05:49 -0600 Subject: [PATCH 14/35] Use full AC + compile for Olmo-Hybrid DPO by skipping checkpoint determinism check Co-Authored-By: Claude Opus 4.8 --- open_instruct/dpo.py | 1 + open_instruct/olmo_core_utils.py | 20 +++++++++++++++++++ .../7b_instruct_dpo_sweep_olmo_core.sh | 2 +- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index b279988958..45efc1f031 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -274,6 +274,7 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz args.activation_checkpointing_mode, args.activation_checkpointing_modules, ) + olmo_core_utils.patch_checkpoint_wrapper_determinism_check() train_module = DPOTrainModule( model=model, diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py index 92fe211efe..a8b3f0b196 100644 --- a/open_instruct/olmo_core_utils.py +++ b/open_instruct/olmo_core_utils.py @@ -27,6 +27,7 @@ TransformerActivationCheckpointingMode, ) from olmo_core.train.train_module.transformer.config import TransformerContextParallelConfig +from torch.distributed.algorithms._checkpoint import checkpoint_wrapper as ptd_checkpoint_wrapper_mod from open_instruct import logger_utils, model_utils, olmo_core_callbacks, utils from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_tulu @@ -160,6 +161,25 @@ def build_ac_config( return None +def patch_checkpoint_wrapper_determinism_check() -> None: + """Make olmo-core's activation checkpointing skip torch's recompute determinism check. + + olmo-core hardcodes `ptd_checkpoint_wrapper(block, preserve_rng_state=False)`, leaving torch's + `determinism_check="default"`, which compares forward vs. recompute tensor metadata. The opaque + `fla` GDN kernel's recompute produces a spurious metadata mismatch under `torch.compile`, raising + CheckpointError. Eager recompute already proves the values match, so we forward + `determinism_check="none"` to suppress the (metadata-only) check and let full checkpointing and + compile coexist. + """ + original = ptd_checkpoint_wrapper_mod.checkpoint_wrapper + + def patched(module, **kwargs): + kwargs.setdefault("determinism_check", "none") + return original(module, **kwargs) + + ptd_checkpoint_wrapper_mod.checkpoint_wrapper = patched + + def build_cp_config(training: TrainingConfig) -> TransformerContextParallelConfig | None: if training.cp_degree is None: return None diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index eebccb6f57..cdfc0f9bd2 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -81,7 +81,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --activation_checkpointing_mode selected_modules \ + --activation_checkpointing_mode full \ --compile_model true \ --with_tracking done From 872ad777431522d0396b9a7b4247f3209efa3076 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 15:34:07 -0600 Subject: [PATCH 15/35] DPO: checkpoint GDN mixer via selected_modules to keep compile outside checkpoint (avoids full-mode inductor stride guard failure) Co-Authored-By: Claude Opus 4.8 --- open_instruct/olmo_core_utils.py | 20 +++++++++++-------- .../7b_instruct_dpo_sweep_olmo_core.sh | 2 +- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/open_instruct/olmo_core_utils.py b/open_instruct/olmo_core_utils.py index a8b3f0b196..ec8817dd9c 100644 --- a/open_instruct/olmo_core_utils.py +++ b/open_instruct/olmo_core_utils.py @@ -114,15 +114,17 @@ class TrainingConfig: "budget" uses torch.compile's partitioner with `activation_memory_budget` (requires compilation, and cannot checkpoint through opaque custom ops such as the GDN `fla` kernels). "full" wraps every - transformer block in `torch.utils.checkpoint`, keeping only one block's activations live at a time, - which is required to fit linear-attention (GDN) models at long sequence lengths. "selected_modules" - wraps only the modules in `activation_checkpointing_modules` (by default every block submodule - except the GDN mixer), leaving the opaque GDN mixer activations live so they are never recomputed, - which lets `torch.compile` coexist with checkpointing (the recomputed regions are compile-safe). + transformer block in `torch.utils.checkpoint`, applying compile *outside* the checkpoint, whose + recompute re-enters the compiled block's backward and fails an inductor stride guard — so "full" + is incompatible with `torch.compile` for GDN models. "selected_modules" wraps the individual block + submodules in `activation_checkpointing_modules` (by default all of them, including the GDN mixer), + which keeps compile *outside* the checkpoint boundary (the supported order) while still recovering + full-block memory savings, letting `torch.compile` coexist with checkpointing at long sequences. """ activation_checkpointing_modules: list[str] = field( default_factory=lambda: [ "blocks.*.attention_norm", + "blocks.*.attention", "blocks.*.attention_residual_stream", "blocks.*.feed_forward_norm", "blocks.*.feed_forward", @@ -131,9 +133,11 @@ class TrainingConfig: ) """Module-name globs to wrap when `activation_checkpointing_mode` is "selected_modules". - Defaults to every transformer-block submodule except the GDN mixer (`blocks.*.attention`), so the - opaque `fla` kernel activations stay resident (never recomputed) while everything else is - checkpointed, which both recovers memory and keeps `torch.compile` compatible. + Defaults to every transformer-block submodule, including the GDN mixer (`blocks.*.attention`). + Wrapping submodules individually (rather than the whole block, as "full" does) keeps `torch.compile` + *outside* the checkpoint boundary, which is the order compile supports, while still recovering + full-block activation memory. The opaque `fla` kernel's recompute metadata check is suppressed by + `patch_checkpoint_wrapper_determinism_check`. """ compile_model: bool = True """Whether to apply torch.compile to model blocks.""" diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index cdfc0f9bd2..eebccb6f57 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -81,7 +81,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --activation_checkpointing_mode full \ + --activation_checkpointing_mode selected_modules \ --compile_model true \ --with_tracking done From 63348c851295abef420d69399653c9f9e500f0c7 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 2 Jun 2026 16:39:33 -0600 Subject: [PATCH 16/35] Add tilelang dep so fla routes GDN chunk_bwd_dqkwg around the broken Triton>=3.4 Hopper kernel (fla #640) Co-Authored-By: Claude Opus 4.8 --- pyproject.toml | 3 ++ requirements.txt | 25 ++++++++++++++-- uv.lock | 76 ++++++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 96 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 685511e00b..4c3dbfd758 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,9 @@ dependencies = [ "openenv-core>=0.2.1", "docker>=7.0.0", "flash-linear-attention>=0.5.0", + # Provides the tilelang backend for fla's gated chunk_bwd_dqkwg, which fla routes to on + # Hopper with Triton>=3.4.0 (the Triton kernel produces incorrect gradients there; see fla #640). + "tilelang>=0.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'", ] [build-system] diff --git a/requirements.txt b/requirements.txt index 75cc8f5f70..22ace6a502 100644 --- a/requirements.txt +++ b/requirements.txt @@ -47,11 +47,17 @@ anyio==4.12.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or ( # sse-starlette # starlette # watchfiles -apache-tvm-ffi==0.1.8.post2 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') +apache-tvm-ffi==0.1.8.post2 ; platform_machine == 'aarch64' and sys_platform == 'linux' # via # flash-attn-4 # flashinfer-python # quack-kernels +apache-tvm-ffi==0.1.11 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via + # flash-attn-4 + # flashinfer-python + # quack-kernels + # tilelang astor==0.8.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') # via depyf attrs==25.4.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' @@ -134,7 +140,9 @@ click==8.3.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (p click-log==0.4.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via alphashape cloudpickle==3.1.2 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') - # via vllm + # via + # tilelang + # vllm cohere==5.20.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via dr-agent colorama==0.4.6 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' @@ -506,6 +514,8 @@ mkdocs-get-deps==0.2.0 ; (platform_machine == 'aarch64' and sys_platform == 'lin mkdocs-material==9.7.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' mkdocs-material-extensions==1.3.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via mkdocs-material +ml-dtypes==0.5.4 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via tilelang model-hosting-container-standards==0.1.13 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') # via vllm more-itertools==10.8.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' @@ -559,6 +569,7 @@ numpy==2.2.6 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (p # gguf # matplotlib # mistral-common + # ml-dtypes # numba # nvidia-cutlass-dsl-libs-base # open-instruct @@ -569,6 +580,7 @@ numpy==2.2.6 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (p # scipy # shapely # tensorboard + # tilelang # torchvision # transformers # trimesh @@ -847,6 +859,7 @@ psutil==7.2.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or ( # deepspeed # nvitop # peft + # tilelang # vllm py==1.11.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via retry @@ -1145,6 +1158,8 @@ tiktoken==0.12.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') o # litellm # mistral-common # vllm +tilelang==0.1.10 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via open-instruct tinydb==4.8.2 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via # dr-agent @@ -1185,6 +1200,7 @@ torch==2.10.0+cu128 ; platform_machine == 'x86_64' and sys_platform == 'linux' # open-instruct # peft # quack-kernels + # tilelang # torch-c-dlpack-ext # torchaudio # torchvision @@ -1213,6 +1229,7 @@ torch-c-dlpack-ext==0.1.5 ; (platform_machine == 'aarch64' and sys_platform == ' # via # flash-attn-4 # quack-kernels + # tilelang torchaudio==2.10.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') # via vllm torchvision==0.25.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') @@ -1227,6 +1244,7 @@ tqdm==4.67.1 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (p # nltk # openai # peft + # tilelang # transformers # vllm transformers==5.4.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' @@ -1289,6 +1307,7 @@ typing-extensions==4.15.0 ; (platform_machine == 'aarch64' and sys_platform == ' # referencing # rich-toolkit # starlette + # tilelang # torch # typer # typing-inspection @@ -1353,5 +1372,7 @@ xxhash==3.6.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or ( # datasets yarl==1.22.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via aiohttp +z3-solver==4.15.4.0 ; platform_machine == 'x86_64' and sys_platform == 'linux' + # via tilelang zipp==3.23.0 ; (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin' # via importlib-metadata diff --git a/uv.lock b/uv.lock index be0fe24907..84ed230394 100644 --- a/uv.lock +++ b/uv.lock @@ -224,15 +224,32 @@ wheels = [ name = "apache-tvm-ffi" version = "0.1.8.post2" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'aarch64' and sys_platform == 'linux'", +] dependencies = [ - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/e3/e9/a13952726228fa6282154ecf927092396bc759739e5e045019f6ab92f3ca/apache_tvm_ffi-0.1.8.post2.tar.gz", hash = "sha256:4513e38852894f290172ecfefcbc18d34e817fd29c16a0f1770e130c82b4067e", size = 2441111, upload-time = "2026-01-13T18:11:27.864Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8e/3a/7b1c9edcaeaebb945038144896cf17eb828a40b6ace0371823e133132664/apache_tvm_ffi-0.1.8.post2-cp312-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c78b4caf17304a1f47881bccdb2f9ac24d98b3b7fbe761a6dd4fd0585934d96", size = 1967259, upload-time = "2026-01-13T18:10:47.851Z" }, - { url = "https://files.pythonhosted.org/packages/6c/b6/463602f57dda2e1c69165c044c07061cd59404593f313a427a3ad9c02cf3/apache_tvm_ffi-0.1.8.post2-cp312-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4a48da3fa8f47130f3502134f01e97044388c5217e7b91be4b0acec4feab81a0", size = 2044821, upload-time = "2026-01-13T18:10:49.396Z" }, { url = "https://files.pythonhosted.org/packages/fe/e6/9cdc7f4814b2fbdfceba5dc640c3704d07d8db18e3d1aef5aa49bbf1ba7e/apache_tvm_ffi-0.1.8.post2-cp312-abi3-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:61cc98e489ebc03bc96d1a966dc863eb1c0a607383f6bf4a416ff0a96170ca85", size = 1910964, upload-time = "2026-01-13T18:10:51.345Z" }, - { url = "https://files.pythonhosted.org/packages/7d/f5/a2e5487cdad575fe6cf34f8a23f8c49e08ce5808fa75dc19d98bcebc20ec/apache_tvm_ffi-0.1.8.post2-cp312-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:caa48509f0c7d9b896823b492a9ee42afac2548065c1ec7ef07f9a0dc30d2796", size = 2025814, upload-time = "2026-01-13T18:10:52.804Z" }, +] + +[[package]] +name = "apache-tvm-ffi" +version = "0.1.11" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'x86_64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/3d/4b9226cd45aa800a6904603dda9b323d728f3c3869952a673f3483b78b19/apache_tvm_ffi-0.1.11.tar.gz", hash = "sha256:153cd2c5a9717804cb0bcd9b2709f22a1e5f80ed05b5a490faf5949b136eedba", size = 2798354, upload-time = "2026-05-04T17:48:43.852Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/99/2848df4e8ed5bf51df1d286d1718510584fa61e88adbc9c5b23d71b38f7c/apache_tvm_ffi-0.1.11-cp312-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:78aa1857b04a2ea718317041ab3f01288b3d496e6036eb1b99ebdc9da0fdaef5", size = 2725887, upload-time = "2026-05-04T17:48:01.381Z" }, + { url = "https://files.pythonhosted.org/packages/4d/18/95569107ee83619d61a3bb0d28743a0599f85c5161981e3e098c82c2b185/apache_tvm_ffi-0.1.11-cp312-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2843f084cdc94dedacd8b257a395a2b71b8a3dc7fc99711b148bf1d161983128", size = 2697683, upload-time = "2026-05-04T17:48:05.222Z" }, ] [[package]] @@ -1227,7 +1244,8 @@ name = "flash-attn-4" version = "4.0.0b5" source = { url = "https://github.com/Dao-AILab/flash-attention/releases/download/fa4-v4.0.0.beta6/flash_attn_4-4.0.0b5-py3-none-any.whl" } dependencies = [ - { name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "apache-tvm-ffi", version = "0.1.8.post2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", version = "0.1.11", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "quack-kernels", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -1280,7 +1298,8 @@ name = "flashinfer-python" version = "0.6.6" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "apache-tvm-ffi", version = "0.1.8.post2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", version = "0.1.11", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "einops", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "ninja", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -2193,6 +2212,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728, upload-time = "2023-11-22T19:09:43.465Z" }, ] +[[package]] +name = "ml-dtypes" +version = "0.5.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/4a/c27b42ed9b1c7d13d9ba8b6905dece787d6259152f2309338aed29b2447b/ml_dtypes-0.5.4.tar.gz", hash = "sha256:8ab06a50fb9bf9666dd0fe5dfb4676fa2b0ac0f31ecff72a6c3af8e22c063453", size = 692314, upload-time = "2025-11-17T22:32:31.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/cb/28ce52eb94390dda42599c98ea0204d74799e4d8047a0eb559b6fd648056/ml_dtypes-0.5.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ad459e99793fa6e13bd5b7e6792c8f9190b4e5a1b45c63aba14a4d0a7f1d5ff", size = 5009002, upload-time = "2025-11-17T22:31:52.001Z" }, +] + [[package]] name = "model-hosting-container-standards" version = "0.1.13" @@ -2738,6 +2769,7 @@ dependencies = [ { name = "ray", extra = ["default"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, { name = "tensorboard", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, + { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torch", version = "2.10.0", source = { registry = "https://pypi.org/simple" }, marker = "sys_platform == 'darwin'" }, { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torch", version = "2.10.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, @@ -2812,6 +2844,7 @@ requires-dist = [ { name = "scipy", marker = "extra == 'dr-tulu'" }, { name = "setuptools", specifier = ">=75.6.0,<80.0.0" }, { name = "tensorboard", specifier = ">=2.18.0" }, + { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'", specifier = ">=0.1.0" }, { name = "torch", marker = "sys_platform != 'linux'", specifier = ">=2.10.0" }, { name = "torch", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu128" }, { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'", specifier = ">=2.10.0", index = "https://download.pytorch.org/whl/cu130" }, @@ -3826,7 +3859,8 @@ name = "quack-kernels" version = "0.3.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "apache-tvm-ffi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "apache-tvm-ffi", version = "0.1.8.post2", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", version = "0.1.11", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cutlass-dsl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torch", version = "2.10.0+cu130", source = { registry = "https://download.pytorch.org/whl/cu130" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, @@ -4376,6 +4410,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, ] +[[package]] +name = "tilelang" +version = "0.1.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi", version = "0.1.11", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "cloudpickle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "ml-dtypes", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "psutil", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch", version = "2.10.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch-c-dlpack-ext", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tqdm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "z3-solver", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/5c/07146b4527656102e48d21c2599aa80477e83ea3f149ac0df3b15a247bd4/tilelang-0.1.10.tar.gz", hash = "sha256:d8813e668fcf75843bc2d68c633c352b419c1e292895a6038a4aadd943e56c2b", size = 93184128, upload-time = "2026-05-25T03:58:57.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b0/66/ab4301dc38ca9f09832df2936c73388c611c198dc938634acb6ce80dfa74/tilelang-0.1.10-cp38-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85180d1a96defeecdf52d5d075a31c3fc551d8485981e6b636762a9cd7eb02fe", size = 49768455, upload-time = "2026-05-25T03:56:17.081Z" }, +] + [[package]] name = "tinydb" version = "4.8.2" @@ -4985,6 +5040,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/ae/b48f95715333080afb75a4504487cbe142cae1268afc482d06692d605ae6/yarl-1.22.0-py3-none-any.whl", hash = "sha256:1380560bdba02b6b6c90de54133c81c9f2a453dee9912fe58c1dcced1edb7cff", size = 46814, upload-time = "2025-10-06T14:12:53.872Z" }, ] +[[package]] +name = "z3-solver" +version = "4.15.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/8e/0c8f17309549d2e5cde9a3ccefa6365437f1e7bafe71878eaf9478e47b18/z3_solver-4.15.4.0.tar.gz", hash = "sha256:928c29b58c4eb62106da51c1914f6a4a55d0441f8f48a81b9da07950434a8946", size = 5018600, upload-time = "2025-10-29T18:12:03.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/c9/bb51a96af0091324c81b803f16c49f719f9f6ea0b0bb52200f5c97ec4892/z3_solver-4.15.4.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e103a6f203f505b8b8b8e5c931cc407c95b61556512d4921c1ddc0b3f41b08e", size = 29268352, upload-time = "2025-10-29T18:11:53.032Z" }, +] + [[package]] name = "zipp" version = "3.23.0" From f0c8b07a196c330c87292fb7e7b8f412a9b26a98 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Jun 2026 08:06:40 -0600 Subject: [PATCH 17/35] DPO: align dpo.py wandb metric keys with dpo_tune_cache.py (rename train/* and perf/* keys, add learning_rate/epoch/training_step) Co-Authored-By: Claude Opus 4.8 --- open_instruct/olmo_core_callbacks.py | 6 ++--- open_instruct/olmo_core_train_modules.py | 31 +++++++++++++++--------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/open_instruct/olmo_core_callbacks.py b/open_instruct/olmo_core_callbacks.py index afbb5eada8..6b7f18837b 100644 --- a/open_instruct/olmo_core_callbacks.py +++ b/open_instruct/olmo_core_callbacks.py @@ -201,11 +201,11 @@ def post_step(self) -> None: seconds_per_step = interval_end - self._step_start_time - self.trainer.record_metric("perf/mfu", mfu_result["mfu"], reduce_type=None) + self.trainer.record_metric("perf/mfu_step", mfu_result["mfu"], reduce_type=None) self.trainer.record_metric("perf/mfu_avg", mfu_avg, reduce_type=None) self.trainer.record_metric("perf/seconds_per_step", seconds_per_step, reduce_type=None) - self.trainer.record_metric("perf/tokens_per_second", tokens_per_second, reduce_type=None) - self.trainer.record_metric("perf/tokens_per_second_avg", tokens_per_second_avg, reduce_type=None) + self.trainer.record_metric("perf/tokens_per_second_step", tokens_per_second, reduce_type=None) + self.trainer.record_metric("perf/tokens_per_second_total", tokens_per_second_avg, reduce_type=None) self.trainer.record_metric( "perf/tokens_per_second_per_gpu", tokens_per_second / (self.dp_world_size * self.tensor_parallel_degree), diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 2d6959b3d3..6e6e6280d8 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -253,29 +253,38 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: global_total_tokens = local_sums[0] global_metrics = {k: local_sums[i + 1] / global_total_tokens for i, k in enumerate(metric_keys)} - self.record_metric("train/loss", global_metrics["loss"].item(), reduce_type=None) - self.record_metric("train/logps_chosen", global_metrics["chosen_logps"].item(), reduce_type=None) - self.record_metric("train/logps_rejected", global_metrics["rejected_logps"].item(), reduce_type=None) + self.record_metric("train_loss", global_metrics["loss"].item(), reduce_type=None) + self.record_metric("logps/chosen", global_metrics["chosen_logps"].item(), reduce_type=None) + self.record_metric("logps/rejected", global_metrics["rejected_logps"].item(), reduce_type=None) token_count = self.trainer.data_loader.global_num_tokens_in_batch(batch) assert token_count is not None self.record_metric("train/token_count", token_count, reduce_type=None) + self.record_metric("training_step", float(self.trainer.global_step), reduce_type=None) + if self.trainer.steps_per_epoch is not None: + self.record_metric("epoch", self.trainer.global_step / self.trainer.steps_per_epoch, reduce_type=None) + if self.scheduler is not None and self.trainer.max_steps is not None: + lr = self.scheduler.get_lr( + self.optim.param_groups[0].get("initial_lr", self.optim.param_groups[0]["lr"]), + self.trainer.global_step, + self.trainer.max_steps, + ) + self.record_metric("learning_rate", float(lr), reduce_type=None) + if self.dpo_config.loss_type.computes_reward_metrics: margin = global_metrics["chosen_rewards"] - global_metrics["rejected_rewards"] - self.record_metric("train/rewards_chosen", global_metrics["chosen_rewards"].item(), reduce_type=None) - self.record_metric( - "train/rewards_rejected", global_metrics["rejected_rewards"].item(), reduce_type=None - ) + self.record_metric("rewards/chosen", global_metrics["chosen_rewards"].item(), reduce_type=None) + self.record_metric("rewards/rejected", global_metrics["rejected_rewards"].item(), reduce_type=None) self.record_metric( - "train/rewards_average", + "rewards/average", ((global_metrics["chosen_rewards"] + global_metrics["rejected_rewards"]) / 2).item(), reduce_type=None, ) - self.record_metric("train/rewards_accuracy", global_metrics["accuracy"].item(), reduce_type=None) - self.record_metric("train/rewards_margin", margin.item(), reduce_type=None) + self.record_metric("rewards/accuracy", global_metrics["accuracy"].item(), reduce_type=None) + self.record_metric("rewards/margin", margin.item(), reduce_type=None) if "aux_loss" in global_metrics: - self.record_metric("train/aux_loss", global_metrics["aux_loss"].item(), reduce_type=None) + self.record_metric("aux_loss", global_metrics["aux_loss"].item(), reduce_type=None) class GRPOTrainModule(TransformerTrainModule): From 7856a4584673f5245020223aed6515a3cf5c64c2 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Jun 2026 08:13:45 -0600 Subject: [PATCH 18/35] committed changes --- .../7b_instruct_dpo_sweep_olmo_core.sh | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index eebccb6f57..9405a52850 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -14,6 +14,20 @@ BEAKER_IMAGE="${1:-nathanl/open_instruct_auto}" +# MFU-tuning knobs (defaults reproduce the original fully-sharded config). +FSDP_SHARD_DEGREE="${FSDP_SHARD_DEGREE:-32}" +FSDP_NUM_REPLICAS="${FSDP_NUM_REPLICAS:-1}" +ACTIVATION_CHECKPOINTING_MODE="${ACTIVATION_CHECKPOINTING_MODE:-selected_modules}" +PER_DEVICE_TRAIN_BATCH_SIZE="${PER_DEVICE_TRAIN_BATCH_SIZE:-1}" +GRADIENT_ACCUMULATION_STEPS="${GRADIENT_ACCUMULATION_STEPS:-4}" +EXP_TAG="${EXP_TAG:-}" +PROFILING="${PROFILING:-false}" +PROFILING_FLAG="" +if [ "$PROFILING" = "true" ]; then PROFILING_FLAG="--profiling"; fi +AC_MODULES_FLAG="" +if [ -n "$AC_MODULES" ]; then AC_MODULES_FLAG="--activation_checkpointing_modules $AC_MODULES"; fi +TENSOR_PARALLEL_DEGREE="${TENSOR_PARALLEL_DEGREE:-1}" + SFT_MODELS=( allenai/Olmo-Hybrid-Instruct-SFT-7B ) @@ -28,7 +42,7 @@ CONFIG_NAME=olmo3_hybrid_7B for MODEL_PATH in "${SFT_MODELS[@]}"; do for LR in "${DPO_LRS[@]}"; do - EXP_NAME="hybrid-7b-DPO-oc-0219-SFT-public-LR-${LR}" + EXP_NAME="hybrid-7b-DPO-oc-0219-SFT-public-LR-${LR}${EXP_TAG}" echo "=====================================" echo "Launching: ${EXP_NAME}" echo " SFT model: ${MODEL_PATH}" @@ -67,9 +81,11 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --chat_template_name olmo123 \ --mixer_list allenai/Dolci-Instruct-DPO-fixed 259922 \ --max_seq_length 16384 \ - --per_device_train_batch_size 1 \ - --gradient_accumulation_steps 4 \ - --fsdp_shard_degree 32 \ + --per_device_train_batch_size "$PER_DEVICE_TRAIN_BATCH_SIZE" \ + --gradient_accumulation_steps "$GRADIENT_ACCUMULATION_STEPS" \ + --fsdp_shard_degree "$FSDP_SHARD_DEGREE" \ + --fsdp_num_replicas "$FSDP_NUM_REPLICAS" \ + --tensor_parallel_degree "$TENSOR_PARALLEL_DEGREE" \ --learning_rate "$LR" \ --lr_scheduler_type linear \ --checkpointing_steps 500 \ @@ -81,8 +97,10 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --activation_checkpointing_mode selected_modules \ + --activation_checkpointing_mode "$ACTIVATION_CHECKPOINTING_MODE" \ + $AC_MODULES_FLAG \ --compile_model true \ + $PROFILING_FLAG \ --with_tracking done done From 9e00c78befbb4fc610ff24b00e926c6775039ded Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Jun 2026 08:14:07 -0600 Subject: [PATCH 19/35] Added scripts --- docs/dpo-mfu-optimization.md | 142 +++++++++++++++++++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 docs/dpo-mfu-optimization.md diff --git a/docs/dpo-mfu-optimization.md b/docs/dpo-mfu-optimization.md new file mode 100644 index 0000000000..0a1093d5a6 --- /dev/null +++ b/docs/dpo-mfu-optimization.md @@ -0,0 +1,142 @@ +# DPO MFU Optimization Plan + +## Context + +Baseline run: + +- Beaker: `01KT58SPP2MVYN4EAKYHZN3DN6` +- W&B: https://wandb.ai/ai2-llm/open_instruct_internal/runs/gzhp5mp7 +- Script: `scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh` +- Model: `allenai/Olmo-Hybrid-Instruct-SFT-7B` +- Objective: improve DPO training MFU from the observed ~8.5%. + +Baseline configuration: + +- 4 nodes, 8 GPUs per node, 32 GPUs total +- `--max_seq_length 16384` +- `--per_device_train_batch_size 1` +- `--gradient_accumulation_steps 4` +- `--fsdp_shard_degree 32` +- `--fsdp_num_replicas 1` +- `--packing` +- `--activation_checkpointing_mode selected_modules` +- `--compile_model true` + +Runtime observations from Beaker logs: + +- Data loading is not the bottleneck: about 0.1-0.2% of wall time. +- GPU active memory is about 52 GiB, leaving meaningful headroom on 80 GiB GPUs. +- Steady-state wall time is about 3.6 seconds per optimizer step. +- Real token throughput varies heavily across steps, indicating uneven packed-token occupancy. + +The main bet is that the run is doing too much fixed-shape compute and cross-node communication for too few useful tokens. + +## Hypotheses + +1. Packed-token occupancy is too low. + + The DPO data loader uses `per_device_train_batch_size * gradient_accumulation_steps` as the per-rank candidate limit for packing. With the baseline values, each rank only considers 4 examples when building a 16k chosen / 16k rejected packed row. The model still runs over the padded packed shape, so underfilled packs waste compute. + +2. Full 32-way sharding is too communication-heavy. + + The baseline uses one FSDP shard group over all 32 GPUs. For a 7B model with 52 GiB active memory, smaller shard groups with HSDP replication may reduce cross-node all-gather/reduce-scatter overhead while still fitting in memory. + +3. Activation checkpointing may be more aggressive than needed. + + `selected_modules` checkpointing saves memory but adds recompute. Since the baseline has memory headroom, a lighter checkpointing mode, or no checkpointing, may improve step time. + +## Experiment Matrix + +Run short A/B experiments first. Each run only needs enough steady-state steps after compile and cache warmup to compare throughput, e.g. 100-200 training steps after training starts. + +| Run | Goal | Key changes | Expected outcome | Main risk | +| --- | --- | --- | --- | --- | +| Baseline repeat | Confirm current behavior on same image/code | Original config | MFU around 8.5%; ~0.28 device-BPS | Cluster noise | +| Pack candidates 16 | Improve useful-token occupancy | `GRADIENT_ACCUMULATION_STEPS=16` | More real tokens per step, higher MFU | Larger effective batch changes optimization | +| HSDP 8x4 | Reduce cross-node FSDP communication | `FSDP_SHARD_DEGREE=8`, `FSDP_NUM_REPLICAS=4` | Faster step time at similar token count | OOM or worse memory pressure | +| HSDP 4x8 | More aggressive communication reduction | `FSDP_SHARD_DEGREE=4`, `FSDP_NUM_REPLICAS=8` | Faster than 8x4 if memory allows | Higher OOM risk | +| No selected-module AC | Reduce recompute | `ACTIVATION_CHECKPOINTING_MODE=budget` with default budget | Faster step time if memory fits | OOM or compile incompatibility | +| Combined best | Validate interaction effects | Best packing + best HSDP + best AC mode | Highest MFU candidate | Interactions may differ from isolated runs | + +## Launch Commands + +Assuming the script keeps env-var overrides for the tuning knobs: + +```bash +# 1. Baseline repeat +EXP_TAG=-baseline \ +./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh + +# 2. More examples considered per packed row +GRADIENT_ACCUMULATION_STEPS=16 \ +EXP_TAG=-pack16 \ +./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh + +# 3. Per-node shard groups, 4 replicas +FSDP_SHARD_DEGREE=8 \ +FSDP_NUM_REPLICAS=4 \ +EXP_TAG=-hsdp8x4 \ +./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh + +# 4. Smaller shard groups, 8 replicas +FSDP_SHARD_DEGREE=4 \ +FSDP_NUM_REPLICAS=8 \ +EXP_TAG=-hsdp4x8 \ +./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh + +# 5. Lighter activation checkpointing +ACTIVATION_CHECKPOINTING_MODE=budget \ +EXP_TAG=-budget-ac \ +./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +``` + +Per repo workflow, `build_image_and_launch.sh` requires committed changes. Commit the script/doc changes before launching experiments. + +## Metrics To Compare + +Track these from W&B and Beaker logs: + +- `perf/mfu_step` +- `perf/mfu_avg` +- `perf/tokens_per_second_step` +- `perf/tokens_per_second_per_gpu` +- `perf/seconds_per_step` +- `perf/data_loading_pct` +- `throughput/device/BPS` +- `throughput/device/BPS (actual avg)` +- `throughput/total tokens` +- `gpu_memory/GPU active mem (GiB)` +- `gpu_memory/GPU reserved mem (GiB)` +- Loss curve and reward metrics, especially if changing effective batch size. + +For packing experiments, compute useful-token occupancy: + +```text +occupancy = real_tokens_per_step / (num_gpus * 2 * max_seq_length) +``` + +For the baseline, the denominator is: + +```text +32 * 2 * 16384 = 1,048,576 tokens per step +``` + +The baseline logs showed roughly 200k real tokens per step in a recent window, or about 20% occupancy. + +## Decision Rules + +1. If `GRADIENT_ACCUMULATION_STEPS=16` materially increases token occupancy and MFU without hurting loss behavior, keep it or sweep nearby values such as 8, 16, and 32. + +2. If `FSDP_SHARD_DEGREE=8, FSDP_NUM_REPLICAS=4` improves step time without OOM, prefer it over 32-way sharding. Try 4x8 only if 8x4 is stable and memory still has headroom. + +3. If `ACTIVATION_CHECKPOINTING_MODE=budget` fits and improves step time, keep it. If it OOMs, try a budgeted run with an explicit memory budget before returning to `selected_modules`. + +4. Once the best individual knobs are identified, run a combined experiment and compare against the baseline repeat, not only against the original run. + +5. Do not judge by MFU alone. A configuration that raises MFU by changing effective batch size still needs a sanity check on loss, reward margin, and downstream eval plan. + +## Possible Code Follow-Up + +The cleaner long-term fix is to decouple packing candidate count from optimizer batch semantics. Today, `gradient_accumulation_steps` controls how many examples the packing loader considers per rank, but with padding-free DPO the packed batch is still one row and `split_batch_dpo()` usually does not create multiple backward microbatches. + +Add a separate argument such as `--packing_max_examples_per_rank` or `--packing_candidate_multiplier`, then use that in `HFDataLoader` for packing while preserving the intended optimizer batch size. This would let us improve token occupancy without changing DPO effective batch size. From 0ba1ec8a342fb3d1c244e5fc9eb6044dd7d5a663 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Jun 2026 11:07:54 -0600 Subject: [PATCH 20/35] DPO: bucket-pad packed microbatches to next power-of-two (not max_seq_length) and wire HSDP knobs to cut padding-FLOP waste Co-Authored-By: Claude Opus 4.8 --- open_instruct/dpo.py | 5 +- open_instruct/padding_free_collator.py | 50 ++++++++++++++++--- open_instruct/test_padding_free_collator.py | 42 ++++++++++++++++ .../7b_instruct_dpo_sweep_olmo_core.sh | 4 ++ 4 files changed, 92 insertions(+), 9 deletions(-) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index 45efc1f031..1cfe59b3a0 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -186,7 +186,10 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz if args.packing: logger.info("Using packing/padding-free collation") collator = TensorDataCollatorWithFlatteningDPO( - return_position_ids=True, return_flash_attn_kwargs=True, max_seq_length=args.max_seq_length + return_position_ids=True, + return_flash_attn_kwargs=True, + max_seq_length=args.max_seq_length, + pad_to_bucket=True, ) else: collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest") diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 85cf9fbd20..0f10c2fd55 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -7,6 +7,20 @@ from open_instruct import tensor_utils +def bucket_length(length: int, max_seq_length: int, min_bucket: int = 512) -> int: + """Round ``length`` up to the next power of two, clamped to [min_bucket, max_seq_length]. + + Padding every packed microbatch to ``max_seq_length`` wastes compute when batches are + short (the model's matmuls run over the padding). Bucketing to a small fixed set of + power-of-two lengths keeps torch.compile from recompiling for every distinct packed + length while padding far less than the full ``max_seq_length``. + """ + if length >= max_seq_length: + return max_seq_length + bucket = max(min_bucket, 1 << max(length - 1, 0).bit_length()) + return min(bucket, max_seq_length) + + def calculate_per_token_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: shifted_labels = torch.full_like(labels, -100) shifted_labels[:, :-1] = labels[:, 1:] @@ -87,8 +101,9 @@ class TensorDataCollatorWithFlattening(DefaultDataCollator): return_seq_idx: bool = True separator_id: int = -100 max_seq_length: int | None = None + pad_to_bucket: bool = False - def __call__(self, features, return_tensors=None, separator_id=None): + def __call__(self, features, return_tensors=None, separator_id=None, pad_target=None): if return_tensors is None: return_tensors = self.return_tensors if separator_id is None: @@ -111,12 +126,17 @@ def __call__(self, features, return_tensors=None, separator_id=None): ret["seq_idx"] = torch.cat(seq_idx, dim=0)[None] if self.max_seq_length is not None: - ret["input_ids"] = tensor_utils.pad_to_length(ret["input_ids"], self.max_seq_length, pad_value=0) - ret["labels"] = tensor_utils.pad_to_length(ret["labels"], self.max_seq_length, pad_value=-100) + if pad_target is None: + content_length = ret["input_ids"].shape[-1] + pad_target = ( + bucket_length(content_length, self.max_seq_length) if self.pad_to_bucket else self.max_seq_length + ) + ret["input_ids"] = tensor_utils.pad_to_length(ret["input_ids"], pad_target, pad_value=0) + ret["labels"] = tensor_utils.pad_to_length(ret["labels"], pad_target, pad_value=-100) if "position_ids" in ret: - ret["position_ids"] = tensor_utils.pad_to_length(ret["position_ids"], self.max_seq_length, pad_value=0) + ret["position_ids"] = tensor_utils.pad_to_length(ret["position_ids"], pad_target, pad_value=0) if "seq_idx" in ret: - ret["seq_idx"] = tensor_utils.pad_to_length(ret["seq_idx"], self.max_seq_length, pad_value=-1) + ret["seq_idx"] = tensor_utils.pad_to_length(ret["seq_idx"], pad_target, pad_value=-1) return ret @@ -144,14 +164,28 @@ def count_features_within_token_budget(features: list[dict[str, Any]], max_seq_l @dataclass class TensorDataCollatorWithFlatteningDPO(TensorDataCollatorWithFlattening): - def __call__(self, features, return_tensors=None, separator_id=None): + def __call__(self, features, return_tensors=None, separator_id=None, pad_target=None): keep = count_features_within_token_budget(features, self.max_seq_length) features = features[:keep] + + # Pad chosen and rejected to a single shared bucket so the concatenated sequence + # (see concatenated_inputs) only ever takes one of a few static lengths under compile. + if pad_target is None and self.max_seq_length is not None and self.pad_to_bucket: + chosen_len = sum(len(f["chosen_input_ids"]) for f in features) + rejected_len = sum(len(f["rejected_input_ids"]) for f in features) + pad_target = bucket_length(max(chosen_len, rejected_len), self.max_seq_length) + chosen_features = super().__call__( - _filter_feature_dicts(features, "chosen_"), return_tensors=return_tensors, separator_id=separator_id + _filter_feature_dicts(features, "chosen_"), + return_tensors=return_tensors, + separator_id=separator_id, + pad_target=pad_target, ) rejected_features = super().__call__( - _filter_feature_dicts(features, "rejected_"), return_tensors=return_tensors, separator_id=separator_id + _filter_feature_dicts(features, "rejected_"), + return_tensors=return_tensors, + separator_id=separator_id, + pad_target=pad_target, ) result = {} diff --git a/open_instruct/test_padding_free_collator.py b/open_instruct/test_padding_free_collator.py index d3c28e1fb3..cb34812555 100644 --- a/open_instruct/test_padding_free_collator.py +++ b/open_instruct/test_padding_free_collator.py @@ -22,6 +22,7 @@ from open_instruct.padding_free_collator import ( TensorDataCollatorWithFlattening, TensorDataCollatorWithFlatteningDPO, + bucket_length, calculate_per_token_logps, concatenated_inputs, get_batch_logps, @@ -397,3 +398,44 @@ def test_forward_without_labels_returns_logits(self): output = head(x, labels=None) self.assertIsInstance(output, torch.Tensor) self.assertEqual(output.shape, (1, seq_len, vocab_size)) + + +class TestBucketLength(unittest.TestCase): + @parameterized.expand( + [ + ("below_min", 10, 16384, 512), + ("at_min", 512, 16384, 512), + ("just_above_min", 513, 16384, 1024), + ("power_of_two", 1024, 16384, 1024), + ("just_above_power", 1025, 16384, 2048), + ("mid", 3000, 16384, 4096), + ("at_max", 16384, 16384, 16384), + ("above_max", 20000, 16384, 16384), + ("near_max", 16383, 16384, 16384), + ] + ) + def test_bucket_length(self, name, length, max_seq_length, expected): + self.assertEqual(bucket_length(length, max_seq_length), expected) + + +class TestDPOBucketedPadding(unittest.TestCase): + def test_pads_to_shared_bucket_below_max(self): + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=16384, pad_to_bucket=True) + features = _make_dpo_features(num_samples=4, chosen_lengths=[100], rejected_lengths=[120]) + batch = collator(features) + + chosen_total = sum(100 for _ in range(4)) + rejected_total = sum(120 for _ in range(4)) + expected = bucket_length(max(chosen_total, rejected_total), 16384) + + self.assertEqual(batch["chosen_input_ids"].shape[-1], expected) + self.assertEqual(batch["rejected_input_ids"].shape[-1], expected) + self.assertLess(expected, 16384) + + def test_default_pads_to_max_seq_length(self): + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=16384) + features = _make_dpo_features(num_samples=4, chosen_lengths=[100], rejected_lengths=[120]) + batch = collator(features) + + self.assertEqual(batch["chosen_input_ids"].shape[-1], 16384) + self.assertEqual(batch["rejected_input_ids"].shape[-1], 16384) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 9405a52850..c7eb4b19be 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -24,6 +24,9 @@ EXP_TAG="${EXP_TAG:-}" PROFILING="${PROFILING:-false}" PROFILING_FLAG="" if [ "$PROFILING" = "true" ]; then PROFILING_FLAG="--profiling"; fi +MAX_TRAIN_STEPS="${MAX_TRAIN_STEPS:-}" +MAX_TRAIN_STEPS_FLAG="" +if [ -n "$MAX_TRAIN_STEPS" ]; then MAX_TRAIN_STEPS_FLAG="--max_train_steps $MAX_TRAIN_STEPS"; fi AC_MODULES_FLAG="" if [ -n "$AC_MODULES" ]; then AC_MODULES_FLAG="--activation_checkpointing_modules $AC_MODULES"; fi TENSOR_PARALLEL_DEGREE="${TENSOR_PARALLEL_DEGREE:-1}" @@ -101,6 +104,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do $AC_MODULES_FLAG \ --compile_model true \ $PROFILING_FLAG \ + $MAX_TRAIN_STEPS_FLAG \ --with_tracking done done From e1dfe41fd76b916ca90cbdf938f1b855e9917f4f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Jun 2026 11:47:50 -0600 Subject: [PATCH 21/35] =?UTF-8?q?DPO:=20pack=20microbatches=20to=20the=20m?= =?UTF-8?q?ax=5Fseq=5Flength=20token=20budget=20instead=20of=20capping=20a?= =?UTF-8?q?t=20per=5Fdevice=5Fbatch=C3=97GAS=20sequences=20(fixes=20paddin?= =?UTF-8?q?g-FLOP=20MFU=20waste);=20revert=20bucketing=20approach=20Co-Aut?= =?UTF-8?q?hored-By:=20Claude=20Opus=204.8=20?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- open_instruct/data_loader.py | 3 +- open_instruct/dpo.py | 5 +-- open_instruct/padding_free_collator.py | 50 ++++----------------- open_instruct/test_data_loader.py | 48 ++++++++++++++++++++ open_instruct/test_padding_free_collator.py | 42 ----------------- 5 files changed, 58 insertions(+), 90 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 2d1660c793..43fdd7117c 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -324,9 +324,8 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: for i in range(len(all_indices)): new_totals = [running_totals[s] + lengths[i][s] for s in range(num_streams)] would_exceed = len(current_batch) > 0 and any(t > max_seq_length for t in new_totals) - at_max_samples = len(current_batch) >= self._per_rank_batch_size - if would_exceed or at_max_samples: + if would_exceed: batches.append(current_batch) current_batch = [i] running_totals = list(lengths[i]) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index 1cfe59b3a0..45efc1f031 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -186,10 +186,7 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz if args.packing: logger.info("Using packing/padding-free collation") collator = TensorDataCollatorWithFlatteningDPO( - return_position_ids=True, - return_flash_attn_kwargs=True, - max_seq_length=args.max_seq_length, - pad_to_bucket=True, + return_position_ids=True, return_flash_attn_kwargs=True, max_seq_length=args.max_seq_length ) else: collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest") diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 0f10c2fd55..85cf9fbd20 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -7,20 +7,6 @@ from open_instruct import tensor_utils -def bucket_length(length: int, max_seq_length: int, min_bucket: int = 512) -> int: - """Round ``length`` up to the next power of two, clamped to [min_bucket, max_seq_length]. - - Padding every packed microbatch to ``max_seq_length`` wastes compute when batches are - short (the model's matmuls run over the padding). Bucketing to a small fixed set of - power-of-two lengths keeps torch.compile from recompiling for every distinct packed - length while padding far less than the full ``max_seq_length``. - """ - if length >= max_seq_length: - return max_seq_length - bucket = max(min_bucket, 1 << max(length - 1, 0).bit_length()) - return min(bucket, max_seq_length) - - def calculate_per_token_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: shifted_labels = torch.full_like(labels, -100) shifted_labels[:, :-1] = labels[:, 1:] @@ -101,9 +87,8 @@ class TensorDataCollatorWithFlattening(DefaultDataCollator): return_seq_idx: bool = True separator_id: int = -100 max_seq_length: int | None = None - pad_to_bucket: bool = False - def __call__(self, features, return_tensors=None, separator_id=None, pad_target=None): + def __call__(self, features, return_tensors=None, separator_id=None): if return_tensors is None: return_tensors = self.return_tensors if separator_id is None: @@ -126,17 +111,12 @@ def __call__(self, features, return_tensors=None, separator_id=None, pad_target= ret["seq_idx"] = torch.cat(seq_idx, dim=0)[None] if self.max_seq_length is not None: - if pad_target is None: - content_length = ret["input_ids"].shape[-1] - pad_target = ( - bucket_length(content_length, self.max_seq_length) if self.pad_to_bucket else self.max_seq_length - ) - ret["input_ids"] = tensor_utils.pad_to_length(ret["input_ids"], pad_target, pad_value=0) - ret["labels"] = tensor_utils.pad_to_length(ret["labels"], pad_target, pad_value=-100) + ret["input_ids"] = tensor_utils.pad_to_length(ret["input_ids"], self.max_seq_length, pad_value=0) + ret["labels"] = tensor_utils.pad_to_length(ret["labels"], self.max_seq_length, pad_value=-100) if "position_ids" in ret: - ret["position_ids"] = tensor_utils.pad_to_length(ret["position_ids"], pad_target, pad_value=0) + ret["position_ids"] = tensor_utils.pad_to_length(ret["position_ids"], self.max_seq_length, pad_value=0) if "seq_idx" in ret: - ret["seq_idx"] = tensor_utils.pad_to_length(ret["seq_idx"], pad_target, pad_value=-1) + ret["seq_idx"] = tensor_utils.pad_to_length(ret["seq_idx"], self.max_seq_length, pad_value=-1) return ret @@ -164,28 +144,14 @@ def count_features_within_token_budget(features: list[dict[str, Any]], max_seq_l @dataclass class TensorDataCollatorWithFlatteningDPO(TensorDataCollatorWithFlattening): - def __call__(self, features, return_tensors=None, separator_id=None, pad_target=None): + def __call__(self, features, return_tensors=None, separator_id=None): keep = count_features_within_token_budget(features, self.max_seq_length) features = features[:keep] - - # Pad chosen and rejected to a single shared bucket so the concatenated sequence - # (see concatenated_inputs) only ever takes one of a few static lengths under compile. - if pad_target is None and self.max_seq_length is not None and self.pad_to_bucket: - chosen_len = sum(len(f["chosen_input_ids"]) for f in features) - rejected_len = sum(len(f["rejected_input_ids"]) for f in features) - pad_target = bucket_length(max(chosen_len, rejected_len), self.max_seq_length) - chosen_features = super().__call__( - _filter_feature_dicts(features, "chosen_"), - return_tensors=return_tensors, - separator_id=separator_id, - pad_target=pad_target, + _filter_feature_dicts(features, "chosen_"), return_tensors=return_tensors, separator_id=separator_id ) rejected_features = super().__call__( - _filter_feature_dicts(features, "rejected_"), - return_tensors=return_tensors, - separator_id=separator_id, - pad_target=pad_target, + _filter_feature_dicts(features, "rejected_"), return_tensors=return_tensors, separator_id=separator_id ) result = {} diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 2bb6390fde..1cc5f8a4a6 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -79,5 +79,53 @@ def test_packing_equal_batches_across_ranks( self.assertEqual(all_indices, expected_indices, f"Missing indices: {expected_indices - all_indices}") +def _make_fixed_length_dpo_dataset(num_samples: int, seq_len: int) -> Dataset: + rng = torch.Generator().manual_seed(42) + data = { + "chosen_input_ids": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "chosen_labels": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "rejected_input_ids": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "rejected_labels": [torch.randint(0, 1000, (seq_len,), generator=rng) for _ in range(num_samples)], + "index": list(range(num_samples)), + } + ds = Dataset.from_dict(data) + ds.set_format(type="pt") + return ds + + +class TestTokenBudgetPacking(unittest.TestCase): + def test_packs_to_token_budget_not_sample_cap(self): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + global_batch_size = 4 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loader = data_loader.HFDataLoader( + dataset=dataset, + batch_size=global_batch_size, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + collator=collator, + drop_last=False, + ) + + batch_sizes = [] + seen_indices = set() + for batch in loader: + num_seqs = len(batch["index"]) + batch_sizes.append(num_seqs) + seen_indices.update(batch["index"].tolist()) + self.assertLessEqual(batch["chosen_cu_seq_lens_k"][-1].item(), max_seq_length) + self.assertLessEqual(batch["rejected_cu_seq_lens_k"][-1].item(), max_seq_length) + + self.assertGreater(max(batch_sizes), global_batch_size) + self.assertEqual(seen_indices, set(range(num_samples))) + + if __name__ == "__main__": unittest.main() diff --git a/open_instruct/test_padding_free_collator.py b/open_instruct/test_padding_free_collator.py index cb34812555..d3c28e1fb3 100644 --- a/open_instruct/test_padding_free_collator.py +++ b/open_instruct/test_padding_free_collator.py @@ -22,7 +22,6 @@ from open_instruct.padding_free_collator import ( TensorDataCollatorWithFlattening, TensorDataCollatorWithFlatteningDPO, - bucket_length, calculate_per_token_logps, concatenated_inputs, get_batch_logps, @@ -398,44 +397,3 @@ def test_forward_without_labels_returns_logits(self): output = head(x, labels=None) self.assertIsInstance(output, torch.Tensor) self.assertEqual(output.shape, (1, seq_len, vocab_size)) - - -class TestBucketLength(unittest.TestCase): - @parameterized.expand( - [ - ("below_min", 10, 16384, 512), - ("at_min", 512, 16384, 512), - ("just_above_min", 513, 16384, 1024), - ("power_of_two", 1024, 16384, 1024), - ("just_above_power", 1025, 16384, 2048), - ("mid", 3000, 16384, 4096), - ("at_max", 16384, 16384, 16384), - ("above_max", 20000, 16384, 16384), - ("near_max", 16383, 16384, 16384), - ] - ) - def test_bucket_length(self, name, length, max_seq_length, expected): - self.assertEqual(bucket_length(length, max_seq_length), expected) - - -class TestDPOBucketedPadding(unittest.TestCase): - def test_pads_to_shared_bucket_below_max(self): - collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=16384, pad_to_bucket=True) - features = _make_dpo_features(num_samples=4, chosen_lengths=[100], rejected_lengths=[120]) - batch = collator(features) - - chosen_total = sum(100 for _ in range(4)) - rejected_total = sum(120 for _ in range(4)) - expected = bucket_length(max(chosen_total, rejected_total), 16384) - - self.assertEqual(batch["chosen_input_ids"].shape[-1], expected) - self.assertEqual(batch["rejected_input_ids"].shape[-1], expected) - self.assertLess(expected, 16384) - - def test_default_pads_to_max_seq_length(self): - collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=16384) - features = _make_dpo_features(num_samples=4, chosen_lengths=[100], rejected_lengths=[120]) - batch = collator(features) - - self.assertEqual(batch["chosen_input_ids"].shape[-1], 16384) - self.assertEqual(batch["rejected_input_ids"].shape[-1], 16384) From 5cf152887af9ec5d34d7b313751bed867e165e6f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 08:29:47 -0600 Subject: [PATCH 22/35] DPO: add configurable per-microbatch sample cap + real gradient accumulation (microbatches_per_step); add train/padding_fraction and train/sequences_per_step metrics Co-Authored-By: Claude Opus 4.8 --- open_instruct/data_loader.py | 57 ++++++++++++++------- open_instruct/dpo.py | 2 + open_instruct/olmo_core_train_modules.py | 40 ++++++++++----- open_instruct/padding_free_collator.py | 13 +++-- open_instruct/test_data_loader.py | 65 ++++++++++++++++++++++++ 5 files changed, 143 insertions(+), 34 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 43fdd7117c..615ed9b52b 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -91,6 +91,8 @@ def __init__( drop_last: bool = True, fs_local_rank: int | None = None, max_seq_length: int = 1, + microbatch_sample_cap: int | None = None, + microbatches_per_step: int = 1, ) -> None: """Initialize the HFDataLoader. @@ -110,6 +112,12 @@ def __init__( fs_local_rank: File system local rank. Defaults to dp_rank when None. max_seq_length: Maximum sequence length. Used to report global_batch_size in tokens to the trainer for batch-size validation. + microbatch_sample_cap: When packing, the maximum number of examples per packed + microbatch. A microbatch closes when either the token budget or this cap is + reached. None means pack purely to the token budget. + microbatches_per_step: Number of packed microbatches grouped into each yielded batch + for gradient accumulation. When > 1, each yielded batch is a list of this many + collated microbatches; the trainer runs one optimizer step per yielded batch. Note: The dataset must have an 'index' column for tracking samples across epochs. @@ -144,6 +152,10 @@ def __init__( f"The effective global batch size will be {batch_size // dp_world_size * dp_world_size}." ) self._per_rank_batch_size = batch_size // dp_world_size + if microbatches_per_step < 1: + raise ValueError(f"microbatches_per_step must be >= 1, got {microbatches_per_step}") + self._microbatch_sample_cap = microbatch_sample_cap + self._microbatches_per_step = microbatches_per_step self._collator = collator if collator is not None else (lambda x: {"examples": x}) self._automatic_reshuffle = automatic_reshuffle self._drop_last = drop_last @@ -181,19 +193,24 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: # batches. Each entry in _precomputed_batch_sizes is the number of # examples in that batch (variable due to packing). if self._precomputed_batch_sizes is not None: + mbs_per_step = self._microbatches_per_step num_real = len(self._precomputed_batch_sizes) - self._num_padding_batches - offset = 0 - for batch_idx, batch_size in enumerate(self._precomputed_batch_sizes): - if batch_idx < self.batches_processed: - offset += batch_size + num_groups = len(self._precomputed_batch_sizes) // mbs_per_step + offsets = [0] + for batch_size in self._precomputed_batch_sizes: + offsets.append(offsets[-1] + batch_size) + for group_idx in range(num_groups): + if group_idx < self.batches_processed: continue - examples = [] - for i in range(offset, offset + batch_size): - example = self.dataset[i] - examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"}) - batch = to_device(self._collator(examples), self._device) | {"is_padding": batch_idx >= num_real} - offset += batch_size - yield batch + group = [] + for mb_idx in range(group_idx * mbs_per_step, (group_idx + 1) * mbs_per_step): + examples = [] + for i in range(offsets[mb_idx], offsets[mb_idx + 1]): + example = self.dataset[i] + examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"}) + collated = to_device(self._collator(examples), self._device) | {"is_padding": mb_idx >= num_real} + group.append(collated) + yield group if mbs_per_step > 1 else group[0] return start_example = self.batches_processed * self._per_rank_batch_size @@ -219,7 +236,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: def total_batches(self) -> int: """Return the total number of batches in an epoch.""" if self._precomputed_batch_sizes is not None: - return len(self._precomputed_batch_sizes) + return len(self._precomputed_batch_sizes) // self._microbatches_per_step return self.effective_size // self._per_rank_batch_size def state_dict(self) -> dict[str, Any]: @@ -324,8 +341,9 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: for i in range(len(all_indices)): new_totals = [running_totals[s] + lengths[i][s] for s in range(num_streams)] would_exceed = len(current_batch) > 0 and any(t > max_seq_length for t in new_totals) + at_cap = self._microbatch_sample_cap is not None and len(current_batch) >= self._microbatch_sample_cap - if would_exceed: + if would_exceed or at_cap: batches.append(current_batch) current_batch = [i] running_totals = list(lengths[i]) @@ -336,14 +354,19 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: if current_batch: batches.append(current_batch) + # Batches are distributed round-robin to ranks and then grouped into + # microbatches_per_step-sized optimizer steps, so the global count must be a + # multiple of dp_world_size * microbatches_per_step for every rank to have the + # same number of complete groups. + group_size = self.dp_world_size * self._microbatches_per_step num_batches = len(batches) padding_start = num_batches if self._drop_last: - num_batches = (num_batches // self.dp_world_size) * self.dp_world_size + num_batches = (num_batches // group_size) * group_size batches = batches[:num_batches] else: - if (remainder := num_batches % self.dp_world_size) > 0: - for _ in range(self.dp_world_size - remainder): + if (remainder := num_batches % group_size) > 0: + for _ in range(group_size - remainder): batches.append(batches[-1]) rank_global_indices = list(range(self.dp_rank, len(batches), self.dp_world_size)) @@ -371,7 +394,7 @@ def get_mock_batch(self) -> dict[str, Any]: examples = [self.dataset[i] for i in range(num_examples)] return to_device(self._collator(examples), self._device) - def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int: + def global_num_tokens_in_batch(self, batch: dict[str, Any] | list[dict[str, Any]]) -> int: """Return the total number of tokens in the batch across all ranks. Counts tokens from all keys containing 'input_ids' that are torch tensors. diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index 45efc1f031..90cf4c17aa 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -204,6 +204,8 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz device=device, drop_last=True, fs_local_rank=global_rank, + microbatch_sample_cap=args.per_device_train_batch_size, + microbatches_per_step=args.gradient_accumulation_steps, ) # 4x batch size: forward-only (no backward), so no activation storage needed. # With packing, the collator's token budget controls the actual forward-pass size diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 6e6e6280d8..15557372c0 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -178,11 +178,10 @@ def __init__( def pre_train(self): pass - def global_num_flops_in_batch(self, batch: dict[str, Any]) -> int | None: - global_num_tokens = self.trainer.data_loader.global_num_tokens_in_batch(batch) - if global_num_tokens is None: - return None - seq_len = batch["chosen_input_ids"].shape[1] + def global_num_flops_in_batch(self, batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: + global_num_tokens = padding_free_collator.get_num_tokens(batch) * self.trainer.data_loader.dp_world_size + first = batch[0] if isinstance(batch, list) else batch + seq_len = first["chosen_input_ids"].shape[1] flops_per_token = self.num_flops_per_token(seq_len=seq_len) return flops_per_token * global_num_tokens if flops_per_token is not None else None @@ -220,12 +219,12 @@ def _compute_microbatch_loss(self, micro_batch: dict[str, Any]) -> tuple[torch.T return loss, step_metrics - def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: + def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: bool = False) -> None: self.model.train() - micro_batches = split_batch_dpo(batch, self.sample_microbatch_size) + micro_batches = batch if isinstance(batch, list) else split_batch_dpo(batch, self.sample_microbatch_size) num_micro_batches = len(micro_batches) - device = batch["chosen_input_ids"].device + device = micro_batches[0]["chosen_input_ids"].device total_tokens = padding_free_collator.get_num_tokens(batch) for v in self._metrics.values(): @@ -243,22 +242,35 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: self.model.post_batch(dry_run=dry_run) if not dry_run: + local_padded_tokens = sum( + v.numel() for mb in micro_batches for k, v in mb.items() if k.endswith("input_ids") + ) + local_num_sequences = padding_free_collator.get_num_sequences(batch) + if local_num_sequences is None: + local_num_sequences = sum(mb["chosen_input_ids"].shape[0] * 2 for mb in micro_batches) metric_keys = sorted(self._metrics.keys()) - local_sums_list = [torch.tensor(total_tokens, dtype=torch.float32, device=device)] + [ - self._metrics[k] for k in metric_keys - ] + local_sums_list = [ + torch.tensor(total_tokens, dtype=torch.float32, device=device), + torch.tensor(local_padded_tokens, dtype=torch.float32, device=device), + torch.tensor(local_num_sequences, dtype=torch.float32, device=device), + ] + [self._metrics[k] for k in metric_keys] local_sums = torch.stack(local_sums_list) dist.all_reduce(local_sums, op=dist.ReduceOp.SUM, group=self.trainer.dp_process_group) global_total_tokens = local_sums[0] - global_metrics = {k: local_sums[i + 1] / global_total_tokens for i, k in enumerate(metric_keys)} + global_padded_tokens = local_sums[1] + global_num_sequences = local_sums[2] + global_metrics = {k: local_sums[i + 3] / global_total_tokens for i, k in enumerate(metric_keys)} self.record_metric("train_loss", global_metrics["loss"].item(), reduce_type=None) self.record_metric("logps/chosen", global_metrics["chosen_logps"].item(), reduce_type=None) self.record_metric("logps/rejected", global_metrics["rejected_logps"].item(), reduce_type=None) - token_count = self.trainer.data_loader.global_num_tokens_in_batch(batch) - assert token_count is not None + token_count = total_tokens * self.trainer.data_loader.dp_world_size self.record_metric("train/token_count", token_count, reduce_type=None) + self.record_metric( + "train/padding_fraction", (1.0 - global_total_tokens / global_padded_tokens).item(), reduce_type=None + ) + self.record_metric("train/sequences_per_step", global_num_sequences.item(), reduce_type=None) self.record_metric("training_step", float(self.trainer.global_step), reduce_type=None) if self.trainer.steps_per_epoch is not None: diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 85cf9fbd20..fb40465626 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -229,13 +229,16 @@ def get_batch_logps( return segment_sums -def get_num_tokens(batch: dict[str, Any]) -> int: +def get_num_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: """Return total non-padding token count from a training batch. For packed batches (DPO or GRPO), reads cu_seq_lens_k tensors whose last element is the total token count for that branch. For padded batches, sums - the attention_mask. Falls back to counting input_ids elements. + the attention_mask. Falls back to counting input_ids elements. A list of + batches (gradient-accumulation microbatches) is summed. """ + if isinstance(batch, list): + return sum(get_num_tokens(b) for b in batch) # cu_seq_lens_k is a cumulative sequence length tensor from the padding-free # collator. Its last element equals the total token count for that branch. # DPO has chosen_cu_seq_lens_k + rejected_cu_seq_lens_k; GRPO has cu_seq_lens_k. @@ -249,12 +252,16 @@ def get_num_tokens(batch: dict[str, Any]) -> int: return sum(v.numel() for k, v in batch.items() if "input_ids" in k and isinstance(v, torch.Tensor)) -def get_num_sequences(batch: dict[str, Any]) -> int | None: +def get_num_sequences(batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: """Return total sequence count from a training batch, or None for non-packing batches. For packed batches, reads cu_seq_lens_k tensors which each have num_seqs + 1 elements (including a leading 0). Returns None if no cu_seq_lens_k keys are found. + A list of batches (gradient-accumulation microbatches) is summed. """ + if isinstance(batch, list): + counts = [get_num_sequences(b) for b in batch] + return sum(c for c in counts if c is not None) cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: # Each cu_seq_lens tensor has num_seqs + 1 elements (leading 0 boundary). diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 1cc5f8a4a6..48a1e7a330 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -126,6 +126,71 @@ def test_packs_to_token_budget_not_sample_cap(self): self.assertGreater(max(batch_sizes), global_batch_size) self.assertEqual(seen_indices, set(range(num_samples))) + def test_microbatch_sample_cap_binds(self): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + cap = 3 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loader = data_loader.HFDataLoader( + dataset=dataset, + batch_size=4, + seed=42, + dp_rank=0, + dp_world_size=1, + work_dir=work_dir, + collator=collator, + drop_last=False, + microbatch_sample_cap=cap, + ) + + for batch in loader: + self.assertLessEqual(len(batch["index"]), cap) + + +class TestGradientAccumulationGrouping(unittest.TestCase): + @parameterized.parameterized.expand([("gas2_dp1", 2, 1), ("gas4_dp1", 4, 1), ("gas2_dp2", 2, 2)]) + def test_groups_microbatches_per_step(self, _name, microbatches_per_step, dp_world_size): + max_seq_length = 16384 + seq_len = 100 + num_samples = 200 + cap = 2 + dataset = _make_fixed_length_dpo_dataset(num_samples, seq_len) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + + with tempfile.TemporaryDirectory() as work_dir: + loaders = [ + data_loader.HFDataLoader( + dataset=dataset, + batch_size=4, + seed=42, + dp_rank=rank, + dp_world_size=dp_world_size, + work_dir=work_dir, + collator=collator, + drop_last=True, + microbatch_sample_cap=cap, + microbatches_per_step=microbatches_per_step, + ) + for rank in range(dp_world_size) + ] + + batch_counts = [loader.total_batches for loader in loaders] + self.assertTrue(all(c == batch_counts[0] for c in batch_counts), f"Step counts differ: {batch_counts}") + + for loader in loaders: + num_steps = 0 + for step in loader: + self.assertIsInstance(step, list) + self.assertEqual(len(step), microbatches_per_step) + for micro_batch in step: + self.assertLessEqual(len(micro_batch["index"]), cap) + num_steps += 1 + self.assertEqual(num_steps, loader.total_batches) + if __name__ == "__main__": unittest.main() From 79492081bfb1ce29f1f981cfaed55e3b26305fbf Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 08:36:14 -0600 Subject: [PATCH 23/35] DPO: bound get_mock_batch rows by token budget so a large microbatch_sample_cap doesn't load the dataset Co-Authored-By: Claude Opus 4.8 --- open_instruct/data_loader.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 615ed9b52b..0ef09afc0a 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -391,6 +391,11 @@ def get_mock_batch(self) -> dict[str, Any]: forward and backward pass before training officially starts. """ num_examples = min(self._per_rank_batch_size, len(self.dataset)) + # When packing, the collator consumes only as many examples as fit the token + # budget, so at most max_seq_length examples (each >= 1 token) can be used. + # Bound the rows loaded so a large microbatch_sample_cap doesn't load the dataset. + if getattr(self._collator, "max_seq_length", None) is not None: + num_examples = min(num_examples, self._collator.max_seq_length) examples = [self.dataset[i] for i in range(num_examples)] return to_device(self._collator(examples), self._device) From ac010f0f8a2d401e0ed54b204064edc9a52063aa Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 09:19:24 -0600 Subject: [PATCH 24/35] set flag --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index c7eb4b19be..6fd447931e 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -77,6 +77,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ --master_port=29400 \ --nproc_per_node=8 \ + --artifact_ttl 1d \ open_instruct/dpo.py \ --exp_name "$EXP_NAME" \ --model_name_or_path "$MODEL_PATH" \ From 7957881b2f90e5b8c9cc943dd6522af955d09f1c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 09:20:31 -0600 Subject: [PATCH 25/35] disable HF upload --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 6fd447931e..fec3ae7bc3 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -101,6 +101,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ + --push_to_hub False \ --activation_checkpointing_mode "$ACTIVATION_CHECKPOINTING_MODE" \ $AC_MODULES_FLAG \ --compile_model true \ From 91afdf0517ab4224ded5bb0c54b9c1b5c83ab147 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 10:14:57 -0600 Subject: [PATCH 26/35] moved flag --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index fec3ae7bc3..1ab489e744 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -58,6 +58,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --workspace ai2/linear-rnns \ --priority urgent \ --max_retries 0 \ + --artifact_ttl 1d \ --preemptible \ --image "$BEAKER_IMAGE" \ --pure_docker_mode \ @@ -77,7 +78,6 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --master_addr=\$BEAKER_LEADER_REPLICA_HOSTNAME \ --master_port=29400 \ --nproc_per_node=8 \ - --artifact_ttl 1d \ open_instruct/dpo.py \ --exp_name "$EXP_NAME" \ --model_name_or_path "$MODEL_PATH" \ From efd9ad43935d08c7f0505759ab03586bf6ffadf3 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 11:20:58 -0600 Subject: [PATCH 27/35] set flags correctly --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index 1ab489e744..a40ebb25fb 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -101,7 +101,8 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --loss_type dpo_norm \ --beta 5 \ --packing \ - --push_to_hub False \ + --push_to_hub False \ + --try_launch_beaker_eval_jobs False \ --activation_checkpointing_mode "$ACTIVATION_CHECKPOINTING_MODE" \ $AC_MODULES_FLAG \ --compile_model true \ From a59d9d21bcaf17f9d2e690cb67dd0e7ef6b982bf Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 13:03:01 -0600 Subject: [PATCH 28/35] cleaned up pr --- docs/dpo-mfu-optimization.md | 142 ----------------------- open_instruct/data_loader.py | 4 +- open_instruct/dpo.py | 3 +- open_instruct/olmo_core_train_modules.py | 100 +++++++++------- open_instruct/padding_free_collator.py | 11 ++ 5 files changed, 73 insertions(+), 187 deletions(-) delete mode 100644 docs/dpo-mfu-optimization.md diff --git a/docs/dpo-mfu-optimization.md b/docs/dpo-mfu-optimization.md deleted file mode 100644 index 0a1093d5a6..0000000000 --- a/docs/dpo-mfu-optimization.md +++ /dev/null @@ -1,142 +0,0 @@ -# DPO MFU Optimization Plan - -## Context - -Baseline run: - -- Beaker: `01KT58SPP2MVYN4EAKYHZN3DN6` -- W&B: https://wandb.ai/ai2-llm/open_instruct_internal/runs/gzhp5mp7 -- Script: `scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh` -- Model: `allenai/Olmo-Hybrid-Instruct-SFT-7B` -- Objective: improve DPO training MFU from the observed ~8.5%. - -Baseline configuration: - -- 4 nodes, 8 GPUs per node, 32 GPUs total -- `--max_seq_length 16384` -- `--per_device_train_batch_size 1` -- `--gradient_accumulation_steps 4` -- `--fsdp_shard_degree 32` -- `--fsdp_num_replicas 1` -- `--packing` -- `--activation_checkpointing_mode selected_modules` -- `--compile_model true` - -Runtime observations from Beaker logs: - -- Data loading is not the bottleneck: about 0.1-0.2% of wall time. -- GPU active memory is about 52 GiB, leaving meaningful headroom on 80 GiB GPUs. -- Steady-state wall time is about 3.6 seconds per optimizer step. -- Real token throughput varies heavily across steps, indicating uneven packed-token occupancy. - -The main bet is that the run is doing too much fixed-shape compute and cross-node communication for too few useful tokens. - -## Hypotheses - -1. Packed-token occupancy is too low. - - The DPO data loader uses `per_device_train_batch_size * gradient_accumulation_steps` as the per-rank candidate limit for packing. With the baseline values, each rank only considers 4 examples when building a 16k chosen / 16k rejected packed row. The model still runs over the padded packed shape, so underfilled packs waste compute. - -2. Full 32-way sharding is too communication-heavy. - - The baseline uses one FSDP shard group over all 32 GPUs. For a 7B model with 52 GiB active memory, smaller shard groups with HSDP replication may reduce cross-node all-gather/reduce-scatter overhead while still fitting in memory. - -3. Activation checkpointing may be more aggressive than needed. - - `selected_modules` checkpointing saves memory but adds recompute. Since the baseline has memory headroom, a lighter checkpointing mode, or no checkpointing, may improve step time. - -## Experiment Matrix - -Run short A/B experiments first. Each run only needs enough steady-state steps after compile and cache warmup to compare throughput, e.g. 100-200 training steps after training starts. - -| Run | Goal | Key changes | Expected outcome | Main risk | -| --- | --- | --- | --- | --- | -| Baseline repeat | Confirm current behavior on same image/code | Original config | MFU around 8.5%; ~0.28 device-BPS | Cluster noise | -| Pack candidates 16 | Improve useful-token occupancy | `GRADIENT_ACCUMULATION_STEPS=16` | More real tokens per step, higher MFU | Larger effective batch changes optimization | -| HSDP 8x4 | Reduce cross-node FSDP communication | `FSDP_SHARD_DEGREE=8`, `FSDP_NUM_REPLICAS=4` | Faster step time at similar token count | OOM or worse memory pressure | -| HSDP 4x8 | More aggressive communication reduction | `FSDP_SHARD_DEGREE=4`, `FSDP_NUM_REPLICAS=8` | Faster than 8x4 if memory allows | Higher OOM risk | -| No selected-module AC | Reduce recompute | `ACTIVATION_CHECKPOINTING_MODE=budget` with default budget | Faster step time if memory fits | OOM or compile incompatibility | -| Combined best | Validate interaction effects | Best packing + best HSDP + best AC mode | Highest MFU candidate | Interactions may differ from isolated runs | - -## Launch Commands - -Assuming the script keeps env-var overrides for the tuning knobs: - -```bash -# 1. Baseline repeat -EXP_TAG=-baseline \ -./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh - -# 2. More examples considered per packed row -GRADIENT_ACCUMULATION_STEPS=16 \ -EXP_TAG=-pack16 \ -./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh - -# 3. Per-node shard groups, 4 replicas -FSDP_SHARD_DEGREE=8 \ -FSDP_NUM_REPLICAS=4 \ -EXP_TAG=-hsdp8x4 \ -./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh - -# 4. Smaller shard groups, 8 replicas -FSDP_SHARD_DEGREE=4 \ -FSDP_NUM_REPLICAS=8 \ -EXP_TAG=-hsdp4x8 \ -./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh - -# 5. Lighter activation checkpointing -ACTIVATION_CHECKPOINTING_MODE=budget \ -EXP_TAG=-budget-ac \ -./scripts/train/build_image_and_launch.sh scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh -``` - -Per repo workflow, `build_image_and_launch.sh` requires committed changes. Commit the script/doc changes before launching experiments. - -## Metrics To Compare - -Track these from W&B and Beaker logs: - -- `perf/mfu_step` -- `perf/mfu_avg` -- `perf/tokens_per_second_step` -- `perf/tokens_per_second_per_gpu` -- `perf/seconds_per_step` -- `perf/data_loading_pct` -- `throughput/device/BPS` -- `throughput/device/BPS (actual avg)` -- `throughput/total tokens` -- `gpu_memory/GPU active mem (GiB)` -- `gpu_memory/GPU reserved mem (GiB)` -- Loss curve and reward metrics, especially if changing effective batch size. - -For packing experiments, compute useful-token occupancy: - -```text -occupancy = real_tokens_per_step / (num_gpus * 2 * max_seq_length) -``` - -For the baseline, the denominator is: - -```text -32 * 2 * 16384 = 1,048,576 tokens per step -``` - -The baseline logs showed roughly 200k real tokens per step in a recent window, or about 20% occupancy. - -## Decision Rules - -1. If `GRADIENT_ACCUMULATION_STEPS=16` materially increases token occupancy and MFU without hurting loss behavior, keep it or sweep nearby values such as 8, 16, and 32. - -2. If `FSDP_SHARD_DEGREE=8, FSDP_NUM_REPLICAS=4` improves step time without OOM, prefer it over 32-way sharding. Try 4x8 only if 8x4 is stable and memory still has headroom. - -3. If `ACTIVATION_CHECKPOINTING_MODE=budget` fits and improves step time, keep it. If it OOMs, try a budgeted run with an explicit memory budget before returning to `selected_modules`. - -4. Once the best individual knobs are identified, run a combined experiment and compare against the baseline repeat, not only against the original run. - -5. Do not judge by MFU alone. A configuration that raises MFU by changing effective batch size still needs a sanity check on loss, reward margin, and downstream eval plan. - -## Possible Code Follow-Up - -The cleaner long-term fix is to decouple packing candidate count from optimizer batch semantics. Today, `gradient_accumulation_steps` controls how many examples the packing loader considers per rank, but with padding-free DPO the packed batch is still one row and `split_batch_dpo()` usually does not create multiple backward microbatches. - -Add a separate argument such as `--packing_max_examples_per_rank` or `--packing_candidate_multiplier`, then use that in `HFDataLoader` for packing while preserving the intended optimizer batch size. This would let us improve token occupancy without changing DPO effective batch size. diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 0ef09afc0a..162b61c5fb 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -199,9 +199,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: offsets = [0] for batch_size in self._precomputed_batch_sizes: offsets.append(offsets[-1] + batch_size) - for group_idx in range(num_groups): - if group_idx < self.batches_processed: - continue + for group_idx in range(self.batches_processed, num_groups): group = [] for mb_idx in range(group_idx * mbs_per_step, (group_idx + 1) * mbs_per_step): examples = [] diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index 90cf4c17aa..940446fb8c 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -25,7 +25,7 @@ from open_instruct import data_loader as data_loader_lib from open_instruct import dataset_transformation, dpo_utils, logger_utils, model_utils, olmo_core_utils, utils from open_instruct.olmo_core_callbacks import PerfCallback -from open_instruct.olmo_core_train_modules import DPOTrainModule +from open_instruct.olmo_core_train_modules import DPOMetricsCallback, DPOTrainModule from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO logger = logger_utils.setup_logger(__name__) @@ -66,6 +66,7 @@ def _setup_callbacks(args: dpo_utils.DPOExperimentConfig, dp_world_size: int): wandb_entity=args.wandb_entity, save_async=False, ) + trainer_callbacks["dpo_metrics"] = DPOMetricsCallback() slack_webhook_url = os.environ.get("SLACK_WEBHOOK_URL") if args.send_slack_alerts and slack_webhook_url: trainer_callbacks["slack"] = callbacks.SlackNotifierCallback(name=run_name, webhook_url=slack_webhook_url) diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 15557372c0..96382995ac 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -18,6 +18,8 @@ from olmo_core.nn.transformer import Transformer from olmo_core.optim import OptimConfig from olmo_core.optim.scheduler import Scheduler +from olmo_core.train.callbacks import Callback +from olmo_core.train.common import ReduceType from olmo_core.train.train_module import TransformerTrainModule from olmo_core.train.train_module.transformer import config as transformer_config from torch.distributed.tensor import DTensor, Replicate, Shard @@ -34,6 +36,33 @@ {AttentionBackendName.flash_2, AttentionBackendName.flash_3, AttentionBackendName.flash_4} ) +# DPO token-weighted metrics are ratios sum_ranks(sum_mb(metric*tokens)) / sum_ranks(tokens). +# DPOTrainModule records each numerator under the _DPO_REDUCE_NS namespace and the shared +# denominators (real and padded token counts) under the keys below, all with ReduceType.sum so +# the trainer reduces them in its batched per-interval all-reduce. DPOMetricsCallback then divides +# numerator/denominator after reduction, avoiding a per-step host sync and explicit all-reduce. +_DPO_REDUCE_NS = "_dpo_reduce" +_DPO_TOKENS_KEY = f"{_DPO_REDUCE_NS}/__tokens__" +_DPO_PADDED_KEY = f"{_DPO_REDUCE_NS}/__padded__" + + +class DPOMetricsCallback(Callback): + """Reconstructs token-weighted DPO metrics from reduced numerator/denominator sums.""" + + priority = 10 + + def pre_log_metrics(self, step: int, metrics: dict[str, float]) -> None: + del step + if _DPO_TOKENS_KEY not in metrics: + return + tokens = metrics.pop(_DPO_TOKENS_KEY) + padded = metrics.pop(_DPO_PADDED_KEY) + prefix = f"{_DPO_REDUCE_NS}/" + for key in [k for k in metrics if k.startswith(prefix)]: + metrics[key[len(prefix) :]] = metrics.pop(key) / tokens + metrics["train/token_count"] = tokens + metrics["train/padding_fraction"] = 1.0 - tokens / padded + class DPOLMHead(LMHead): """LM head that returns per-token log-probabilities for DPO training. @@ -225,7 +254,8 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo micro_batches = batch if isinstance(batch, list) else split_batch_dpo(batch, self.sample_microbatch_size) num_micro_batches = len(micro_batches) device = micro_batches[0]["chosen_input_ids"].device - total_tokens = padding_free_collator.get_num_tokens(batch) + micro_token_counts = [padding_free_collator.get_num_tokens(mb) for mb in micro_batches] + total_tokens = sum(micro_token_counts) for v in self._metrics.values(): v.zero_() @@ -233,7 +263,7 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo for micro_batch_idx, micro_batch in enumerate(micro_batches): with self._train_microbatch_context(micro_batch_idx, num_micro_batches): loss, step_metrics = self._compute_microbatch_loss(micro_batch) - micro_tokens = padding_free_collator.get_num_tokens(micro_batch) + micro_tokens = micro_token_counts[micro_batch_idx] weight = micro_tokens / total_tokens for k, v in step_metrics.items(): self._metrics[k] += v.detach() * micro_tokens @@ -242,36 +272,39 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo self.model.post_batch(dry_run=dry_run) if not dry_run: - local_padded_tokens = sum( - v.numel() for mb in micro_batches for k, v in mb.items() if k.endswith("input_ids") - ) + local_padded_tokens = padding_free_collator.get_num_padded_tokens(micro_batches) local_num_sequences = padding_free_collator.get_num_sequences(batch) if local_num_sequences is None: local_num_sequences = sum(mb["chosen_input_ids"].shape[0] * 2 for mb in micro_batches) - metric_keys = sorted(self._metrics.keys()) - local_sums_list = [ - torch.tensor(total_tokens, dtype=torch.float32, device=device), - torch.tensor(local_padded_tokens, dtype=torch.float32, device=device), - torch.tensor(local_num_sequences, dtype=torch.float32, device=device), - ] + [self._metrics[k] for k in metric_keys] - local_sums = torch.stack(local_sums_list) - dist.all_reduce(local_sums, op=dist.ReduceOp.SUM, group=self.trainer.dp_process_group) - - global_total_tokens = local_sums[0] - global_padded_tokens = local_sums[1] - global_num_sequences = local_sums[2] - global_metrics = {k: local_sums[i + 3] / global_total_tokens for i, k in enumerate(metric_keys)} - - self.record_metric("train_loss", global_metrics["loss"].item(), reduce_type=None) - self.record_metric("logps/chosen", global_metrics["chosen_logps"].item(), reduce_type=None) - self.record_metric("logps/rejected", global_metrics["rejected_logps"].item(), reduce_type=None) - token_count = total_tokens * self.trainer.data_loader.dp_world_size - self.record_metric("train/token_count", token_count, reduce_type=None) + + tokens_tensor = torch.tensor(float(total_tokens), device=device) + self.record_metric(_DPO_TOKENS_KEY, tokens_tensor, reduce_type=ReduceType.sum) self.record_metric( - "train/padding_fraction", (1.0 - global_total_tokens / global_padded_tokens).item(), reduce_type=None + _DPO_PADDED_KEY, torch.tensor(float(local_padded_tokens), device=device), reduce_type=ReduceType.sum ) - self.record_metric("train/sequences_per_step", global_num_sequences.item(), reduce_type=None) + weighted_sums = { + "train_loss": self._metrics["loss"], + "logps/chosen": self._metrics["chosen_logps"], + "logps/rejected": self._metrics["rejected_logps"], + } + if self.dpo_config.loss_type.computes_reward_metrics: + chosen_rewards = self._metrics["chosen_rewards"] + rejected_rewards = self._metrics["rejected_rewards"] + weighted_sums["rewards/chosen"] = chosen_rewards + weighted_sums["rewards/rejected"] = rejected_rewards + weighted_sums["rewards/average"] = (chosen_rewards + rejected_rewards) / 2 + weighted_sums["rewards/accuracy"] = self._metrics["accuracy"] + weighted_sums["rewards/margin"] = chosen_rewards - rejected_rewards + if "aux_loss" in self._metrics: + weighted_sums["aux_loss"] = self._metrics["aux_loss"] + for name, value in weighted_sums.items(): + self.record_metric(f"{_DPO_REDUCE_NS}/{name}", value, reduce_type=ReduceType.sum) + self.record_metric( + "train/sequences_per_step", + torch.tensor(float(local_num_sequences), device=device), + reduce_type=ReduceType.sum, + ) self.record_metric("training_step", float(self.trainer.global_step), reduce_type=None) if self.trainer.steps_per_epoch is not None: self.record_metric("epoch", self.trainer.global_step / self.trainer.steps_per_epoch, reduce_type=None) @@ -283,21 +316,6 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo ) self.record_metric("learning_rate", float(lr), reduce_type=None) - if self.dpo_config.loss_type.computes_reward_metrics: - margin = global_metrics["chosen_rewards"] - global_metrics["rejected_rewards"] - self.record_metric("rewards/chosen", global_metrics["chosen_rewards"].item(), reduce_type=None) - self.record_metric("rewards/rejected", global_metrics["rejected_rewards"].item(), reduce_type=None) - self.record_metric( - "rewards/average", - ((global_metrics["chosen_rewards"] + global_metrics["rejected_rewards"]) / 2).item(), - reduce_type=None, - ) - self.record_metric("rewards/accuracy", global_metrics["accuracy"].item(), reduce_type=None) - self.record_metric("rewards/margin", margin.item(), reduce_type=None) - - if "aux_loss" in global_metrics: - self.record_metric("aux_loss", global_metrics["aux_loss"].item(), reduce_type=None) - class GRPOTrainModule(TransformerTrainModule): """ diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index fb40465626..1f9dab1700 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -252,6 +252,17 @@ def get_num_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: return sum(v.numel() for k, v in batch.items() if "input_ids" in k and isinstance(v, torch.Tensor)) +def get_num_padded_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: + """Return total token count including padding from a training batch. + + Counts all elements of input_ids tensors. A list of batches + (gradient-accumulation microbatches) is summed. + """ + if isinstance(batch, list): + return sum(get_num_padded_tokens(b) for b in batch) + return sum(v.numel() for k, v in batch.items() if k.endswith("input_ids") and isinstance(v, torch.Tensor)) + + def get_num_sequences(batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: """Return total sequence count from a training batch, or None for non-packing batches. From a4ccbfd40a6004d9af0677c0343798776150b925 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Jun 2026 14:29:26 -0600 Subject: [PATCH 29/35] DPO: restore per-step train/token_count record so PerfCallback can compute MFU (metric refactor moved it into the deferred callback, breaking get_metric) Co-Authored-By: Claude Opus 4.8 --- open_instruct/olmo_core_train_modules.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 96382995ac..98090f3501 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -60,7 +60,6 @@ def pre_log_metrics(self, step: int, metrics: dict[str, float]) -> None: prefix = f"{_DPO_REDUCE_NS}/" for key in [k for k in metrics if k.startswith(prefix)]: metrics[key[len(prefix) :]] = metrics.pop(key) / tokens - metrics["train/token_count"] = tokens metrics["train/padding_fraction"] = 1.0 - tokens / padded @@ -282,6 +281,11 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo self.record_metric( _DPO_PADDED_KEY, torch.tensor(float(local_padded_tokens), device=device), reduce_type=ReduceType.sum ) + # PerfCallback reads train/token_count from the per-step buffer (before the deferred + # reduction), so it must be recorded here rather than reconstructed in DPOMetricsCallback. + self.record_metric( + "train/token_count", float(total_tokens) * self.trainer.data_loader.dp_world_size, reduce_type=None + ) weighted_sums = { "train_loss": self._metrics["loss"], "logps/chosen": self._metrics["chosen_logps"], From 6d807e69a26d7b9c6d2075e44ae754d03d5c161a Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 08:59:49 -0600 Subject: [PATCH 30/35] Drop leftover TRITON_PRINT_AUTOTUNING debug env from oc DPO sweep script and extend CHANGELOG entry to cover the MFU work (token-budget packing, grad accumulation, selected_modules AC, GDN-aware ModelDims) Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 2 +- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb866204ce..4e09eb1805 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed -- Support training the Olmo-Hybrid (GDN) model with the OLMo-core DPO trainer (`dpo.py`): bump olmo-core to a commit that adds the `olmo3_hybrid_7B` config preset and HF→olmo-core hybrid weight conversion (`convert_hybrid_state_from_hf`), and add an OLMo-core hybrid DPO sweep script (https://github.com/allenai/open-instruct/pull/1713). +- Support training the Olmo-Hybrid (GDN) model with the OLMo-core DPO trainer (`dpo.py`) and improve its MFU: bump olmo-core to a commit that adds the `olmo3_hybrid_7B` config preset and HF→olmo-core hybrid weight conversion (`convert_hybrid_state_from_hf`), pack DPO microbatches to the `max_seq_length` token budget with real gradient accumulation (`microbatches_per_step`) instead of capping at `per_device_train_batch_size` sequences, add a `selected_modules` activation checkpointing mode so torch.compile and checkpointing coexist with GDN, make `ModelDims` FLOPs/memory GDN-aware for correct MFU reporting, and add an OLMo-core hybrid DPO sweep script (https://github.com/allenai/open-instruct/pull/1713). - Expand type-checking coverage by replacing `# ty: ignore` directives with typed casts and fixing related type issues (https://github.com/allenai/open-instruct/pull/1688). - Add TV divergence rho filtering for GRPO (https://github.com/allenai/open-instruct/pull/1681). - Export `SETUPTOOLS_SCM_PRETEND_VERSION_FOR_OPEN_INSTRUCT=0.0.0+debug` in `scripts/train/debug/grpo.sh` and `grpo_fast.sh` (local Ray debug scripts that disable torch compile) so setuptools-scm can resolve the package version (https://github.com/allenai/open-instruct/pull/1696). diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh index a40ebb25fb..994a624fe8 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep_olmo_core.sh @@ -70,7 +70,6 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do --env TORCH_NCCL_AVOID_RECORD_STREAMS=1 \ --env TORCH_DIST_INIT_BARRIER=1 \ --env TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800 \ - --env TRITON_PRINT_AUTOTUNING=1 \ --num_nodes 4 \ --gpus 8 -- torchrun \ --nnodes=4 \ From 066265bc250d134e575b1076e7f64e8f9740d84f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 09:00:56 -0600 Subject: [PATCH 31/35] Fix stale SFT_LR reference in DeepSpeed sweep description (PR review) Co-Authored-By: Claude Opus 4.8 --- scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh index 28aa919c8f..743333456e 100755 --- a/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh +++ b/scripts/train/olmo-hybrid/7b_instruct_dpo_sweep.sh @@ -39,7 +39,7 @@ for MODEL_PATH in "${SFT_MODELS[@]}"; do uv run python mason.py \ --cluster ai2/jupiter \ - --description "Hybrid 7B DPO sweep, SFT-${SFT_LR}, LR=${LR}, 4 nodes, 16k seq, ZeRO-3." \ + --description "Hybrid 7B DPO sweep, SFT-public, LR=${LR}, 4 nodes, 16k seq, ZeRO-3." \ --workspace ai2/linear-rnns \ --priority urgent \ --max_retries 0 \ From 9d99e3e570bc1e6bc899e64e817285ec6a3845fc Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 09:08:11 -0600 Subject: [PATCH 32/35] Simplify: delegate global_num_flops_in_batch token count to data_loader.global_num_tokens_in_batch and unify the collator packing probe behind _collator_max_seq_length Co-Authored-By: Claude Opus 4.8 --- open_instruct/data_loader.py | 13 ++++++------- open_instruct/olmo_core_train_modules.py | 5 +++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 162b61c5fb..968bfc0ae1 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -157,6 +157,7 @@ def __init__( self._microbatch_sample_cap = microbatch_sample_cap self._microbatches_per_step = microbatches_per_step self._collator = collator if collator is not None else (lambda x: {"examples": x}) + self._collator_max_seq_length = getattr(self._collator, "max_seq_length", None) self._automatic_reshuffle = automatic_reshuffle self._drop_last = drop_last self._excluded_indices: set[int] = set() @@ -287,9 +288,8 @@ def _reshard(self, epoch: int) -> None: mask = np.isin(all_indices, list(self._excluded_indices), invert=True) all_indices = all_indices[mask] - packing_enabled = hasattr(self._collator, "max_seq_length") and self._collator.max_seq_length is not None - if packing_enabled: - self._reshard_with_packing(all_indices) + if self._collator_max_seq_length is not None: + self._reshard_with_packing(all_indices, self._collator_max_seq_length) return self._precomputed_batch_sizes = None @@ -315,7 +315,7 @@ def _reshard(self, epoch: int) -> None: self.effective_size = len(rank_indices) self.dataset = self._full_dataset.select(rank_indices.tolist()) - def _reshard_with_packing(self, all_indices: np.ndarray) -> None: + def _reshard_with_packing(self, all_indices: np.ndarray, max_seq_length: int) -> None: """Reshard with world-aware packing so all ranks get the same batch count. Instead of distributing examples to ranks and letting each rank pack @@ -323,7 +323,6 @@ def _reshard_with_packing(self, all_indices: np.ndarray) -> None: overflow), this packs globally first and then distributes packed batches round-robin to ranks. """ - max_seq_length = self._collator.max_seq_length column_names = self._full_dataset.column_names subset = self._full_dataset.select(all_indices.tolist()) if "chosen_input_ids" in column_names: @@ -392,8 +391,8 @@ def get_mock_batch(self) -> dict[str, Any]: # When packing, the collator consumes only as many examples as fit the token # budget, so at most max_seq_length examples (each >= 1 token) can be used. # Bound the rows loaded so a large microbatch_sample_cap doesn't load the dataset. - if getattr(self._collator, "max_seq_length", None) is not None: - num_examples = min(num_examples, self._collator.max_seq_length) + if self._collator_max_seq_length is not None: + num_examples = min(num_examples, self._collator_max_seq_length) examples = [self.dataset[i] for i in range(num_examples)] return to_device(self._collator(examples), self._device) diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 98090f3501..b6416b4c58 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -5,7 +5,7 @@ """ import math -from typing import Any, Literal +from typing import Any, Literal, cast import numpy as np import torch @@ -207,7 +207,8 @@ def pre_train(self): pass def global_num_flops_in_batch(self, batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: - global_num_tokens = padding_free_collator.get_num_tokens(batch) * self.trainer.data_loader.dp_world_size + data_loader = cast(data_loader_lib.HFDataLoader, self.trainer.data_loader) + global_num_tokens = data_loader.global_num_tokens_in_batch(batch) first = batch[0] if isinstance(batch, list) else batch seq_len = first["chosen_input_ids"].shape[1] flops_per_token = self.num_flops_per_token(seq_len=seq_len) From cecab926f7d1f35dbcd33c9b01b95b611e95cca0 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 11:26:58 -0600 Subject: [PATCH 33/35] DPO packing: yield rectangular stacked-row batches (stack/unstack_packed_rows) so OLMo-core's dict batch contract, pre_train validation, and token accounting work natively; rank_microbatch_size = 2*max_seq_length tokens per packed row; drop microbatches_per_step and list-batch handling Co-Authored-By: Claude Opus 4.8 --- open_instruct/data_loader.py | 67 ++++++++--------- open_instruct/dpo.py | 25 ++++--- open_instruct/dpo_utils.py | 51 ++++++------- open_instruct/olmo_core_train_modules.py | 45 ++++++------ open_instruct/padding_free_collator.py | 94 ++++++++++++++++++------ open_instruct/test_data_loader.py | 73 ++++++++++++------ 6 files changed, 220 insertions(+), 135 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 968bfc0ae1..8f43909ea6 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -92,13 +92,16 @@ def __init__( fs_local_rank: int | None = None, max_seq_length: int = 1, microbatch_sample_cap: int | None = None, - microbatches_per_step: int = 1, ) -> None: """Initialize the HFDataLoader. Args: dataset: The HuggingFace Dataset to load data from. Must have an 'index' column. - batch_size: The global batch size (in sequences). + batch_size: The global batch size in instances. Without packing, an instance is one + example. With packing (a collator with max_seq_length set), an instance is one + packed row holding a variable number of examples, and each rank's batch is the + stack of its `batch_size // dp_world_size` rows (see + padding_free_collator.stack_packed_rows). seed: Random seed for shuffling. dp_rank: The rank of the current process in the distributed setup. dp_world_size: Total number of data-parallel processes in the distributed setup. @@ -110,14 +113,11 @@ def __init__( drop_last: If True, drop the last incomplete batch. If False, pad the last batch with repeated indices to fill a complete batch. fs_local_rank: File system local rank. Defaults to dp_rank when None. - max_seq_length: Maximum sequence length. Used to report global_batch_size in tokens + max_seq_length: Tokens per instance. Used to report global_batch_size in tokens to the trainer for batch-size validation. microbatch_sample_cap: When packing, the maximum number of examples per packed - microbatch. A microbatch closes when either the token budget or this cap is - reached. None means pack purely to the token budget. - microbatches_per_step: Number of packed microbatches grouped into each yielded batch - for gradient accumulation. When > 1, each yielded batch is a list of this many - collated microbatches; the trainer runs one optimizer step per yielded batch. + row. A row closes when either the token budget or this cap is reached. + None means pack purely to the token budget. Note: The dataset must have an 'index' column for tracking samples across epochs. @@ -152,10 +152,7 @@ def __init__( f"The effective global batch size will be {batch_size // dp_world_size * dp_world_size}." ) self._per_rank_batch_size = batch_size // dp_world_size - if microbatches_per_step < 1: - raise ValueError(f"microbatches_per_step must be >= 1, got {microbatches_per_step}") self._microbatch_sample_cap = microbatch_sample_cap - self._microbatches_per_step = microbatches_per_step self._collator = collator if collator is not None else (lambda x: {"examples": x}) self._collator_max_seq_length = getattr(self._collator, "max_seq_length", None) self._automatic_reshuffle = automatic_reshuffle @@ -189,27 +186,28 @@ def __next__(self) -> dict[str, Any]: def _iter_batches(self) -> Iterable[dict[str, Any]]: """Return an iterable over all batches in the epoch.""" - # World-aware packing: batch boundaries were precomputed by + # World-aware packing: row boundaries were precomputed by # _reshard_with_packing so that every rank has the same number of - # batches. Each entry in _precomputed_batch_sizes is the number of - # examples in that batch (variable due to packing). + # rows. Each entry in _precomputed_batch_sizes is the number of + # examples in that row (variable due to packing). Each yielded batch + # stacks per_rank_batch_size rows into one rectangular dict; the train + # module splits it back into per-row microbatches. if self._precomputed_batch_sizes is not None: - mbs_per_step = self._microbatches_per_step + rows_per_batch = self._per_rank_batch_size num_real = len(self._precomputed_batch_sizes) - self._num_padding_batches - num_groups = len(self._precomputed_batch_sizes) // mbs_per_step + num_batches = len(self._precomputed_batch_sizes) // rows_per_batch offsets = [0] - for batch_size in self._precomputed_batch_sizes: - offsets.append(offsets[-1] + batch_size) - for group_idx in range(self.batches_processed, num_groups): - group = [] - for mb_idx in range(group_idx * mbs_per_step, (group_idx + 1) * mbs_per_step): + for row_size in self._precomputed_batch_sizes: + offsets.append(offsets[-1] + row_size) + for batch_idx in range(self.batches_processed, num_batches): + rows = [] + for row_idx in range(batch_idx * rows_per_batch, (batch_idx + 1) * rows_per_batch): examples = [] - for i in range(offsets[mb_idx], offsets[mb_idx + 1]): + for i in range(offsets[row_idx], offsets[row_idx + 1]): example = self.dataset[i] examples.append(example | {"prompt_id": f"{self._epoch}_{example['index']}"}) - collated = to_device(self._collator(examples), self._device) | {"is_padding": mb_idx >= num_real} - group.append(collated) - yield group if mbs_per_step > 1 else group[0] + rows.append(self._collator(examples) | {"is_padding": row_idx >= num_real}) + yield to_device(padding_free_collator.stack_packed_rows(rows), self._device) return start_example = self.batches_processed * self._per_rank_batch_size @@ -235,7 +233,7 @@ def _iter_batches(self) -> Iterable[dict[str, Any]]: def total_batches(self) -> int: """Return the total number of batches in an epoch.""" if self._precomputed_batch_sizes is not None: - return len(self._precomputed_batch_sizes) // self._microbatches_per_step + return len(self._precomputed_batch_sizes) // self._per_rank_batch_size return self.effective_size // self._per_rank_batch_size def state_dict(self) -> dict[str, Any]: @@ -351,11 +349,11 @@ def _reshard_with_packing(self, all_indices: np.ndarray, max_seq_length: int) -> if current_batch: batches.append(current_batch) - # Batches are distributed round-robin to ranks and then grouped into - # microbatches_per_step-sized optimizer steps, so the global count must be a - # multiple of dp_world_size * microbatches_per_step for every rank to have the - # same number of complete groups. - group_size = self.dp_world_size * self._microbatches_per_step + # Rows are distributed round-robin to ranks and then stacked into + # per_rank_batch_size-sized batches, so the global count must be a + # multiple of dp_world_size * per_rank_batch_size for every rank to have + # the same number of complete batches. + group_size = self.dp_world_size * self._per_rank_batch_size num_batches = len(batches) padding_start = num_batches if self._drop_last: @@ -394,9 +392,12 @@ def get_mock_batch(self) -> dict[str, Any]: if self._collator_max_seq_length is not None: num_examples = min(num_examples, self._collator_max_seq_length) examples = [self.dataset[i] for i in range(num_examples)] - return to_device(self._collator(examples), self._device) + collated = self._collator(examples) + if self._collator_max_seq_length is not None: + collated = padding_free_collator.stack_packed_rows([collated]) + return to_device(collated, self._device) - def global_num_tokens_in_batch(self, batch: dict[str, Any] | list[dict[str, Any]]) -> int: + def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int: """Return the total number of tokens in the batch across all ranks. Counts tokens from all keys containing 'input_ids' that are torch tensors. diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index 940446fb8c..a16ec35ef0 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -192,8 +192,15 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz else: collator = dpo_utils.DataCollatorForSeq2SeqDPO(tokenizer=tokenizer, model=None, padding="longest") - rank_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps - global_batch_size = rank_batch_size * dp_world_size + # With packing, an instance is one packed row (one microbatch), so a rank accumulates + # gradient_accumulation_steps rows per optimizer step. Without packing, an instance is + # one example and a rank consumes per_device_train_batch_size * gradient_accumulation_steps + # examples per step. Either way, each instance spans 2 * max_seq_length tokens + # (chosen + rejected streams). + if args.packing: + global_batch_size = args.gradient_accumulation_steps * dp_world_size + else: + global_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * dp_world_size data_loader = data_loader_lib.HFDataLoader( dataset=dataset, batch_size=global_batch_size, @@ -205,16 +212,13 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz device=device, drop_last=True, fs_local_rank=global_rank, + max_seq_length=2 * args.max_seq_length, microbatch_sample_cap=args.per_device_train_batch_size, - microbatches_per_step=args.gradient_accumulation_steps, ) - # 4x batch size: forward-only (no backward), so no activation storage needed. - # With packing, the collator's token budget controls the actual forward-pass size - # and the overflow mechanism in HFDataLoader ensures no examples are dropped. - # We could probably have logic to use a longer sequence length here when packing - # is enabled, but for simplicity we just keep the 4x increase in batch size regardless of packing. - # We want the batch size to be as large as possible so that we always pack efficiently. - cache_batch_size = int(args.per_device_train_batch_size * 4 * dp_world_size) + # Forward-only (no backward), so no activation storage is needed. With packing each + # instance is already a full token-budget row, so one row per rank per batch suffices; + # without packing we use a 4x larger example batch. + cache_batch_size = dp_world_size if args.packing else int(args.per_device_train_batch_size * 4 * dp_world_size) cache_data_loader = data_loader_lib.HFDataLoader( dataset=dataset, batch_size=cache_batch_size, @@ -227,6 +231,7 @@ def main(args: dpo_utils.DPOExperimentConfig, tc: dataset_transformation.Tokeniz # We need to process every example to cache reference logprobs, so we can't drop the last batch. drop_last=False, fs_local_rank=global_rank, + max_seq_length=2 * args.max_seq_length, ) forward_fn = dpo_utils.concatenated_forward_olmo if args.concatenated_forward else dpo_utils.separate_forward_olmo diff --git a/open_instruct/dpo_utils.py b/open_instruct/dpo_utils.py index 7f1b8f3553..0a1c9c4d10 100644 --- a/open_instruct/dpo_utils.py +++ b/open_instruct/dpo_utils.py @@ -459,34 +459,35 @@ def build_reference_logprobs_cache( with torch.no_grad(): for batch in (pbar := tqdm(dataloader, disable=not is_main_process, desc="Caching reference logprobs")): - batch_start = time.perf_counter() - if use_lora and disable_adapter_context is not None: - with disable_adapter_context(): + for row in padding_free_collator.unstack_packed_rows(batch): + row_start = time.perf_counter() + if use_lora and disable_adapter_context is not None: + with disable_adapter_context(): + chosen_logps, rejected_logps, _ = forward_fn( + model, row, average_log_prob=average_log_prob, **(forward_kwargs or {}) + ) + else: chosen_logps, rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {}) + model, row, average_log_prob=average_log_prob, **(forward_kwargs or {}) ) - else: - chosen_logps, rejected_logps, _ = forward_fn( - model, batch, average_log_prob=average_log_prob, **(forward_kwargs or {}) - ) - if batch.get("is_padding", False): - continue - - chosen_tensor[batch["index"]] = chosen_logps - rejected_tensor[batch["index"]] = rejected_logps - - batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(batch) - total_tokens += batch_tokens - total_examples += batch_size - pbar.set_postfix( - { - "avg_tok/ex": f"{total_tokens / total_examples:.0f}", - "MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - batch_start):.1f}", - "mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}", - "mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}", - } - ) + if row.get("is_padding", False): + continue + + chosen_tensor[row["index"]] = chosen_logps + rejected_tensor[row["index"]] = rejected_logps + + batch_tokens, batch_size, chosen_lengths, rejected_lengths = _get_batch_stats(row) + total_tokens += batch_tokens + total_examples += batch_size + pbar.set_postfix( + { + "avg_tok/ex": f"{total_tokens / total_examples:.0f}", + "MFU%": f"{model_dims.calculate_mfu(chosen_lengths + rejected_lengths, time.perf_counter() - row_start):.1f}", + "mem_GB": f"{torch.cuda.max_memory_allocated() / 1e9:.1f}", + "mem%": f"{torch.cuda.max_memory_allocated() / torch.cuda.get_device_properties(0).total_memory * 100:.0f}", + } + ) dist.all_reduce(chosen_tensor, op=dist.ReduceOp.MAX) dist.all_reduce(rejected_tensor, op=dist.ReduceOp.MAX) diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index 7ab02c2cfc..f8291ba28d 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -52,7 +52,6 @@ class DPOMetricsCallback(Callback): priority = 10 def pre_log_metrics(self, step: int, metrics: dict[str, float]) -> None: - del step if _DPO_TOKENS_KEY not in metrics: return tokens = metrics.pop(_DPO_TOKENS_KEY) @@ -61,6 +60,11 @@ def pre_log_metrics(self, step: int, metrics: dict[str, float]) -> None: for key in [k for k in metrics if k.startswith(prefix)]: metrics[key[len(prefix) :]] = metrics.pop(key) / tokens metrics["train/padding_fraction"] = 1.0 - tokens / padded + metrics["training_step"] = float(step) + if self.trainer.steps_per_epoch is not None: + metrics["epoch"] = step / self.trainer.steps_per_epoch + if "optim/LR (group 0)" in metrics: + metrics["learning_rate"] = metrics["optim/LR (group 0)"] class DPOLMHead(LMHead): @@ -166,7 +170,13 @@ def __init__( ) # TODO(finbarrtimbers): Remove this hack once Transformer supports configuring the LM head. model.lm_head.__class__ = DPOLMHead - rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length * 2 + # With packing, a microbatch is one packed row: chosen + rejected streams, each + # padded to max_sequence_length tokens. Without packing, it is sample_microbatch_size + # example pairs of up to max_sequence_length tokens per stream. + if dpo_config.packing: + rank_microbatch_size_tokens = 2 * max_sequence_length + else: + rank_microbatch_size_tokens = sample_microbatch_size * max_sequence_length * 2 super().__init__( model=model, optim=optim, @@ -203,14 +213,10 @@ def __init__( if dpo_config.packing: self._forward_kwargs["packing"] = True - def pre_train(self): - pass - - def global_num_flops_in_batch(self, batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: + def global_num_flops_in_batch(self, batch: dict[str, Any]) -> int | None: data_loader = cast(data_loader_lib.HFDataLoader, self.trainer.data_loader) global_num_tokens = data_loader.global_num_tokens_in_batch(batch) - first = batch[0] if isinstance(batch, list) else batch - seq_len = first["chosen_input_ids"].shape[1] + seq_len = batch["chosen_input_ids"].shape[1] flops_per_token = self.num_flops_per_token(seq_len=seq_len) return flops_per_token * global_num_tokens if flops_per_token is not None else None @@ -248,12 +254,15 @@ def _compute_microbatch_loss(self, micro_batch: dict[str, Any]) -> tuple[torch.T return loss, step_metrics - def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: bool = False) -> None: + def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: self.model.train() - micro_batches = batch if isinstance(batch, list) else split_batch_dpo(batch, self.sample_microbatch_size) + if self.dpo_config.packing: + micro_batches = padding_free_collator.unstack_packed_rows(batch) + else: + micro_batches = split_batch_dpo(batch, self.sample_microbatch_size) num_micro_batches = len(micro_batches) - device = micro_batches[0]["chosen_input_ids"].device + device = batch["chosen_input_ids"].device micro_token_counts = [padding_free_collator.get_num_tokens(mb) for mb in micro_batches] total_tokens = sum(micro_token_counts) @@ -272,10 +281,10 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo self.model.post_batch(dry_run=dry_run) if not dry_run: - local_padded_tokens = padding_free_collator.get_num_padded_tokens(micro_batches) + local_padded_tokens = padding_free_collator.get_num_padded_tokens(batch) local_num_sequences = padding_free_collator.get_num_sequences(batch) if local_num_sequences is None: - local_num_sequences = sum(mb["chosen_input_ids"].shape[0] * 2 for mb in micro_batches) + local_num_sequences = batch["chosen_input_ids"].shape[0] * 2 tokens_tensor = torch.tensor(float(total_tokens), device=device) self.record_metric(_DPO_TOKENS_KEY, tokens_tensor, reduce_type=ReduceType.sum) @@ -310,16 +319,6 @@ def train_batch(self, batch: dict[str, Any] | list[dict[str, Any]], dry_run: boo torch.tensor(float(local_num_sequences), device=device), reduce_type=ReduceType.sum, ) - self.record_metric("training_step", float(self.trainer.global_step), reduce_type=None) - if self.trainer.steps_per_epoch is not None: - self.record_metric("epoch", self.trainer.global_step / self.trainer.steps_per_epoch, reduce_type=None) - if self.scheduler is not None and self.trainer.max_steps is not None: - lr = self.scheduler.get_lr( - self.optim.param_groups[0].get("initial_lr", self.optim.param_groups[0]["lr"]), - self.trainer.global_step, - self.trainer.max_steps, - ) - self.record_metric("learning_rate", float(lr), reduce_type=None) class GRPOTrainModule(TransformerTrainModule): diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 1f9dab1700..877edbabde 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -229,22 +229,81 @@ def get_batch_logps( return segment_sums -def get_num_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: +def stack_packed_rows(rows: list[dict[str, Any]]) -> dict[str, Any]: + """Stack collated packed rows into one rectangular batch dict. + + Each row is a collator output where every stream tensor has shape + (1, max_seq_length). Stream tensors are concatenated along dim 0 to + (num_rows, max_seq_length). Per-row cu_seq_lens tensors are padded to a + common length by repeating their final value (the row's total real token + count), producing zero-length phantom boundaries that unstack_packed_rows + trims away. `index` is padded with -1 and `max_length_q/k` ints are reduced + with max. + """ + out: dict[str, Any] = {} + for k, v in rows[0].items(): + if k.endswith(("cu_seq_lens_q", "cu_seq_lens_k")): + width = max(len(r[k]) for r in rows) + out[k] = torch.stack([torch.cat([r[k], r[k][-1:].expand(width - len(r[k]))]) for r in rows]) + elif k.endswith(("max_length_q", "max_length_k")): + out[k] = max(r[k] for r in rows) + elif k == "index": + width = max(len(r[k]) for r in rows) + out[k] = torch.stack([torch.cat([r[k], r[k].new_full((width - len(r[k]),), -1)]) for r in rows]) + elif k == "is_padding": + out[k] = torch.tensor([r[k] for r in rows]) + elif isinstance(v, torch.Tensor): + out[k] = torch.cat([r[k] for r in rows], dim=0) + else: + out[k] = v + return out + + +def unstack_packed_rows(batch: dict[str, Any]) -> list[dict[str, Any]]: + """Invert stack_packed_rows, recovering the per-row collated dicts. + + Trims the repeated-value padding from cu_seq_lens rows and the -1 padding + from index rows. Batches without stacked (2-D) cu_seq_lens tensors are + returned unchanged as a single-element list, so callers can treat stacked + and unstacked batches uniformly. + """ + cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] + if not cu_keys or batch[cu_keys[0]].dim() == 1: + return [batch] + rows = [] + for i in range(batch[cu_keys[0]].shape[0]): + row: dict[str, Any] = {} + for k, v in batch.items(): + if k.endswith(("cu_seq_lens_q", "cu_seq_lens_k")): + cu = v[i] + num_seqs = int((cu[1:] > cu[:-1]).sum().item()) + row[k] = cu[: num_seqs + 1] + elif k == "index": + row[k] = v[i][v[i] >= 0] + elif k == "is_padding": + row[k] = bool(v[i].item()) + elif isinstance(v, torch.Tensor): + row[k] = v[i : i + 1] + else: + row[k] = v + rows.append(row) + return rows + + +def get_num_tokens(batch: dict[str, Any]) -> int: """Return total non-padding token count from a training batch. For packed batches (DPO or GRPO), reads cu_seq_lens_k tensors whose last - element is the total token count for that branch. For padded batches, sums - the attention_mask. Falls back to counting input_ids elements. A list of - batches (gradient-accumulation microbatches) is summed. + element (per row, for stacked batches) is the total token count for that + branch. For padded batches, sums the attention_mask. Falls back to counting + input_ids elements. """ - if isinstance(batch, list): - return sum(get_num_tokens(b) for b in batch) # cu_seq_lens_k is a cumulative sequence length tensor from the padding-free # collator. Its last element equals the total token count for that branch. # DPO has chosen_cu_seq_lens_k + rejected_cu_seq_lens_k; GRPO has cu_seq_lens_k. cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: - return sum(batch[k][-1].item() for k in cu_keys) + return sum(batch[k][..., -1].sum().item() for k in cu_keys) # DPO batches have chosen_attention_mask and rejected_attention_mask; sum both branches. attn_keys = [k for k in batch if k.endswith("attention_mask")] if attn_keys: @@ -252,29 +311,22 @@ def get_num_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: return sum(v.numel() for k, v in batch.items() if "input_ids" in k and isinstance(v, torch.Tensor)) -def get_num_padded_tokens(batch: dict[str, Any] | list[dict[str, Any]]) -> int: +def get_num_padded_tokens(batch: dict[str, Any]) -> int: """Return total token count including padding from a training batch. - Counts all elements of input_ids tensors. A list of batches - (gradient-accumulation microbatches) is summed. + Counts all elements of input_ids tensors. """ - if isinstance(batch, list): - return sum(get_num_padded_tokens(b) for b in batch) return sum(v.numel() for k, v in batch.items() if k.endswith("input_ids") and isinstance(v, torch.Tensor)) -def get_num_sequences(batch: dict[str, Any] | list[dict[str, Any]]) -> int | None: +def get_num_sequences(batch: dict[str, Any]) -> int | None: """Return total sequence count from a training batch, or None for non-packing batches. - For packed batches, reads cu_seq_lens_k tensors which each have num_seqs + 1 - elements (including a leading 0). Returns None if no cu_seq_lens_k keys are found. - A list of batches (gradient-accumulation microbatches) is summed. + Counts strictly-increasing boundaries in each cu_seq_lens_k tensor, which + works for both 1-D rows and stacked 2-D rows (whose repeated-value padding + contributes no increase). Returns None if no cu_seq_lens_k keys are found. """ - if isinstance(batch, list): - counts = [get_num_sequences(b) for b in batch] - return sum(c for c in counts if c is not None) cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: - # Each cu_seq_lens tensor has num_seqs + 1 elements (leading 0 boundary). - return sum(len(batch[k]) - 1 for k in cu_keys) + return int(sum((batch[k][..., 1:] > batch[k][..., :-1]).sum().item() for k in cu_keys)) return None diff --git a/open_instruct/test_data_loader.py b/open_instruct/test_data_loader.py index 48a1e7a330..fc2b93ad71 100644 --- a/open_instruct/test_data_loader.py +++ b/open_instruct/test_data_loader.py @@ -5,7 +5,7 @@ import torch from datasets import Dataset -from open_instruct import data_loader +from open_instruct import data_loader, padding_free_collator from open_instruct.padding_free_collator import TensorDataCollatorWithFlatteningDPO @@ -72,7 +72,7 @@ def test_packing_equal_batches_across_ranks( for loader in loaders: for batch in loader: if "index" in batch: - all_indices.update(batch["index"].tolist()) + all_indices.update(batch["index"][batch["index"] >= 0].tolist()) if not drop_last: expected_indices = set(range(num_samples)) @@ -114,16 +114,16 @@ def test_packs_to_token_budget_not_sample_cap(self): drop_last=False, ) - batch_sizes = [] + row_sizes = [] seen_indices = set() for batch in loader: - num_seqs = len(batch["index"]) - batch_sizes.append(num_seqs) - seen_indices.update(batch["index"].tolist()) - self.assertLessEqual(batch["chosen_cu_seq_lens_k"][-1].item(), max_seq_length) - self.assertLessEqual(batch["rejected_cu_seq_lens_k"][-1].item(), max_seq_length) + for row in padding_free_collator.unstack_packed_rows(batch): + row_sizes.append(len(row["index"])) + seen_indices.update(row["index"].tolist()) + self.assertLessEqual(row["chosen_cu_seq_lens_k"][-1].item(), max_seq_length) + self.assertLessEqual(row["rejected_cu_seq_lens_k"][-1].item(), max_seq_length) - self.assertGreater(max(batch_sizes), global_batch_size) + self.assertGreater(max(row_sizes), global_batch_size) self.assertEqual(seen_indices, set(range(num_samples))) def test_microbatch_sample_cap_binds(self): @@ -148,12 +148,13 @@ def test_microbatch_sample_cap_binds(self): ) for batch in loader: - self.assertLessEqual(len(batch["index"]), cap) + for row in padding_free_collator.unstack_packed_rows(batch): + self.assertLessEqual(len(row["index"]), cap) -class TestGradientAccumulationGrouping(unittest.TestCase): - @parameterized.parameterized.expand([("gas2_dp1", 2, 1), ("gas4_dp1", 4, 1), ("gas2_dp2", 2, 2)]) - def test_groups_microbatches_per_step(self, _name, microbatches_per_step, dp_world_size): +class TestStackedPackedBatches(unittest.TestCase): + @parameterized.parameterized.expand([("rows2_dp1", 2, 1), ("rows4_dp1", 4, 1), ("rows2_dp2", 2, 2)]) + def test_yields_per_rank_rows_per_batch(self, _name, rows_per_rank, dp_world_size): max_seq_length = 16384 seq_len = 100 num_samples = 200 @@ -165,7 +166,7 @@ def test_groups_microbatches_per_step(self, _name, microbatches_per_step, dp_wor loaders = [ data_loader.HFDataLoader( dataset=dataset, - batch_size=4, + batch_size=rows_per_rank * dp_world_size, seed=42, dp_rank=rank, dp_world_size=dp_world_size, @@ -173,7 +174,6 @@ def test_groups_microbatches_per_step(self, _name, microbatches_per_step, dp_wor collator=collator, drop_last=True, microbatch_sample_cap=cap, - microbatches_per_step=microbatches_per_step, ) for rank in range(dp_world_size) ] @@ -182,14 +182,41 @@ def test_groups_microbatches_per_step(self, _name, microbatches_per_step, dp_wor self.assertTrue(all(c == batch_counts[0] for c in batch_counts), f"Step counts differ: {batch_counts}") for loader in loaders: - num_steps = 0 - for step in loader: - self.assertIsInstance(step, list) - self.assertEqual(len(step), microbatches_per_step) - for micro_batch in step: - self.assertLessEqual(len(micro_batch["index"]), cap) - num_steps += 1 - self.assertEqual(num_steps, loader.total_batches) + num_batches = 0 + for batch in loader: + self.assertIsInstance(batch, dict) + self.assertEqual(batch["chosen_input_ids"].shape, (rows_per_rank, max_seq_length)) + rows = padding_free_collator.unstack_packed_rows(batch) + self.assertEqual(len(rows), rows_per_rank) + for row in rows: + self.assertLessEqual(len(row["index"]), cap) + num_batches += 1 + self.assertEqual(num_batches, loader.total_batches) + + def test_stack_unstack_round_trip(self): + max_seq_length = 512 + dataset = _make_dpo_dataset(num_samples=7, max_seq_length=max_seq_length) + collator = TensorDataCollatorWithFlatteningDPO(max_seq_length=max_seq_length) + rows = [ + collator([dataset[0], dataset[1]]) | {"is_padding": False}, + collator([dataset[2]]) | {"is_padding": False}, + collator([dataset[3], dataset[4], dataset[5]]) | {"is_padding": True}, + ] + + stacked = padding_free_collator.stack_packed_rows(rows) + unstacked = padding_free_collator.unstack_packed_rows(stacked) + + self.assertEqual(len(unstacked), len(rows)) + for original, restored in zip(rows, unstacked): + self.assertEqual(set(original.keys()), set(restored.keys())) + for k, v in original.items(): + if k.endswith(("max_length_q", "max_length_k")): + # Stacking reduces max_length to a batch-level max (a safe upper bound). + self.assertEqual(restored[k], max(r[k] for r in rows)) + elif isinstance(v, torch.Tensor): + torch.testing.assert_close(restored[k], v) + else: + self.assertEqual(restored[k], v) if __name__ == "__main__": From de4ff38aaae25930ce5c021c2a28d3c07cfedc17 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 12:10:14 -0600 Subject: [PATCH 34/35] Update CHANGELOG entry for rectangular stacked packed-row DPO batches (microbatches_per_step removed) Co-Authored-By: Claude Opus 4.8 --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed952bc245..5832a771a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ All notable changes to this project will be documented in this file. ### Changed -- Support training the Olmo-Hybrid (GDN) model with the OLMo-core DPO trainer (`dpo.py`) and improve its MFU: bump olmo-core to a commit that adds the `olmo3_hybrid_7B` config preset and HF→olmo-core hybrid weight conversion (`convert_hybrid_state_from_hf`), pack DPO microbatches to the `max_seq_length` token budget with real gradient accumulation (`microbatches_per_step`) instead of capping at `per_device_train_batch_size` sequences, add a `selected_modules` activation checkpointing mode so torch.compile and checkpointing coexist with GDN, make `ModelDims` FLOPs/memory GDN-aware for correct MFU reporting, and add an OLMo-core hybrid DPO sweep script (https://github.com/allenai/open-instruct/pull/1713). +- Support training the Olmo-Hybrid (GDN) model with the OLMo-core DPO trainer (`dpo.py`) and improve its MFU: bump olmo-core to a commit that adds the `olmo3_hybrid_7B` config preset and HF→olmo-core hybrid weight conversion (`convert_hybrid_state_from_hf`), pack DPO microbatches to the `max_seq_length` token budget instead of capping at `per_device_train_batch_size` sequences and yield rectangular stacked packed-row batches (`stack_packed_rows`/`unstack_packed_rows`) so OLMo-core's dict batch contract, batch-size validation, and token accounting work natively (gradient accumulation = packed rows per rank per step), add a `selected_modules` activation checkpointing mode so torch.compile and checkpointing coexist with GDN, make `ModelDims` FLOPs/memory GDN-aware for correct MFU reporting, and add an OLMo-core hybrid DPO sweep script (https://github.com/allenai/open-instruct/pull/1713). - Record a `_metrics_keepalive` metric on every rank every GRPO+OLMo-core step to keep `_metrics` non-empty, preventing OLMo-core's empty-skip in `_log_metrics` from desyncing the bookkeeping process group and deadlocking gloo for 30 minutes at save-time flushes (https://github.com/allenai/open-instruct/pull/1708). - Expand type-checking coverage by replacing `# ty: ignore` directives with typed casts and fixing related type issues (https://github.com/allenai/open-instruct/pull/1688). - Add TV divergence rho filtering for GRPO (https://github.com/allenai/open-instruct/pull/1681). From e224447d0ac4da969f979566a5d032a5376f69a9 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Jun 2026 12:20:16 -0600 Subject: [PATCH 35/35] Simplify: get_num_sequences always returns int (counts input_ids rows for non-packed batches), removing the None fallbacks in train_batch and PerfCallback.pre_step and the now-unused per_device_train_batch_size field Co-Authored-By: Claude Opus 4.8 --- open_instruct/dpo.py | 1 - open_instruct/olmo_core_callbacks.py | 3 --- open_instruct/olmo_core_train_modules.py | 2 -- open_instruct/padding_free_collator.py | 13 +++++++------ 4 files changed, 7 insertions(+), 12 deletions(-) diff --git a/open_instruct/dpo.py b/open_instruct/dpo.py index a16ec35ef0..055ae93383 100644 --- a/open_instruct/dpo.py +++ b/open_instruct/dpo.py @@ -73,7 +73,6 @@ def _setup_callbacks(args: dpo_utils.DPOExperimentConfig, dp_world_size: int): model_dims = utils.ModelDims.from_hf_config(args.model_name_or_path) trainer_callbacks["perf"] = PerfCallback( model_dims=model_dims, - per_device_train_batch_size=args.per_device_train_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, dp_world_size=dp_world_size, tensor_parallel_degree=args.tensor_parallel_degree, diff --git a/open_instruct/olmo_core_callbacks.py b/open_instruct/olmo_core_callbacks.py index 6b7f18837b..cc1e35529b 100644 --- a/open_instruct/olmo_core_callbacks.py +++ b/open_instruct/olmo_core_callbacks.py @@ -123,7 +123,6 @@ class PerfCallback(Callback): """Calculates MFU and tokens_per_second using same formula as dpo_tune_cache.py.""" model_dims: utils.ModelDims - per_device_train_batch_size: int gradient_accumulation_steps: int dp_world_size: int tensor_parallel_degree: int = 1 @@ -161,8 +160,6 @@ def pre_step(self, batch: dict[str, Any]) -> None: self._pre_step_time = time.perf_counter() self._step_start_time = self._pre_step_time num_seqs = padding_free_collator.get_num_sequences(batch) - if num_seqs is None: - num_seqs = self.per_device_train_batch_size * 2 self._interval_num_sequences += num_seqs * self.dp_world_size def post_step(self) -> None: diff --git a/open_instruct/olmo_core_train_modules.py b/open_instruct/olmo_core_train_modules.py index f8291ba28d..e329d4b383 100644 --- a/open_instruct/olmo_core_train_modules.py +++ b/open_instruct/olmo_core_train_modules.py @@ -283,8 +283,6 @@ def train_batch(self, batch: dict[str, Any], dry_run: bool = False) -> None: if not dry_run: local_padded_tokens = padding_free_collator.get_num_padded_tokens(batch) local_num_sequences = padding_free_collator.get_num_sequences(batch) - if local_num_sequences is None: - local_num_sequences = batch["chosen_input_ids"].shape[0] * 2 tokens_tensor = torch.tensor(float(total_tokens), device=device) self.record_metric(_DPO_TOKENS_KEY, tokens_tensor, reduce_type=ReduceType.sum) diff --git a/open_instruct/padding_free_collator.py b/open_instruct/padding_free_collator.py index 877edbabde..119f8e3368 100644 --- a/open_instruct/padding_free_collator.py +++ b/open_instruct/padding_free_collator.py @@ -319,14 +319,15 @@ def get_num_padded_tokens(batch: dict[str, Any]) -> int: return sum(v.numel() for k, v in batch.items() if k.endswith("input_ids") and isinstance(v, torch.Tensor)) -def get_num_sequences(batch: dict[str, Any]) -> int | None: - """Return total sequence count from a training batch, or None for non-packing batches. +def get_num_sequences(batch: dict[str, Any]) -> int: + """Return total sequence count from a training batch. - Counts strictly-increasing boundaries in each cu_seq_lens_k tensor, which - works for both 1-D rows and stacked 2-D rows (whose repeated-value padding - contributes no increase). Returns None if no cu_seq_lens_k keys are found. + For packed batches, counts strictly-increasing boundaries in each + cu_seq_lens_k tensor, which works for both 1-D rows and stacked 2-D rows + (whose repeated-value padding contributes no increase). For non-packed + batches, counts rows of each input_ids tensor. """ cu_keys = [k for k in batch if k.endswith("cu_seq_lens_k")] if cu_keys: return int(sum((batch[k][..., 1:] > batch[k][..., :-1]).sum().item() for k in cu_keys)) - return None + return sum(batch[k].shape[0] for k in batch if k.endswith("input_ids"))