Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions emerging_optimizers/orthogonalized_optimizers/muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.weight_update_hooks import WeightUpdateHook


__all__ = ["Muon", "get_muon_scale_factor"]
Expand Down Expand Up @@ -86,6 +87,7 @@ def __init__(
scale_mode: MuonScaleT = "spectral",
extra_scale_factor: float = 1.0,
use_syrk: bool = False,
weight_update_hook: WeightUpdateHook | None = None,
) -> None:
if num_ns_steps < 1:
raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
Expand Down Expand Up @@ -127,6 +129,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
weight_decay_method=weight_decay_method,
fp32_matmul_prec=fp32_matmul_prec,
scaled_orthogonalize_fn=scaled_orthogonalize_fn,
weight_update_hook=weight_update_hook,
)


Expand Down
46 changes: 10 additions & 36 deletions emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, override
from typing import Any

import torch

from emerging_optimizers import registry
from emerging_optimizers.orthogonalized_optimizers import muon
from emerging_optimizers.weight_update_hooks import Hyperball


__all__ = ["MuonHyperball"]
Expand Down Expand Up @@ -63,8 +64,12 @@ def __init__(
hyperball_radius: float | None = None,
**kwargs: Any,
) -> None:
self.hyperball_eps = hyperball_eps
self.hyperball_radius = hyperball_radius
if "weight_update_hook" in kwargs:
raise TypeError(
"MuonHyperball does not accept a 'weight_update_hook' argument; "
"it manages its own Hyperball hook internally."
)
kwargs["weight_update_hook"] = Hyperball(radius=hyperball_radius, eps=hyperball_eps)
super().__init__(*args, **kwargs)
Comment thread
mkhona-nvidia marked this conversation as resolved.

# Validate and optionally rescale parameters based on hyperball_radius.
Expand All @@ -78,36 +83,5 @@ def __init__(
"MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm."
)
# Rescale parameter to have the specified radius if provided.
if self.hyperball_radius is not None:
p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps))

@override
def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
"""Store the original weight norm and normalize the update using Frobenius norm.

Args:
p: The parameter tensor.
update: The orthogonalized gradient tensor.
"""
# Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
self.state[p]["hyperball_R"] = R

# Normalize the update in-place and scale by R
# This modifies update to be: R * normalize(update) using Frobenius norm.
update_norm = update.norm().clamp_min(self.hyperball_eps)
update.mul_(R / update_norm)

@override
def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None:
"""Normalize the updated weights and scale back to original norm using Frobenius norm.

Args:
p: The parameter tensor (already updated).
"""
# Retrieve R from per-parameter state
R = self.state[p]["hyperball_R"]

# Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm.
p_norm = p.norm().clamp_min(self.hyperball_eps)
p.mul_(R / p_norm)
if hyperball_radius is not None:
p.mul_(hyperball_radius / p_norm.clamp_min(hyperball_eps))
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import utils
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.weight_update_hooks import NoOpWeightUpdateHook, WeightUpdateHook


