Sven is a PyTorch optimizer that replaces standard gradient descent with parameter updates computed via the Moore-Penrose pseudoinverse of the per-sample Jacobian matrix. Where SGD computes a single gradient by averaging over the batch, Sven decomposes the loss into individual per-sample components and solves for the minimum-norm parameter update that simultaneously reduces all of them, using a truncated SVD to keep the computation tractable.
In the over-parameterized regime this yields the minimum-norm solution among all updates that minimize the L2 error across the batch, and under favorable conditions can achieve exponential loss decay rather than the power-law behavior typical of first-order methods.
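To make the core idea concrete, here is a minimal sketch of a single raw pseudoinverse step on a toy model, written directly with `torch.func` and a dense `torch.linalg.pinv`. It is illustrative only, not Sven's implementation (which truncates the SVD and adds device handling, parameter subsampling, and microbatching):

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap

# Toy model and batch
model = nn.Linear(4, 1)
params = dict(model.named_parameters())
x, y = torch.randn(16, 4), torch.randn(16, 1)

# Per-sample loss as a function of the parameters, so we can differentiate w.r.t. them
def per_sample_loss(p, xi, yi):
    pred = functional_call(model, p, (xi.unsqueeze(0),)).squeeze(0)
    return ((pred - yi) ** 2).sum()

# Per-sample losses and the (B, P) Jacobian of those losses w.r.t. all parameters
losses = vmap(per_sample_loss, in_dims=(None, 0, 0))(params, x, y)
jac = vmap(jacrev(per_sample_loss), in_dims=(None, 0, 0))(params, x, y)
J = torch.cat([j.reshape(x.shape[0], -1) for j in jac.values()], dim=1)

# Minimum-norm update that tries to cancel every per-sample loss at once
delta = -torch.linalg.pinv(J) @ losses
print(J.shape, delta.shape)  # torch.Size([16, 5]) torch.Size([5])
```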
Install with:

```bash
pip install -e .
```

Sven is a near drop-in replacement for a standard PyTorch optimizer, with two differences: (1) the model must be wrapped with `SvenWrapper`, which converts it to a functional form for per-sample Jacobian computation, and (2) the training step calls `loss_and_grad` instead of the usual `loss.backward()`.
```python
import torch
import torch.nn as nn

from sven.nn import SvenWrapper
from sven.opt import Sven

# Define any standard PyTorch model and a per-sample loss function
model = nn.Sequential(nn.Linear(1, 64), nn.GELU(), nn.Linear(64, 1))
loss_fn = lambda pred, y: ((pred - y) ** 2).sum(dim=-1)  # must return shape (B,)

# Wrap the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wrapped = SvenWrapper(model, loss_fn, device)

# Create the optimizer
optimizer = Sven(wrapped, lr=0.1, k=64, rtol=1e-3)

# Training step
for xb, yb in train_loader:
    xb, yb = xb.to(device), yb.to(device)
    losses, preds = wrapped.loss_and_grad((xb, yb))
    optimizer.step()
```

See `examples/toy_1d_regression.ipynb` for a complete worked example comparing Sven to Adam.
The loss function passed to SvenWrapper must return per-sample losses with shape (B,), not a scalar. This is because Sven needs the individual loss components to construct the Jacobian matrix.
```python
# Correct: returns a (B,) tensor
loss_fn = lambda pred, y: ((pred - y) ** 2).sum(dim=-1)

# Wrong: returns a scalar
loss_fn = nn.MSELoss()
```

The `Sven` constructor takes the following hyperparameters:

- `k`: Number of singular values to keep in the truncated SVD. Controls the rank of the pseudoinverse approximation. A good starting point is `batch_size // 2`.
- `rtol`: Relative tolerance for singular value truncation. Singular values smaller than `rtol * sigma_max` are discarded. Default `1e-3`.
- `lr`: Learning rate applied to the pseudoinverse update.
- `svd_mode`: Algorithm for computing the truncated SVD. Options: `"torch"` (full SVD, then truncate), `"randomized"` (randomized SVD), `"randomized_v2"` (a different randomized strategy using an eigendecomposition), `"scipy"`, `"lobpcg"`. Default `"torch"`.
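To make the roles of `k` and `rtol` concrete, here is a rough sketch of a truncated-SVD pseudoinverse step in the spirit of the `"torch"` mode (full SVD, then truncate). The function name and details are illustrative only; see `sven/opt/pinv.py` for the actual implementations:

```python
import torch

def truncated_pinv_apply(J: torch.Tensor, losses: torch.Tensor,
                         k: int, rtol: float = 1e-3) -> torch.Tensor:
    # Full SVD, then keep at most k singular values above rtol * sigma_max
    U, S, Vh = torch.linalg.svd(J, full_matrices=False)
    keep = min(k, int((S > rtol * S[0]).sum()))
    U, S, Vh = U[:, :keep], S[:keep], Vh[:keep]
    # delta = -V_k @ diag(1 / S_k) @ U_k^T @ losses
    return -(Vh.T * (1.0 / S)) @ (U.T @ losses)

B, P = 32, 2000
J, losses = torch.randn(B, P), torch.rand(B)
delta = truncated_pinv_apply(J, losses, k=16)
print(delta.shape)  # torch.Size([2000])
```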
The per-sample Jacobian has shape (B, P) where B is batch size and P is the number of parameters, so memory scales as O(B * P). Two options help manage this:
- `param_fraction`: Compute the Jacobian with respect to a random subset of parameters each step. Set to e.g. `0.5` to halve memory usage.
- `microbatch_size`: Aggregate losses within sub-batches before computing the Jacobian, reducing the effective batch dimension.
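As a back-of-the-envelope illustration of the scaling (the numbers below are hypothetical, and it assumes `microbatch_size` groups samples into sub-batches of that size, so the Jacobian has `B / microbatch_size` rows):

```python
# Rough float32 memory estimate for the (B, P) Jacobian; illustrative, not measured.
B, P = 256, 10_000_000
bytes_per_float = 4

full = B * P * bytes_per_float                 # full Jacobian
half_params = full * 0.5                       # param_fraction=0.5
micro = (B // 4) * P * bytes_per_float         # microbatch_size=4 -> B/4 rows

print(f"full Jacobian      : {full / 1e9:.1f} GB")         # 10.2 GB
print(f"param_fraction=0.5 : {half_params / 1e9:.1f} GB")   #  5.1 GB
print(f"microbatch_size=4  : {micro / 1e9:.1f} GB")         #  2.6 GB
```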
```
sven/
├── nn/
│   ├── sven_wrapper.py   # SvenWrapper: functional model wrapper + Jacobian computation
│   └── __init__.py
└── opt/
    ├── sven.py           # Sven optimizer
    ├── pinv.py           # Truncated SVD pseudoinverse implementations
    ├── polyak.py         # PolyakSGD baseline optimizer
    └── __init__.py
```
Given a batch $\{(x_i, y_i)\}_{i=1}^{B}$, standard SGD forms a single scalar loss by averaging the per-sample terms,

$$L(\theta) = \frac{1}{B} \sum_{i=1}^{B} \ell_i(\theta), \qquad \ell_i(\theta) = \ell\big(f_\theta(x_i),\, y_i\big),$$

and steps along its gradient, $\theta \leftarrow \theta - \eta \, \nabla_\theta L(\theta)$.

Sven instead treats each element's contribution to the loss separately. Collect the per-sample losses into a vector

$$\boldsymbol{\ell}(\theta) = \big(\ell_1(\theta), \dots, \ell_B(\theta)\big)^\top \in \mathbb{R}^{B},$$

where $\theta \in \mathbb{R}^{P}$ denotes the flattened model parameters. In the linear (first-order) approximation around the current parameters, an update $\Delta\theta$ changes the per-sample losses as

$$\boldsymbol{\ell}(\theta + \Delta\theta) \approx \boldsymbol{\ell}(\theta) + J \, \Delta\theta,$$

with the Jacobian matrix defined as

$$J \in \mathbb{R}^{B \times P}, \qquad J_{ij} = \frac{\partial \ell_i(\theta)}{\partial \theta_j}.$$

We seek solutions that drive each term of the loss to zero (or as close to zero as it can get in the linear approximation):

$$J \, \Delta\theta = -\boldsymbol{\ell}(\theta).$$

An exact solution rarely exists, but the closest approximation to one is given by

$$\Delta\theta = -J^{+} \, \boldsymbol{\ell}(\theta),$$

where $J^{+}$ is the Moore-Penrose pseudoinverse of $J$: among all updates that minimize $\|J \, \Delta\theta + \boldsymbol{\ell}(\theta)\|_2$, this is the one with the smallest norm.

For a generic loss function as written above, scaling the update by the learning rate, $\Delta\theta = -\eta \, J^{+} \boldsymbol{\ell}(\theta)$, gives to first order

$$\boldsymbol{\ell}(\theta + \Delta\theta) \approx (1 - \eta) \, \boldsymbol{\ell}(\theta),$$

where we used $J J^{+} = I$, which holds when $J$ has full row rank (typical in the over-parameterized regime $P > B$). Iterating this step shrinks every per-sample loss geometrically, which is the source of the exponential decay mentioned above, in contrast to the power-law behavior of first-order methods.

In practice, while the pseudoinverse of the $B \times P$ Jacobian is well defined, forming it exactly via a full SVD is expensive, so Sven approximates $J^{+}$ with a truncated SVD, controlled by the `k`, `rtol`, and `svd_mode` options described earlier, to keep each step tractable.
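A quick numerical check of the claim above, using a random stand-in for $J$ and the per-sample losses (illustrative only): in the over-parameterized case the pseudoinverse update cancels every linearized per-sample loss, and any other solution of the system has a larger norm.

```python
import torch

torch.manual_seed(0)
B, P = 8, 64                       # over-parameterized: more parameters than samples
J = torch.randn(B, P)              # stand-in for the per-sample loss Jacobian
l = torch.rand(B)                  # stand-in for the per-sample losses

J_pinv = torch.linalg.pinv(J)
delta = -J_pinv @ l                # Moore-Penrose update

# Every linearized per-sample loss is driven to zero
print(torch.allclose(J @ delta, -l, atol=1e-4))   # True

# Any other solution of J @ x = -l differs by a null-space component and is longer
other = delta + (torch.eye(P) - J_pinv @ J) @ torch.randn(P)
print(torch.allclose(J @ other, -l, atol=1e-4))   # True
print(delta.norm() < other.norm())                # True
```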