diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu old mode 100644 new mode 100755 index 5158630937..9d609b450e --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2898,11 +2898,25 @@ void fused_attn_fp8_bwd( void* devPtrdK = output_dK->data.dptr; void* devPtrdV = output_dV->data.dptr; void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dQ->amax.dptr; - void* devPtrAmaxdV = output_dQ->amax.dptr; + void* devPtrAmaxdK = output_dK->amax.dptr; + void* devPtrAmaxdV = output_dV->amax.dptr; void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dQ->scale.dptr; - void* devPtrScaledV = output_dQ->scale.dptr; + void* devPtrScaledK = output_dK->scale.dptr; + void* devPtrScaledV = output_dV->scale.dptr; + // For MXFP8, amax pointers are unused but cuDNN still writes to them. + // They may alias other output tensors (e.g. dSoftmaxOffset) since the allocator + // can reuse memory for these dummy scalars. Use scratch space to prevent stomping. + float* amax_dqkv_scratch = nullptr; + if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING && softmax_type != NVTE_VANILLA_SOFTMAX && + output_dSoftmaxOffset->data.dptr != nullptr && + (devPtrAmaxdQ == output_dSoftmaxOffset->data.dptr || + devPtrAmaxdK == output_dSoftmaxOffset->data.dptr || + devPtrAmaxdV == output_dSoftmaxOffset->data.dptr)) { + NVTE_CHECK_CUDA(cudaMallocAsync(&amax_dqkv_scratch, 3 * sizeof(float), stream)); + devPtrAmaxdQ = &amax_dqkv_scratch[0]; + devPtrAmaxdK = &amax_dqkv_scratch[1]; + devPtrAmaxdV = &amax_dqkv_scratch[2]; + } void* devPtrcuSeqlensQ = reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); @@ -2954,13 +2968,22 @@ void fused_attn_fp8_bwd( if (workspace->data.dptr == nullptr) { workspace->data.shape = {workspace_size}; workspace->data.dtype = DType::kByte; + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } return; } } else if (workspace_size == 0) { workspace->data.shape = {1}; workspace->data.dtype = DType::kByte; + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } return; } + if (amax_dqkv_scratch != nullptr) { + NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream)); + } } #endif // end of CUDNN>=8900 } // namespace transformer_engine