[BUG] LinearFunctionForZeroStage3 crashes with torch.func transforms (missing setup_context) #7913

@roycho96

Description

Describe the bug

LinearFunctionForZeroStage3 in deepspeed/runtime/zero/linear.py uses the legacy autograd.Function pattern where forward(ctx, ...) directly calls ctx.save_for_backward(). This crashes when any code path uses torch.func transforms (torch.func.grad, torch.func.grad_and_value, vmap, etc.) on a model running with ZeRO Stage 3.

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.

This affects any library that uses torch.func internally, including Liger-Kernel (fused cross-entropy) and Axolotl's KD (Knowledge Distillation) kernel.

To Reproduce

  1. Enable ZeRO Stage 3
  2. Use any operation that internally calls torch.func.grad_and_value during forward/loss computation (e.g. Axolotl offline KD with Liger kernel, or any custom loss using functorch)
  3. Training crashes at the first forward pass

Minimal example:

import torch
import deepspeed

model = ...  # any model
model, _, _, _ = deepspeed.initialize(model=model, config={"zero_optimization": {"stage": 3}, ...})

# Any torch.func usage on a ZeRO-3 wrapped model triggers the crash
# because F.linear is replaced by LinearFunctionForZeroStage3
# which lacks setup_context
torch.func.grad_and_value(some_loss_fn)(params, inputs)

Expected behavior

LinearFunctionForZeroStage3 should be compatible with torch.func transforms. The ctx usage in forward is minimal (only save_for_backward), so splitting into the setup_context pattern is straightforward:

# Current (crashes with torch.func)
class LinearFunctionForZeroStage3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        ...

# Fix: separate setup_context (compatible with torch.func)
class LinearFunctionForZeroStage3(torch.autograd.Function):
    @staticmethod
    def forward(input, weight, bias=None):
        ...
        return ret

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

backward needs no changes. Note: setup_context requires PyTorch >= 2.0, so a version-conditional definition may be needed if PyTorch < 2.0 support is still required.

ds_report output

DeepSpeed C++/CUDA extension op report
NOTE: Alarm op not installed
NOTE: AsyncIO op not installed
...
torch install path: /usr/local/lib/python3.11/dist-packages/torch
torch version: 2.8.0+cu128
deepspeed install path: /usr/local/lib/python3.11/dist-packages/deepspeed
deepspeed info: 0.16.4, unknown, unknown
torch cuda version: 12.8
torch hip version: None
nvcc version: Not Available
deepspeed wheel compiled w.: torch 2.8, cuda 12.8

System info:

  • OS: Ubuntu 22.04
  • GPU: H100 80GB PCIe
  • Python: 3.11
  • PyTorch: 2.8.0+cu128

Launcher context

deepspeed CLI launcher

Additional context

The root cause is that PyTorch's torch.func transforms require autograd.Function subclasses to use the separate forward() + setup_context() pattern instead of the combined forward(ctx, ...) pattern. This is documented in Extending torch.func with autograd.Function.

The current forward only uses ctx for a single save_for_backward call, making the migration minimal. The backward method uses ctx.saved_tensors and ctx.needs_input_grad, both of which work identically with the new pattern.
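To illustrate, the split pattern can be exercised end to end under torch.func.grad with a small stand-in (PatchedLinear is a hypothetical name and the F.linear-style math is illustrative, not DeepSpeed's exact implementation; the point is only that save_for_backward, saved_tensors, and needs_input_grad all behave the same once setup_context exists):

```python
import torch

class PatchedLinear(torch.autograd.Function):
    """Stand-in linear function using the forward() + setup_context() split."""

    @staticmethod
    def forward(input, weight, bias=None):
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        # Identical to a legacy-pattern backward: saved_tensors and
        # needs_input_grad are unchanged by the migration.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().matmul(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

x = torch.randn(3, 4)
w = torch.randn(2, 4)

def loss_fn(weight):
    return PatchedLinear.apply(x, weight, None).sum()

# Without setup_context this call raises the RuntimeError from this issue;
# with it, the functorch grad transform goes through the custom backward.
grad_w = torch.func.grad(loss_fn)(w)
```

For loss = sum(x @ w.T), the analytic gradient of each row of w is x.sum(0), which makes the result easy to check by hand.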

One consideration: setup_context was introduced in PyTorch 2.0, while DeepSpeed's documented minimum is PyTorch 1.9+. If PyTorch < 2.0 is no longer actively supported, the migration can be done unconditionally.

Happy to submit a PR if this direction is acceptable.

Labels

bug (Something isn't working), training
