Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions megatron/core/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,13 @@
HAVE_EMERGING_OPTIMIZERS = _eo_ver >= (0, 2)

if HAVE_EMERGING_OPTIMIZERS:
from emerging_optimizers.orthogonalized_optimizers import OrthogonalizedOptimizer
from emerging_optimizers.scalar_optimizers import Lion
else:
# Sentinel so ``isinstance(opt, OrthogonalizedOptimizer)`` is always False when the
# package is unavailable. All sites that test it are already guarded by
# HAVE_EMERGING_OPTIMIZERS, so this is only a defensive fallback for import safety.
OrthogonalizedOptimizer = ()

from megatron.core import parallel_state
from megatron.core.optimizer.cpu_offloading.hybrid_optimizer import HybridDeviceOptimizer
Expand Down Expand Up @@ -874,7 +880,21 @@ def _get_megatron_emerging_optimizer(
optimizer, init_state_fn = _create_emerging_optimizer(
config, groups, eopt_name, model_chunks, pg_collection
)
# Only orthogonalizing optimizers (Muon family, i.e. subclasses of
# OrthogonalizedOptimizer) have scale-invariant updates that make magnitude
# grad-norm clipping a no-op-at-best / harmful-at-worst (issue #5394).
# Other emerging optimizers (SOAP, Lion) are NOT scale-invariant and must
# still be clipped, so the flag is gated on this check rather than set for
# every emerging optimizer.
is_orthogonalizing = isinstance(optimizer, OrthogonalizedOptimizer)
if use_layer_wise:
# Mark the raw (unwrapped) Muon sub-optimizer so the flag is visible in
# the layer-wise DIRECT path. NOTE: LayerWiseDistributedOptimizer re-wraps
# base optimizers in Float16OptimizerWithFloat16Params when config.bf16 is
# True, so this flag is *also* re-propagated onto the actual
# ``chained_optimizers`` members after construction below.
if is_orthogonalizing:
setattr(optimizer, 'skip_grad_norm_clip', True)
layer_wise_base_results.append((optimizer, init_state_fn))
continue
if config.bf16:
Expand All @@ -884,6 +904,11 @@ def _get_megatron_emerging_optimizer(
else:
optimizer = FP32Optimizer(optimizer, config, init_state_fn)
setattr(optimizer, 'grad_stats_parallel_group', model_parallel_group)
# Orthogonalizing optimizers (Muon) have scale-invariant updates; magnitude
# gradient clipping is a no-op at best and harmful at worst (see issue #5394).
# Gate on is_orthogonalizing so SOAP/Lion (non-scale-invariant) keep clipping.
if is_orthogonalizing:
setattr(optimizer, 'skip_grad_norm_clip', True)
if pg_collection is None or not hasattr(pg_collection, 'tp'):
tp_group = parallel_state.get_tensor_model_parallel_group()
else:
Expand Down Expand Up @@ -963,6 +988,28 @@ def _get_megatron_emerging_optimizer(
init_state_fn_list=list(init_fns),
model_chunks=model_chunks,
)
# Re-propagate the skip flag onto the actual chained sub-optimizers. LayerWise
# re-wraps each base optimizer in Float16OptimizerWithFloat16Params when
# config.bf16 is True, so the flag set on the raw Muon sub-optimizer above is
# NOT visible on ``layer_wise_optimizer.chained_optimizers`` (the wrappers do not
# forward attribute access). The DIRECT path returns ``layer_wise_optimizer`` and
# runs the inherited ``ChainedOptimizer.step()`` over these inner subs, reading
# their per-sub flag. Carry the flag from each member's underlying raw optimizer
# (``.optimizer`` on a wrapper, else the member itself) so Muon subs stay flagged
# while non-emerging (Adam) subs stay unflagged.
for sub_optimizer in layer_wise_optimizer.chained_optimizers:
raw_sub = getattr(sub_optimizer, 'optimizer', sub_optimizer)
if getattr(raw_sub, 'skip_grad_norm_clip', False):
setattr(sub_optimizer, 'skip_grad_norm_clip', True)
# The CHAINED path (results non-empty) treats ``layer_wise_optimizer`` as a leaf
# and reads the CONTAINER flag, so set it only when every base sub-optimizer is
# orthogonalizing (Muon-only container). In the separate-distopt path the
# container is Muon-only -> True; if any non-Muon sub is inside (legacy path) the
# container must NOT claim skip -- and that path is direct anyway.
if base_optimizers and all(
getattr(o, 'skip_grad_norm_clip', False) for o in base_optimizers
):
setattr(layer_wise_optimizer, 'skip_grad_norm_clip', True)
# LayerWise owns Muon-managed params; DistOpt instances in ``results``
# own the rest. Chain them so the training loop sees one optimizer.
if results:
Expand Down
91 changes: 88 additions & 3 deletions megatron/core/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,6 +1544,64 @@ def get_grad_norm(self):
grad_norm = math.sqrt(sum([x**2 for x in grad_norms]))
return grad_norm

@torch.no_grad()
def _get_grad_norm_skip_threshold(self, full_grad_norm=None):
    """Return the grad norm to compare against ``grad_norm_skip_threshold``.

    This is the same quantity as :meth:`get_grad_norm`, except that gradients owned
    by sub-optimizers flagged with ``skip_grad_norm_clip`` (orthogonalizing /
    Muon-family, see issue #5394) are left out. The magnitude of those gradients is
    discarded by the orthogonalizing update, so folding it into the shared norm
    could trip the skip threshold for sub-optimizers whose own gradients are fine.

    Args:
        full_grad_norm: Optional result of ``self.get_grad_norm()`` that the caller
            has already computed for the *current* gradients. It is only used when
            no sub-optimizer is flagged, where the two norms are identical, so the
            norm (and whatever collectives it issues) is not computed a second
            time. Defaults to ``None``, which recomputes the norm.

    Returns:
        The norm over the non-flagged sub-optimizers' gradients:

        * nothing flagged -> ``full_grad_norm`` if given, else ``get_grad_norm()``;
        * everything flagged -> ``0.0`` (nothing is subject to the threshold);
        * exactly one non-flagged sub -> that sub's own norm;
        * several non-flagged subs -> one norm over their concatenated grads when
          they share a grad-stats group, otherwise the root-sum-of-squares of
          their individual norms.

        A falsy / ``None`` per-sub norm is reported as ``0.0`` in every branch that
        reads a sub-optimizer's norm directly, so the result is always comparable
        with ``>``.

    NOTE(review): the branch taken depends only on the ``skip_grad_norm_clip``
    flags; this presumes the flags are identical on every rank so that all ranks
    issue the same collectives -- confirm against the construction code.
    """
    sub_optimizers = list(self.chained_optimizers)
    # Single pass: keep only the subs that are still magnitude-clipped.
    non_skip = [
        optimizer
        for optimizer in sub_optimizers
        if not getattr(optimizer, 'skip_grad_norm_clip', False)
    ]

    if len(non_skip) == len(sub_optimizers):
        # Nothing is flagged: the threshold norm *is* the full grad norm. Reuse the
        # caller's value when supplied instead of computing it again.
        if full_grad_norm is not None:
            return full_grad_norm
        return self.get_grad_norm()

    if not non_skip:
        # Everything is skip-flagged (e.g. a Muon-only chain): there is nothing to
        # threshold.
        return 0.0

    if len(non_skip) == 1:
        # Single non-skip sub: defer to its own norm. Coerce a falsy / None result
        # to 0.0, exactly as the multi-sub fallback below does.
        return non_skip[0].get_grad_norm() or 0.0

    # Decide group sharedness over the NON-SKIP subs only. A sub's
    # get_grad_stats_parallel_group() may itself assert (e.g. a nested container
    # whose members disagree); treat that as "not shared" and use the per-sub
    # fallback rather than failing.
    try:
        ref_group = non_skip[0].get_grad_stats_parallel_group()
        shared = all(o.get_grad_stats_parallel_group() == ref_group for o in non_skip)
    except AssertionError:
        shared = False

    if shared:
        grads_for_norm = []
        for optimizer in non_skip:
            grads_for_norm.extend(optimizer.get_grads_for_grad_norm())
        return get_grad_norm_fp32(grads_for_norm, grad_stats_parallel_group=ref_group)

    # Non-shared groups: combine the per-sub norms as a root-sum-of-squares.
    grad_norms = [o.get_grad_norm() or 0.0 for o in non_skip]
    return math.sqrt(sum(x * x for x in grad_norms))

@torch.no_grad()
def count_zeros(self):
if self.grads_states_parallel_group_is_shared():
Expand Down Expand Up @@ -1631,11 +1689,19 @@ def step(self):
return False, None, None

grad_norm = self.get_grad_norm()
# Norm used for the grad_norm_skip_threshold comparison only. It excludes the grads
# of skip-flagged (orthogonalizing / Muon) sub-optimizers so a Muon sub's huge but
# meaningless grad magnitude cannot wrongly trip the skip threshold for well-behaved
# Adam subs (see issue #5394). Equals ``grad_norm`` when nothing is flagged.
# NOTE: scoped to the threshold check; the clip ``total_norm`` below still uses the
# full ``grad_norm`` (a separate follow-up may narrow the clip norm too).
threshold_grad_norm = self._get_grad_norm_skip_threshold()
should_skip_update = False

should_clip = any(
not (hasattr(optimizer, 'is_stub_optimizer') and optimizer.is_stub_optimizer)
and optimizer.config.clip_grad > 0.0
and not getattr(optimizer, 'skip_grad_norm_clip', False)
for optimizer in self.chained_optimizers
)
if should_clip:
Expand Down Expand Up @@ -1668,7 +1734,15 @@ def step(self):
else:
main_params.append(p)

if optimizer.config.clip_grad > 0.0:
# Skip magnitude-based gradient clipping for orthogonalizing optimizers
# (e.g. Muon): their update is scale-invariant (Newton-Schulz discards the
# gradient magnitude), so clipping is a no-op at best. At worst, when the
# global ``grad_norm`` is large the tiny clip coefficient pushes per-matrix
# gradients below Newton-Schulz's normalization floor, silently degenerating
# the orthogonalization and stalling training. See issue #5394.
if optimizer.config.clip_grad > 0.0 and not getattr(
optimizer, "skip_grad_norm_clip", False
):
if main_params:
clip_grad_by_total_norm_fp32(
main_params,
Expand All @@ -1687,9 +1761,20 @@ def step(self):
use_decoupled_grad=use_decoupled_grad,
)

if grad_norm > optimizer.config.grad_norm_skip_threshold and main_params:
# Skip-flagged (Muon) subs are exempt from the magnitude-based skip threshold:
# their grad magnitude is discarded by Newton-Schulz, so it is not a meaningful
# signal for skipping the update. For all other subs, compare against the
# narrower ``threshold_grad_norm`` (which excludes the Muon subs' grads).
if (
not getattr(optimizer, "skip_grad_norm_clip", False)
and threshold_grad_norm > optimizer.config.grad_norm_skip_threshold
and main_params
):
log_single_rank(
logger, logging.INFO, "skipping grad norm because it's too large %s", grad_norm
logger,
logging.INFO,
"skipping grad norm because it's too large %s",
threshold_grad_norm,
)
should_skip_update = True

Expand Down
Loading