Improve hyperball implementation #233
Signed-off-by: mikail <mkhona@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
- hyperball_eps: float = 1e-8,
- hyperball_radius: float | None = None,
+ hyperball_radius: float,
+ hyperball_eps: float = 1e-15,
hyperball_eps default weakens float32 numerical protection
The default dropped from 1e-8 to 1e-15. clamp_min(self.hyperball_eps) is used in both pre_weight_update_fn_inplace and post_weight_update_fn_inplace to guard against division by near-zero norms. Float32 machine epsilon is ~1.19e-7, and a gradient norm such as 1e-10 is well above 1e-15, so it would no longer be clamped, whereas the old 1e-8 floor would have clamped it. As a result, a near-zero gradient norm can now produce a scaling factor (hyperball_radius / update_norm) up to 10^7× larger than before, potentially creating very large intermediate values during the weight-update step before post_weight_update_fn_inplace projects back onto the sphere. The PR description notes that eps handling is being revisited across the repo, so this is worth tracking alongside that broader change.
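To make the magnitude concrete, here is a minimal sketch of the arithmetic in pure Python. It is not the optimizer's actual code: `scale_factor` is a hypothetical helper, `max(norm, eps)` stands in for `update_norm.clamp_min(self.hyperball_eps)`, and the radius of 1.0 is an assumed value chosen only for illustration.

```python
import math


def scale_factor(update_norm: float, radius: float, eps: float) -> float:
    """Hypothetical stand-in for hyperball_radius / update_norm.clamp_min(eps)."""
    return radius / max(update_norm, eps)


radius = 1.0       # assumed radius, for illustration only
tiny_norm = 1e-10  # below the old eps (1e-8), above the new eps (1e-15)

# Old default: the norm is clamped up to 1e-8 before dividing.
old_scale = scale_factor(tiny_norm, radius, eps=1e-8)
# New default: 1e-10 is larger than 1e-15, so it is used unclamped.
new_scale = scale_factor(tiny_norm, radius, eps=1e-15)

# Worst case: an exactly-zero norm is clamped to eps itself,
# so the ratio of scale factors is old_eps / new_eps.
worst_case_ratio = scale_factor(0.0, radius, eps=1e-15) / scale_factor(
    0.0, radius, eps=1e-8
)

print(f"old scale:        {old_scale:.1e}")
print(f"new scale:        {new_scale:.1e}")
print(f"worst-case ratio: {worst_case_ratio:.1e}")
```

Whether these intermediate values matter in practice depends on the dtype of the real tensors and on how quickly the post-update projection is applied, which this sketch does not model.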
def test_zero_norm_raises_error(self) -> None:
    test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
Missing docstring on test_zero_norm_raises_error, inconsistent with the surrounding tests, which all carry docstrings describing their intent.
- def test_zero_norm_raises_error(self) -> None:
-     test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
+ def test_zero_norm_raises_error(self) -> None:
+     """Test that MuonHyperball raises ValueError for zero-norm parameters."""
+     test_param = nn.Parameter(torch.zeros((5, 7), device=self.device))
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
@greptile, don't flag this if the test name is self-explanatory enough.
Fair point — test_zero_norm_raises_error is clear enough on its own. I'll remember not to flag missing docstrings on test methods when the test name already explains the intent.
Shall I create a rule: "Do not flag missing docstrings on test methods when the test name is sufficiently self-explanatory"?
yes
/ok to test f32adcf
Updated naming and interface.
eps handling needs to be revisited for the entire repo.