diff --git a/megatron/core/optimizer/__init__.py b/megatron/core/optimizer/__init__.py index 27b675d1b8d..f3091100273 100644 --- a/megatron/core/optimizer/__init__.py +++ b/megatron/core/optimizer/__init__.py @@ -45,7 +45,13 @@ HAVE_EMERGING_OPTIMIZERS = _eo_ver >= (0, 2) if HAVE_EMERGING_OPTIMIZERS: + from emerging_optimizers.orthogonalized_optimizers import OrthogonalizedOptimizer from emerging_optimizers.scalar_optimizers import Lion +else: + # Sentinel so ``isinstance(opt, OrthogonalizedOptimizer)`` is always False when the + # package is unavailable. All sites that test it are already guarded by + # HAVE_EMERGING_OPTIMIZERS, so this is only a defensive fallback for import safety. + OrthogonalizedOptimizer = () from megatron.core import parallel_state from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer @@ -874,7 +880,21 @@ def _get_megatron_emerging_optimizer( optimizer, init_state_fn = _create_emerging_optimizer( config, groups, eopt_name, model_chunks, pg_collection ) + # Only orthogonalizing optimizers (Muon family, i.e. subclasses of + # OrthogonalizedOptimizer) have scale-invariant updates that make magnitude + # grad-norm clipping a no-op-at-best / harmful-at-worst (issue #5394). + # Other emerging optimizers (SOAP, Lion) are NOT scale-invariant and must + # still be clipped, so the flag is gated on this check rather than set for + # every emerging optimizer. + is_orthogonalizing = isinstance(optimizer, OrthogonalizedOptimizer) if use_layer_wise: + # Mark the raw (unwrapped) Muon sub-optimizer so the flag is visible in + # the layer-wise DIRECT path. NOTE: LayerWiseDistributedOptimizer re-wraps + # base optimizers in Float16OptimizerWithFloat16Params when config.bf16 is + # True, so this flag is *also* re-propagated onto the actual + # ``chained_optimizers`` members after construction below. + if is_orthogonalizing: + setattr(optimizer, 'skip_grad_norm_clip', True) layer_wise_base_results.append((optimizer, init_state_fn)) continue if config.bf16: @@ -884,6 +904,11 @@ def _get_megatron_emerging_optimizer( else: optimizer = FP32Optimizer(optimizer, config, init_state_fn) setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group) + # Orthogonalizing optimizers (Muon) have scale-invariant updates; magnitude + # gradient clipping is a no-op at best and harmful at worst (see issue #5394). + # Gate on is_orthogonalizing so SOAP/Lion (non-scale-invariant) keep clipping. + if is_orthogonalizing: + setattr(optimizer, 'skip_grad_norm_clip', True) if pg_collection is None or not hasattr(pg_collection, 'tp'): tp_group = parallel_state.get_tensor_model_parallel_group() else: @@ -963,6 +988,28 @@ def _get_megatron_emerging_optimizer( init_state_fn_list=list(init_fns), model_chunks=model_chunks, ) + # Re-propagate the skip flag onto the actual chained sub-optimizers. LayerWise + # re-wraps each base optimizer in Float16OptimizerWithFloat16Params when + # config.bf16 is True, so the flag set on the raw Muon sub-optimizer above is + # NOT visible on ``layer_wise_optimizer.chained_optimizers`` (the wrappers do not + # forward attribute access). The DIRECT path returns ``layer_wise_optimizer`` and + # runs the inherited ``ChainedOptimizer.step()`` over these inner subs, reading + # their per-sub flag. Carry the flag from each member's underlying raw optimizer + # (``.optimizer`` on a wrapper, else the member itself) so Muon subs stay flagged + # while non-emerging (Adam) subs stay unflagged. + for sub_optimizer in layer_wise_optimizer.chained_optimizers: + raw_sub = getattr(sub_optimizer, 'optimizer', sub_optimizer) + if getattr(raw_sub, 'skip_grad_norm_clip', False): + setattr(sub_optimizer, 'skip_grad_norm_clip', True) + # The CHAINED path (results non-empty) treats ``layer_wise_optimizer`` as a leaf + # and reads the CONTAINER flag, so set it only when every base sub-optimizer is + # orthogonalizing (Muon-only container). In the separate-distopt path the + # container is Muon-only -> True; if any non-Muon sub is inside (legacy path) the + # container must NOT claim skip -- and that path is direct anyway. + if base_optimizers and all( + getattr(o, 'skip_grad_norm_clip', False) for o in base_optimizers + ): + setattr(layer_wise_optimizer, 'skip_grad_norm_clip', True) # LayerWise owns Muon-managed params; DistOpt instances in ``results`` # own the rest. Chain them so the training loop sees one optimizer. if results: diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index e03992e0657..3f22ba67e8e 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -1544,6 +1544,64 @@ def get_grad_norm(self): grad_norm = math.sqrt(sum([x**2 for x in grad_norms])) return grad_norm + @torch.no_grad() + def _get_grad_norm_skip_threshold(self): + """Grad norm used for the ``grad_norm_skip_threshold`` comparison. + + This is the same quantity as :meth:`get_grad_norm` but EXCLUDES the gradients of + sub-optimizers flagged with ``skip_grad_norm_clip`` (orthogonalizing / Muon-family, + see issue #5394). A Muon-managed sub can produce a combined grad norm of order 1e7 + whose magnitude is irrelevant (Newton-Schulz discards it); folding it into the + shared norm would wrongly trip the skip threshold for well-behaved Adam-managed + subs. Excluding those grads gives each non-skipped sub a threshold check against a + norm computed only over the params that are actually magnitude-clipped. + + When no sub is flagged, this returns exactly :meth:`get_grad_norm` (no behavior + change for non-Muon users). The distributed all-reduce semantics of + :func:`get_grad_norm_fp32` are preserved; because ``skip_grad_norm_clip`` is fixed + at construction it is identical across ranks, so the set of collectives issued here + is globally consistent. + """ + # Fast path / no-op-equivalence: if nothing is flagged, reuse get_grad_norm so the + # value (and the collectives issued) are bit-for-bit the existing behavior. + if not any( + getattr(optimizer, 'skip_grad_norm_clip', False) + for optimizer in self.chained_optimizers + ): + return self.get_grad_norm() + + non_skip = [ + optimizer + for optimizer in self.chained_optimizers + if not getattr(optimizer, 'skip_grad_norm_clip', False) + ] + if not non_skip: + # Everything is skip-flagged (e.g. a Muon-only chain). There is nothing to + # threshold; skip-flagged subs are exempt from the threshold check anyway. + return 0.0 + if len(non_skip) == 1: + # Single non-skip sub: defer to its own norm (handles its own group sharedness). + return non_skip[0].get_grad_norm() + + # Determine grad-stats sharedness over the NON-SKIP subs only, not the whole + # container: the container-level get_grad_stats_parallel_group() asserts that ALL + # subs share a group, which fails on the distributed-optimizer path where a Muon + # LayerWise sub and an Adam DistributedOptimizer sub have different grad-stats + # groups. This mirrors get_grad_norm()'s shared / non-shared handling. + try: + ref_group = non_skip[0].get_grad_stats_parallel_group() + shared = all(o.get_grad_stats_parallel_group() == ref_group for o in non_skip) + except AssertionError: + shared = False + if shared: + grads_for_norm = [] + for optimizer in non_skip: + grads_for_norm += optimizer.get_grads_for_grad_norm() + return get_grad_norm_fp32(grads_for_norm, grad_stats_parallel_group=ref_group) + # Non-shared groups: combine per-sub norms (mirrors get_grad_norm()'s fallback). + grad_norms = [o.get_grad_norm() or 0.0 for o in non_skip] + return math.sqrt(sum(x * x for x in grad_norms)) + @torch.no_grad() def count_zeros(self): if self.grads_states_parallel_group_is_shared(): @@ -1631,11 +1689,19 @@ def step(self): return False, None, None grad_norm = self.get_grad_norm() + # Norm used for the grad_norm_skip_threshold comparison only. It excludes the grads + # of skip-flagged (orthogonalizing / Muon) sub-optimizers so a Muon sub's huge but + # meaningless grad magnitude cannot wrongly trip the skip threshold for well-behaved + # Adam subs (see issue #5394). Equals ``grad_norm`` when nothing is flagged. + # NOTE: scoped to the threshold check; the clip ``total_norm`` below still uses the + # full ``grad_norm`` (a separate follow-up may narrow the clip norm too). + threshold_grad_norm = self._get_grad_norm_skip_threshold() should_skip_update = False should_clip = any( not (hasattr(optimizer, 'is_stub_optimizer') and optimizer.is_stub_optimizer) and optimizer.config.clip_grad > 0.0 + and not getattr(optimizer, 'skip_grad_norm_clip', False) for optimizer in self.chained_optimizers ) if should_clip: @@ -1668,7 +1734,15 @@ def step(self): else: main_params.append(p) - if optimizer.config.clip_grad > 0.0: + # Skip magnitude-based gradient clipping for orthogonalizing optimizers + # (e.g. Muon): their update is scale-invariant (Newton-Schulz discards the + # gradient magnitude), so clipping is a no-op at best. At worst, when the + # global ``grad_norm`` is large the tiny clip coefficient pushes per-matrix + # gradients below Newton-Schulz's normalization floor, silently degenerating + # the orthogonalization and stalling training. See issue #5394. + if optimizer.config.clip_grad > 0.0 and not getattr( + optimizer, "skip_grad_norm_clip", False + ): if main_params: clip_grad_by_total_norm_fp32( main_params, @@ -1687,9 +1761,20 @@ def step(self): use_decoupled_grad=use_decoupled_grad, ) - if grad_norm > optimizer.config.grad_norm_skip_threshold and main_params: + # Skip-flagged (Muon) subs are exempt from the magnitude-based skip threshold: + # their grad magnitude is discarded by Newton-Schulz, so it is not a meaningful + # signal for skipping the update. For all other subs, compare against the + # narrower ``threshold_grad_norm`` (which excludes the Muon subs' grads). + if ( + not getattr(optimizer, "skip_grad_norm_clip", False) + and threshold_grad_norm > optimizer.config.grad_norm_skip_threshold + and main_params + ): log_single_rank( - logger, logging.INFO, "skipping grad norm because it's too large %s", grad_norm + logger, + logging.INFO, + "skipping grad norm because it's too large %s", + threshold_grad_norm, ) should_skip_update = True diff --git a/tests/unit_tests/optimizer/test_skip_grad_norm_clip.py b/tests/unit_tests/optimizer/test_skip_grad_norm_clip.py new file mode 100644 index 00000000000..5c7a9564c0d --- /dev/null +++ b/tests/unit_tests/optimizer/test_skip_grad_norm_clip.py @@ -0,0 +1,285 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Targeted tests for the Muon skip-grad-norm-clip PR (#5395, revision 937d8677d). + +Covers the four reviewer findings: + B1 (Bug1, critical): skip_grad_norm_clip is set on the *bf16 wrapper* members of + LayerWiseDistributedOptimizer.chained_optimizers (not only the raw sub), for the + results-empty (Muon-only) DIRECT path; Adam subs in a mix stay unflagged. + B2 (Bug2, high): the flag is gated on isinstance(opt, OrthogonalizedOptimizer) so only + the Muon family is flagged; SOAP/Lion keep clipping. + B3 (Bug3, medium): ChainedOptimizer._get_grad_norm_skip_threshold() excludes flagged + subs' grads so a Muon sub's huge grad cannot trip the skip threshold for an Adam sub. + B4 (Bug4, low): should_clip is False for a Muon-only chain, so _compute_grad_norms_by_group + (and its AllReduce) is not run; a clippable Adam chain still runs it. + +Run: torchrun --nproc_per_node=2 -m pytest tests/unit_tests/optimizer/test_skip_grad_norm_clip.py + with NVIDIA_PYTORCH_VERSION>25.05 set. +""" +import os + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging.version import Version + +from megatron.core import parallel_state +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer +from megatron.core.optimizer.emerging_optimizers import HAVE_EMERGING_OPTIMIZERS +from megatron.core.optimizer.layer_wise_optimizer import LayerWiseDistributedOptimizer +from megatron.core.optimizer.optimizer import ChainedOptimizer +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer import TransformerConfig +from tests.unit_tests.test_utilities import Utils + +if HAVE_EMERGING_OPTIMIZERS: + from emerging_optimizers.orthogonalized_optimizers import OrthogonalizedOptimizer + +pytestmark = [ + pytest.mark.skipif( + Version(os.getenv('NVIDIA_PYTORCH_VERSION', "24.01")) <= Version("25.05"), + reason="Skip emerging optimizer tests for LTS test", + ), + pytest.mark.skipif( + not HAVE_EMERGING_OPTIMIZERS, reason="emerging_optimizers package is not installed" + ), + pytest.mark.skipif( + int(os.getenv('WORLD_SIZE', '1')) == 1, reason="Multi-rank test requires WORLD_SIZE > 1" + ), +] + + +class MuonOnly(nn.Module): + """All-2D-weights (bias=False) so every managed param goes to Muon -> results empty.""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(64, 48, bias=False) + self.fc2 = nn.Linear(48, 32, bias=False) + self.fc3 = nn.Linear(32, 16, bias=False) + + def forward(self, x): + return self.fc3(F.relu(self.fc2(F.relu(self.fc1(x))))) + + +class MuonAdamMix(nn.Module): + """2D weights -> Muon, 1D biases -> Adam (a single bare LayerWise chain of [muon, adam]).""" + + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(64, 48, bias=True) + self.fc2 = nn.Linear(48, 16, bias=True) + + def forward(self, x): + return self.fc2(F.relu(self.fc1(x))) + + +def _inner(sub): + """Unwrap a chained member to the actual torch optimizer (wrapper.optimizer or itself).""" + raw = getattr(sub, 'optimizer', sub) + return getattr(raw, 'optimizer', raw) + + +def _is_orthogonalizing(sub): + return isinstance(_inner(sub), OrthogonalizedOptimizer) + + +@pytest.mark.skipif( + int(os.getenv('WORLD_SIZE', '1')) == 1, reason="Multi-rank test requires WORLD_SIZE > 1" +) +class TestSkipGradNormClip: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + Utils.initialize_model_parallel() + yield + Utils.destroy_model_parallel() + + # ---- builders ----------------------------------------------------------------- + def _build(self, model, optimizer_name, use_layer_wise, clip_grad=1.0): + model = model.bfloat16().cuda() + model.requires_grad_(True) + ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=False) + model = DistributedDataParallel( + TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, model + ) + model.broadcast_params() + cfg = OptimizerConfig( + optimizer=optimizer_name, + lr=0.01, + weight_decay=0.01, + bf16=True, + use_distributed_optimizer=False, + clip_grad=clip_grad, + muon_tp_mode="duplicated", + use_layer_wise_distributed_optimizer=use_layer_wise, + ) + pg = ProcessGroupCollection.use_mpu_process_groups() + pg.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) + pg.expt_dp = parallel_state.get_expert_data_parallel_group() + opt = get_megatron_optimizer(cfg, [model], pg_collection=pg, use_gloo_process_groups=False) + return model, opt + + @staticmethod + def _forward_backward(model, batch=8, in_dim=64): + x = torch.randn(batch, in_dim, dtype=torch.bfloat16, device='cuda') + model(x).sum().backward() + + # ================================ B1 (Bug1) ================================ + def test_b1_muon_only_every_member_flagged(self): + """Muon-only -> bare LayerWiseDistributedOptimizer; EVERY chained member (the bf16 + wrapper itself) and the container carry skip_grad_norm_clip is True.""" + _, opt = self._build(MuonOnly(), 'muon', use_layer_wise=True) + assert isinstance(opt, LayerWiseDistributedOptimizer) + # The DIRECT path returns the container; the container flag must be True (all-orthog). + assert getattr(opt, 'skip_grad_norm_clip', False) is True + assert len(opt.chained_optimizers) >= 1 + for i, sub in enumerate(opt.chained_optimizers): + # bf16 re-wraps base optimizers in Float16OptimizerWithFloat16Params; the flag must + # be visible on the WRAPPER (not only the raw sub it forwards to). + assert getattr(sub, 'skip_grad_norm_clip', False) is True, ( + f"member {i} ({type(sub).__name__}) wrapping {type(_inner(sub)).__name__} " + f"is not flagged" + ) + assert _is_orthogonalizing(sub), f"member {i} should be a Muon-family optimizer" + + def test_b1_mix_adam_sub_not_flagged(self): + """Muon+Adam mix -> the Adam sub must NOT be flagged; the Muon sub must be flagged; + the container must NOT be flagged (not all base subs are orthogonalizing).""" + _, opt = self._build(MuonAdamMix(), 'muon', use_layer_wise=True) + assert isinstance(opt, ChainedOptimizer) + flagged = {i: getattr(s, 'skip_grad_norm_clip', False) for i, s in enumerate(opt.chained_optimizers)} + orthog = {i: _is_orthogonalizing(s) for i, s in enumerate(opt.chained_optimizers)} + # exactly the orthogonalizing (Muon) subs are flagged + for i, s in enumerate(opt.chained_optimizers): + assert bool(flagged[i]) == orthog[i], ( + f"member {i} ({type(_inner(s)).__name__}): flagged={flagged[i]} orthog={orthog[i]}" + ) + assert any(orthog.values()), "expected a Muon sub" + assert not all(orthog.values()), "expected an Adam sub in the mix" + # container must not claim skip when a non-orthogonalizing sub is present + assert getattr(opt, 'skip_grad_norm_clip', False) is False + + # ================================ B2 (Bug2) ================================ + def test_b2_muon_flagged_lion_not(self): + """isinstance(OrthogonalizedOptimizer) gate: muon flagged, lion (scalar) not flagged.""" + _, muon_opt = self._build(nn.Linear(64, 32, bias=False), 'muon', use_layer_wise=False) + assert any(getattr(s, 'skip_grad_norm_clip', False) for s in muon_opt.chained_optimizers) + for s in muon_opt.chained_optimizers: + assert getattr(s, 'skip_grad_norm_clip', False) == _is_orthogonalizing(s) + + _, lion_opt = self._build(nn.Linear(64, 32, bias=False), 'lion', use_layer_wise=False) + for s in lion_opt.chained_optimizers: + assert getattr(s, 'skip_grad_norm_clip', False) is False, "Lion must keep clipping" + assert not _is_orthogonalizing(s) + + # ================================ B3 (Bug3) ================================ + def test_b3_threshold_excludes_muon_grad(self): + """_get_grad_norm_skip_threshold() excludes the flagged (Muon) sub's huge grad, so a + well-behaved Adam sub is not wrongly skipped. Contrast with get_grad_norm() (full).""" + model, opt = self._build(MuonAdamMix(), 'muon', use_layer_wise=True, clip_grad=1.0) + assert isinstance(opt, ChainedOptimizer) + # find the Adam sub and give it a finite skip threshold + adam_subs = [s for s in opt.chained_optimizers if not _is_orthogonalizing(s)] + muon_subs = [s for s in opt.chained_optimizers if _is_orthogonalizing(s)] + assert adam_subs and muon_subs + for s in adam_subs: + s.config.grad_norm_skip_threshold = 10.0 + + # populate grads, then make the Muon (2D) grads huge and the Adam (1D) grads tiny + # we reach the model via the optimizer's param groups + self._forward_backward(model) + for p in model.parameters(): + g = p.main_grad if getattr(p, 'main_grad', None) is not None else p.grad + if g is None: + continue + if p.dim() >= 2: + g.fill_(1.0e5) # Muon sub -> huge norm + else: + g.fill_(1.0e-4) # Adam sub -> tiny norm + + opt.prepare_grads() + full = float(opt.get_grad_norm()) + threshold_norm = float(opt._get_grad_norm_skip_threshold()) + if Utils.rank == 0: + print(f"\n[B3] get_grad_norm()={full:.3e} _get_grad_norm_skip_threshold()={threshold_norm:.3e}") + # the skip-threshold norm must be far smaller than the full norm (Muon grad excluded) + assert threshold_norm < full + # and below the Adam sub's finite threshold (10) so the update is NOT skipped, + # whereas the full global norm would have tripped it. + assert threshold_norm < 10.0 < full + + def test_b3_update_not_skipped(self): + """End-to-end: with a huge Muon grad and a finite Adam threshold, step() must NOT skip.""" + model, opt = self._build(MuonAdamMix(), 'muon', use_layer_wise=True, clip_grad=1.0) + for s in opt.chained_optimizers: + if not _is_orthogonalizing(s): + s.config.grad_norm_skip_threshold = 10.0 + self._forward_backward(model) + for p in model.parameters(): + g = p.main_grad if getattr(p, 'main_grad', None) is not None else p.grad + if g is None: + continue + g.fill_(1.0e5 if p.dim() >= 2 else 1.0e-4) + update_successful, grad_norm, _ = opt.step() + assert update_successful is True, "update was wrongly skipped despite small non-Muon norm" + + # ================================ B4 (Bug4) ================================ + def test_b4_muon_only_skips_clip_norm_compute(self): + """Muon-only chain: should_clip is False -> _compute_grad_norms_by_group not called.""" + model, opt = self._build(MuonOnly(), 'muon', use_layer_wise=True, clip_grad=1.0) + self._forward_backward(model) + calls = {'n': 0} + orig = opt._compute_grad_norms_by_group + + def counting(*a, **k): + calls['n'] += 1 + return orig(*a, **k) + + opt._compute_grad_norms_by_group = counting + update_successful, _, _ = opt.step() + assert calls['n'] == 0, "Muon-only chain should not compute per-group clip norms" + assert update_successful is True + + def test_b4_adam_chain_runs_clip_norm_compute(self): + """Clippable Adam chain: should_clip is True -> _compute_grad_norms_by_group is called.""" + model, opt = self._build(nn.Linear(64, 32, bias=True), 'adam', use_layer_wise=False, clip_grad=1.0) + self._forward_backward(model) + calls = {'n': 0} + orig = opt._compute_grad_norms_by_group + + def counting(*a, **k): + calls['n'] += 1 + return orig(*a, **k) + + opt._compute_grad_norms_by_group = counting + opt.step() + assert calls['n'] >= 1, "clippable Adam chain must compute clip norms" + + # ===== distributed-optimizer path: step() must succeed (was the f207dc2-fixed regression) ===== + def test_distopt_path_step_succeeds(self): + """Muon LayerWise chained with an Adam DistributedOptimizer => non-shared grad-stats + groups. _get_grad_norm_skip_threshold() must handle that (per-sub fallback) instead of + asserting a shared group. Regressed in 937d8677d, fixed in f207dc2.""" + from megatron.training.training import wrap_model_chunks_with_ddp + + model = MuonAdamMix().bfloat16().cuda() + model.requires_grad_(True) + ddp_config = DistributedDataParallelConfig() # use_distributed_optimizer=True (default) + model = wrap_model_chunks_with_ddp( + [model], + TransformerConfig(num_attention_heads=1, num_layers=1), + ddp_config, + use_layer_wise_distributed_optimizer=True, + )[0] + cfg = OptimizerConfig( + optimizer='muon', lr=0.01, weight_decay=0.01, bf16=True, clip_grad=1.0, + muon_tp_mode="duplicated", use_layer_wise_distributed_optimizer=True, + ) + pg = ProcessGroupCollection.use_mpu_process_groups() + pg.dp_cp = parallel_state.get_data_parallel_group(with_context_parallel=True) + pg.expt_dp = parallel_state.get_expert_data_parallel_group() + opt = get_megatron_optimizer(cfg, [model], pg_collection=pg, use_gloo_process_groups=False) + self._forward_backward(model) + update_successful, _, _ = opt.step() + assert update_successful is True