Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions emerging_optimizers/soap/soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

__all__ = [
"SOAP",
"StackedSoap",
"precondition",
"init_kronecker_factors",
"update_kronecker_factors",
Expand Down Expand Up @@ -584,3 +585,133 @@ def _clip_update_rms_in_place(u: torch.Tensor, max_rms: float, eps: float = 1e-7
scale = (max_rms / (rms + eps)).clamp(max=1.0)
# in‐place scale
u.mul_(scale)


def _stack_2d(x: torch.Tensor) -> torch.Tensor:
"""Flattens a 2D or 3D tensor to 2D, merging the batch dim into the smaller matrix edge.

A 2D tensor is returned unchanged. A 3D tensor ``(b, m, n)`` is merged into the smaller of its two
matrix edges: ``(m, b * n)`` when ``n <= m``, otherwise ``(b * m, n)``.

Args:
x: A 2D matrix ``(m, n)`` or a 3D batched matrix ``(b, m, n)``.

Returns:
The 2D stacking of ``x``.
"""
if x.ndim == 2:
return x
b, m, n = x.shape
if n <= m:
# -> (m, b*n): move the batch next to the smaller edge, then merge.
out = x.permute(1, 0, 2).reshape(m, b * n)
else:
# -> (b*m, n): contiguous merge into rows.
out = x.reshape(b * m, n)
return out.contiguous()


def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor:
"""Inverse of :func:`_stack_2d`, restoring the original ``shape``."""
if len(shape) == 2:
return u
b, m, n = shape
if n <= m:
return u.reshape(m, b, n).permute(1, 0, 2).reshape(shape)
return u.reshape(shape)


@registry.register_optimizer("stacked_soap")
class StackedSoap(SOAP):
"""Limited-memory SOAP for batched / 3D parameters via transient 2D stacking.

Optimizes the real parameters directly: ``self.param_groups``, ``self.state``, and gradients are all
keyed by the user's parameters, so learning-rate schedulers, gradient clipping, and ``state_dict``
behave exactly as for plain :class:`SOAP`. Each 3D parameter is flattened to 2D by merging its batch
dim into the smaller matrix edge (see :func:`_stack_2d`) only for the duration of :meth:`step`: the
parameter's ``data`` and ``grad`` are swapped to their 2D views, the inherited SOAP step runs, and the
2D update is unstacked back into the original storage. Because the swap happens before the inherited
step, its lazy state initialization sizes the optimizer state to the stacked 2D shape automatically.

Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a single
shared factor) while reusing the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). The
stacking is a storage-sharing view except for the permute branch (``q <= p``), which allocates one
transient 2D buffer per step. A plain 2D parameter is stacked as itself, so this is exactly stock SOAP.

SOAP is configured with the fixed settings appropriate for this use: decoupled weight decay, no
Nesterov, bias correction on, the QR eigenbasis path with 1 power-iteration step, KL-Shampoo on, and
the default matmul precision.

Args:
params: Iterable of 2D or 3D parameters to optimize or dicts defining parameter groups.
lr: The learning rate.
betas: Inner Adam betas ``(b1, b2)``.
shampoo_beta: Beta for the kronecker factor moving average.
eps: Inner Adam epsilon.
weight_decay: Decoupled weight decay coefficient.
"""

def __init__(
self,
params: ParamsT,
lr: float,
betas: tuple[float, float] = (0.9, 0.95),
shampoo_beta: float = 0.95,
eps: float = 1e-8,
weight_decay: float = 0.01,
) -> None:
super().__init__(
params,
lr,
betas=betas,
shampoo_beta=shampoo_beta,
eps=eps,
weight_decay=weight_decay,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

if TYPE_CHECKING:

@overload
def step(self, closure: None = ...) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float: ...

@torch.no_grad() # type: ignore[misc]
@override
def step(self, closure: Callable[[], float] | None = None) -> float | None:
if closure is not None:
raise ValueError("closure is not supported")

# Swap each parameter's data/grad to their 2D stacking, run the inherited SOAP step on the 2D
# views (state is keyed by the real parameter and sized for the stacked shape), then unstack the
# update back into the original storage. The restore runs in a finally so that an exception inside
# super().step() (e.g. OOM, a NaN check) cannot leave parameters stuck in their 2D stacked shape.
saved: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
try:
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue # pragma: no cover
data, grad = p.data, p.grad
saved.append((p, data, grad))
p.data = _stack_2d(data)
p.grad = _stack_2d(grad)

super().step()
finally:
for p, data, grad in saved:
stacked = p.data
p.data = data
p.grad = grad
# Copy back only when stacking allocated an independent buffer (permute branch); the view
# branches already wrote the update through to the original storage.
if stacked.data_ptr() != data.data_ptr():
data.copy_(_unstack(stacked, data.shape))
return None
101 changes: 100 additions & 1 deletion tests/test_soap.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from absl.testing import absltest, parameterized

from emerging_optimizers.soap import REKLS, SOAP, soap
from emerging_optimizers.soap.soap import _clip_update_rms_in_place
from emerging_optimizers.soap.soap import StackedSoap, _clip_update_rms_in_place, _stack_2d, _unstack


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")
Expand Down Expand Up @@ -586,5 +586,104 @@ def test_eigenbasis_matches_reference(self, shape: tuple, num_steps: int):
self.assertEqual(test_state["step"], ref_state["step"])


class StackedSoapTest(parameterized.TestCase):
def setUp(self):
self.device = FLAGS.device

@parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6)])
def test_smoke(self, shape) -> None:
p = torch.nn.Parameter(torch.randn(shape, device=self.device))
opt = StackedSoap([p], lr=1e-2, weight_decay=0.01)
for _ in range(3):
p.grad = torch.randn_like(p)
opt.step()
self.assertTrue(torch.isfinite(p).all())

@parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6), (4, 5, 5)])
def test_stack_unstack_shapes_and_roundtrip(self, shape) -> None:
x = torch.randn(shape, device=self.device)

if x.ndim == 2:
expected_2d = shape
else:
b, m, n = shape
expected_2d = (m, b * n) if n <= m else (b * m, n)

stacked = _stack_2d(x)
self.assertEqual(stacked.shape, torch.Size(expected_2d))

restored = _unstack(stacked, x.shape)
self.assertEqual(restored.shape, x.shape)
torch.testing.assert_close(restored, x, atol=0, rtol=0)

@parameterized.product(shape=[(8, 5), (16, 16), (5, 7)])
def test_2d_input_7steps_matches_vanilla_soap(self, shape) -> None:
x = torch.randn(shape, device=self.device)
p_stacked = torch.nn.Parameter(x.clone())
p_ref = torch.nn.Parameter(x.clone())

opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
opt_ref = SOAP(
[p_ref],
1e-2,
weight_decay=0.01,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

for _ in range(7):
grad = torch.randn(shape, device=self.device)
p_stacked.grad = grad.clone()
p_ref.grad = grad.clone()
opt_stacked.step()
opt_ref.step()
torch.testing.assert_close(
p_stacked.detach(),
p_ref.detach(),
atol=0,
rtol=0,
msg=lambda m: f"StackedSoap must match stock SOAP exactly on 2D params.\n\n{m}",
)

@parameterized.product(shape=[(4, 6, 3), (4, 3, 6)])
def test_3d_input_5steps_matches_vanilla_soap(self, shape) -> None:
"""StackedSoap on a 3D param must match vanilla SOAP run on the manually stacked 2D param."""
x = torch.randn(shape, device=self.device)
p_stacked = torch.nn.Parameter(x.clone())
# Reference is vanilla SOAP on the 2D stacking of the same parameter.
p_ref = torch.nn.Parameter(_stack_2d(x).clone())

opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
opt_ref = SOAP(
[p_ref],
1e-2,
weight_decay=0.01,
weight_decay_method="decoupled",
nesterov=False,
correct_bias=True,
use_eigh=False,
power_iter_steps=1,
use_kl_shampoo=True,
)

for _ in range(5):
grad = torch.randn(shape, device=self.device)
p_stacked.grad = grad.clone()
p_ref.grad = _stack_2d(grad)
opt_stacked.step()
opt_ref.step()
torch.testing.assert_close(
_stack_2d(p_stacked.detach()),
p_ref.detach(),
atol=0,
rtol=0,
msg=lambda m: f"StackedSoap on a 3D param must match vanilla SOAP on its 2D stacking.\n\n{m}",
)


if __name__ == "__main__":
absltest.main()
Loading