Refactor of pre/post weight update hook functions #224

Open
mkhona-nvidia wants to merge 2 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/pre_post_weight_update_refactors

@mkhona-nvidia

Contributor

Weight Update Hooks

This change adds a small weight_update_hooks library for reusable behavior around an optimizer's final in-place parameter update. A hook is a configured object passed into an optimizer, for example:

from emerging_optimizers.weight_update_hooks import RadialBrake

optimizer = Muon(
    params,
    weight_decay=0.0,
    weight_update_hook=RadialBrake(outward_scale_factor=0.5),
)

The optimizer owns only one hook object. Hook-specific arguments live on the hook constructor, not on the optimizer constructor. The base update flow is:

pre_update_state = weight_update_hook.pre_weight_update_inplace(p, update)
p.add_(update, alpha=-lr)
post_weight_update_fn_inplace(p)
weight_update_hook.post_weight_update_inplace(p, pre_update_state)

pre_update_state is a transient value returned by the pre hook and immediately consumed by the post hook for the same parameter. It is not stored in optimizer.state, so it does not pollute checkpoints.
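The flow above can be sketched in plain Python. This is a minimal illustration, not the PR's code: parameters are lists of floats instead of tensors, the subclass-level post_weight_update_fn_inplace call is omitted, and `apply_update` and `RecordingHook` are hypothetical names introduced only for this example.

```python
class NoOpWeightUpdateHook:
    """Default behavior: no pre-update state, parameter left untouched."""

    def pre_weight_update_inplace(self, p, update):
        return None

    def post_weight_update_inplace(self, p, pre_update_state):
        return None


class RecordingHook:
    """Hypothetical hook that records call order and the state it receives."""

    def __init__(self):
        self.calls = []

    def pre_weight_update_inplace(self, p, update):
        self.calls.append("pre")
        return sum(abs(x) for x in p)  # any value the post hook needs

    def post_weight_update_inplace(self, p, pre_update_state):
        self.calls.append(("post", pre_update_state))


def apply_update(p, update, lr, hook):
    """Hypothetical stand-in for the optimizer's final in-place update."""
    pre_update_state = hook.pre_weight_update_inplace(p, update)
    for i, u in enumerate(update):
        p[i] -= lr * u  # list analogue of p.add_(update, alpha=-lr)
    hook.post_weight_update_inplace(p, pre_update_state)
    # pre_update_state is a local and is dropped here; nothing is persisted.


p = [1.0, 2.0]
hook = RecordingHook()
apply_update(p, [0.5, 0.5], lr=0.1, hook=hook)
apply_update(p, [0.0, 0.0], lr=0.1, hook=NoOpWeightUpdateHook())
```

Because the state only lives inside `apply_update`, a checkpoint taken between steps never sees it.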

The library currently provides three hook implementations, including the newly added RadialBrake based on https://nilin.github.io/radial-brake/:

  • NoOpWeightUpdateHook: default no-op behavior.
  • Hyperball: normalizes the update to a target radius before the update, then projects the updated weight back to that radius.
  • RadialBrake: applies the normal optimizer update, then rescales the updated weight so radial norm changes are damped. For pre-update weight w_prev and updated weight w, it sets:
$$\|w_{\text{brake}}\| = \|w_{\text{prev}}\| + s(\|w\| - \|w_{\text{prev}}\|)$$

where s = outward_scale_factor if the update increases the norm, otherwise s = inward_scale_factor.
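A small worked example of the formula, written as a plain-Python sketch rather than the PR's tensor code. The function name `radial_brake`, the list-based weights, and the `eps` value used to guard the division are assumptions made for illustration.

```python
import math


def radial_brake(w_prev, w, outward_scale_factor, inward_scale_factor, eps=1e-8):
    """Rescale w so its norm change relative to w_prev is damped by s."""
    prev_norm = math.sqrt(sum(x * x for x in w_prev))
    norm = math.sqrt(sum(x * x for x in w))
    delta = norm - prev_norm
    # s depends on whether the update grew or shrank the norm.
    s = outward_scale_factor if delta > 0 else inward_scale_factor
    target_norm = prev_norm + s * delta
    # Guard against dividing by a zero norm (eps is an assumed value).
    scale = target_norm / max(norm, eps)
    return [x * scale for x in w]


# The norm grows from 5 to 10; with s = 0.5 the braked norm is 5 + 0.5 * 5 = 7.5.
w_brake = radial_brake(
    [3.0, 4.0], [6.0, 8.0], outward_scale_factor=0.5, inward_scale_factor=1.0
)
```

The direction of `w` is unchanged; only its length is pulled back toward the previous norm.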

MuonHyperball now uses this shared hook machinery by passing Hyperball(radius=hyperball_radius, eps=hyperball_eps) into Muon, instead of implementing custom pre/post update methods and storing temporary values in optimizer state.
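To make the Hyperball description concrete, here is a plain-Python sketch of the two hook methods. It follows the wording of the PR description only; the list-based parameters, the helper `l2_norm`, and the exact way `eps` is applied are assumptions and may differ from the actual implementation.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


class Hyperball:
    """Sketch: keep a parameter on a sphere of a given radius."""

    def __init__(self, radius=None, eps=1e-8):
        self.radius = radius
        self.eps = eps

    def pre_weight_update_inplace(self, p, update):
        if self.radius is not None:
            radius = self.radius
        else:
            # Without a fixed radius, the current norm defines the sphere.
            radius = l2_norm(p)
            if radius < self.eps:
                raise ValueError("non-zero norm required when radius is not fixed")
        # Normalize the update to the target radius, in place.
        scale = radius / max(l2_norm(update), self.eps)
        for i in range(len(update)):
            update[i] *= scale
        return radius  # transient state handed to the post hook

    def post_weight_update_inplace(self, p, radius):
        # Project the updated weight back onto the sphere, in place.
        scale = radius / max(l2_norm(p), self.eps)
        for i in range(len(p)):
            p[i] *= scale


hook = Hyperball(radius=None)
p, update = [3.0, 4.0], [1.0, 0.0]
state = hook.pre_weight_update_inplace(p, update)
p[:] = [w - 0.1 * u for w, u in zip(p, update)]
hook.post_weight_update_inplace(p, state)
```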

Signed-off-by: mikail <mkhona@nvidia.com>
@mkhona-nvidia mkhona-nvidia requested a review from skyw June 4, 2026 18:10
@copy-pr-bot

copy-pr-bot Bot commented Jun 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@mkhona-nvidia mkhona-nvidia changed the title refactor of pre/post weight update hook functions Refactor of pre/post weight update hook functions Jun 4, 2026
@greptile-apps

greptile-apps Bot commented Jun 4, 2026

Contributor

Greptile Summary

This PR extracts the pre/post weight-update logic from MuonHyperball into a reusable weight_update_hooks library and adds a new RadialBrake hook. The three previously reported issues (silent overwrite of weight_update_hook in MuonHyperball, potential negative target_norm in RadialBrake, and the misplaced zero-norm guard in Hyperball) are all addressed in this revision.

  • New weight_update_hooks package introduces WeightUpdateHook Protocol, NoOpWeightUpdateHook, Hyperball, and RadialBrake; pre-update state is transient and never written to optimizer.state, keeping checkpoints clean.
  • MuonHyperball refactored to delegate entirely to the Hyperball hook rather than overriding pre/post_weight_update_fn_inplace; now raises TypeError if a caller attempts to pass their own weight_update_hook.
  • RadialBrake.__init__ validates both scale factors to [0.0, 1.0], eliminating the previously identified negative-target_norm path.

Confidence Score: 5/5

The change is safe to merge; all three issues from the prior review round are addressed and the new hook machinery is well-tested.

The refactor is logically sound: transient state is correctly not persisted, the MuonHyperball overwrite bug is fixed with an explicit guard, and RadialBrake's init validation prevents the negative-target-norm scenario. The only finding is a wasted vector_norm call on the fixed-radius path of Hyperball, which does not affect correctness.

emerging_optimizers/weight_update_hooks/hyperball.py — minor inefficiency in the fixed-radius code path

Important Files Changed

| Filename | Overview |
| --- | --- |
| emerging_optimizers/weight_update_hooks/hyperball.py | New Hyperball hook; fixed-radius zero-norm guard is correctly scoped, but current_norm is computed unconditionally even when unused (fixed-radius path). |
| emerging_optimizers/weight_update_hooks/radial_brake.py | New RadialBrake hook; __init__ correctly validates scale factors to [0,1], eliminating the negative-target-norm risk; post_norm zero-division is guarded with clamp_min. |
| emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py | Integrates the hook into the step loop cleanly; pre-state is transient (not stored in optimizer.state), consistent with PR intent. |
| emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py | Refactored to delegate to Hyperball hook; raises TypeError when caller passes weight_update_hook, removing the previous silent-overwrite bug. |
| emerging_optimizers/weight_update_hooks/base.py | Defines WeightUpdateHook Protocol and NoOpWeightUpdateHook; Protocol return types are consistent with all concrete implementations. |
| emerging_optimizers/orthogonalized_optimizers/muon.py | Minimal change: adds weight_update_hook parameter and forwards it to OrthogonalizedOptimizer.__init__. |
| tests/test_weight_update_hooks.py | New test file covering NoOp, RadialBrake, and Hyperball hooks including edge cases (zero-norm rejection, fixed-radius zero-norm parameter, scale-factor validation). |
| tests/test_orthogonalized_optimizer.py | Adds test verifying MuonHyperball rejects weight_update_hook argument (including None), covering the previously reported silent-overwrite bug. |

Sequence Diagram

sequenceDiagram
    participant OO as OrthogonalizedOptimizer.step()
    participant Sub as pre/post_weight_update_fn_inplace (subclass override)
    participant Hook as WeightUpdateHook (e.g. Hyperball / RadialBrake)
    participant P as Parameter p

    OO->>Sub: pre_weight_update_fn_inplace(p, orth_grad)
    Sub-->>OO: (may modify orth_grad in-place)
    OO->>Hook: pre_weight_update_inplace(p, orth_grad)
    Hook-->>OO: pre_update_state (transient, not in optimizer.state)
    OO->>P: "p.add_(orth_grad, alpha=-lr)"
    OO->>Sub: post_weight_update_fn_inplace(p)
    Sub-->>OO: (may modify p in-place)
    OO->>Hook: post_weight_update_inplace(p, pre_update_state)
    Hook-->>OO: (rescales p in-place)

Reviews (2): Last reviewed commit: "fixed greptile caught errors" | Re-trigger Greptile

Comment on lines +41 to +49
current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))
if current_norm.item() == 0:
    raise ValueError("Hyperball requires all parameters to have non-zero norm.")

radius = (
    torch.as_tensor(self.radius, device=p.device, dtype=torch.float32)
    if self.radius is not None
    else current_norm
)

Contributor

P1 The zero-norm guard fires even when self.radius is not None, but in that branch current_norm is never used as the radius — it is only assigned to radius when self.radius is None. A fixed-radius Hyperball applied to a parameter that starts at zero (e.g. a zero-initialized bias) will therefore raise a ValueError on the very first optimizer step, even though the fixed-radius path has no mathematical dependency on the pre-update parameter norm.

Suggested change
- current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))
- if current_norm.item() == 0:
-     raise ValueError("Hyperball requires all parameters to have non-zero norm.")
- radius = (
-     torch.as_tensor(self.radius, device=p.device, dtype=torch.float32)
-     if self.radius is not None
-     else current_norm
- )
+ current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))
+ if self.radius is not None:
+     radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32)
+ else:
+     if current_norm.item() == 0:
+         raise ValueError("Hyperball requires all parameters to have non-zero norm when radius is not fixed.")
+     radius = current_norm

Comment thread emerging_optimizers/weight_update_hooks/radial_brake.py
Comment thread emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py
Signed-off-by: mikail <mkhona@nvidia.com>

@skyw skyw left a comment

Contributor

Code logic LGTM. Many test names could be improved, but that is not critical.

Before merging, please get familiar with Protocol, i.e. be able to explain what purpose it serves here and why a subclass is not used instead.

One critical thing is zero handling, i.e. epsilon.

  • At the least, it should be applied consistently, for vector_norm for example.
  • Comparing a floating-point number directly against zero (a == 0, for example) actually tests for underflow, not numerical zero. The same eps applies: if numbers smaller than eps are considered zero, the logical test for a numerical zero should be abs(a) < eps.
  • Magnitude-aware handling may be out of scope, but it is something to consider as we are running into this more and more.
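The difference between the two zero tests can be shown with plain floats. This is only an illustration of the point above; the helper name `is_numerically_zero` and the eps value of 1e-8 are assumptions, not names from the PR.

```python
def is_numerically_zero(a: float, eps: float = 1e-8) -> bool:
    """Treat anything smaller in magnitude than eps as zero."""
    return abs(a) < eps


tiny = 1e-30                      # far below eps, but still representable
underflowed = 1e-200 * 1e-200     # 1e-400 is not representable in float64

exact_test_on_tiny = tiny == 0.0              # only true after underflow
eps_test_on_tiny = is_numerically_zero(tiny)  # true for any |a| < eps
exact_test_on_underflow = underflowed == 0.0
```

The exact comparison only fires once the value has underflowed to 0.0, while the eps test flags `tiny` as well.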


# perform weight update with pre and post weight update functions for subclass customization
self.pre_weight_update_fn_inplace(p, orth_grad)
weight_update_hook_pre_update_state = self.weight_update_hook.pre_weight_update_inplace(p, orth_grad)

Contributor

The name is too long, and "update" appears in it twice.

__all__ = ["NoOpWeightUpdateHook", "WeightUpdateHook"]


class WeightUpdateHook(Protocol):

Contributor

@mkhona-nvidia I don't oppose using Protocol, but I think you need to get comfortable with PEP 544 before this can be merged.
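For context, PEP 544 protocols describe structural subtyping: a class conforms by having the right methods, without inheriting from anything. The sketch below illustrates that idea with hypothetical classes (`ThirdPartyHook`, `NotAHook`); marking the protocol `runtime_checkable` is an assumption made so the check can be shown with isinstance, and it may not match the PR.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class WeightUpdateHook(Protocol):
    def pre_weight_update_inplace(self, p: Any, update: Any) -> Any: ...

    def post_weight_update_inplace(self, p: Any, pre_update_state: Any) -> None: ...


class ThirdPartyHook:
    """Does not inherit from WeightUpdateHook, yet has both methods."""

    def pre_weight_update_inplace(self, p, update):
        return None

    def post_weight_update_inplace(self, p, pre_update_state):
        return None


class NotAHook:
    """Missing the post hook, so it does not satisfy the protocol."""

    def pre_weight_update_inplace(self, p, update):
        return None


# runtime_checkable isinstance only checks that the methods exist,
# not their signatures; static type checkers verify the signatures.
conforms = isinstance(ThirdPartyHook(), WeightUpdateHook)
missing_method = isinstance(NotAHook(), WeightUpdateHook)
```

With an abstract base class, `ThirdPartyHook` would have to subclass it explicitly; with a Protocol, hooks defined outside the package can be passed in as long as their shape matches.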

p: torch.Tensor,
update: torch.Tensor,
) -> torch.Tensor:
current_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))

Contributor

detach is not necessary as all of our optimizers are wrapped in no_grad.

Contributor

Actually, vector_norm has a dtype argument, so the explicit cast to fp32 is not necessary.

Technically it is a compute type rather than a dtype, though; I don't know who added it to PyTorch.

if self.radius is not None:
    radius = torch.as_tensor(self.radius, device=p.device, dtype=torch.float32)
else:
    if current_norm.item() == 0:

Contributor

This triggers a synchronous device-to-host copy. Simply use (current_norm == 0).all().

def __init__(
    self,
    radius: float | None = None,
    eps: float = 1e-8,

Contributor

Maybe reduce this to 1e-15?

pre_norm = pre_update_state
post_norm = torch.linalg.vector_norm(p.detach().to(torch.float32))
norm_delta = post_norm - pre_norm
scale_factor = self.outward_scale_factor if norm_delta.item() > 0 else self.inward_scale_factor

Contributor

Same as before: don't use .item().
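One way to avoid .item() in the scale-factor selection, sketched under the assumption that the norms are 0-dim tensors as in the snippet above, is to pick the factor with torch.where so the choice stays on the tensor's device. The numeric values here are illustrative only and are not taken from the PR.

```python
import torch

outward_scale_factor, inward_scale_factor = 0.5, 1.0  # illustrative values

pre_norm = torch.tensor(5.0)
post_norm = torch.tensor(10.0)
norm_delta = post_norm - pre_norm

# Select the factor elementwise instead of branching in Python on .item().
scale_factor = torch.where(
    norm_delta > 0,
    torch.full_like(norm_delta, outward_scale_factor),
    torch.full_like(norm_delta, inward_scale_factor),
)
target_norm = pre_norm + scale_factor * norm_delta
```

A Python `if` on a tensor expression also converts the result to a host bool, so a branch-free form like this sidesteps that conversion as well.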

"""MuonHyperball manages its own Hyperball hook internally."""
test_param = nn.Parameter(torch.randn((5, 7), dtype=torch.float32, device=self.device))

with self.assertRaisesRegex(TypeError, "does not accept a 'weight_update_hook' argument"):

Contributor

This would be better as a KeyError.

pre_update_state = hook.pre_weight_update_inplace(param, update)
hook.post_weight_update_inplace(param, pre_update_state)

torch.testing.assert_close(param, param_before, atol=0.0, rtol=0.0)

Contributor

Good to test exact match.

We should probably have a partial function called assert_equal given it is so widely used.
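A minimal sketch of such a helper, assuming it would simply pin both tolerances of torch.testing.assert_close to zero; where it would live in the test utilities is not specified here.

```python
import functools

import torch

# Exact-match comparison: both tolerances fixed at zero.
assert_equal = functools.partial(torch.testing.assert_close, atol=0.0, rtol=0.0)

param_before = torch.tensor([1.0, 2.0, 3.0])
param = param_before.clone()
assert_equal(param, param_before)  # passes: values are bitwise identical
```

Call sites then read `assert_equal(param, param_before)` instead of repeating `atol=0.0, rtol=0.0` each time.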

torch.testing.assert_close(param, param_before, atol=0.0, rtol=0.0)
torch.testing.assert_close(update, update_before, atol=0.0, rtol=0.0)

def test_radial_brake_dampens_outward_norm_change(self) -> None:

Contributor

Suggestion: always state the expected behavior in the test name, e.g. in this case, what should be close to what. "Change" is too vague a description of the behavior.

weight_decay_method: opt_mixin.WeightDecayT,
fp32_matmul_prec: FP32MatmulPrecT,
scaled_orthogonalize_fn: Callable | None = None,
weight_update_hook: WeightUpdateHook | None = None,

Contributor

Let's make it private, _weight_update_hook, to signal that it is not supposed to be modified after initialization.

Hooks easily invite abuse. Making it private at least makes people aware that they are abusing it. A setter can be provided if we want to properly support changing the hook in flight.
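The private-attribute-plus-setter idea could look roughly like this. It is a sketch with a hypothetical `OptimizerSketch` class; the validation performed in the setter is an assumption about what "properly" supporting a swap might involve.

```python
class NoOpWeightUpdateHook:
    def pre_weight_update_inplace(self, p, update):
        return None

    def post_weight_update_inplace(self, p, pre_update_state):
        return None


class OptimizerSketch:
    def __init__(self, weight_update_hook=None):
        # Leading underscore: not meant to be reassigned casually.
        self._weight_update_hook = (
            weight_update_hook if weight_update_hook is not None else NoOpWeightUpdateHook()
        )

    @property
    def weight_update_hook(self):
        return self._weight_update_hook

    @weight_update_hook.setter
    def weight_update_hook(self, hook):
        # A single, explicit entry point for swapping the hook in flight.
        for name in ("pre_weight_update_inplace", "post_weight_update_inplace"):
            if not callable(getattr(hook, name, None)):
                raise TypeError(f"hook is missing required method {name!r}")
        self._weight_update_hook = hook


opt = OptimizerSketch()
replacement = NoOpWeightUpdateHook()
opt.weight_update_hook = replacement  # goes through the validating setter
```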
