Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/apidocs/orthogonalized-optimizers.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ emerging_optimizers.orthogonalized_optimizers
.. autoclass:: PolarGrad
:members:

.. autofunction:: right_polargrad_orth_fn


:hidden:`AdaptiveMuon`
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
75 changes: 74 additions & 1 deletion emerging_optimizers/orthogonalized_optimizers/polargrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
from emerging_optimizers.utils import FP32MatmulPrecT
from emerging_optimizers.utils.eig import eigh_with_fallback


__all__ = ["PolarGrad"]
__all__ = ["PolarGrad", "right_polargrad_orth_fn"]


@registry.register_optimizer("polargrad")
Expand Down Expand Up @@ -103,3 +104,75 @@ def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:


PolarGrad.__doc__ = PolarGrad.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]


def right_polargrad_orth_fn(
grad: torch.Tensor,
*,
alpha: float = 1.0,
center_rows: bool = False,
eps: float = 1e-15,
extra_scale_factor: float = 1.0,
) -> torch.Tensor:
r"""Right-spectral (one-sided polar) orthogonalization for tall matrices.

Orthogonalizes only the right factor of a tall matrix ``G`` (e.g. an embedding or LM-head weight,
``vocab x hidden``):

.. math::
u = G \, (G^\top G)^{-1/2}, \qquad \text{update} = \lVert G \rVert_*^{\,\alpha} \, u

.. code-block:: python
:caption: Define a ``RightPolarGrad`` by partially applying this orthogonalization

class RightPolarGrad(OrthogonalizedOptimizer):
def __init__(
self,
params,
lr: float = 3e-4,
momentum: float = 0.95,
weight_decay: float = 0.01,
*,
...
alpha: float = 1.0,
center_rows: bool = False,
eps: float = 1e-15,
extra_scale_factor: float = 1.0,
) -> None:
scaled_orthogonalize_fn = functools.partial(
right_polargrad_orth_fn,
alpha=alpha,
center_rows=center_rows,
eps=eps,
extra_scale_factor=extra_scale_factor,
)
super().__init__(
...
)

Args:
grad: The (momentum) tensor to orthogonalize.
alpha: Exponent applied to the nuclear-norm scale factor.
center_rows: If True, subtract the per-column mean (the average over the row / vocabulary axis,
``dim=0``) before and after the update, so each column is zero-mean.
eps: Floor on the right-Gram eigenvalues for the inverse sqrt and nuclear-norm computation.
extra_scale_factor: Extra multiplier on the update.

Returns:
The scaled right-polar update, same shape and dtype as ``grad``.
"""
m = grad.to(torch.float32)
if center_rows:
m = m - m.mean(dim=0, keepdim=True)

eigvals, eigvecs = eigh_with_fallback(m.transpose(-1, -2) @ m)
eigvals.clamp_min_(eps)
right_gram_inv_sqrt = (eigvecs * eigvals.rsqrt().unsqueeze(-2)) @ eigvecs.transpose(-1, -2)

u = m @ right_gram_inv_sqrt
nuclear_norm = eigvals.sqrt().sum()
update = u * nuclear_norm.pow(alpha) * extra_scale_factor

if center_rows:
update = update - update.mean(dim=0, keepdim=True)
return update.to(grad.dtype)
54 changes: 1 addition & 53 deletions tests/test_orthogonalized_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from absl import flags, logging
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, polargrad, scion
from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, scion
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


Expand Down Expand Up @@ -408,57 +408,5 @@ def test_radius_mismatch_raises_error(self) -> None:
muon_hyperball.MuonHyperball([test_param], lr=0.01, hyperball_radius=mismatched_radius)


class PolarGradTest(parameterized.TestCase):
def setUp(self):
self.device = FLAGS.device

@parameterized.product(
shape=[(5, 7), (33, 65), (127, 257)],
extra_scale_factor=[1.0, 0.2],
)
def test_smoke(self, shape, extra_scale_factor) -> None:
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
test_param.grad = torch.randint_like(test_param, -5, 5)

polargrad_opt = polargrad.PolarGrad(
[test_param],
extra_scale_factor=extra_scale_factor,
)
polargrad_opt.step()

@parameterized.product(
shape=[(4, 8), (16, 16), (32, 64), (13, 17)],
extra_scale_factor=[0.25, 0.125],
)
def test_orthogonalize_fn_matches_ref(self, shape, extra_scale_factor) -> None:
dummy_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
dummy_grad = torch.full(shape, 0.5, dtype=torch.float32, device=self.device)

# Set num_ns_steps to 0 to skip Newton-Schulz iterations and only normalize the input gradient.
polargrad_opt = polargrad.PolarGrad([dummy_param], num_ns_steps=0, extra_scale_factor=extra_scale_factor)
norm_grad = torch.nn.functional.normalize(dummy_grad, p=2, dim=(-2, -1), eps=1e-7)