_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
Expand All @@ -37,6 +38,7 @@
weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin`
for more details.
fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
weight_update_hook: Optional hook that runs around the final in-place weight update.
"""


Expand Down Expand Up @@ -103,6 +105,7 @@ def __init__(
weight_decay_method: opt_mixin.WeightDecayT,
fp32_matmul_prec: FP32MatmulPrecT,
scaled_orthogonalize_fn: Callable | None = None,
weight_update_hook: WeightUpdateHook | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make it private, _weight_update_hook, to suggest it is not supposed to be modified after initialization.

Hook easily invites abuse. making it private at least making people aware of the abusing. A setter can be provided if we want to support properly change the hook inflight.

**kwargs: Any,
):
if scaled_orthogonalize_fn is None:
Expand All @@ -112,6 +115,7 @@ def __init__(
self.fp32_matmul_prec = fp32_matmul_prec
self.nesterov = nesterov
self.weight_decay_method = weight_decay_method
self.weight_update_hook = weight_update_hook if weight_update_hook is not None else NoOpWeightUpdateHook()

default_args_dict = dict(
lr=lr,
Expand Down Expand Up @@ -195,8 +199,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None:

# perform weight update with pre and post weight update functions for subclass customization
self.pre_weight_update_fn_inplace(p, orth_grad)
weight_update_hook_pre_update_state = self.weight_update_hook.pre_weight_update_inplace(p, orth_grad)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name is too long, two update in the name.

p.add_(orth_grad, alpha=-group["lr"])
self.post_weight_update_fn_inplace(p)
self.weight_update_hook.post_weight_update_inplace(p, weight_update_hook_pre_update_state)

return None

Expand Down
25 changes: 25 additions & 0 deletions emerging_optimizers/weight_update_hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from emerging_optimizers.weight_update_hooks.base import NoOpWeightUpdateHook, WeightUpdateHook
from emerging_optimizers.weight_update_hooks.hyperball import Hyperball
from emerging_optimizers.weight_update_hooks.radial_brake import RadialBrake


__all__ = [
"Hyperball",
"NoOpWeightUpdateHook",
"RadialBrake",
"WeightUpdateHook",
]
56 changes: 56 additions & 0 deletions emerging_optimizers/weight_update_hooks/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Protocol

import torch


__all__ = ["NoOpWeightUpdateHook", "WeightUpdateHook"]


class WeightUpdateHook(Protocol):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mkhona-nvidia I don't oppose using Protocol. But I think need to get yourself comfortable with PEP544 before this can be merged.

"""Protocol for behavior around an optimizer's final in-place weight update."""

def pre_weight_update_inplace(
self,
p: torch.Tensor,
update: torch.Tensor,
) -> torch.Tensor | None:
"""Called immediately before ``p.add_(update, alpha=-lr)`` and returns pre-update state."""

def post_weight_update_inplace(
self,
p: torch.Tensor,
pre_update_state: torch.Tensor | None,
) -> None:
"""Called after the optimizer's final update and optimizer-specific post-update hook."""


class NoOpWeightUpdateHook:
"""Default hook that leaves the optimizer update unchanged."""

def pre_weight_update_inplace(
self,
p: torch.Tensor,
update: torch.Tensor,
) -> None:
return None

def post_weight_update_inplace(
self,
p: torch.Tensor,
pre_update_state: torch.Tensor | None,
) -> None:
pass
63 changes: 63 additions & 0 deletions emerging_optimizers/weight_update_hooks/hyperball.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


__all__ = ["Hyperball"]


class Hyperball:
"""Normalize update and post-update weights to a fixed Frobenius norm.

This hook mirrors the hyperball-style behavior used by MuonHyperball: before the weight update, normalize the
update to the target radius; after the weight update, project the parameter back to that radius.
"""

def __init__(
self,
radius: float | None = None,
eps: float = 1e-8,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe reduce this to 1e-15?

) -> None:
self.radius = radius
self.eps = eps

def pre_weight_update_inplace(
self,
p: torch.Tensor,
update: torch.Tensor,
) -> torch.Tensor:
current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detach is not necessary as all of our optimizers are wrapped in no_grad.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, vector_norm has an dtype argument, explicit to fp32 is not necessary.

It technically not dtype but compute type though, don't know who added it to pytorch.


if self.radius is not None:
radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32)
else:
if current_norm.item() == 0:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This triggers synced device to host copy. Simply using (current_norm == 0).all().

raise ValueError("Hyperball requires all parameters to have non-zero norm when radius is not fixed.")
radius = current_norm

update_norm = torch.linalg.vector_norm(update.to(torch.float32)).clamp_min(self.eps)
update.mul_((radius / update_norm).to(dtype=update.dtype))
return radius

def post_weight_update_inplace(
self,
p: torch.Tensor,
pre_update_state: torch.Tensor | None,
) -> None:
if pre_update_state is None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: why does it allow None and only to raise error? making it non optional is not sufficient?

raise RuntimeError("Hyperball requires radius state")
radius = pre_update_state
post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)).clamp_min(self.eps)
p.mul_((radius / post_norm).to(dtype=p.dtype))
67 changes: 67 additions & 0 deletions emerging_optimizers/weight_update_hooks/radial_brake.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


__all__ = ["RadialBrake"]


class RadialBrake:
"""Dampen radial norm changes after an optimizer update.

The optimizer first applies its usual update ``w = w_prev + dw``. This hook then rescales ``w`` so that

.. math::

\\|w_{brake}\\| = \\|w_{prev}\\| + s(\\|w\\| - \\|w_{prev}\\|)

where ``s`` is ``outward_scale_factor`` when the update increases the norm, otherwise
``inward_scale_factor``.
"""

def __init__(
self,
outward_scale_factor: float = 0.5,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scale_factor looks redundant, would outward_scale be ambiguous? if not, use the short form

inward_scale_factor: float = 1.0,
eps: float = 1e-12,
) -> None:
if not 0.0 <= outward_scale_factor <= 1.0:
raise ValueError(f"outward_scale_factor must be in [0, 1], got {outward_scale_factor}")
if not 0.0 <= inward_scale_factor <= 1.0:
raise ValueError(f"inward_scale_factor must be in [0, 1], got {inward_scale_factor}")
self.outward_scale_factor = outward_scale_factor
self.inward_scale_factor = inward_scale_factor
self.eps = eps

def pre_weight_update_inplace(
self,
p: torch.Tensor,
update: torch.Tensor,
) -> torch.Tensor:
return torch.linalg.vector_norm(p.detach().to(torch.float32))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why eps is not used here?


def post_weight_update_inplace(
self,
p: torch.Tensor,
pre_update_state: torch.Tensor | None,
) -> None:
if pre_update_state is None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same question as before, what's the reason to allow None and only to raise error?

raise RuntimeError("RadialBrake requires pre-update norm state")
pre_norm = pre_update_state
post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't detach.

norm_delta = post_norm - pre_norm
scale_factor = self.outward_scale_factor if norm_delta.item() > 0 else self.inward_scale_factor

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as before, don't use .item().

target_norm = pre_norm + scale_factor * norm_delta
p.mul_((target_norm / post_norm.clamp_min(self.eps)).to(dtype=p.dtype))
Comment thread
mkhona-nvidia marked this conversation as resolved.
11 changes: 11 additions & 0 deletions tests/test_orthogonalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,17 @@ def test_zero_norm_raises_error(self) -> None:
with self.assertRaises(ValueError):
muon_hyperball.MuonHyperball([test_param], lr=0.01)

def test_rejects_weight_update_hook(self) -> None:
"""MuonHyperball manages its own Hyperball hook internally."""
test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device))

with self.assertRaisesRegex(TypeError, "does not accept a 'weight_update_hook' argument"):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is better to be a KeyError,

muon_hyperball.MuonHyperball(
[test_param],
lr=0.01,
weight_update_hook=None,
)


class PolarGradTest(parameterized.TestCase):
def setUp(self):
Expand Down
Loading
Loading