From 74a1aae66e26a0b3ebd4d151ca927a44f9d516e4 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla Date: Mon, 30 Mar 2026 02:18:59 -0700 Subject: [PATCH] fix: prevent MXFP8 amax pointer aliasing with dSoftmaxOffset in FP8 attention backward In fused_attn_fp8_bwd, the amax pointers for dQ/dK/dV all pointed to output_dQ->amax.dptr. For MXFP8 (which doesn't use per-tensor amax), PyTorch's caching allocator could reuse this memory for the dSoftmaxOffset tensor. cuDNN's second bprop kernel then wrote amax_dQ to that address, corrupting d_softmax_offset[0]. Fix: use each tensor's own amax/scale pointers, and when aliasing with dSoftmaxOffset is detected, allocate scratch space for the amax outputs. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../common/fused_attn/fused_attn_fp8.cu | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) mode change 100644 => 100755 transformer_engine/common/fused_attn/fused_attn_fp8.cu diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu old mode 100644 new mode 100755 index 5158630937..9d609b450e --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2898,11 +2898,25 @@ void fused_attn_fp8_bwd( void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dQ->amax.dptr; - void* devPtrAmaxdV = output_dQ->amax.dptr; + void* devPtrAmaxdK = output_dK->amax.dptr; + void* devPtrAmaxdV = output_dV->amax.dptr; void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dQ->scale.dptr; - void* devPtrScaledV = output_dQ->scale.dptr; + void* devPtrScaledK = output_dK->scale.dptr; + void* devPtrScaledV = output_dV->scale.dptr; + // For MXFP8, amax pointers are unused but cuDNN still writes to them. + // They may alias other output tensors (e.g. dSoftmaxOffset) since the allocator + // can reuse memory for these dummy scalars. Use scratch space to prevent stomping. + float* amax_dqkv_scratch = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING && softmax_type != NVTE_VANILLA_SOFTMAX && + output_dSoftmaxOffset->data.dptr != nullptr && + (devPtrAmaxdQ == output_dSoftmaxOffset->data.dptr || + devPtrAmaxdK == output_dSoftmaxOffset->data.dptr || + devPtrAmaxdV == output_dSoftmaxOffset->data.dptr)) { + NVTE_CHECK_CUDA(cudaMallocAsync(&amax_dqkv_scratch, 3 * sizeof(float), stream)); + devPtrAmaxdQ = &amax_dqkv_scratch[0]; + devPtrAmaxdK = &amax_dqkv_scratch[1]; + devPtrAmaxdV = &amax_dqkv_scratch[2]; + } void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); @@ -2954,13 +2968,22 @@ void fused_attn_fp8_bwd( if (workspace->data.dptr == nullptr) { workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } return; } } else if (workspace_size == 0) { workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } return; } + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } } #endif // end of CUDNN>=8900 } // namespace transformer_engine