# Assert normalization took effect
self.assertFalse((norm_grad == 1).all())

ref_scale = (norm_grad * dummy_grad).sum()
ref_out = norm_grad * ref_scale * extra_scale_factor

test_out = polargrad_opt.scaled_orthogonalize_fn(dummy_grad)

torch.testing.assert_close(
ref_out,
test_out,
atol=0,
rtol=0,
)

def test_negative_num_ns_steps_raises_value_error(self) -> None:
"""Test that PolarGrad raises ValueError for negative num_ns_steps."""
test_param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
with self.assertRaisesRegex(ValueError, "num_ns_steps must be positive"):
polargrad.PolarGrad([test_param], num_ns_steps=-1)


if __name__ == "__main__":
absltest.main()
133 changes: 133 additions & 0 deletions tests/test_polargrad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools

import torch
import torch.nn as nn
from absl import flags, logging
from absl.testing import absltest, parameterized

from emerging_optimizers.orthogonalized_optimizers import polargrad
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
flags.DEFINE_integer("seed", None, "Random seed for reproducible tests")
Comment thread
skyw marked this conversation as resolved.
FLAGS = flags.FLAGS


def setUpModule() -> None:
if FLAGS.seed is not None:
logging.info("Setting random seed to %d", FLAGS.seed)
torch.manual_seed(FLAGS.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(FLAGS.seed)


class PolarGradTest(parameterized.TestCase):
def setUp(self):
self.device = FLAGS.device

@parameterized.product(
shape=[(5, 7), (33, 65), (127, 257)],
extra_scale_factor=[1.0, 0.2],
)
def test_smoke(self, shape, extra_scale_factor) -> None:
test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
test_param.grad = torch.randint_like(test_param, -5, 5)

polargrad_opt = polargrad.PolarGrad(
[test_param],
extra_scale_factor=extra_scale_factor,
)
polargrad_opt.step()

@parameterized.product(
shape=[(4, 8), (16, 16), (32, 64), (13, 17)],
extra_scale_factor=[0.25, 0.125],
)
def test_orthogonalize_fn_matches_ref(self, shape, extra_scale_factor) -> None:
dummy_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
dummy_grad = torch.full(shape, 0.5, dtype=torch.float32, device=self.device)

# Set num_ns_steps to 0 to skip Newton-Schulz iterations and only normalize the input gradient.
polargrad_opt = polargrad.PolarGrad([dummy_param], num_ns_steps=0, extra_scale_factor=extra_scale_factor)
norm_grad = torch.nn.functional.normalize(dummy_grad, p=2, dim=(-2, -1), eps=1e-7)

# Assert normalization took effect
self.assertFalse((norm_grad == 1).all())

ref_scale = (norm_grad * dummy_grad).sum()
ref_out = norm_grad * ref_scale * extra_scale_factor

test_out = polargrad_opt.scaled_orthogonalize_fn(dummy_grad)

torch.testing.assert_close(
ref_out,
test_out,
atol=0,
rtol=0,
)

def test_negative_num_ns_steps_raises_value_error(self) -> None:
"""Test that PolarGrad raises ValueError for negative num_ns_steps."""
test_param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
with self.assertRaisesRegex(ValueError, "num_ns_steps must be positive"):
polargrad.PolarGrad([test_param], num_ns_steps=-1)


class RightPolarGradOrthFnTest(parameterized.TestCase):
def setUp(self):
self.device = FLAGS.device

@parameterized.product(shape=[(8, 4), (32, 8)], center_rows=[False, True])
def test_right_orthogonal_equivariance(self, shape, center_rows) -> None:
"""f(G Q) == f(G) Q for an orthogonal Q acting on the hidden (right) dimension."""
grad = torch.randn(shape, device=self.device)
n = shape[1]
q, _ = torch.linalg.qr(torch.randn(n, n, device=self.device))

rotated = polargrad.right_polargrad_orth_fn(grad @ q, center_rows=center_rows)
expected = polargrad.right_polargrad_orth_fn(grad, center_rows=center_rows) @ q
torch.testing.assert_close(
rotated,
expected,
atol=1e-5,
rtol=1e-5,
)

def test_usable_as_scaled_orthogonalize_fn(self) -> None:
param = nn.Parameter(torch.randn((32, 8), device=self.device))
scaled_orthogonalize_fn = functools.partial(
polargrad.right_polargrad_orth_fn, alpha=1.0, center_rows=True, extra_scale_factor=0.2
)
opt = OrthogonalizedOptimizer(
[param],
lr=1e-2,
momentum=0.95,
weight_decay=0.01,
nesterov=False,
weight_decay_method="decoupled",
fp32_matmul_prec="highest",
scaled_orthogonalize_fn=scaled_orthogonalize_fn,
)
for _ in range(3):
param.grad = torch.randn_like(param)
opt.step()
self.assertTrue(torch.isfinite(param).all())


if __name__ == "__main__":
absltest.main()
Loading