Refactor of pre/post weight update hook functions#224
Conversation
Signed-off-by: mikail <mkhona@nvidia.com>
Greptile SummaryThis PR extracts the pre/post weight-update logic from
Confidence Score: 5/5The change is safe to merge; all three issues from the prior review round are addressed and the new hook machinery is well-tested. The refactor is logically sound: transient state is correctly not persisted, the MuonHyperball overwrite bug is fixed with an explicit guard, and RadialBrake's init validation prevents the negative-target-norm scenario. The only finding is a wasted vector_norm call on the fixed-radius path of Hyperball, which does not affect correctness. emerging_optimizers/weight_update_hooks/hyperball.py — minor inefficiency in the fixed-radius code path Important Files Changed
Sequence DiagramsequenceDiagram
participant OO as OrthogonalizedOptimizer.step()
participant Sub as pre/post_weight_update_fn_inplace (subclass override)
participant Hook as WeightUpdateHook (e.g. Hyperball / RadialBrake)
participant P as Parameter p
OO->>Sub: pre_weight_update_fn_inplace(p, orth_grad)
Sub-->>OO: (may modify orth_grad in-place)
OO->>Hook: pre_weight_update_inplace(p, orth_grad)
Hook-->>OO: pre_update_state (transient, not in optimizer.state)
OO->>P: "p.add_(orth_grad, alpha=-lr)"
OO->>Sub: post_weight_update_fn_inplace(p)
Sub-->>OO: (may modify p in-place)
OO->>Hook: post_weight_update_inplace(p, pre_update_state)
Hook-->>OO: (rescales p in-place)
Reviews (2): Last reviewed commit: "fixed greptile caught errors" | Re-trigger Greptile |
| current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | ||
| if current_norm.item() == 0: | ||
| raise ValueError("Hyperball requires all parameters to have non-zero norm.") | ||
|
|
||
| radius = ( | ||
| torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) | ||
| if self.radius is not None | ||
| else current_norm | ||
| ) |
There was a problem hiding this comment.
The zero-norm guard fires even when
self.radius is not None, but in that branch current_norm is never used as the radius — it is only assigned to radius when self.radius is None. A fixed-radius Hyperball applied to a parameter that starts at zero (e.g. a zero-initialized bias) will therefore raise a ValueError on the very first optimizer step, even though the fixed-radius path has no mathematical dependency on the pre-update parameter norm.
| current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | |
| if current_norm.item() == 0: | |
| raise ValueError("Hyperball requires all parameters to have non-zero norm.") | |
| radius = ( | |
| torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) | |
| if self.radius is not None | |
| else current_norm | |
| ) | |
| current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | |
| if self.radius is not None: | |
| radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) | |
| else: | |
| if current_norm.item() == 0: | |
| raise ValueError("Hyperball requires all parameters to have non-zero norm when radius is not fixed.") | |
| radius = current_norm |
Signed-off-by: mikail <mkhona@nvidia.com>
skyw
left a comment
There was a problem hiding this comment.
Code logic LGTM. A lot of test names can be improved, not super critical.
Before merge, please get familiar with protocol, i.e. being able to explain what purpose it serves here and why not use subclass.
One critical thing is 0 handling, i.e. epsilon.
- at least, it should be consistently applied, for vector_norm for example.
- Comparing a floating point number directly against zero (a == 0 for example) is actually testing underflow, not numerical 0. Same eps apply, if numbers smaller than eps are considered 0, logical test to determine a numerical 0 should be abs(a) < eps
- Magnitude aware handling maybe out of scope, but something to consider as we running into this more and more.
|
|
||
| # perform weight update with pre and post weight update functions for subclass customization | ||
| self.pre_weight_update_fn_inplace(p, orth_grad) | ||
| weight_update_hook_pre_update_state = self.weight_update_hook.pre_weight_update_inplace(p, orth_grad) |
There was a problem hiding this comment.
name is too long, two update in the name.
| __all__ = ["NoOpWeightUpdateHook", "WeightUpdateHook"] | ||
|
|
||
|
|
||
| class WeightUpdateHook(Protocol): |
There was a problem hiding this comment.
@mkhona-nvidia I don't oppose using Protocol. But I think need to get yourself comfortable with PEP544 before this can be merged.
| p: torch.Tensor, | ||
| update: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) |
There was a problem hiding this comment.
detach is not necessary as all of our optimizers are wrapped in no_grad.
There was a problem hiding this comment.
Actually, vector_norm has an dtype argument, explicit to fp32 is not necessary.
It technically not dtype but compute type though, don't know who added it to pytorch.
| if self.radius is not None: | ||
| radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32) | ||
| else: | ||
| if current_norm.item() == 0: |
There was a problem hiding this comment.
This triggers synced device to host copy. Simply using (current_norm == 0).all().
| def __init__( | ||
| self, | ||
| radius: float | None = None, | ||
| eps: float = 1e-8, |
| pre_norm = pre_update_state | ||
| post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32)) | ||
| norm_delta = post_norm - pre_norm | ||
| scale_factor = self.outward_scale_factor if norm_delta.item() > 0 else self.inward_scale_factor |
There was a problem hiding this comment.
same as before, don't use .item().
| """MuonHyperball manages its own Hyperball hook internally.""" | ||
| test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device)) | ||
|
|
||
| with self.assertRaisesRegex(TypeError, "does not accept a 'weight_update_hook' argument"): |
There was a problem hiding this comment.
This is better to be a KeyError,
| pre_update_state = hook.pre_weight_update_inplace(param, update) | ||
| hook.post_weight_update_inplace(param, pre_update_state) | ||
|
|
||
| torch.testing.assert_close(param, param_before, atol=0.0, rtol=0.0) |
There was a problem hiding this comment.
Good to test exact match.
We should probably have a partial function called assert_equal given it is so widely used.
| torch.testing.assert_close(param, param_before, atol=0.0, rtol=0.0) | ||
| torch.testing.assert_close(update, update_before, atol=0.0, rtol=0.0) | ||
|
|
||
| def test_radial_brake_dampens_outward_norm_change(self) -> None: |
There was a problem hiding this comment.
Suggestion: always state expected behavior in test name. e.g. in this case, what close to what. "change" is too vague of a behavior.
| weight_decay_method: opt_mixin.WeightDecayT, | ||
| fp32_matmul_prec: FP32MatmulPrecT, | ||
| scaled_orthogonalize_fn: Callable | None = None, | ||
| weight_update_hook: WeightUpdateHook | None = None, |
There was a problem hiding this comment.
Let's make it private, _weight_update_hook, to suggest it is not supposed to be modified after initialization.
Hook easily invites abuse. making it private at least making people aware of the abusing. A setter can be provided if we want to support properly change the hook inflight.
Weight Update Hooks
This change adds a small
weight_update_hookslibrary for reusable behavior around an optimizer's final in-place parameter update. A hook is a configured object passed into an optimizer, for example:The optimizer owns only one hook object. Hook-specific arguments live on the hook constructor, not on the optimizer constructor. The base update flow is:
pre_update_stateis a transient value returned by the pre hook and immediately consumed by the post hook for the same parameter. It is not stored inoptimizer.state, so it does not pollute checkpoints.The library currently provides three hook implementations, including the newly added
RadialBrakebased on https://nilin.github.io/radial-brake/:NoOpWeightUpdateHook: default no-op behavior.Hyperball: normalizes the update to a target radius before the update, then projects the updated weight back to that radius.RadialBrake: applies the normal optimizer update, then rescales the updated weight so radial norm changes are damped. For pre-update weightw_prevand updated weightw, it sets:where
s = outward_scale_factorif the update increases the norm, otherwises = inward_scale_factor.MuonHyperballnow uses this shared hook machinery by passingHyperball(radius=hyperball_radius, eps=hyperball_eps)intoMuon, instead of implementing custom pre/post update methods and storing temporary values in optimizer state.