Describe the bug
LinearFunctionForZeroStage3 in deepspeed/runtime/zero/linear.py uses the legacy autograd.Function pattern where forward(ctx, ...) directly calls ctx.save_for_backward(). This crashes when any code path uses torch.func transforms (torch.func.grad, torch.func.grad_and_value, vmap, etc.) on a model running with ZeRO Stage 3.
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
This affects any library that uses torch.func internally, including Liger-Kernel (fused cross-entropy) and Axolotl's KD (Knowledge Distillation) kernel.
To Reproduce
- Enable ZeRO Stage 3
- Use any operation that internally calls torch.func.grad_and_value during forward/loss computation (e.g. Axolotl offline KD with Liger kernel, or any custom loss using functorch)
- Training crashes at the first forward pass
Minimal example:
import torch
import deepspeed
model = ... # any model
model, _, _, _ = deepspeed.initialize(model=model, config={"zero_optimization": {"stage": 3}, ...})
# Any torch.func usage on a ZeRO-3 wrapped model triggers the crash
# because F.linear is replaced by LinearFunctionForZeroStage3
# which lacks setup_context
torch.func.grad_and_value(some_loss_fn)(params, inputs)
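The DeepSpeed-free core of the failure can be reproduced with any legacy-pattern autograd.Function (the class below is illustrative, not DeepSpeed code):

```python
import torch

class LegacyDouble(torch.autograd.Function):
    # Old-style pattern: forward takes ctx directly, no setup_context.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

failed = False
try:
    # Any torch.func transform over an old-style Function raises.
    torch.func.grad(lambda t: LegacyDouble.apply(t).sum())(torch.randn(3))
except RuntimeError as err:
    failed = True
    print(type(err).__name__, str(err)[:60])
print("crashed:", failed)
```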
Expected behavior
LinearFunctionForZeroStage3 should be compatible with torch.func transforms. The ctx usage in forward is minimal (only save_for_backward), so splitting into the setup_context pattern is straightforward:
# Current (crashes with torch.func)
class LinearFunctionForZeroStage3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        ...

# Fix: separate setup_context (compatible with torch.func)
class LinearFunctionForZeroStage3(torch.autograd.Function):
    @staticmethod
    def forward(input, weight, bias=None):
        ...
        return ret

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)
backward needs no changes. Note: setup_context requires PyTorch >= 2.0, so a version-conditional definition may be needed if PyTorch < 2.0 support is still required.
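As a hedged, self-contained illustration of the split (class name and shapes are made up here, not DeepSpeed's actual code), a minimal linear Function using the forward()/setup_context() pattern works under torch.func.grad_and_value:

```python
import torch

class LinearWithSetupContext(torch.autograd.Function):
    """Minimal sketch of the forward()/setup_context() split."""

    @staticmethod
    def forward(input, weight, bias=None):
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple passed to forward; bias is passed
        # explicitly at the call site so unpacking always succeeds.
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        grad_weight = grad_output.transpose(-2, -1).matmul(input)
        grad_bias = grad_output.sum(0) if bias is not None else None
        return grad_input, grad_weight, grad_bias

x = torch.randn(4, 3)
w = torch.randn(2, 3)

def loss_fn(w):
    return LinearWithSetupContext.apply(x, w, None).sum()

# Succeeds because setup_context is defined; the legacy pattern raises here.
grad_w, loss = torch.func.grad_and_value(loss_fn)(w)
print(grad_w.shape)
```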
ds_report output
DeepSpeed C++/CUDA extension op report
NOTE: Alarm op not installed
NOTE: AsyncIO op not installed
...
torch install path: /usr/local/lib/python3.11/dist-packages/torch
torch version: 2.8.0+cu128
deepspeed install path: /usr/local/lib/python3.11/dist-packages/deepspeed
deepspeed info: 0.16.4, unknown, unknown
torch cuda version: 12.8
torch hip version: None
nvcc version: Not Available
deepspeed wheel compiled w.: torch 2.8, cuda 12.8
System info:
- OS: Ubuntu 22.04
- GPU: H100 80GB PCIe
- Python: 3.11
- PyTorch: 2.8.0+cu128
Launcher context
deepspeed CLI launcher
Additional context
The root cause is that PyTorch's torch.func transforms require autograd.Function subclasses to use the separate forward() + setup_context() pattern instead of the combined forward(ctx, ...) pattern. This is documented in Extending torch.func with autograd.Function.
The current forward only uses ctx for a single save_for_backward call, making the migration minimal. The backward method uses ctx.saved_tensors and ctx.needs_input_grad, both of which work identically with the new pattern.
One consideration: setup_context was introduced in PyTorch 2.0, while DeepSpeed's documented minimum is PyTorch 1.9+. If PyTorch < 2.0 is no longer actively supported, the migration can be done unconditionally.
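One way to express such a version gate, sketched with a feature check rather than version-string parsing (class and variable names are hypothetical, not DeepSpeed's):

```python
import torch

# setup_context exists on torch.autograd.Function from PyTorch 2.0 on,
# so a feature check avoids parsing version strings.
_HAS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context")

if _HAS_SETUP_CONTEXT:
    class CompatLinearFn(torch.autograd.Function):
        @staticmethod
        def forward(input, weight):
            return input.matmul(weight.t())

        @staticmethod
        def setup_context(ctx, inputs, output):
            ctx.save_for_backward(*inputs)

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output.matmul(weight), grad_output.t().matmul(input)
else:
    class CompatLinearFn(torch.autograd.Function):  # pre-2.0 fallback
        @staticmethod
        def forward(ctx, input, weight):
            ctx.save_for_backward(input, weight)
            return input.matmul(weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            input, weight = ctx.saved_tensors
            return grad_output.matmul(weight), grad_output.t().matmul(input)

y = CompatLinearFn.apply(torch.randn(4, 3), torch.randn(2, 3))
print(tuple(y.shape))  # (4, 2)
```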
Happy to submit a PR if this direction is acceptable.