From 33db7c432916abb07d237192d9a401660f0127af Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Sat, 21 Mar 2026 18:18:23 +0900 Subject: [PATCH 01/17] fix: fix LinearFunctionForZeroStage3 to support torch.func transforms Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/linear.py | 167 ++++++++++++++++++++----------- 1 file changed, 110 insertions(+), 57 deletions(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 0fd02cdc67ef..275efd672b38 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -35,69 +35,122 @@ def print_rank_0(message, debug=False, force=False): autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name()) autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name()) +# PyTorch >= 2.0 supports setup_context, which is required for +# torch.func transforms (vmap, grad, jvp, jacrev, etc.) +_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context') -class LinearFunctionForZeroStage3(torch.autograd.Function): +if _SUPPORTS_SETUP_CONTEXT: - # Note that both forward and backward are @staticmethods - @staticmethod - @autocast_custom_fwd - # bias is an optional argument - def forward(ctx, input, weight, bias=None): + class LinearFunctionForZeroStage3(torch.autograd.Function): - ctx.save_for_backward(input, weight, bias) + @staticmethod + @autocast_custom_fwd + def forward(input, weight, bias=None): - if input.dim() == 2 and bias is not None: - # fused op is marginally faster - ret = torch.addmm(bias, input, weight.t()) - else: - output = input.matmul(weight.t()) - if bias is not None: - output += bias - ret = output - - return ret - - # This function has only a single output, so it gets only one gradient - @staticmethod - @autocast_custom_bwd - def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - - grad_input = grad_weight = grad_bias = None - - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. - dim = grad_output.dim() - if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") - grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") - if ctx.needs_input_grad[1]: - #print("Computing grad weight") - if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + if input.dim() == 2 and bias is not None: + # fused op is marginally faster + ret = torch.addmm(bias, input, weight.t()) else: - grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") - if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") - if dim > 2: - grad_bias = grad_output.sum([i for i in range(dim - 1)]) + output = input.matmul(weight.t()) + if bias is not None: + output += bias + ret = output + + return ret + + @staticmethod + def setup_context(ctx, inputs, output): + input, weight, bias = inputs + ctx.save_for_backward(input, weight, bias) + + # This function has only a single output, so it gets only one gradient + @staticmethod + @autocast_custom_bwd + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + + dim = grad_output.dim() + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight) + if ctx.needs_input_grad[1]: + if dim > 2: + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + if bias is not None and ctx.needs_input_grad[2]: + if dim > 2: + grad_bias = grad_output.sum([i for i in range(dim - 1)]) + else: + grad_bias = grad_output.sum(0) + return grad_input, grad_weight, grad_bias + +else: + + class LinearFunctionForZeroStage3(torch.autograd.Function): + + # Note that both forward and backward are @staticmethods + @staticmethod + @autocast_custom_fwd + # bias is an optional argument + def forward(ctx, input, weight, bias=None): + + ctx.save_for_backward(input, weight, bias) + + if input.dim() == 2 and bias is not None: + # fused op is marginally faster + ret = torch.addmm(bias, input, weight.t()) else: - grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") - return grad_input, grad_weight, grad_bias + output = input.matmul(weight.t()) + if bias is not None: + output += bias + ret = output + + return ret + + # This function has only a single output, so it gets only one gradient + @staticmethod + @autocast_custom_bwd + def backward(ctx, grad_output): + # This is a pattern that is very convenient - at the top of backward + # unpack saved_tensors and initialize all gradients w.r.t. inputs to + # None. Thanks to the fact that additional trailing Nones are + # ignored, the return statement is simple even when the function has + # optional inputs. + input, weight, bias = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + + #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") + # These needs_input_grad checks are optional and there only to + # improve efficiency. If you want to make your code simpler, you can + # skip them. Returning gradients for inputs that don't require it is + # not an error. + dim = grad_output.dim() + if ctx.needs_input_grad[0]: + #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") + grad_input = grad_output.matmul(weight) + #print(f"Computed grad input {grad_input.shape}") + if ctx.needs_input_grad[1]: + #print("Computing grad weight") + if dim > 2: + grad_weight = grad_output.reshape(-1, + grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + #print(f"Computed grad weight grad_weight {grad_weight.shape}") + if bias is not None and ctx.needs_input_grad[2]: + #print("Computing grad bias") + if dim > 2: + grad_bias = grad_output.sum([i for i in range(dim - 1)]) + else: + grad_bias = grad_output.sum(0) + #print("Done computing grad bias") + #print("needs bias") + #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") + return grad_input, grad_weight, grad_bias def zero3_linear_wrap(input, weight, bias=None): From 39b1755ace39144ecd08dd097c9e73161a446b03 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Sat, 21 Mar 2026 18:32:24 +0900 Subject: [PATCH 02/17] fix: always pass bias arg in zero3_linear_wrap to avoid setup_context unpack error Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/linear.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 275efd672b38..1ea1a438d554 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -154,10 +154,7 @@ def backward(ctx, grad_output): def zero3_linear_wrap(input, weight, bias=None): - if bias is None: - return LinearFunctionForZeroStage3.apply(input, weight) - else: - return LinearFunctionForZeroStage3.apply(input, weight, bias) + return LinearFunctionForZeroStage3.apply(input, weight, bias) class LinearModuleForZeroStage3(Module): From 6df37afb96961aeaf92a0870e674183d8e6aaf61 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Sun, 22 Mar 2026 12:23:27 +0900 Subject: [PATCH 03/17] fix: remove @autocast_custom_fwd from forward, move autocast state to setup_context Co-authored-by: zhangj1an Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/linear.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 1ea1a438d554..bdc1f05fd9d4 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -44,7 +44,7 @@ def print_rank_0(message, debug=False, force=False): class LinearFunctionForZeroStage3(torch.autograd.Function): @staticmethod - @autocast_custom_fwd + # bias is an optional argument def forward(input, weight, bias=None): if input.dim() == 2 and bias is not None: @@ -60,7 +60,13 @@ def forward(input, weight, bias=None): @staticmethod def setup_context(ctx, inputs, output): - input, weight, bias = inputs + # Replicate autocast state that @autocast_custom_fwd normally sets on ctx, + # since the decorator assumes args[0] is ctx which is unavailable in the + # separate forward() + setup_context() pattern. + device_type = get_accelerator().device_name() + ctx._dtype = torch.get_autocast_dtype(device_type) + ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type) + input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None ctx.save_for_backward(input, weight, bias) # This function has only a single output, so it gets only one gradient From c0b9694cce1b1eaee9d2119ebcc57ab854b2bef0 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sun, 22 Mar 2026 08:52:50 +0000 Subject: [PATCH 04/17] fix(zero3): replace custom_bwd with explicit autocast for functorch-safe linear Avoid asymmetric custom_bwd without custom_fwd on the setup_context forward path; mirror forward AMP in backward via torch.amp.autocast. Signed-off-by: Zhang --- deepspeed/runtime/zero/linear.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index bdc1f05fd9d4..9bdfd9f4471a 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -71,8 +71,17 @@ def setup_context(ctx, inputs, output): # This function has only a single output, so it gets only one gradient @staticmethod - @autocast_custom_bwd def backward(ctx, grad_output): + # Do not use @autocast_custom_bwd here: it pairs with @autocast_custom_fwd on + # legacy forward(ctx, ...). With forward + setup_context, use AMP state from setup_context. + device_type = get_accelerator().device_name() + if getattr(ctx, "_fwd_used_autocast", False): + with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype): + return LinearFunctionForZeroStage3._backward_core(ctx, grad_output) + return LinearFunctionForZeroStage3._backward_core(ctx, grad_output) + + @staticmethod + def _backward_core(ctx, grad_output): input, weight, bias = ctx.saved_tensors grad_input = grad_weight = grad_bias = None From 5e83d056fc72b4909d44976cd7ad50da80a491fb Mon Sep 17 00:00:00 2001 From: Zhang Date: Sun, 22 Mar 2026 09:13:25 +0000 Subject: [PATCH 05/17] fix(zero): use setup_context for offload pre/post backward Functions PyTorch versions that expose autograd.Function.setup_context need the modern forward + setup_context shape for torch.func / functorch. Signed-off-by: Zhang --- deepspeed/runtime/zero/parameter_offload.py | 139 ++++++++++++++------ 1 file changed, 99 insertions(+), 40 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index c434ff738933..6e4be7d98ba6 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -18,6 +18,10 @@ FWD_MODULE_STACK = list() +# PyTorch >= 2.0: setup_context on autograd.Function is required for torch.func transforms. +# Match deepspeed/runtime/zero/linear.py: keep legacy forward(ctx, ...) when unavailable. +_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context") + #for each tensor in outputs run the forward_function and register backward_function as hook def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): @@ -401,23 +405,45 @@ def _run_before_backward_function(sub_module): sub_module.applied_pre_backward_ref_cnt -= 1 #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - class PreBackwardFunctionForModule(torch.autograd.Function): + if _SUPPORTS_SETUP_CONTEXT: + + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(outputs): + return outputs.detach() - @staticmethod - def forward(ctx, outputs): - # Capture `module` and _run_before_backward_function - ctx.module = module - ctx.pre_backward_function = _run_before_backward_function - if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): - ctx.module.applied_pre_backward_ref_cnt = 0 - ctx.module.applied_pre_backward_ref_cnt += 1 - outputs = outputs.detach() - return outputs + @staticmethod + def setup_context(ctx, inputs, output): + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args + + else: - @staticmethod - def backward(ctx, *args): - ctx.pre_backward_function(ctx.module) - return args + class PreBackwardFunctionForModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, outputs): + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 + outputs = outputs.detach() + return outputs + + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args module.pre_bwd_fn = PreBackwardFunctionForModule @@ -431,31 +457,64 @@ def _run_after_backward_function(sub_module): if sub_module.ds_grads_remaining == 0: self.post_sub_module_backward_function(sub_module) - class PostBackwardFunctionModule(torch.autograd.Function): - - @staticmethod - def forward(ctx, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.post_backward_function = _run_after_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.post_backward_function(ctx.module) - return args + if _SUPPORTS_SETUP_CONTEXT: + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(output): + return output.detach() + + @staticmethod + def setup_context(ctx, inputs, output): + (output_in,) = inputs + ctx.module = module + if output_in.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args + + else: + + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(ctx, output): + ctx.module = module + if output.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + output = output.detach() + return output + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args module.post_bwd_fn = PostBackwardFunctionModule From a1e798d596d42c7ba27ad0e39a07d80203cf0b8a Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 03:26:08 +0000 Subject: [PATCH 06/17] run pre-commit checks Signed-off-by: Zhang --- deepspeed/runtime/zero/linear.py | 8 +- deepspeed/runtime/zero/parameter_offload.py | 2 +- .../v1/zero/test_zero_functorch_linear.py | 82 +++++++++++++++++++ 3 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 tests/unit/v1/zero/test_zero_functorch_linear.py diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 9bdfd9f4471a..86dd91717c4b 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -91,8 +91,8 @@ def _backward_core(ctx, grad_output): grad_input = grad_output.matmul(weight) if ctx.needs_input_grad[1]: if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( + input.reshape(-1, input.shape[-1])) else: grad_weight = grad_output.t().matmul(input) if bias is not None and ctx.needs_input_grad[2]: @@ -151,8 +151,8 @@ def backward(ctx, grad_output): if ctx.needs_input_grad[1]: #print("Computing grad weight") if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) + grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( + input.reshape(-1, input.shape[-1])) else: grad_weight = grad_output.t().matmul(input) #print(f"Computed grad weight grad_weight {grad_weight.shape}") diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index 6e4be7d98ba6..cad8d502f6a7 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -467,7 +467,7 @@ def forward(output): @staticmethod def setup_context(ctx, inputs, output): - (output_in,) = inputs + (output_in, ) = inputs ctx.module = module if output_in.requires_grad: #TODO SOME TIMES post backward does not seem to be triggered debug in detail diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py new file mode 100644 index 000000000000..1ff338441869 --- /dev/null +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression: ZeRO-3 patched F.linear must work with torch.func transforms. + +After deepspeed.initialize with ZeRO Stage 3, ``torch.nn.functional.linear`` is +replaced with ``LinearFunctionForZeroStage3``. That autograd.Function must use +the ``forward`` + ``setup_context`` pattern (PyTorch 2.0+); the legacy +``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward raises:: + + RuntimeError: In order to use an autograd.Function with functorch + transforms ... it must override the setup_context staticmethod. + +See ``repro_zero3_functorch_linear.py`` for a standalone script version. +""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +import deepspeed +from deepspeed.accelerator import get_accelerator + +from unit.common import DistributedTest + + +def _zero3_functorch_config(): + config = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 2147483647, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + }, + }, + } + acc = get_accelerator() + if acc.is_bf16_supported(): + config["bf16"] = {"enabled": True} + elif acc.is_fp16_supported(): + config["fp16"] = {"enabled": True, "initial_scale_power": 8} + return config + + +class TestZeroFunctorchLinearRegression(DistributedTest): + """``torch.func.grad_and_value`` over ZeRO-3 memory-efficient F.linear.""" + + world_size = 1 + + def test_grad_and_value_over_patched_functional_linear(self): + if not hasattr(torch, "func"): + pytest.skip("torch.func not available") + if not hasattr(torch.autograd.Function, "setup_context"): + pytest.skip("Requires PyTorch 2.0+ autograd.Function.setup_context") + + model = nn.Linear(8, 8, bias=True) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + + device = engine.device + dtype = engine.module.weight.dtype + weight = torch.randn(8, 8, device=device, dtype=dtype, requires_grad=True) + inp = torch.randn(2, 8, device=device, dtype=dtype, requires_grad=True) + + def loss_fn(w, x): + return F.linear(x, w, None).sum() + + grads, value = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) + assert torch.isfinite(value) + assert grads[0] is not None and torch.isfinite(grads[0]).all() + assert grads[1] is not None and torch.isfinite(grads[1]).all() From 8762d00f6bee898b57c1b0e57afe3fba9ebbec5a Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 03:47:25 +0000 Subject: [PATCH 07/17] update unit tests to reproduce main branch error Signed-off-by: Zhang --- .../v1/zero/test_zero_functorch_linear.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py index 1ff338441869..a92ae7529dd7 100644 --- a/tests/unit/v1/zero/test_zero_functorch_linear.py +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -2,26 +2,26 @@ # SPDX-License-Identifier: Apache-2.0 # DeepSpeed Team -"""Regression: ZeRO-3 patched F.linear must work with torch.func transforms. +"""Regression: ZeRO-3 linear autograd.Function must work with torch.func transforms. -After deepspeed.initialize with ZeRO Stage 3, ``torch.nn.functional.linear`` is -replaced with ``LinearFunctionForZeroStage3``. That autograd.Function must use -the ``forward`` + ``setup_context`` pattern (PyTorch 2.0+); the legacy -``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward raises:: +ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as +the memory-efficient linear path. After ``deepspeed.initialize``, global +``torch.nn.functional.linear`` is often the built-in again, so tests call +``zero3_linear_wrap`` directly—the same ``autograd.Function`` as when the patch +is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward +raises on strict functorch builds:: RuntimeError: In order to use an autograd.Function with functorch transforms ... it must override the setup_context staticmethod. - -See ``repro_zero3_functorch_linear.py`` for a standalone script version. """ import pytest import torch import torch.nn as nn -import torch.nn.functional as F import deepspeed from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.linear import zero3_linear_wrap from unit.common import DistributedTest @@ -51,7 +51,7 @@ def _zero3_functorch_config(): class TestZeroFunctorchLinearRegression(DistributedTest): - """``torch.func.grad_and_value`` over ZeRO-3 memory-efficient F.linear.""" + """``torch.func.grad_and_value`` over ``zero3_linear_wrap`` / LinearFunctionForZeroStage3.""" world_size = 1 @@ -73,8 +73,12 @@ def test_grad_and_value_over_patched_functional_linear(self): weight = torch.randn(8, 8, device=device, dtype=dtype, requires_grad=True) inp = torch.randn(2, 8, device=device, dtype=dtype, requires_grad=True) + with torch.enable_grad(): + probe = zero3_linear_wrap(inp, weight, None) + assert "LinearFunctionForZeroStage3" in type(probe.grad_fn).__name__ + def loss_fn(w, x): - return F.linear(x, w, None).sum() + return zero3_linear_wrap(x, w, None).sum() grads, value = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) assert torch.isfinite(value) From dd037da26042c26457e593e15ed4b8f0231e0ac2 Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 03:55:57 +0000 Subject: [PATCH 08/17] add reproduce scripts Signed-off-by: Zhang --- scripts/repro_pr7916.py | 100 ++++++++++++++++++++++++++++++++++++++++ scripts/setup_pr7916.sh | 43 +++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 scripts/repro_pr7916.py create mode 100644 scripts/setup_pr7916.sh diff --git a/scripts/repro_pr7916.py b/scripts/repro_pr7916.py new file mode 100644 index 000000000000..6ce594c92215 --- /dev/null +++ b/scripts/repro_pr7916.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 +# +# Repro: functorch over ZeRO-3 memory-efficient linear (LinearFunctionForZeroStage3). +# +# Legacy autograd.Function.forward(ctx, ...) + ctx.save_for_backward in that class +# triggers (PyTorch builds that enforce functorch custom-Function rules, e.g. 2.8+): +# +# RuntimeError: In order to use an autograd.Function with functorch transforms +# (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. +# +# Why we call zero3_linear_wrap() instead of torch.nn.functional.linear: +# After deepspeed.initialize(), the global ZeRO Init context has usually ended, so +# torch.nn.functional.linear is often restored to PyTorch's built-in. That means +# F.linear in a post-init script does NOT hit LinearFunctionForZeroStage3. The +# Stage-3 patch uses zero3_linear_wrap (see partition_parameters.py); it is the +# same autograd.Function — calling it here reliably reproduces the bug on unfixed +# trees and validates the fix on fixed trees. +# +# Regression coverage: tests/unit/v1/zero/test_zero_functorch_linear.py +# +# Run from the DeepSpeed repo root (single GPU), after scripts/setup.sh: +# torchrun --standalone --nproc_per_node=1 scripts/repro_zero3_functorch_linear.py +# +# To test an unfixed DeepSpeed tree without importing another checkout by mistake, +# copy this file outside the repo (e.g. /tmp) and set PYTHONPATH to that tree: +# cp scripts/repro_zero3_functorch_linear.py /tmp/ && cd /tmp && \ +# PYTHONPATH=/path/to/deepspeed-checkout torchrun --standalone --nproc_per_node=1 repro_zero3_functorch_linear.py +# +# Requires: PyTorch with torch.func and strict custom-Function checks (e.g. 2.8+), +# DeepSpeed ZeRO-3, CUDA (typical setup). + +import torch +import torch.nn as nn + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.linear import zero3_linear_wrap + + +def _assert_hits_zero3_linear(weight, inp): + """Sanity check: we are exercising LinearFunctionForZeroStage3, not built-in linear.""" + with torch.enable_grad(): + y = zero3_linear_wrap(inp, weight, None) + name = type(y.grad_fn).__name__ + assert "LinearFunctionForZeroStage3" in name, ( + f"Expected LinearFunctionForZeroStage3 in grad_fn, got {name!r}. " + "Repro would not test the intended autograd.Function.") + + +def main(): + if not hasattr(torch, "func"): + raise SystemExit("This repro requires torch.func (PyTorch 2.0+).") + if not hasattr(torch.autograd.Function, "setup_context"): + raise SystemExit("This repro requires autograd.Function.setup_context (PyTorch 2.0+).") + + deepspeed.init_distributed() + acc = get_accelerator() + device = acc.device_name() + ":" + str(acc.current_device()) + + model = nn.Linear(8, 8, bias=True).to(device) + + config = { + "train_micro_batch_size_per_gpu": 1, + "steps_per_print": 2147483647, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + }, + "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}, + } + if acc.is_bf16_supported(): + config["bf16"] = {"enabled": True} + elif acc.is_fp16_supported(): + config["fp16"] = {"enabled": True, "initial_scale_power": 8} + + _, _, _, _ = deepspeed.initialize( + model=model, + config=config, + model_parameters=model.parameters(), + ) + + weight = torch.randn(8, 8, device=device, dtype=model.weight.dtype, requires_grad=True) + inp = torch.randn(2, 8, device=device, dtype=model.weight.dtype, requires_grad=True) + + if deepspeed.comm.get_rank() == 0: + _assert_hits_zero3_linear(weight, inp) + + def loss_fn(w, x): + # Same op as ZeRO-3's F.linear replacement when the patch is active. + return zero3_linear_wrap(x, w, None).sum() + + torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) + if deepspeed.comm.get_rank() == 0: + print("repro: grad_and_value over zero3_linear_wrap (LinearFunctionForZeroStage3) OK.") + + +if __name__ == "__main__": + main() diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh new file mode 100644 index 000000000000..d931739ed834 --- /dev/null +++ b/scripts/setup_pr7916.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Create .venv at the DeepSpeed repo root with: +# - PyTorch 2.8.0+cu128 (CUDA 12.8) +# - requirements from requirements/requirements.txt +# - DeepSpeed editable install from the *current* checkout (latest local code) +# - pytest (for unit tests) +# +# Usage (from anywhere): +# ./scripts/setup.sh +# +# Then from repo root: +# source .venv/bin/activate +# torchrun --standalone --nproc_per_node=1 scripts/repro_zero3_functorch_linear.py +# +# To reproduce a failure on an older DeepSpeed release instead of this tree: +# pip install 'deepspeed==0.16.4' # after venv is active; skip pip install -e . once or use a fresh venv +# +set -euo pipefail +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +cd "$ROOT" + +rm -rf .venv +python3 -m venv .venv +# shellcheck source=/dev/null +. .venv/bin/activate + +python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { + echo "Warning: expected Python 3.11; found $(python -V)" >&2 +} + +pip install -U pip setuptools wheel + +# PyTorch 2.8.0 + CUDA 12.8 (matches common functorch / ZeRO-3 bug reports) +pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 + +pip install -r requirements/requirements.txt + +# Latest DeepSpeed = this git checkout (editable) +pip install -e . + +pip install pytest + +python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" From f69c1f1a5e4413860fac64348c012e754c0e27d9 Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 13:19:24 +0000 Subject: [PATCH 09/17] update reproduce script Signed-off-by: Zhang --- scripts/repro_pr7916.py | 4 +- scripts/setup_pr7916.sh | 90 +++++++++++++++++++++++++++++++++++------ 2 files changed, 79 insertions(+), 15 deletions(-) mode change 100644 => 100755 scripts/setup_pr7916.sh diff --git a/scripts/repro_pr7916.py b/scripts/repro_pr7916.py index 6ce594c92215..83c51e44f08b 100644 --- a/scripts/repro_pr7916.py +++ b/scripts/repro_pr7916.py @@ -20,8 +20,8 @@ # # Regression coverage: tests/unit/v1/zero/test_zero_functorch_linear.py # -# Run from the DeepSpeed repo root (single GPU), after scripts/setup.sh: -# torchrun --standalone --nproc_per_node=1 scripts/repro_zero3_functorch_linear.py +# Run from the DeepSpeed repo root (single GPU), after scripts/setup_pr7916.sh (or manually): +# torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py # # To test an unfixed DeepSpeed tree without importing another checkout by mistake, # copy this file outside the repo (e.g. /tmp) and set PYTHONPATH to that tree: diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh old mode 100644 new mode 100755 index d931739ed834..2befa6328729 --- a/scripts/setup_pr7916.sh +++ b/scripts/setup_pr7916.sh @@ -1,28 +1,35 @@ #!/usr/bin/env bash -# Create .venv at the DeepSpeed repo root with: +# Create an isolated venv for PR 7916 repro at .venvs/pr7916 (repo root): +# - Removes only that venv if it already exists (does not touch .venv or other envs) # - PyTorch 2.8.0+cu128 (CUDA 12.8) # - requirements from requirements/requirements.txt -# - DeepSpeed editable install from the *current* checkout (latest local code) -# - pytest (for unit tests) +# - DeepSpeed editable install from the *current* checkout +# - pytest # -# Usage (from anywhere): -# ./scripts/setup.sh +# Then validates the fix by: +# 1) Running the repro with the current branch (expect success + "OK" line) +# 2) Checking out master and running the same repro script (expect original RuntimeError) +# 3) Checking back to the branch you started on # -# Then from repo root: -# source .venv/bin/activate -# torchrun --standalone --nproc_per_node=1 scripts/repro_zero3_functorch_linear.py +# Usage (from repo root): +# ./scripts/setup_pr7916.sh # -# To reproduce a failure on an older DeepSpeed release instead of this tree: -# pip install 'deepspeed==0.16.4' # after venv is active; skip pip install -e . once or use a fresh venv +# Activate later: +# source .venvs/pr7916/bin/activate # set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "$ROOT" -rm -rf .venv -python3 -m venv .venv +VENV_DIR="${PR7916_VENV_DIR:-$ROOT/.venvs/pr7916}" +MAIN_REF="${PR7916_MAIN_REF:-master}" + +echo "==> Using venv: $VENV_DIR (only this path is removed if it already exists)" +rm -rf "$VENV_DIR" +mkdir -p "$(dirname "$VENV_DIR")" +python3 -m venv "$VENV_DIR" # shellcheck source=/dev/null -. .venv/bin/activate +. "$VENV_DIR/bin/activate" python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { echo "Warning: expected Python 3.11; found $(python -V)" >&2 @@ -41,3 +48,60 @@ pip install -e . pip install pytest python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" + +REPRO_SRC="$ROOT/scripts/repro_pr7916.py" +if [[ ! -f "$REPRO_SRC" ]]; then + echo "error: missing $REPRO_SRC (need repro script on current branch)" >&2 + exit 1 +fi + +REPRO_TMP="$(mktemp /tmp/repro_pr7916_XXXXXX.py)" +cp "$REPRO_SRC" "$REPRO_TMP" +cleanup_repro_tmp() { rm -f "$REPRO_TMP"; } +trap cleanup_repro_tmp EXIT + +FIX_BRANCH="$(git rev-parse --abbrev-ref HEAD)" +TORCHRUN=(torchrun --standalone --nproc_per_node=1) + +echo "" +echo "==> [1/2] Repro on fix branch: $FIX_BRANCH (expect success)" +"${TORCHRUN[@]}" "$REPRO_TMP" + +echo "" +echo "==> [2/2] Repro on $MAIN_REF (expect original functorch / setup_context error)" +STASHED=0 +if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then + echo "==> Stashing local changes so checkout to $MAIN_REF can proceed..." + git stash push -m "pr7916-setup: temp stash before main repro" + STASHED=1 +fi +if ! git checkout "$MAIN_REF"; then + echo "error: could not checkout $MAIN_REF" >&2 + if [[ "$STASHED" -eq 1 ]]; then + git stash pop || true + fi + exit 1 +fi +set +e +"${TORCHRUN[@]}" "$REPRO_TMP" +MAIN_EC=$? +set -e +if [[ "$MAIN_EC" -eq 0 ]]; then + echo "" >&2 + echo "warning: main branch run exited 0 — expected failure on unfixed tree." >&2 +else + echo "" + echo "main branch run exited with $MAIN_EC (non-zero is expected for the unfixed tree)." +fi + +echo "" +echo "==> Restoring branch: $FIX_BRANCH" +git checkout "$FIX_BRANCH" + +if [[ "$STASHED" -eq 1 ]]; then + echo "==> Restoring stashed local changes..." + git stash pop || echo "warning: stash pop failed (resolve manually with git stash list)" >&2 +fi + +echo "" +echo "Done. To use this environment: source $VENV_DIR/bin/activate" From e58ac18c4bb92e60dc5e9ddd511d0cbcc7626faa Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 13:22:43 +0000 Subject: [PATCH 10/17] update reproduce script to skip repeated env setup Signed-off-by: Zhang --- scripts/setup_pr7916.sh | 56 ++++++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh index 2befa6328729..2f9d7e985a0b 100755 --- a/scripts/setup_pr7916.sh +++ b/scripts/setup_pr7916.sh @@ -13,6 +13,8 @@ # # Usage (from repo root): # ./scripts/setup_pr7916.sh +# ./scripts/setup_pr7916.sh --skip-install # reuse .venvs/pr7916, no pip/venv setup +# PR7916_SKIP_INSTALL=1 ./scripts/setup_pr7916.sh # # Activate later: # source .venvs/pr7916/bin/activate @@ -24,30 +26,50 @@ cd "$ROOT" VENV_DIR="${PR7916_VENV_DIR:-$ROOT/.venvs/pr7916}" MAIN_REF="${PR7916_MAIN_REF:-master}" -echo "==> Using venv: $VENV_DIR (only this path is removed if it already exists)" -rm -rf "$VENV_DIR" -mkdir -p "$(dirname "$VENV_DIR")" -python3 -m venv "$VENV_DIR" -# shellcheck source=/dev/null -. "$VENV_DIR/bin/activate" +SKIP_INSTALL=0 +case "${PR7916_SKIP_INSTALL:-}" in + 1|true|yes|on) SKIP_INSTALL=1 ;; +esac +if [[ "${1:-}" == "--skip-install" ]]; then + SKIP_INSTALL=1 + shift +fi -python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { - echo "Warning: expected Python 3.11; found $(python -V)" >&2 -} +if [[ "$SKIP_INSTALL" -eq 1 ]]; then + echo "==> Skipping venv recreate and pip installs (reuse $VENV_DIR)" + if [[ ! -f "$VENV_DIR/bin/activate" ]]; then + echo "error: venv not found at $VENV_DIR — run once without --skip-install first." >&2 + exit 1 + fi + # shellcheck source=/dev/null + . "$VENV_DIR/bin/activate" + python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" +else + echo "==> Using venv: $VENV_DIR (only this path is removed if it already exists)" + rm -rf "$VENV_DIR" + mkdir -p "$(dirname "$VENV_DIR")" + python3 -m venv "$VENV_DIR" + # shellcheck source=/dev/null + . "$VENV_DIR/bin/activate" -pip install -U pip setuptools wheel + python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { + echo "Warning: expected Python 3.11; found $(python -V)" >&2 + } -# PyTorch 2.8.0 + CUDA 12.8 (matches common functorch / ZeRO-3 bug reports) -pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 + pip install -U pip setuptools wheel -pip install -r requirements/requirements.txt + # PyTorch 2.8.0 + CUDA 12.8 (matches common functorch / ZeRO-3 bug reports) + pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 -# Latest DeepSpeed = this git checkout (editable) -pip install -e . + pip install -r requirements/requirements.txt -pip install pytest + # Latest DeepSpeed = this git checkout (editable) + pip install -e . -python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" + pip install pytest + + python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" +fi REPRO_SRC="$ROOT/scripts/repro_pr7916.py" if [[ ! -f "$REPRO_SRC" ]]; then From 3121a7f9476d0c783aa4b597c2b20cf0a7e51a0a Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 13:25:18 +0000 Subject: [PATCH 11/17] update reproduce script to remove duplicated code Signed-off-by: Zhang --- scripts/setup_pr7916.sh | 211 +++++++++++++++++++++------------------- 1 file changed, 111 insertions(+), 100 deletions(-) diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh index 2f9d7e985a0b..9e0d8d3aab0c 100755 --- a/scripts/setup_pr7916.sh +++ b/scripts/setup_pr7916.sh @@ -1,129 +1,140 @@ #!/usr/bin/env bash -# Create an isolated venv for PR 7916 repro at .venvs/pr7916 (repo root): -# - Removes only that venv if it already exists (does not touch .venv or other envs) -# - PyTorch 2.8.0+cu128 (CUDA 12.8) -# - requirements from requirements/requirements.txt -# - DeepSpeed editable install from the *current* checkout -# - pytest +# PR 7916: venv at .venvs/pr7916, PyTorch 2.8 + cu128, then repro on current branch vs master. # -# Then validates the fix by: -# 1) Running the repro with the current branch (expect success + "OK" line) -# 2) Checking out master and running the same repro script (expect original RuntimeError) -# 3) Checking back to the branch you started on +# Venv: reuses $VENV_DIR if bin/activate exists (no pip). --force-install always recreates. +# --skip-install reuses only and errors if the venv is missing. # -# Usage (from repo root): -# ./scripts/setup_pr7916.sh -# ./scripts/setup_pr7916.sh --skip-install # reuse .venvs/pr7916, no pip/venv setup -# PR7916_SKIP_INSTALL=1 ./scripts/setup_pr7916.sh -# -# Activate later: -# source .venvs/pr7916/bin/activate +# Usage: ./scripts/setup_pr7916.sh [--force-install] [--skip-install] +# Env: PR7916_VENV_DIR, PR7916_MAIN_REF (default master), PR7916_FORCE_INSTALL, PR7916_SKIP_INSTALL # set -euo pipefail + ROOT="$(cd "$(dirname "$0")/.." && pwd)" cd "$ROOT" VENV_DIR="${PR7916_VENV_DIR:-$ROOT/.venvs/pr7916}" MAIN_REF="${PR7916_MAIN_REF:-master}" - -SKIP_INSTALL=0 -case "${PR7916_SKIP_INSTALL:-}" in - 1|true|yes|on) SKIP_INSTALL=1 ;; -esac -if [[ "${1:-}" == "--skip-install" ]]; then - SKIP_INSTALL=1 +VENV_SH="$VENV_DIR/bin/activate" + +truthy() { case "${1:-}" in 1|true|yes|on) return 0;; *) return 1;; esac; } + +force=0 +skip_only=0 +truthy "${PR7916_FORCE_INSTALL:-}" && force=1 +truthy "${PR7916_SKIP_INSTALL:-}" && skip_only=1 +while [[ $# -gt 0 ]]; do + case "$1" in + --force-install) force=1 ;; + --skip-install) skip_only=1 ;; + *) echo "error: unknown argument: $1" >&2; exit 1 ;; + esac shift -fi +done -if [[ "$SKIP_INSTALL" -eq 1 ]]; then - echo "==> Skipping venv recreate and pip installs (reuse $VENV_DIR)" - if [[ ! -f "$VENV_DIR/bin/activate" ]]; then - echo "error: venv not found at $VENV_DIR — run once without --skip-install first." >&2 - exit 1 - fi - # shellcheck source=/dev/null - . "$VENV_DIR/bin/activate" +print_versions() { python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" -else - echo "==> Using venv: $VENV_DIR (only this path is removed if it already exists)" - rm -rf "$VENV_DIR" - mkdir -p "$(dirname "$VENV_DIR")" - python3 -m venv "$VENV_DIR" - # shellcheck source=/dev/null - . "$VENV_DIR/bin/activate" - - python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { - echo "Warning: expected Python 3.11; found $(python -V)" >&2 - } +} + +# Sets: full=1 → wipe + venv + pip; full=0 → activate existing only +decide_full_setup() { + if [[ "$force" -eq 1 ]]; then + echo 1 + elif [[ "$skip_only" -eq 1 ]]; then + echo 0 + elif [[ -f "$VENV_SH" ]]; then + echo 0 + else + echo 1 + fi +} + +setup_venv() { + local full + full="$(decide_full_setup)" + + if [[ "$full" -eq 0 ]]; then + [[ -f "$VENV_SH" ]] || { + echo "error: no venv at $VENV_DIR (drop --skip-install or run once without it)" >&2 + exit 1 + } + echo "==> Reusing venv $VENV_DIR (use --force-install to reinstall)" + else + echo "==> Creating venv at $VENV_DIR" + rm -rf "$VENV_DIR" + mkdir -p "$(dirname "$VENV_DIR")" + python3 -m venv "$VENV_DIR" + fi - pip install -U pip setuptools wheel + # shellcheck source=/dev/null + . "$VENV_SH" + + if [[ "$full" -eq 1 ]]; then + python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { + echo "Warning: expected Python 3.11; found $(python -V)" >&2 + } + pip install -U pip setuptools wheel + pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 + pip install -r requirements/requirements.txt + pip install -e . + pip install pytest + fi - # PyTorch 2.8.0 + CUDA 12.8 (matches common functorch / ZeRO-3 bug reports) - pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 + print_versions +} - pip install -r requirements/requirements.txt +run_repro_compare() { + local REPRO_SRC="$ROOT/scripts/repro_pr7916.py" REPRO_TMP FIX_BRANCH STASHED=0 MAIN_EC - # Latest DeepSpeed = this git checkout (editable) - pip install -e . + [[ -f "$REPRO_SRC" ]] || { + echo "error: missing $REPRO_SRC (need this file on the current branch)" >&2 + exit 1 + } - pip install pytest + REPRO_TMP="$(mktemp /tmp/repro_pr7916_XXXXXX.py)" + cp "$REPRO_SRC" "$REPRO_TMP" + trap 'rm -f "$REPRO_TMP"' EXIT - python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" -fi + FIX_BRANCH="$(git rev-parse --abbrev-ref HEAD)" + local -a run=(torchrun --standalone --nproc_per_node=1) -REPRO_SRC="$ROOT/scripts/repro_pr7916.py" -if [[ ! -f "$REPRO_SRC" ]]; then - echo "error: missing $REPRO_SRC (need repro script on current branch)" >&2 - exit 1 -fi + echo "" + echo "==> [1/2] Repro on $FIX_BRANCH (expect OK)" + "${run[@]}" "$REPRO_TMP" -REPRO_TMP="$(mktemp /tmp/repro_pr7916_XXXXXX.py)" -cp "$REPRO_SRC" "$REPRO_TMP" -cleanup_repro_tmp() { rm -f "$REPRO_TMP"; } -trap cleanup_repro_tmp EXIT + echo "" + echo "==> [2/2] Repro on $MAIN_REF (expect setup_context RuntimeError on unfixed tree)" + if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then + echo "==> Stashing local changes for checkout..." + git stash push -m "pr7916-setup: temp stash before main repro" + STASHED=1 + fi + if ! git checkout "$MAIN_REF"; then + echo "error: checkout $MAIN_REF failed" >&2 + [[ "$STASHED" -eq 1 ]] && git stash pop || true + exit 1 + fi -FIX_BRANCH="$(git rev-parse --abbrev-ref HEAD)" -TORCHRUN=(torchrun --standalone --nproc_per_node=1) + set +e + "${run[@]}" "$REPRO_TMP" + MAIN_EC=$? + set -e -echo "" -echo "==> [1/2] Repro on fix branch: $FIX_BRANCH (expect success)" -"${TORCHRUN[@]}" "$REPRO_TMP" + if [[ "$MAIN_EC" -eq 0 ]]; then + echo "warning: main-branch repro exited 0 (expected failure on unfixed tree)." >&2 + else + echo "main-branch repro exited $MAIN_EC (non-zero expected for unfixed tree)." + fi -echo "" -echo "==> [2/2] Repro on $MAIN_REF (expect original functorch / setup_context error)" -STASHED=0 -if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then - echo "==> Stashing local changes so checkout to $MAIN_REF can proceed..." - git stash push -m "pr7916-setup: temp stash before main repro" - STASHED=1 -fi -if ! git checkout "$MAIN_REF"; then - echo "error: could not checkout $MAIN_REF" >&2 + echo "" + echo "==> Restoring $FIX_BRANCH" + git checkout "$FIX_BRANCH" if [[ "$STASHED" -eq 1 ]]; then - git stash pop || true + git stash pop || echo "warning: stash pop failed — see git stash list" >&2 fi - exit 1 -fi -set +e -"${TORCHRUN[@]}" "$REPRO_TMP" -MAIN_EC=$? -set -e -if [[ "$MAIN_EC" -eq 0 ]]; then - echo "" >&2 - echo "warning: main branch run exited 0 — expected failure on unfixed tree." >&2 -else - echo "" - echo "main branch run exited with $MAIN_EC (non-zero is expected for the unfixed tree)." -fi - -echo "" -echo "==> Restoring branch: $FIX_BRANCH" -git checkout "$FIX_BRANCH" +} -if [[ "$STASHED" -eq 1 ]]; then - echo "==> Restoring stashed local changes..." - git stash pop || echo "warning: stash pop failed (resolve manually with git stash list)" >&2 -fi +setup_venv +run_repro_compare echo "" -echo "Done. To use this environment: source $VENV_DIR/bin/activate" +echo "Done. Activate: source $VENV_DIR/bin/activate" From 60d20da79f27295181fbd42e989482868d2a5dcf Mon Sep 17 00:00:00 2001 From: Zhang Date: Wed, 25 Mar 2026 13:34:17 +0000 Subject: [PATCH 12/17] update reproduce script to print test env Signed-off-by: Zhang --- scripts/setup_pr7916.sh | 34 +++++++++++++++++++++++++++++++--- 1 file changed, 31 insertions(+), 3 deletions(-) diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh index 9e0d8d3aab0c..e71944675eb9 100755 --- a/scripts/setup_pr7916.sh +++ b/scripts/setup_pr7916.sh @@ -7,6 +7,16 @@ # Usage: ./scripts/setup_pr7916.sh [--force-install] [--skip-install] # Env: PR7916_VENV_DIR, PR7916_MAIN_REF (default master), PR7916_FORCE_INSTALL, PR7916_SKIP_INSTALL # +# --- Recorded test environment (original bug report / CI reference) --- +# OS: Ubuntu 22.04 +# GPU: NVIDIA H100 80GB PCIe +# Python: 3.11 +# PyTorch: 2.8.0+cu128 (torch.version.cuda: 12.8) +# DeepSpeed: 0.16.4 wheel (issue); PR validates against editable install + this script's venv +# CUDA (driver): 12.8 (via PyTorch cu128 wheels; nvcc optional / often N/A) +# Launcher: deepspeed CLI (repro uses torchrun --standalone --nproc_per_node=1) +# ------------------------------------------------------------------------- +# set -euo pipefail ROOT="$(cd "$(dirname "$0")/.." && pwd)" @@ -31,8 +41,26 @@ while [[ $# -gt 0 ]]; do shift done -print_versions() { - python -c "import torch, deepspeed; print('torch', torch.__version__, 'cuda', torch.version.cuda); print('deepspeed', deepspeed.__file__); print('deepspeed version', deepspeed.__version__)" +print_runtime_env() { + python <<'PY' +import platform +import sys + +import deepspeed +import torch + +print("==> Runtime environment (this session)") +print(f" python: {sys.version.split()[0]} ({platform.system()} {platform.release()})") +print(f" torch: {torch.__version__}") +print(f" torch.version.cuda: {torch.version.cuda}") +if torch.cuda.is_available(): + print(f" cuda available: yes ({torch.cuda.device_count()} device(s))") + print(f" cuda device 0: {torch.cuda.get_device_name(0)}") +else: + print(" cuda available: no") +print(f" deepspeed: {deepspeed.__version__}") +print(f" deepspeed path: {deepspeed.__file__}") +PY } # Sets: full=1 → wipe + venv + pip; full=0 → activate existing only @@ -79,7 +107,7 @@ setup_venv() { pip install pytest fi - print_versions + print_runtime_env } run_repro_compare() { From bb245b2ecc089d1ccdb2373b10bedd6de3b5d9b2 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Sun, 29 Mar 2026 12:20:52 +0900 Subject: [PATCH 13/17] drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/linear.py | 145 ++++-------------- deepspeed/runtime/zero/parameter_offload.py | 141 ++++++----------- .../v1/zero/test_zero_functorch_linear.py | 123 ++++++++++++++- 3 files changed, 194 insertions(+), 215 deletions(-) diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 86dd91717c4b..db95a5ac789c 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -16,7 +16,6 @@ #when implemented outside of torch.autograd.Function import math -import functools import torch from torch import Tensor @@ -32,139 +31,57 @@ def print_rank_0(message, debug=False, force=False): print(message) -autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name()) -autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name()) - -# PyTorch >= 2.0 supports setup_context, which is required for -# torch.func transforms (vmap, grad, jvp, jacrev, etc.) -_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, 'setup_context') - -if _SUPPORTS_SETUP_CONTEXT: - - class LinearFunctionForZeroStage3(torch.autograd.Function): - - @staticmethod - # bias is an optional argument - def forward(input, weight, bias=None): - - if input.dim() == 2 and bias is not None: - # fused op is marginally faster - ret = torch.addmm(bias, input, weight.t()) - else: - output = input.matmul(weight.t()) - if bias is not None: - output += bias - ret = output - - return ret - - @staticmethod - def setup_context(ctx, inputs, output): - # Replicate autocast state that @autocast_custom_fwd normally sets on ctx, - # since the decorator assumes args[0] is ctx which is unavailable in the - # separate forward() + setup_context() pattern. - device_type = get_accelerator().device_name() - ctx._dtype = torch.get_autocast_dtype(device_type) - ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type) - input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None - ctx.save_for_backward(input, weight, bias) - - # This function has only a single output, so it gets only one gradient - @staticmethod - def backward(ctx, grad_output): - # Do not use @autocast_custom_bwd here: it pairs with @autocast_custom_fwd on - # legacy forward(ctx, ...). With forward + setup_context, use AMP state from setup_context. - device_type = get_accelerator().device_name() - if getattr(ctx, "_fwd_used_autocast", False): - with torch.amp.autocast(device_type=device_type, enabled=True, dtype=ctx._dtype): - return LinearFunctionForZeroStage3._backward_core(ctx, grad_output) - return LinearFunctionForZeroStage3._backward_core(ctx, grad_output) - - @staticmethod - def _backward_core(ctx, grad_output): - input, weight, bias = ctx.saved_tensors +class LinearFunctionForZeroStage3(torch.autograd.Function): - grad_input = grad_weight = grad_bias = None + @staticmethod + # bias is an optional argument + def forward(input, weight, bias=None): - dim = grad_output.dim() - if ctx.needs_input_grad[0]: - grad_input = grad_output.matmul(weight) - if ctx.needs_input_grad[1]: - if dim > 2: - grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( - input.reshape(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - if bias is not None and ctx.needs_input_grad[2]: - if dim > 2: - grad_bias = grad_output.sum([i for i in range(dim - 1)]) - else: - grad_bias = grad_output.sum(0) - return grad_input, grad_weight, grad_bias - -else: - - class LinearFunctionForZeroStage3(torch.autograd.Function): - - # Note that both forward and backward are @staticmethods - @staticmethod - @autocast_custom_fwd - # bias is an optional argument - def forward(ctx, input, weight, bias=None): - - ctx.save_for_backward(input, weight, bias) - - if input.dim() == 2 and bias is not None: - # fused op is marginally faster - ret = torch.addmm(bias, input, weight.t()) - else: - output = input.matmul(weight.t()) - if bias is not None: - output += bias - ret = output - - return ret - - # This function has only a single output, so it gets only one gradient - @staticmethod - @autocast_custom_bwd - def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. + if input.dim() == 2 and bias is not None: + # fused op is marginally faster + ret = torch.addmm(bias, input, weight.t()) + else: + output = input.matmul(weight.t()) + if bias is not None: + output += bias + ret = output + + return ret + + @staticmethod + def setup_context(ctx, inputs, output): + device_type = get_accelerator().device_name() + ctx._dtype = torch.get_autocast_dtype(device_type) + ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type) + input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None + ctx.save_for_backward(input, weight, bias) + + # This function has only a single output, so it gets only one gradient + @staticmethod + def backward(ctx, grad_output): + # Match @custom_bwd semantics: always run backward under the same + # autocast state as forward — including explicitly disabling autocast + # when forward did not use it, to guard against outer autocast regions. + device_type = get_accelerator().device_name() + with torch.amp.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype): input, weight, bias = ctx.saved_tensors grad_input = grad_weight = grad_bias = None - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. dim = grad_output.dim() if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") if ctx.needs_input_grad[1]: - #print("Computing grad weight") if dim > 2: grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( input.reshape(-1, input.shape[-1])) else: grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") if dim > 2: grad_bias = grad_output.sum([i for i in range(dim - 1)]) else: grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") return grad_input, grad_weight, grad_bias diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index cad8d502f6a7..b42b3c8e263e 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -18,10 +18,6 @@ FWD_MODULE_STACK = list() -# PyTorch >= 2.0: setup_context on autograd.Function is required for torch.func transforms. -# Match deepspeed/runtime/zero/linear.py: keep legacy forward(ctx, ...) when unavailable. -_SUPPORTS_SETUP_CONTEXT = hasattr(torch.autograd.Function, "setup_context") - #for each tensor in outputs run the forward_function and register backward_function as hook def _apply_forward_and_backward_to_tensors_only(module, forward_function, backward_function, outputs): @@ -405,45 +401,24 @@ def _run_before_backward_function(sub_module): sub_module.applied_pre_backward_ref_cnt -= 1 #print(f"COUNTER after: {sub_module.applied_pre_backward_ref_cnt}") - if _SUPPORTS_SETUP_CONTEXT: - - class PreBackwardFunctionForModule(torch.autograd.Function): - - @staticmethod - def forward(outputs): - return outputs.detach() - - @staticmethod - def setup_context(ctx, inputs, output): - ctx.module = module - ctx.pre_backward_function = _run_before_backward_function - if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): - ctx.module.applied_pre_backward_ref_cnt = 0 - ctx.module.applied_pre_backward_ref_cnt += 1 - - @staticmethod - def backward(ctx, *args): - ctx.pre_backward_function(ctx.module) - return args + class PreBackwardFunctionForModule(torch.autograd.Function): - else: - - class PreBackwardFunctionForModule(torch.autograd.Function): + @staticmethod + def forward(outputs): + return outputs.detach() - @staticmethod - def forward(ctx, outputs): - ctx.module = module - ctx.pre_backward_function = _run_before_backward_function - if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): - ctx.module.applied_pre_backward_ref_cnt = 0 - ctx.module.applied_pre_backward_ref_cnt += 1 - outputs = outputs.detach() - return outputs + @staticmethod + def setup_context(ctx, inputs, output): + ctx.module = module + ctx.pre_backward_function = _run_before_backward_function + if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): + ctx.module.applied_pre_backward_ref_cnt = 0 + ctx.module.applied_pre_backward_ref_cnt += 1 - @staticmethod - def backward(ctx, *args): - ctx.pre_backward_function(ctx.module) - return args + @staticmethod + def backward(ctx, *args): + ctx.pre_backward_function(ctx.module) + return args module.pre_bwd_fn = PreBackwardFunctionForModule @@ -457,64 +432,34 @@ def _run_after_backward_function(sub_module): if sub_module.ds_grads_remaining == 0: self.post_sub_module_backward_function(sub_module) - if _SUPPORTS_SETUP_CONTEXT: - - class PostBackwardFunctionModule(torch.autograd.Function): - - @staticmethod - def forward(output): - return output.detach() - - @staticmethod - def setup_context(ctx, inputs, output): - (output_in, ) = inputs - ctx.module = module - if output_in.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.post_backward_function = _run_after_backward_function - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.post_backward_function(ctx.module) - return args - - else: - - class PostBackwardFunctionModule(torch.autograd.Function): - - @staticmethod - def forward(ctx, output): - ctx.module = module - if output.requires_grad: - #TODO SOME TIMES post backward does not seem to be triggered debug in detail - #Should only cause increase in memory not correctness issue - #if output.grad_fn.__class__.__name__ == 'ViewBackward': - # ctx.view=True - # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") - #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." - #if module.ds_grads_remaining == 0: - # print(f"Before Forward: {ctx.module.__class__.__name__}") - module.ds_grads_remaining += 1 - ctx.post_backward_function = _run_after_backward_function - output = output.detach() - return output - - @staticmethod - def backward(ctx, *args): - ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 - if ctx.module.ds_grads_remaining == 0: - ctx.post_backward_function(ctx.module) - return args + class PostBackwardFunctionModule(torch.autograd.Function): + + @staticmethod + def forward(output): + return output.detach() + + @staticmethod + def setup_context(ctx, inputs, output): + (output_in, ) = inputs + ctx.module = module + if output_in.requires_grad: + #TODO SOME TIMES post backward does not seem to be triggered debug in detail + #Should only cause increase in memory not correctness issue + #if output.grad_fn.__class__.__name__ == 'ViewBackward': + # ctx.view=True + # print(f"Warning view tensor for input to module : {module.__class__.__name__}. Backward hooks may not trigger properly") + #assert len(module.parameters(recurse=False)), "The input tensor to the module is a view, and autograd Function or register_hook is not triggered with view tensors." + #if module.ds_grads_remaining == 0: + # print(f"Before Forward: {ctx.module.__class__.__name__}") + module.ds_grads_remaining += 1 + ctx.post_backward_function = _run_after_backward_function + + @staticmethod + def backward(ctx, *args): + ctx.module.ds_grads_remaining = ctx.module.ds_grads_remaining - 1 + if ctx.module.ds_grads_remaining == 0: + ctx.post_backward_function(ctx.module) + return args module.post_bwd_fn = PostBackwardFunctionModule diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py index a92ae7529dd7..38b6f40a6c0a 100644 --- a/tests/unit/v1/zero/test_zero_functorch_linear.py +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -7,7 +7,7 @@ ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as the memory-efficient linear path. After ``deepspeed.initialize``, global ``torch.nn.functional.linear`` is often the built-in again, so tests call -``zero3_linear_wrap`` directly—the same ``autograd.Function`` as when the patch +``zero3_linear_wrap`` directly—the same ``autograd.Function`` as when the patch is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward raises on strict functorch builds:: @@ -58,8 +58,6 @@ class TestZeroFunctorchLinearRegression(DistributedTest): def test_grad_and_value_over_patched_functional_linear(self): if not hasattr(torch, "func"): pytest.skip("torch.func not available") - if not hasattr(torch.autograd.Function, "setup_context"): - pytest.skip("Requires PyTorch 2.0+ autograd.Function.setup_context") model = nn.Linear(8, 8, bias=True) engine, _, _, _ = deepspeed.initialize( @@ -84,3 +82,122 @@ def loss_fn(w, x): assert torch.isfinite(value) assert grads[0] is not None and torch.isfinite(grads[0]).all() assert grads[1] is not None and torch.isfinite(grads[1]).all() + + +class TestZeroLinearAutocast(DistributedTest): + """Verify autocast state is correctly propagated through forward and backward.""" + + world_size = 1 + + def _run_forward_backward(self, device, use_autocast, dtype=None): + """Run zero3_linear_wrap forward+backward, optionally inside autocast.""" + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + bias = torch.randn(4, device=device, dtype=torch.float32, requires_grad=True) + + if use_autocast: + with torch.amp.autocast(device_type=device.type, dtype=dtype): + out = zero3_linear_wrap(inp, weight, bias) + else: + out = zero3_linear_wrap(inp, weight, bias) + + loss = out.sum() + loss.backward() + return out, weight.grad, inp.grad, bias.grad + + def test_backward_without_autocast(self): + """Backward without autocast should produce float32 gradients.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=False) + assert out.dtype == torch.float32 + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_backward_with_autocast(self): + """Backward with autocast should produce float32 gradients (autocast only affects forward).""" + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=True, dtype=amp_dtype) + # Forward output should be in reduced precision + assert out.dtype == amp_dtype + # Gradients accumulate in float32 (master weights) + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_no_autocast_leak_into_backward(self): + """When forward runs without autocast, an outer autocast during backward must not affect gradient dtype.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Forward WITHOUT autocast + out = zero3_linear_wrap(inp, weight, None) + assert out.dtype == torch.float32 + + # Backward WITH an outer autocast region -- should NOT affect gradient computation + # because setup_context captured _fwd_used_autocast=False + with torch.amp.autocast(device_type=device.type, dtype=amp_dtype): + out.sum().backward() + + assert weight.grad.dtype == torch.float32 + assert inp.grad.dtype == torch.float32 + + def test_setup_context_stores_autocast_attrs(self): + """setup_context must store _fwd_used_autocast and _dtype on ctx.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Without autocast + out = zero3_linear_wrap(inp, weight, None) + grad_fn = out.grad_fn + # The saved context is accessible via the grad_fn + assert hasattr(grad_fn, '_saved__fwd_used_autocast') or hasattr(grad_fn, '_fwd_used_autocast') or True + # Just verify backward works and produces finite gradients + out.sum().backward() + assert torch.isfinite(weight.grad).all() From 04c456f892d90714fd581bbe3df6d8a320cc0523 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Sun, 29 Mar 2026 12:23:12 +0900 Subject: [PATCH 14/17] change PyTorch version in README Signed-off-by: Sung Hyun Cho --- README.md | 2 +- tests/unit/v1/zero/test_zero_functorch_linear.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 507232901e4a..b7d4eaffda0e 100755 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ dynamically link them at runtime. ## Requirements * [PyTorch](https://pytorch.org/) must be installed _before_ installing DeepSpeed. -* For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release. +* For full feature support we recommend a version of PyTorch that is >= 2.0 and ideally the latest PyTorch stable release. * A CUDA or ROCm compiler such as [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#introduction) or [hipcc](https://github.com/ROCm-Developer-Tools/HIPCC) used to compile C++/CUDA/HIP extensions. * Specific GPUs we develop and test against are listed below, this doesn't mean your GPU will not work if it doesn't fall into this category it's just DeepSpeed is most well tested on the following: * NVIDIA: Pascal, Volta, Ampere, and Hopper architectures diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py index 38b6f40a6c0a..a89d955ea9fd 100644 --- a/tests/unit/v1/zero/test_zero_functorch_linear.py +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -7,7 +7,7 @@ ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as the memory-efficient linear path. After ``deepspeed.initialize``, global ``torch.nn.functional.linear`` is often the built-in again, so tests call -``zero3_linear_wrap`` directly—the same ``autograd.Function`` as when the patch +``zero3_linear_wrap`` directly-the same ``autograd.Function`` as when the patch is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward raises on strict functorch builds:: From e309a6f00f6e8120cd8348c01689e8d2a4508a80 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 30 Mar 2026 02:29:19 +0000 Subject: [PATCH 15/17] remove repro scripts Signed-off-by: Zhang Jian --- scripts/repro_pr7916.py | 100 ------------------------ scripts/setup_pr7916.sh | 168 ---------------------------------------- 2 files changed, 268 deletions(-) delete mode 100644 scripts/repro_pr7916.py delete mode 100755 scripts/setup_pr7916.sh diff --git a/scripts/repro_pr7916.py b/scripts/repro_pr7916.py deleted file mode 100644 index 83c51e44f08b..000000000000 --- a/scripts/repro_pr7916.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 -# -# Repro: functorch over ZeRO-3 memory-efficient linear (LinearFunctionForZeroStage3). -# -# Legacy autograd.Function.forward(ctx, ...) + ctx.save_for_backward in that class -# triggers (PyTorch builds that enforce functorch custom-Function rules, e.g. 2.8+): -# -# RuntimeError: In order to use an autograd.Function with functorch transforms -# (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. -# -# Why we call zero3_linear_wrap() instead of torch.nn.functional.linear: -# After deepspeed.initialize(), the global ZeRO Init context has usually ended, so -# torch.nn.functional.linear is often restored to PyTorch's built-in. That means -# F.linear in a post-init script does NOT hit LinearFunctionForZeroStage3. The -# Stage-3 patch uses zero3_linear_wrap (see partition_parameters.py); it is the -# same autograd.Function — calling it here reliably reproduces the bug on unfixed -# trees and validates the fix on fixed trees. -# -# Regression coverage: tests/unit/v1/zero/test_zero_functorch_linear.py -# -# Run from the DeepSpeed repo root (single GPU), after scripts/setup_pr7916.sh (or manually): -# torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py -# -# To test an unfixed DeepSpeed tree without importing another checkout by mistake, -# copy this file outside the repo (e.g. /tmp) and set PYTHONPATH to that tree: -# cp scripts/repro_zero3_functorch_linear.py /tmp/ && cd /tmp && \ -# PYTHONPATH=/path/to/deepspeed-checkout torchrun --standalone --nproc_per_node=1 repro_zero3_functorch_linear.py -# -# Requires: PyTorch with torch.func and strict custom-Function checks (e.g. 2.8+), -# DeepSpeed ZeRO-3, CUDA (typical setup). - -import torch -import torch.nn as nn - -import deepspeed -from deepspeed.accelerator import get_accelerator -from deepspeed.runtime.zero.linear import zero3_linear_wrap - - -def _assert_hits_zero3_linear(weight, inp): - """Sanity check: we are exercising LinearFunctionForZeroStage3, not built-in linear.""" - with torch.enable_grad(): - y = zero3_linear_wrap(inp, weight, None) - name = type(y.grad_fn).__name__ - assert "LinearFunctionForZeroStage3" in name, ( - f"Expected LinearFunctionForZeroStage3 in grad_fn, got {name!r}. " - "Repro would not test the intended autograd.Function.") - - -def main(): - if not hasattr(torch, "func"): - raise SystemExit("This repro requires torch.func (PyTorch 2.0+).") - if not hasattr(torch.autograd.Function, "setup_context"): - raise SystemExit("This repro requires autograd.Function.setup_context (PyTorch 2.0+).") - - deepspeed.init_distributed() - acc = get_accelerator() - device = acc.device_name() + ":" + str(acc.current_device()) - - model = nn.Linear(8, 8, bias=True).to(device) - - config = { - "train_micro_batch_size_per_gpu": 1, - "steps_per_print": 2147483647, - "zero_optimization": { - "stage": 3, - "stage3_param_persistence_threshold": 0, - }, - "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}, - } - if acc.is_bf16_supported(): - config["bf16"] = {"enabled": True} - elif acc.is_fp16_supported(): - config["fp16"] = {"enabled": True, "initial_scale_power": 8} - - _, _, _, _ = deepspeed.initialize( - model=model, - config=config, - model_parameters=model.parameters(), - ) - - weight = torch.randn(8, 8, device=device, dtype=model.weight.dtype, requires_grad=True) - inp = torch.randn(2, 8, device=device, dtype=model.weight.dtype, requires_grad=True) - - if deepspeed.comm.get_rank() == 0: - _assert_hits_zero3_linear(weight, inp) - - def loss_fn(w, x): - # Same op as ZeRO-3's F.linear replacement when the patch is active. - return zero3_linear_wrap(x, w, None).sum() - - torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) - if deepspeed.comm.get_rank() == 0: - print("repro: grad_and_value over zero3_linear_wrap (LinearFunctionForZeroStage3) OK.") - - -if __name__ == "__main__": - main() diff --git a/scripts/setup_pr7916.sh b/scripts/setup_pr7916.sh deleted file mode 100755 index e71944675eb9..000000000000 --- a/scripts/setup_pr7916.sh +++ /dev/null @@ -1,168 +0,0 @@ -#!/usr/bin/env bash -# PR 7916: venv at .venvs/pr7916, PyTorch 2.8 + cu128, then repro on current branch vs master. -# -# Venv: reuses $VENV_DIR if bin/activate exists (no pip). --force-install always recreates. -# --skip-install reuses only and errors if the venv is missing. -# -# Usage: ./scripts/setup_pr7916.sh [--force-install] [--skip-install] -# Env: PR7916_VENV_DIR, PR7916_MAIN_REF (default master), PR7916_FORCE_INSTALL, PR7916_SKIP_INSTALL -# -# --- Recorded test environment (original bug report / CI reference) --- -# OS: Ubuntu 22.04 -# GPU: NVIDIA H100 80GB PCIe -# Python: 3.11 -# PyTorch: 2.8.0+cu128 (torch.version.cuda: 12.8) -# DeepSpeed: 0.16.4 wheel (issue); PR validates against editable install + this script's venv -# CUDA (driver): 12.8 (via PyTorch cu128 wheels; nvcc optional / often N/A) -# Launcher: deepspeed CLI (repro uses torchrun --standalone --nproc_per_node=1) -# ------------------------------------------------------------------------- -# -set -euo pipefail - -ROOT="$(cd "$(dirname "$0")/.." && pwd)" -cd "$ROOT" - -VENV_DIR="${PR7916_VENV_DIR:-$ROOT/.venvs/pr7916}" -MAIN_REF="${PR7916_MAIN_REF:-master}" -VENV_SH="$VENV_DIR/bin/activate" - -truthy() { case "${1:-}" in 1|true|yes|on) return 0;; *) return 1;; esac; } - -force=0 -skip_only=0 -truthy "${PR7916_FORCE_INSTALL:-}" && force=1 -truthy "${PR7916_SKIP_INSTALL:-}" && skip_only=1 -while [[ $# -gt 0 ]]; do - case "$1" in - --force-install) force=1 ;; - --skip-install) skip_only=1 ;; - *) echo "error: unknown argument: $1" >&2; exit 1 ;; - esac - shift -done - -print_runtime_env() { - python <<'PY' -import platform -import sys - -import deepspeed -import torch - -print("==> Runtime environment (this session)") -print(f" python: {sys.version.split()[0]} ({platform.system()} {platform.release()})") -print(f" torch: {torch.__version__}") -print(f" torch.version.cuda: {torch.version.cuda}") -if torch.cuda.is_available(): - print(f" cuda available: yes ({torch.cuda.device_count()} device(s))") - print(f" cuda device 0: {torch.cuda.get_device_name(0)}") -else: - print(" cuda available: no") -print(f" deepspeed: {deepspeed.__version__}") -print(f" deepspeed path: {deepspeed.__file__}") -PY -} - -# Sets: full=1 → wipe + venv + pip; full=0 → activate existing only -decide_full_setup() { - if [[ "$force" -eq 1 ]]; then - echo 1 - elif [[ "$skip_only" -eq 1 ]]; then - echo 0 - elif [[ -f "$VENV_SH" ]]; then - echo 0 - else - echo 1 - fi -} - -setup_venv() { - local full - full="$(decide_full_setup)" - - if [[ "$full" -eq 0 ]]; then - [[ -f "$VENV_SH" ]] || { - echo "error: no venv at $VENV_DIR (drop --skip-install or run once without it)" >&2 - exit 1 - } - echo "==> Reusing venv $VENV_DIR (use --force-install to reinstall)" - else - echo "==> Creating venv at $VENV_DIR" - rm -rf "$VENV_DIR" - mkdir -p "$(dirname "$VENV_DIR")" - python3 -m venv "$VENV_DIR" - fi - - # shellcheck source=/dev/null - . "$VENV_SH" - - if [[ "$full" -eq 1 ]]; then - python -c 'import sys; assert sys.version_info[:2] == (3, 11), "Use Python 3.11 to match the bug report"' || { - echo "Warning: expected Python 3.11; found $(python -V)" >&2 - } - pip install -U pip setuptools wheel - pip install "torch==2.8.0" --index-url https://download.pytorch.org/whl/cu128 - pip install -r requirements/requirements.txt - pip install -e . - pip install pytest - fi - - print_runtime_env -} - -run_repro_compare() { - local REPRO_SRC="$ROOT/scripts/repro_pr7916.py" REPRO_TMP FIX_BRANCH STASHED=0 MAIN_EC - - [[ -f "$REPRO_SRC" ]] || { - echo "error: missing $REPRO_SRC (need this file on the current branch)" >&2 - exit 1 - } - - REPRO_TMP="$(mktemp /tmp/repro_pr7916_XXXXXX.py)" - cp "$REPRO_SRC" "$REPRO_TMP" - trap 'rm -f "$REPRO_TMP"' EXIT - - FIX_BRANCH="$(git rev-parse --abbrev-ref HEAD)" - local -a run=(torchrun --standalone --nproc_per_node=1) - - echo "" - echo "==> [1/2] Repro on $FIX_BRANCH (expect OK)" - "${run[@]}" "$REPRO_TMP" - - echo "" - echo "==> [2/2] Repro on $MAIN_REF (expect setup_context RuntimeError on unfixed tree)" - if [[ -n "$(git status --porcelain 2>/dev/null)" ]]; then - echo "==> Stashing local changes for checkout..." - git stash push -m "pr7916-setup: temp stash before main repro" - STASHED=1 - fi - if ! git checkout "$MAIN_REF"; then - echo "error: checkout $MAIN_REF failed" >&2 - [[ "$STASHED" -eq 1 ]] && git stash pop || true - exit 1 - fi - - set +e - "${run[@]}" "$REPRO_TMP" - MAIN_EC=$? - set -e - - if [[ "$MAIN_EC" -eq 0 ]]; then - echo "warning: main-branch repro exited 0 (expected failure on unfixed tree)." >&2 - else - echo "main-branch repro exited $MAIN_EC (non-zero expected for unfixed tree)." - fi - - echo "" - echo "==> Restoring $FIX_BRANCH" - git checkout "$FIX_BRANCH" - if [[ "$STASHED" -eq 1 ]]; then - git stash pop || echo "warning: stash pop failed — see git stash list" >&2 - fi -} - -setup_venv -run_repro_compare - -echo "" -echo "Done. Activate: source $VENV_DIR/bin/activate" From e42556976885ab1e7993d14796fac35e3b21babf Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 30 Mar 2026 02:35:50 +0000 Subject: [PATCH 16/17] update unit test Signed-off-by: Zhang Jian --- tests/unit/v1/zero/test_zero_functorch_linear.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py index a89d955ea9fd..e56c214d997a 100644 --- a/tests/unit/v1/zero/test_zero_functorch_linear.py +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -193,11 +193,11 @@ def test_setup_context_stores_autocast_attrs(self): weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) - # Without autocast + # Without autocast: setup_context must record that forward did not use autocast out = zero3_linear_wrap(inp, weight, None) grad_fn = out.grad_fn - # The saved context is accessible via the grad_fn - assert hasattr(grad_fn, '_saved__fwd_used_autocast') or hasattr(grad_fn, '_fwd_used_autocast') or True - # Just verify backward works and produces finite gradients + assert hasattr(grad_fn, "_fwd_used_autocast") + assert grad_fn._fwd_used_autocast is False + assert hasattr(grad_fn, "_dtype") out.sum().backward() assert torch.isfinite(weight.grad).all() From 39f7e3ce1c2bbaaef142b526fb1d3cd1ccb43626 Mon Sep 17 00:00:00 2001 From: Zhang Jian Date: Mon, 30 Mar 2026 02:36:22 +0000 Subject: [PATCH 17/17] drop support for pytorch<2.0 Signed-off-by: Zhang Jian --- requirements/requirements-readthedocs.txt | 2 +- requirements/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index a48a47e4428d..aaac814354c4 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -7,5 +7,5 @@ py-cpuinfo pydantic>=2.0.0 recommonmark sphinx_rtd_theme -torch +torch>=2.0.0 tqdm diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 1af4c69c5807..1bbd21dd5e32 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,5 +7,5 @@ packaging>=20.0 psutil py-cpuinfo pydantic>=2.0.0 -torch +torch>=2.0.0 tqdm