-
Notifications
You must be signed in to change notification settings - Fork 33
Refactor of pre/post weight update hook functions #224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,6 +26,7 @@ | |
| from emerging_optimizers import mixin as opt_mixin | ||
| from emerging_optimizers import utils | ||
| from emerging_optimizers.utils import FP32MatmulPrecT | ||
| from emerging_optimizers.weight_update_hooks import NoOpWeightUpdateHook, WeightUpdateHook | ||
|
|
||
|
|
||
| _args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups | ||
|
|
@@ -37,6 +38,7 @@ | |
| weight_decay_method: Method to apply weight decay, see :class:`~emerging_optimizers.mixin.WeightDecayMixin` | ||
| for more details. | ||
| fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations. | ||
| weight_update_hook: Optional hook that runs around the final in-place weight update. | ||
| """ | ||
|
|
||
|
|
||
|
|
@@ -103,6 +105,7 @@ def __init__( | |
| weight_decay_method: opt_mixin.WeightDecayT, | ||
| fp32_matmul_prec: FP32MatmulPrecT, | ||
| scaled_orthogonalize_fn: Callable | None = None, | ||
| weight_update_hook: WeightUpdateHook | None = None, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's make it private, Hook easily invites abuse. making it private at least making people aware of the abusing. A setter can be provided if we want to support properly change the hook inflight. |
||
| **kwargs: Any, | ||
| ): | ||
| if scaled_orthogonalize_fn is None: | ||
|
|
@@ -112,6 +115,7 @@ def __init__( | |
| self.fp32_matmul_prec = fp32_matmul_prec | ||
| self.nesterov = nesterov | ||
| self.weight_decay_method = weight_decay_method | ||
| self.weight_update_hook = weight_update_hook if weight_update_hook is not None else NoOpWeightUpdateHook() | ||
|
|
||
| default_args_dict = dict( | ||
| lr=lr, | ||
|
|
@@ -195,8 +199,10 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: | |
|
|
||
| # perform weight update with pre and post weight update functions for subclass customization | ||
| self.pre_weight_update_fn_inplace(p, orth_grad) | ||
| weight_update_hook_pre_update_state = self.weight_update_hook.pre_weight_update_inplace(p, orth_grad) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. name is too long, two update in the name. |
||
| p.add_(orth_grad, alpha=-group["lr"]) | ||
| self.post_weight_update_fn_inplace(p) | ||
| self.weight_update_hook.post_weight_update_inplace(p, weight_update_hook_pre_update_state) | ||
|
|
||
| return None | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from emerging_optimizers.weight_update_hooks.base import NoOpWeightUpdateHook, WeightUpdateHook | ||
| from emerging_optimizers.weight_update_hooks.hyperball import Hyperball | ||
| from emerging_optimizers.weight_update_hooks.radial_brake import RadialBrake | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "Hyperball", | ||
| "NoOpWeightUpdateHook", | ||
| "RadialBrake", | ||
| "WeightUpdateHook", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from typing import Protocol | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| __all__ = ["NoOpWeightUpdateHook", "WeightUpdateHook"] | ||
|
|
||
|
|
||
| class WeightUpdateHook(Protocol): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mkhona-nvidia I don't oppose using Protocol. But I think need to get yourself comfortable with PEP544 before this can be merged. |
||
| """Protocol for behavior around an optimizer's final in-place weight update.""" | ||
|
|
||
| def pre_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| update: torch.Tensor, | ||
| ) -> torch.Tensor | None: | ||
| """Called immediately before ``p.add_(update, alpha=-lr)`` and returns pre-update state.""" | ||
|
|
||
| def post_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| pre_update_state: torch.Tensor | None, | ||
| ) -> None: | ||
| """Called after the optimizer's final update and optimizer-specific post-update hook.""" | ||
|
|
||
|
|
||
| class NoOpWeightUpdateHook: | ||
| """Default hook that leaves the optimizer update unchanged.""" | ||
|
|
||
| def pre_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| update: torch.Tensor, | ||
| ) -> None: | ||
| return None | ||
|
|
||
| def post_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| pre_update_state: torch.Tensor | None, | ||
| ) -> None: | ||
| pass | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import torch | ||
|
|
||
|
|
||
| __all__ = ["Hyperball"] | ||
|
|
||
|
|
||
| class Hyperball: | ||
| """Normalize update and post-update weights to a fixed Frobenius norm. | ||
|
|
||
| This hook mirrors the hyperball-style behavior used by MuonHyperball: before the weight update, normalize the | ||
| update to the target radius; after the weight update, project the parameter back to that radius. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| radius: float | None = None, | ||
| eps: float = 1e-8, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe reduce this to 1e-15? |
||
| ) -> None: | ||
| self.radius = radius | ||
| self.eps = eps | ||
|
|
||
| def pre_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| update: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. detach is not necessary as all of our optimizers are wrapped in no_grad.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, vector_norm has an dtype argument, explicit to fp32 is not necessary. It technically not dtype but compute type though, don't know who added it to pytorch. |
||
|
|
||
| if self.radius is not None: | ||
| radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) | ||
| else: | ||
| if current_norm.item() == 0: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This triggers synced device to host copy. Simply using (current_norm == 0).all(). |
||
| raise ValueError("Hyperball requires all parameters to have non-zero norm when radius is not fixed.") | ||
| radius = current_norm | ||
|
|
||
| update_norm = torch.linalg.vector_norm(update.to(torch.float32)).clamp_min(self.eps) | ||
| update.mul_((radius / update_norm).to(dtype=update.dtype)) | ||
| return radius | ||
|
|
||
| def post_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| pre_update_state: torch.Tensor | None, | ||
| ) -> None: | ||
| if pre_update_state is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q: why does it allow None and only to raise error? making it non optional is not sufficient? |
||
| raise RuntimeError("Hyperball requires radius state") | ||
| radius = pre_update_state | ||
| post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)).clamp_min(self.eps) | ||
| p.mul_((radius / post_norm).to(dtype=p.dtype)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import torch | ||
|
|
||
|
|
||
| __all__ = ["RadialBrake"] | ||
|
|
||
|
|
||
| class RadialBrake: | ||
| """Dampen radial norm changes after an optimizer update. | ||
|
|
||
| The optimizer first applies its usual update ``w = w_prev + dw``. This hook then rescales ``w`` so that | ||
|
|
||
| .. math:: | ||
|
|
||
| \\|w_{brake}\\| = \\|w_{prev}\\| + s(\\|w\\| - \\|w_{prev}\\|) | ||
|
|
||
| where ``s`` is ``outward_scale_factor`` when the update increases the norm, otherwise | ||
| ``inward_scale_factor``. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| outward_scale_factor: float = 0.5, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. scale_factor looks redundant, would outward_scale be ambiguous? if not, use the short form |
||
| inward_scale_factor: float = 1.0, | ||
| eps: float = 1e-12, | ||
| ) -> None: | ||
| if not 0.0 <= outward_scale_factor <= 1.0: | ||
| raise ValueError(f"outward_scale_factor must be in [0, 1], got {outward_scale_factor}") | ||
| if not 0.0 <= inward_scale_factor <= 1.0: | ||
| raise ValueError(f"inward_scale_factor must be in [0, 1], got {inward_scale_factor}") | ||
| self.outward_scale_factor = outward_scale_factor | ||
| self.inward_scale_factor = inward_scale_factor | ||
| self.eps = eps | ||
|
|
||
| def pre_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| update: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| return torch.linalg.vector_norm(p.detach().to(torch.float32)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why eps is not used here? |
||
|
|
||
| def post_weight_update_inplace( | ||
| self, | ||
| p: torch.Tensor, | ||
| pre_update_state: torch.Tensor | None, | ||
| ) -> None: | ||
| if pre_update_state is None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same question as before, what's the reason to allow None and only to raise error? |
||
| raise RuntimeError("RadialBrake requires pre-update norm state") | ||
| pre_norm = pre_update_state | ||
| post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't detach. |
||
| norm_delta = post_norm - pre_norm | ||
| scale_factor = self.outward_scale_factor if norm_delta.item() > 0 else self.inward_scale_factor | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same as before, don't use .item(). |
||
| target_norm = pre_norm + scale_factor * norm_delta | ||
| p.mul_((target_norm / post_norm.clamp_min(self.eps)).to(dtype=p.dtype)) | ||
|
mkhona-nvidia marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -434,6 +434,17 @@ def test_zero_norm_raises_error(self) -> None: | |
| with self.assertRaises(ValueError): | ||
| muon_hyperball.MuonHyperball([test_param], lr=0.01) | ||
|
|
||
| def test_rejects_weight_update_hook(self) -> None: | ||
| """MuonHyperball manages its own Hyperball hook internally.""" | ||
| test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) | ||
|
|
||
| with self.assertRaisesRegex(TypeError, "does not accept a 'weight_update_hook' argument"): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is better to be a KeyError, |
||
| muon_hyperball.MuonHyperball( | ||
| [test_param], | ||
| lr=0.01, | ||
| weight_update_hook=None, | ||
| ) | ||
|
|
||
|
|
||
| class PolarGradTest(parameterized.TestCase): | ||
| def setUp(self): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.