From ad2967ebdafd2f14230dcf45338f781f1b0e5ad3 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 14:45:53 -0700 Subject: [PATCH 1/5] add stacked soap and tests Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 130 ++++++++++++++++++++++++++++++- tests/test_soap.py | 101 +++++++++++++++++++++++- 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index c05672ac..ac7e3ee8 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterable from contextlib import nullcontext from functools import partial from typing import TYPE_CHECKING, Callable, override @@ -22,7 +23,7 @@ import torch from absl import logging -from torch import optim +from torch import nn, optim from torch.optim.optimizer import ParamsT from emerging_optimizers import mixin as opt_mixin @@ -34,6 +35,7 @@ __all__ = [ "SOAP", + "StackedSoap", "precondition", "init_kronecker_factors", "update_kronecker_factors", @@ -584,3 +586,129 @@ def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7 scale = (max_rms / (rms + eps)).clamp(max=1.0) # in‐place scale u.mul_(scale) + + +def _stack_2d(x: torch.Tensor) -> torch.Tensor: + """Flattens a (>=2D) tensor to 2D, merging the leading (batch) dims into the smaller matrix edge. + + Args: + x: Tensor with at least 2 dimensions. + + Returns: + The 2D stacking of ``x``. + """ + if x.ndim == 2: + return x + b, p, q = x.shape + if q <= p: + # -> (p, b*q): move the batch next to the smaller edge, then merge. + out = x.reshape(b, p, q).permute(1, 0, 2).reshape(p, b * q) + else: + # -> (b*p, q): contiguous merge into rows. + out = x.reshape(b * p, q) + return out.contiguous() + + +def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor: + """Inverse of :func:`_stack_2d`, restoring the original ``shape``.""" + if len(shape) == 2: + return u + b, p, q = shape + if q <= p: + return u.reshape(p, b, q).permute(1, 0, 2).reshape(shape) + return u.reshape(shape) + + +@registry.register_optimizer("stacked_soap") +class StackedSoap(SOAP): + """Limited-memory SOAP for batched / 3D parameters via 2D stacking. + + Each parameter is flattened to 2D by merging its leading (batch) dims into the smaller matrix edge + (see :func:`_stack_2d`); the resulting 2D ``shadow`` parameters are what the base :class:`SOAP` + optimizes. Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a + single shared factor) and reuses the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). + ``step`` only bridges gradients in and the update out; all preconditioning is the inherited SOAP step. + + The shadow shares storage with the real parameter whenever the stacking is a pure view (the row-merge + branch and plain 2D parameters), in which case the base SOAP step writes through directly and no + copy-back is needed; the permute branch allocates an independent shadow and the update is copied back. + A plain 2D parameter is stacked as itself, so this is exactly stock SOAP. + + Because the base optimizer operates on the shadows, ``self.param_groups`` and ``state_dict`` are keyed + by the shadow parameters; ``zero_grad`` is overridden to clear the *real* parameters' gradients (the + ones ``backward()`` populates). + + SOAP is configured with the fixed settings appropriate for this use: decoupled weight decay, no + Nesterov, bias correction on, the QR eigenbasis path with 1 power-iteration step, KL-Shampoo on, and + the default matmul precision. + + Args: + params: Iterable of parameters to optimize. + lr: The learning rate. + betas: Inner Adam betas ``(b1, b2)``. + shampoo_beta: Beta for the kronecker factor moving average. + eps: Inner Adam epsilon. + weight_decay: Decoupled weight decay coefficient. + """ + + def __init__( + self, + params: Iterable[torch.Tensor], + lr: float, + betas: tuple[float, float] = (0.9, 0.95), + shampoo_beta: float = 0.95, + eps: float = 1e-8, + weight_decay: float = 0.01, + ) -> None: + # Build a 2D shadow per parameter; it shares storage with the real param when _stack_2d returns a + # view (row-merge / plain 2D), otherwise it is an independent buffer. + self._pairs: list[tuple[torch.Tensor, nn.Parameter]] = [ + (p, nn.Parameter(_stack_2d(p.detach()))) for p in params + ] + super().__init__( + [shadow for _, shadow in self._pairs], + lr, + betas=betas, + shampoo_beta=shampoo_beta, + eps=eps, + weight_decay=weight_decay, + weight_decay_method="decoupled", + nesterov=False, + correct_bias=True, + use_eigh=False, + power_iter_steps=1, + use_kl_shampoo=True, + ) + + @override + def zero_grad(self, set_to_none: bool = True) -> None: + for p, _ in self._pairs: + if p.grad is None: + continue + if set_to_none: + p.grad = None + else: + p.grad.zero_() + + if TYPE_CHECKING: + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + if closure is not None: + raise ValueError("closure is not supported") + for p, shadow in self._pairs: + shadow.grad = _stack_2d(p.grad) if p.grad is not None else None + super().step() + for p, shadow in self._pairs: + # Only copy back when the shadow has independent storage (the permute branch); otherwise + # the base SOAP step already wrote through to p. + if shadow.data_ptr() != p.data_ptr(): + p.copy_(_unstack(shadow, p.shape)) + return None diff --git a/tests/test_soap.py b/tests/test_soap.py index 4ed8f913..d8449e98 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -21,7 +21,7 @@ from absl.testing import absltest, parameterized from emerging_optimizers.soap import REKLS, SOAP, soap -from emerging_optimizers.soap.soap import _clip_update_rms_in_place +from emerging_optimizers.soap.soap import StackedSoap, _clip_update_rms_in_place, _stack_2d, _unstack flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") @@ -586,5 +586,104 @@ def test_eigenbasis_matches_reference(self, shape: tuple, num_steps: int): self.assertEqual(test_state["step"], ref_state["step"]) +class StackedSoapTest(parameterized.TestCase): + def setUp(self): + self.device = FLAGS.device + + @parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6)]) + def test_smoke(self, shape) -> None: + p = torch.nn.Parameter(torch.randn(shape, device=self.device)) + opt = StackedSoap([p], lr=1e-2, weight_decay=0.01) + for _ in range(3): + p.grad = torch.randn_like(p) + opt.step() + self.assertTrue(torch.isfinite(p).all()) + + @parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6), (4, 5, 5)]) + def test_stack_unstack_shapes_and_roundtrip(self, shape) -> None: + x = torch.randn(shape, device=self.device) + + if x.ndim == 2: + expected_2d = shape + else: + b, p, q = shape + expected_2d = (p, b * q) if q <= p else (b * p, q) + + stacked = _stack_2d(x) + self.assertEqual(stacked.shape, torch.Size(expected_2d)) + + restored = _unstack(stacked, x.shape) + self.assertEqual(restored.shape, x.shape) + torch.testing.assert_close(restored, x, atol=0, rtol=0) + + @parameterized.product(shape=[(8, 5), (16, 16), (5, 7)]) + def test_2d_input_7steps_matches_vanilla_soap(self, shape) -> None: + x = torch.randn(shape, device=self.device) + p_stacked = torch.nn.Parameter(x.clone()) + p_ref = torch.nn.Parameter(x.clone()) + + opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01) + opt_ref = SOAP( + [p_ref], + 1e-2, + weight_decay=0.01, + weight_decay_method="decoupled", + nesterov=False, + correct_bias=True, + use_eigh=False, + power_iter_steps=1, + use_kl_shampoo=True, + ) + + for _ in range(7): + grad = torch.randn(shape, device=self.device) + p_stacked.grad = grad.clone() + p_ref.grad = grad.clone() + opt_stacked.step() + opt_ref.step() + torch.testing.assert_close( + p_stacked.detach(), + p_ref.detach(), + atol=0, + rtol=0, + msg=lambda m: f"StackedSoap must match stock SOAP exactly on 2D params.\n\n{m}", + ) + + @parameterized.product(shape=[(4, 6, 3), (4, 3, 6)]) + def test_3d_input_5steps_matches_vanilla_soap(self, shape) -> None: + """StackedSoap on a 3D param must match vanilla SOAP run on the manually stacked 2D param.""" + x = torch.randn(shape, device=self.device) + p_stacked = torch.nn.Parameter(x.clone()) + # Reference is vanilla SOAP on the 2D stacking of the same parameter. + p_ref = torch.nn.Parameter(_stack_2d(x).clone()) + + opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01) + opt_ref = SOAP( + [p_ref], + 1e-2, + weight_decay=0.01, + weight_decay_method="decoupled", + nesterov=False, + correct_bias=True, + use_eigh=False, + power_iter_steps=1, + use_kl_shampoo=True, + ) + + for _ in range(5): + grad = torch.randn(shape, device=self.device) + p_stacked.grad = grad.clone() + p_ref.grad = _stack_2d(grad) + opt_stacked.step() + opt_ref.step() + torch.testing.assert_close( + _stack_2d(p_stacked.detach()), + p_ref.detach(), + atol=0, + rtol=0, + msg=lambda m: f"StackedSoap on a 3D param must match vanilla SOAP on its 2D stacking.\n\n{m}", + ) + + if __name__ == "__main__": absltest.main() From e9a628b70daa804682bfd5c85e4193ac7986071d Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 15:30:22 -0700 Subject: [PATCH 2/5] improve stack and unstack logic Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 84 ++++++++++++++++---------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index ac7e3ee8..72e7ad94 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -23,7 +23,7 @@ import torch from absl import logging -from torch import nn, optim +from torch import optim from torch.optim.optimizer import ParamsT from emerging_optimizers import mixin as opt_mixin @@ -589,10 +589,13 @@ def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7 def _stack_2d(x: torch.Tensor) -> torch.Tensor: - """Flattens a (>=2D) tensor to 2D, merging the leading (batch) dims into the smaller matrix edge. + """Flattens a 2D or 3D tensor to 2D, merging the batch dim into the smaller matrix edge. + + A 2D tensor is returned unchanged. A 3D tensor ``(b, p, q)`` is merged into the smaller of its two + matrix edges: ``(p, b * q)`` when ``q <= p``, otherwise ``(b * p, q)``. Args: - x: Tensor with at least 2 dimensions. + x: A 2D matrix or a 3D batched matrix ``(batch, p, q)``. Returns: The 2D stacking of ``x``. @@ -621,29 +624,27 @@ def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor: @registry.register_optimizer("stacked_soap") class StackedSoap(SOAP): - """Limited-memory SOAP for batched / 3D parameters via 2D stacking. - - Each parameter is flattened to 2D by merging its leading (batch) dims into the smaller matrix edge - (see :func:`_stack_2d`); the resulting 2D ``shadow`` parameters are what the base :class:`SOAP` - optimizes. Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a - single shared factor) and reuses the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). - ``step`` only bridges gradients in and the update out; all preconditioning is the inherited SOAP step. + """Limited-memory SOAP for batched / 3D parameters via transient 2D stacking. - The shadow shares storage with the real parameter whenever the stacking is a pure view (the row-merge - branch and plain 2D parameters), in which case the base SOAP step writes through directly and no - copy-back is needed; the permute branch allocates an independent shadow and the update is copied back. - A plain 2D parameter is stacked as itself, so this is exactly stock SOAP. + Optimizes the real parameters directly: ``self.param_groups``, ``self.state``, and gradients are all + keyed by the user's parameters, so learning-rate schedulers, gradient clipping, and ``state_dict`` + behave exactly as for plain :class:`SOAP`. Each 3D parameter is flattened to 2D by merging its batch + dim into the smaller matrix edge (see :func:`_stack_2d`) only for the duration of :meth:`step`: the + parameter's ``data`` and ``grad`` are swapped to their 2D views, the inherited SOAP step runs, and the + 2D update is unstacked back into the original storage. Because the swap happens before the inherited + step, its lazy state initialization sizes the optimizer state to the stacked 2D shape automatically. - Because the base optimizer operates on the shadows, ``self.param_groups`` and ``state_dict`` are keyed - by the shadow parameters; ``zero_grad`` is overridden to clear the *real* parameters' gradients (the - ones ``backward()`` populates). + Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a single + shared factor) while reusing the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). The + stacking is a storage-sharing view except for the permute branch (``q <= p``), which allocates one + transient 2D buffer per step. A plain 2D parameter is stacked as itself, so this is exactly stock SOAP. SOAP is configured with the fixed settings appropriate for this use: decoupled weight decay, no Nesterov, bias correction on, the QR eigenbasis path with 1 power-iteration step, KL-Shampoo on, and the default matmul precision. Args: - params: Iterable of parameters to optimize. + params: Iterable of 2D or 3D parameters to optimize. lr: The learning rate. betas: Inner Adam betas ``(b1, b2)``. shampoo_beta: Beta for the kronecker factor moving average. @@ -660,13 +661,8 @@ def __init__( eps: float = 1e-8, weight_decay: float = 0.01, ) -> None: - # Build a 2D shadow per parameter; it shares storage with the real param when _stack_2d returns a - # view (row-merge / plain 2D), otherwise it is an independent buffer. - self._pairs: list[tuple[torch.Tensor, nn.Parameter]] = [ - (p, nn.Parameter(_stack_2d(p.detach()))) for p in params - ] super().__init__( - [shadow for _, shadow in self._pairs], + params, lr, betas=betas, shampoo_beta=shampoo_beta, @@ -680,16 +676,6 @@ def __init__( use_kl_shampoo=True, ) - @override - def zero_grad(self, set_to_none: bool = True) -> None: - for p, _ in self._pairs: - if p.grad is None: - continue - if set_to_none: - p.grad = None - else: - p.grad.zero_() - if TYPE_CHECKING: @overload @@ -703,12 +689,28 @@ def step(self, closure: Callable[[], float]) -> float: ... def step(self, closure: Callable[[], float] | None = None) -> float | None: if closure is not None: raise ValueError("closure is not supported") - for p, shadow in self._pairs: - shadow.grad = _stack_2d(p.grad) if p.grad is not None else None + + # Swap each parameter's data/grad to their 2D stacking, run the inherited SOAP step on the 2D + # views (state is keyed by the real parameter and sized for the stacked shape), then unstack the + # update back into the original storage. + saved: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + data, grad = p.data, p.grad + saved.append((p, data, grad)) + p.data = _stack_2d(data) + p.grad = _stack_2d(grad) + super().step() - for p, shadow in self._pairs: - # Only copy back when the shadow has independent storage (the permute branch); otherwise - # the base SOAP step already wrote through to p. - if shadow.data_ptr() != p.data_ptr(): - p.copy_(_unstack(shadow, p.shape)) + + for p, data, grad in saved: + stacked = p.data + p.data = data + p.grad = grad + # Copy back only when stacking allocated an independent buffer (permute branch); the view + # branches already wrote the update through to the original storage. + if stacked.data_ptr() != data.data_ptr(): + data.copy_(_unstack(stacked, data.shape)) return None From 283d7ed61c9cd1caab12aecc735cb0795f095eb6 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 15:44:36 -0700 Subject: [PATCH 3/5] Add exception handling Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 44 +++++++++++++++++--------------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index 72e7ad94..16783a44 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -605,7 +605,7 @@ def _stack_2d(x: torch.Tensor) -> torch.Tensor: b, p, q = x.shape if q <= p: # -> (p, b*q): move the batch next to the smaller edge, then merge. - out = x.reshape(b, p, q).permute(1, 0, 2).reshape(p, b * q) + out = x.permute(1, 0, 2).reshape(p, b * q) else: # -> (b*p, q): contiguous merge into rows. out = x.reshape(b * p, q) @@ -692,25 +692,27 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: # Swap each parameter's data/grad to their 2D stacking, run the inherited SOAP step on the 2D # views (state is keyed by the real parameter and sized for the stacked shape), then unstack the - # update back into the original storage. + # update back into the original storage. The restore runs in a finally so that an exception inside + # super().step() (e.g. OOM, a NaN check) cannot leave parameters stuck in their 2D stacked shape. saved: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = [] - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - data, grad = p.data, p.grad - saved.append((p, data, grad)) - p.data = _stack_2d(data) - p.grad = _stack_2d(grad) - - super().step() - - for p, data, grad in saved: - stacked = p.data - p.data = data - p.grad = grad - # Copy back only when stacking allocated an independent buffer (permute branch); the view - # branches already wrote the update through to the original storage. - if stacked.data_ptr() != data.data_ptr(): - data.copy_(_unstack(stacked, data.shape)) + try: + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + data, grad = p.data, p.grad + saved.append((p, data, grad)) + p.data = _stack_2d(data) + p.grad = _stack_2d(grad) + + super().step() + finally: + for p, data, grad in saved: + stacked = p.data + p.data = data + p.grad = grad + # Copy back only when stacking allocated an independent buffer (permute branch); the view + # branches already wrote the update through to the original storage. + if stacked.data_ptr() != data.data_ptr(): + data.copy_(_unstack(stacked, data.shape)) return None From 18920899de47f58f75c68a827a4ebb8e8e5ea1ae Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 20:58:03 -0700 Subject: [PATCH 4/5] qol improvement Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 29 ++++++++++++++--------------- tests/test_soap.py | 4 ++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index 16783a44..ac862387 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterable from contextlib import nullcontext from functools import partial from typing import TYPE_CHECKING, Callable, override @@ -591,24 +590,24 @@ def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7 def _stack_2d(x: torch.Tensor) -> torch.Tensor: """Flattens a 2D or 3D tensor to 2D, merging the batch dim into the smaller matrix edge. - A 2D tensor is returned unchanged. A 3D tensor ``(b, p, q)`` is merged into the smaller of its two - matrix edges: ``(p, b * q)`` when ``q <= p``, otherwise ``(b * p, q)``. + A 2D tensor is returned unchanged. A 3D tensor ``(b, m, n)`` is merged into the smaller of its two + matrix edges: ``(m, b * n)`` when ``n <= m``, otherwise ``(b * m, n)``. Args: - x: A 2D matrix or a 3D batched matrix ``(batch, p, q)``. + x: A 2D matrix ``(m, n)`` or a 3D batched matrix ``(b, m, n)``. Returns: The 2D stacking of ``x``. """ if x.ndim == 2: return x - b, p, q = x.shape - if q <= p: - # -> (p, b*q): move the batch next to the smaller edge, then merge. - out = x.permute(1, 0, 2).reshape(p, b * q) + b, m, n = x.shape + if n <= m: + # -> (m, b*n): move the batch next to the smaller edge, then merge. + out = x.permute(1, 0, 2).reshape(m, b * n) else: - # -> (b*p, q): contiguous merge into rows. - out = x.reshape(b * p, q) + # -> (b*m, n): contiguous merge into rows. + out = x.reshape(b * m, n) return out.contiguous() @@ -616,9 +615,9 @@ def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor: """Inverse of :func:`_stack_2d`, restoring the original ``shape``.""" if len(shape) == 2: return u - b, p, q = shape - if q <= p: - return u.reshape(p, b, q).permute(1, 0, 2).reshape(shape) + b, m, n = shape + if n <= m: + return u.reshape(m, b, n).permute(1, 0, 2).reshape(shape) return u.reshape(shape) @@ -644,7 +643,7 @@ class StackedSoap(SOAP): the default matmul precision. Args: - params: Iterable of 2D or 3D parameters to optimize. + params: Iterable of 2D or 3D parameters to optimize or dicts defining parameter groups. lr: The learning rate. betas: Inner Adam betas ``(b1, b2)``. shampoo_beta: Beta for the kronecker factor moving average. @@ -654,7 +653,7 @@ class StackedSoap(SOAP): def __init__( self, - params: Iterable[torch.Tensor], + params: ParamsT, lr: float, betas: tuple[float, float] = (0.9, 0.95), shampoo_beta: float = 0.95, diff --git a/tests/test_soap.py b/tests/test_soap.py index d8449e98..1c8ee5fc 100644 --- a/tests/test_soap.py +++ b/tests/test_soap.py @@ -606,8 +606,8 @@ def test_stack_unstack_shapes_and_roundtrip(self, shape) -> None: if x.ndim == 2: expected_2d = shape else: - b, p, q = shape - expected_2d = (p, b * q) if q <= p else (b * p, q) + b, m, n = shape + expected_2d = (m, b * n) if n <= m else (b * m, n) stacked = _stack_2d(x) self.assertEqual(stacked.shape, torch.Size(expected_2d)) From fdb8c274906c6c4cabcd6ed386e2a84f84d36695 Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Wed, 24 Jun 2026 21:11:35 -0700 Subject: [PATCH 5/5] qol improvement Signed-off-by: Hao Wu --- emerging_optimizers/soap/soap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/emerging_optimizers/soap/soap.py b/emerging_optimizers/soap/soap.py index ac862387..710e906d 100644 --- a/emerging_optimizers/soap/soap.py +++ b/emerging_optimizers/soap/soap.py @@ -698,7 +698,7 @@ def step(self, closure: Callable[[], float] | None = None) -> float | None: for group in self.param_groups: for p in group["params"]: if p.grad is None: - continue + continue # pragma: no cover data, grad = p.data, p.grad saved.append((p, data, grad)) p.data = _stack_2d(data)