diff --git a/emerging_optimizers/orthogonalized_optimizers/muon.py b/emerging_optimizers/orthogonalized_optimizers/muon.py index e962239..403bfa2 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon.py @@ -25,6 +25,7 @@ from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc from emerging_optimizers.utils import FP32MatmulPrecT +from emerging_optimizers.weight_update_hooks import WeightUpdateHook __all__ = ["Muon", "get_muon_scale_factor"] @@ -86,6 +87,7 @@ def __init__( scale_mode: MuonScaleT = "spectral", extra_scale_factor: float = 1.0, use_syrk: bool = False, + weight_update_hook: WeightUpdateHook | None = None, ) -> None: if num_ns_steps < 1: raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}") @@ -127,6 +129,7 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor: weight_decay_method=weight_decay_method, fp32_matmul_prec=fp32_matmul_prec, scaled_orthogonalize_fn=scaled_orthogonalize_fn, + weight_update_hook=weight_update_hook, ) diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index c9a8c09..62acdbb 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, override +from typing import Any import torch from emerging_optimizers import registry from emerging_optimizers.orthogonalized_optimizers import muon +from emerging_optimizers.weight_update_hooks import Hyperball __all__ = ["MuonHyperball"] @@ -63,8 +64,12 @@ def __init__( hyperball_radius: float | None = None, **kwargs: Any, ) -> None: - self.hyperball_eps = hyperball_eps - self.hyperball_radius = hyperball_radius + if "weight_update_hook" in kwargs: + raise TypeError( + "MuonHyperball does not accept a 'weight_update_hook' argument; " + "it manages its own Hyperball hook internally." + ) + kwargs["weight_update_hook"] = Hyperball(radius=hyperball_radius, eps=hyperball_eps) super().__init__(*args, **kwargs) # Validate and optionally rescale parameters based on hyperball_radius. @@ -78,36 +83,5 @@ def __init__( "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm." ) # Rescale parameter to have the specified radius if provided. - if self.hyperball_radius is not None: - p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps)) - - @override - def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None: - """Store the original weight norm and normalize the update using Frobenius norm. - - Args: - p: The parameter tensor. - update: The orthogonalized gradient tensor. - """ - # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm) - R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item() - self.state[p]["hyperball_R"] = R - - # Normalize the update in-place and scale by R - # This modifies update to be: R * normalize(update) using Frobenius norm. - update_norm = update.norm().clamp_min(self.hyperball_eps) - update.mul_(R / update_norm) - - @override - def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None: - """Normalize the updated weights and scale back to original norm using Frobenius norm. - - Args: - p: The parameter tensor (already updated). - """ - # Retrieve R from per-parameter state - R = self.state[p]["hyperball_R"] - - # Normalize the result and scale back by R: p = R * (p / ||p||_F) using Frobenius norm. - p_norm = p.norm().clamp_min(self.hyperball_eps) - p.mul_(R / p_norm) + if hyperball_radius is not None: + p.mul_(hyperball_radius / p_norm.clamp_min(hyperball_eps)) diff --git a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py index e6ddda5..c41f1d6 100644 --- a/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py +++ b/emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py @@ -26,6 +26,7 @@ from emerging_optimizers import mixin as opt_mixin from emerging_optimizers import utils from emerging_optimizers.utils import FP32MatmulPrecT +from emerging_optimizers.weight_update_hooks import NoOpWeightUpdateHook, WeightUpdateHook _args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups @@ -37,6 +38,7 @@ weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` for more details. fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. + weight_update_hook: Optional hook that runs around the final in-place weight update. """ @@ -103,6 +105,7 @@ def __init__( weight_decay_method: opt_mixin.WeightDecayT, fp32_matmul_prec: FP32MatmulPrecT, scaled_orthogonalize_fn: Callable | None = None, + weight_update_hook: WeightUpdateHook | None = None, **kwargs: Any, ): if scaled_orthogonalize_fn is None: @@ -112,6 +115,7 @@ def __init__( self.fp32_matmul_prec = fp32_matmul_prec self.nesterov = nesterov self.weight_decay_method = weight_decay_method + self.weight_update_hook = weight_update_hook if weight_update_hook is not None else NoOpWeightUpdateHook() default_args_dict = dict( lr=lr, @@ -195,8 +199,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # perform weight update with pre and post weight update functions for subclass customization self.pre_weight_update_fn_inplace(p, orth_grad) + weight_update_hook_pre_update_state = self.weight_update_hook.pre_weight_update_inplace(p, orth_grad) p.add_(orth_grad, alpha=-group["lr"]) self.post_weight_update_fn_inplace(p) + self.weight_update_hook.post_weight_update_inplace(p, weight_update_hook_pre_update_state) return None diff --git a/emerging_optimizers/weight_update_hooks/__init__.py b/emerging_optimizers/weight_update_hooks/__init__.py new file mode 100644 index 0000000..e984114 --- /dev/null +++ b/emerging_optimizers/weight_update_hooks/__init__.py @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from emerging_optimizers.weight_update_hooks.base import NoOpWeightUpdateHook, WeightUpdateHook +from emerging_optimizers.weight_update_hooks.hyperball import Hyperball +from emerging_optimizers.weight_update_hooks.radial_brake import RadialBrake + + +__all__ = [ + "Hyperball", + "NoOpWeightUpdateHook", + "RadialBrake", + "WeightUpdateHook", +] diff --git a/emerging_optimizers/weight_update_hooks/base.py b/emerging_optimizers/weight_update_hooks/base.py new file mode 100644 index 0000000..dae5459 --- /dev/null +++ b/emerging_optimizers/weight_update_hooks/base.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Protocol + +import torch + + +__all__ = ["NoOpWeightUpdateHook", "WeightUpdateHook"] + + +class WeightUpdateHook(Protocol): + """Protocol for behavior around an optimizer's final in-place weight update.""" + + def pre_weight_update_inplace( + self, + p: torch.Tensor, + update: torch.Tensor, + ) -> torch.Tensor | None: + """Called immediately before ``p.add_(update, alpha=-lr)`` and returns pre-update state.""" + + def post_weight_update_inplace( + self, + p: torch.Tensor, + pre_update_state: torch.Tensor | None, + ) -> None: + """Called after the optimizer's final update and optimizer-specific post-update hook.""" + + +class NoOpWeightUpdateHook: + """Default hook that leaves the optimizer update unchanged.""" + + def pre_weight_update_inplace( + self, + p: torch.Tensor, + update: torch.Tensor, + ) -> None: + return None + + def post_weight_update_inplace( + self, + p: torch.Tensor, + pre_update_state: torch.Tensor | None, + ) -> None: + pass diff --git a/emerging_optimizers/weight_update_hooks/hyperball.py b/emerging_optimizers/weight_update_hooks/hyperball.py new file mode 100644 index 0000000..c1b0ef4 --- /dev/null +++ b/emerging_optimizers/weight_update_hooks/hyperball.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +__all__ = ["Hyperball"] + + +class Hyperball: + """Normalize update and post-update weights to a fixed Frobenius norm. + + This hook mirrors the hyperball-style behavior used by MuonHyperball: before the weight update, normalize the + update to the target radius; after the weight update, project the parameter back to that radius. + """ + + def __init__( + self, + radius: float | None = None, + eps: float = 1e-8, + ) -> None: + self.radius = radius + self.eps = eps + + def pre_weight_update_inplace( + self, + p: torch.Tensor, + update: torch.Tensor, + ) -> torch.Tensor: + current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) + + if self.radius is not None: + radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) + else: + if current_norm.item() == 0: + raise ValueError("Hyperball requires all parameters to have non-zero norm when radius is not fixed.") + radius = current_norm + + update_norm = torch.linalg.vector_norm(update.to(torch.float32)).clamp_min(self.eps) + update.mul_((radius / update_norm).to(dtype=update.dtype)) + return radius + + def post_weight_update_inplace( + self, + p: torch.Tensor, + pre_update_state: torch.Tensor | None, + ) -> None: + if pre_update_state is None: + raise RuntimeError("Hyperball requires radius state") + radius = pre_update_state + post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)).clamp_min(self.eps) + p.mul_((radius / post_norm).to(dtype=p.dtype)) diff --git a/emerging_optimizers/weight_update_hooks/radial_brake.py b/emerging_optimizers/weight_update_hooks/radial_brake.py new file mode 100644 index 0000000..96e3d73 --- /dev/null +++ b/emerging_optimizers/weight_update_hooks/radial_brake.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + + +__all__ = ["RadialBrake"] + + +class RadialBrake: + """Dampen radial norm changes after an optimizer update. + + The optimizer first applies its usual update ``w = w_prev + dw``. This hook then rescales ``w`` so that + + .. math:: + + \\|w_{brake}\\| = \\|w_{prev}\\| + s(\\|w\\| - \\|w_{prev}\\|) + + where ``s`` is ``outward_scale_factor`` when the update increases the norm, otherwise + ``inward_scale_factor``. + """ + + def __init__( + self, + outward_scale_factor: float = 0.5, + inward_scale_factor: float = 1.0, + eps: float = 1e-12, + ) -> None: + if not 0.0 <= outward_scale_factor <= 1.0: + raise ValueError(f"outward_scale_factor must be in [0, 1], got {outward_scale_factor}") + if not 0.0 <= inward_scale_factor <= 1.0: + raise ValueError(f"inward_scale_factor must be in [0, 1], got {inward_scale_factor}") + self.outward_scale_factor = outward_scale_factor + self.inward_scale_factor = inward_scale_factor + self.eps = eps + + def pre_weight_update_inplace( + self, + p: torch.Tensor, + update: torch.Tensor, + ) -> torch.Tensor: + return torch.linalg.vector_norm(p.detach().to(torch.float32)) + + def post_weight_update_inplace( + self, + p: torch.Tensor, + pre_update_state: torch.Tensor | None, + ) -> None: + if pre_update_state is None: + raise RuntimeError("RadialBrake requires pre-update norm state") + pre_norm = pre_update_state + post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) + norm_delta = post_norm - pre_norm + scale_factor = self.outward_scale_factor if norm_delta.item() > 0 else self.inward_scale_factor + target_norm = pre_norm + scale_factor * norm_delta + p.mul_((target_norm / post_norm.clamp_min(self.eps)).to(dtype=p.dtype)) diff --git a/tests/test_orthogonalized_optimizer.py b/tests/test_orthogonalized_optimizer.py index 252252a..b8e5c4b 100644 --- a/tests/test_orthogonalized_optimizer.py +++ b/tests/test_orthogonalized_optimizer.py @@ -434,6 +434,17 @@ def test_zero_norm_raises_error(self) -> None: with self.assertRaises(ValueError): muon_hyperball.MuonHyperball([test_param], lr=0.01) + def test_rejects_weight_update_hook(self) -> None: + """MuonHyperball manages its own Hyperball hook internally.""" + test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) + + with self.assertRaisesRegex(TypeError, "does not accept a 'weight_update_hook' argument"): + muon_hyperball.MuonHyperball( + [test_param], + lr=0.01, + weight_update_hook=None, + ) + class PolarGradTest(parameterized.TestCase): def setUp(self): diff --git a/tests/test_weight_update_hooks.py b/tests/test_weight_update_hooks.py new file mode 100644 index 0000000..763701d --- /dev/null +++ b/tests/test_weight_update_hooks.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from absl import flags, logging +from absl.testing import absltest + +from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer +from emerging_optimizers.weight_update_hooks import Hyperball, NoOpWeightUpdateHook, RadialBrake + + +flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") +flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") +FLAGS = flags.FLAGS + + +def setUpModule() -> None: + if FLAGS.seed is not None: + logging.info("Setting random seed to %d", FLAGS.seed) + torch.manual_seed(FLAGS.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(FLAGS.seed) + + +class WeightUpdateHooksTest(absltest.TestCase): + def setUp(self) -> None: + super().setUp() + self.device = FLAGS.device + + def test_no_op_hook_leaves_update_and_param_unchanged(self) -> None: + hook = NoOpWeightUpdateHook() + param = torch.tensor([3.0, 4.0], device=self.device) + update = torch.tensor([1.0, -2.0], device=self.device) + param_before = param.clone() + update_before = update.clone() + + pre_update_state = hook.pre_weight_update_inplace(param, update) + hook.post_weight_update_inplace(param, pre_update_state) + + torch.testing.assert_close(param, param_before, atol=0.0, rtol=0.0) + torch.testing.assert_close(update, update_before, atol=0.0, rtol=0.0) + + def test_radial_brake_dampens_outward_norm_change(self) -> None: + hook = RadialBrake(outward_scale_factor=0.5, inward_scale_factor=1.0) + param = torch.tensor([3.0, 4.0], device=self.device) + update = torch.tensor([3.0, 4.0], device=self.device) + + pre_update_state = hook.pre_weight_update_inplace(param, update) + param.add_(update) + hook.post_weight_update_inplace(param, pre_update_state) + + torch.testing.assert_close(torch.linalg.vector_norm(param), torch.tensor(7.5, device=self.device)) + + def test_radial_brake_dampens_inward_norm_change(self) -> None: + hook = RadialBrake(outward_scale_factor=1.0, inward_scale_factor=0.2) + param = torch.tensor([6.0, 8.0], device=self.device) + update = torch.tensor([-3.0, -4.0], device=self.device) + + pre_update_state = hook.pre_weight_update_inplace(param, update) + param.add_(update) + hook.post_weight_update_inplace(param, pre_update_state) + + torch.testing.assert_close(torch.linalg.vector_norm(param), torch.tensor(9.0, device=self.device)) + + def test_radial_brake_rejects_amplifying_scale_factors(self) -> None: + with self.assertRaisesRegex(ValueError, "outward_scale_factor"): + RadialBrake(outward_scale_factor=1.1) + with self.assertRaisesRegex(ValueError, "inward_scale_factor"): + RadialBrake(inward_scale_factor=1.1) + + def test_hyperball_normalizes_update_and_final_weight_norm(self) -> None: + hook = Hyperball() + param = torch.tensor([3.0, 4.0], device=self.device) + update = torch.tensor([0.0, 10.0], device=self.device) + + pre_update_state = hook.pre_weight_update_inplace(param, update) + torch.testing.assert_close(torch.linalg.vector_norm(update), torch.tensor(5.0, device=self.device)) + + param.add_(update, alpha=-1.0) + hook.post_weight_update_inplace(param, pre_update_state) + + torch.testing.assert_close(torch.linalg.vector_norm(param), torch.tensor(5.0, device=self.device)) + + def test_hyperball_fixed_radius_allows_zero_norm_param(self) -> None: + hook = Hyperball(radius=2.0) + param = torch.zeros(2, device=self.device) + update = torch.tensor([0.0, 3.0], device=self.device) + + pre_update_state = hook.pre_weight_update_inplace(param, update) + param.add_(update, alpha=-1.0) + hook.post_weight_update_inplace(param, pre_update_state) + + torch.testing.assert_close(torch.linalg.vector_norm(param), torch.tensor(2.0, device=self.device)) + + def test_hyperball_dynamic_radius_rejects_zero_norm_param(self) -> None: + hook = Hyperball() + param = torch.zeros(2, device=self.device) + update = torch.tensor([0.0, 3.0], device=self.device) + + with self.assertRaisesRegex(ValueError, "when radius is not fixed"): + hook.pre_weight_update_inplace(param, update) + + def test_orthogonalized_optimizer_applies_weight_update_hook(self) -> None: + param = torch.tensor([[3.0, 4.0]], device=self.device) + param.grad = torch.tensor([[3.0, 4.0]], device=self.device) + optimizer = OrthogonalizedOptimizer( + [param], + lr=-1.0, + momentum=0.0, + weight_decay=0.0, + nesterov=False, + weight_decay_method="l2", + fp32_matmul_prec="highest", + scaled_orthogonalize_fn=torch.nn.Identity(), + weight_update_hook=RadialBrake(outward_scale_factor=0.5), + ) + + optimizer.step() + + torch.testing.assert_close(torch.linalg.vector_norm(param), torch.tensor(7.5, device=self.device)) + + +if __name__ == "__main__": + absltest.main()