Improve hyperball implementation #233
Changes from all commits
40ad7e7
1a204b6
7f3be7f
1cc83c0
5c1db4e
f32adcf
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,7 +12,6 @@ | |||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||
| # limitations under the License. | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
| import torch.nn as nn | ||||||||||||
| from absl import flags, logging | ||||||||||||
|
|
@@ -377,6 +376,7 @@ def test_norm_preservation(self, shape) -> None: | |||||||||||
| lr=0.01, | ||||||||||||
| momentum=0.0, | ||||||||||||
| weight_decay=0.0, | ||||||||||||
| hyperball_radius=initial_norm, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Run multiple steps with random gradients | ||||||||||||
|
|
@@ -392,47 +392,20 @@ def test_norm_preservation(self, shape) -> None: | |||||||||||
| rtol=1e-5, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| @parameterized.product( | ||||||||||||
| shape=[(5, 7), (33, 65), (127, 257)], | ||||||||||||
| hyperball_radius=[0.5, 1.0, 2.0], | ||||||||||||
| ) | ||||||||||||
| def test_hyperball_radius_rescales_params(self, shape, hyperball_radius) -> None: | ||||||||||||
| """Test that hyperball_radius kwarg rescales parameters to specified radius.""" | ||||||||||||
| test_param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device)) | ||||||||||||
|
|
||||||||||||
| opt = muon_hyperball.MuonHyperball( | ||||||||||||
| [test_param], | ||||||||||||
| lr=0.01, | ||||||||||||
| hyperball_radius=hyperball_radius, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # After initialization, parameter should have the specified radius | ||||||||||||
| torch.testing.assert_close( | ||||||||||||
| test_param.norm(), | ||||||||||||
| torch.tensor(hyperball_radius, device=self.device), | ||||||||||||
| atol=1e-5, | ||||||||||||
| rtol=1e-5, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Run multiple steps with random gradients | ||||||||||||
| for _ in range(5): | ||||||||||||
| test_param.grad = torch.randn_like(test_param) | ||||||||||||
| opt.step() | ||||||||||||
| def test_zero_norm_raises_error(self) -> None: | ||||||||||||
| test_param = nn.Parameter(torch.zeros((5, 7), device=self.device)) | ||||||||||||
|
Comment on lines +395 to +396
Contributor
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Contributor
Author
@greptile, don't signal this if the test name is self-explanatory enough.
Contributor
Fair point. Shall I create a rule: "Do not flag missing docstrings on test methods when the test name is sufficiently self-explanatory"?
Contributor
Author
yes |
||||||||||||
|
|
||||||||||||
| # Norm should remain at hyperball_radius after each step | ||||||||||||
| torch.testing.assert_close( | ||||||||||||
| test_param.norm(), | ||||||||||||
| torch.tensor(hyperball_radius, device=self.device), | ||||||||||||
| atol=1e-5, | ||||||||||||
| rtol=1e-5, | ||||||||||||
| ) | ||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||
| muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=1.0) | ||||||||||||
|
|
||||||||||||
| def test_zero_norm_raises_error(self) -> None: | ||||||||||||
| """Test that MuonHyperball raises ValueError for zero-norm parameters.""" | ||||||||||||
| test_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device)) | ||||||||||||
| def test_radius_mismatch_raises_error(self) -> None: | ||||||||||||
| """Test that MuonHyperball raises ValueError when a parameter's norm does not match the radius.""" | ||||||||||||
| test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) | ||||||||||||
| # Pick a radius that differs from the parameter's actual Frobenius norm. | ||||||||||||
| mismatched_radius = test_param.norm().item() + 1.0 | ||||||||||||
|
|
||||||||||||
| with self.assertRaises(ValueError): | ||||||||||||
| muon_hyperball.MuonHyperball([test_param], lr=0.01) | ||||||||||||
| muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=mismatched_radius) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class PolarGradTest(parameterized.TestCase): | ||||||||||||
|
|
||||||||||||
`hyperball_eps` default weakens float32 numerical protection

The default dropped from `1e-8` to `1e-15`. `clamp_min(self.hyperball_eps)` is used in both `pre_weight_update_fn_inplace` and `post_weight_update_fn_inplace` to guard against division by near-zero norms. Float32 machine epsilon is ~1.19e-7, so a gradient norm like `1e-10` is well above `1e-15` and would not be clamped, while the old `1e-8` would have clamped it. The result is that a near-zero gradient norm can now produce a scaling factor (`hyperball_radius / update_norm`) that is up to 10^7× larger than before, potentially creating very large intermediate values during the weight-update step before `post_weight_update_fn_inplace` projects back to the sphere. The PR description notes eps handling is being revisited across the repo, so it is worth tracking this alongside that broader change.
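
To make the size of the effect concrete, here is a minimal scalar sketch of the clamped division described above. It is not the optimizer's code: `scale_factor` is a hypothetical helper, plain Python floats stand in for tensor norms, and `max(norm, eps)` stands in for `norm.clamp_min(eps)`.

```python
def scale_factor(update_norm: float, radius: float = 1.0, eps: float = 1e-8) -> float:
    """Hypothetical scalar stand-in for radius / update_norm.clamp_min(eps)."""
    # max() plays the role of clamp_min: norms below eps are raised to eps.
    return radius / max(update_norm, eps)


# Compare the old (1e-8) and new (1e-15) defaults for a few update norms.
for norm in (1e-3, 1e-10, 1e-20):
    old = scale_factor(norm, eps=1e-8)
    new = scale_factor(norm, eps=1e-15)
    print(f"norm={norm:.0e}  eps=1e-8: {old:.1e}  eps=1e-15: {new:.1e}  ratio={new / old:.1e}")
```

Under these assumptions, the two defaults agree for ordinary norms, differ by roughly 100× at a norm of `1e-10`, and differ by about 10^7× only once the norm falls at or below `1e-15`, which is the worst case the comment refers to.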