Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions transformer_engine/common/fused_attn/fused_attn_fp8.cu
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2898,11 +2898,25 @@ void fused_attn_fp8_bwd(
void* devPtrdK = output_dK->data.dptr;
void* devPtrdV = output_dV->data.dptr;
void* devPtrAmaxdQ = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dQ->amax.dptr;
void* devPtrAmaxdV = output_dQ->amax.dptr;
void* devPtrAmaxdK = output_dK->amax.dptr;
void* devPtrAmaxdV = output_dV->amax.dptr;
void* devPtrScaledQ = output_dQ->scale.dptr;
void* devPtrScaledK = output_dQ->scale.dptr;
void* devPtrScaledV = output_dQ->scale.dptr;
void* devPtrScaledK = output_dK->scale.dptr;
void* devPtrScaledV = output_dV->scale.dptr;
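// For MXFP8, the dQ/dK/dV amax outputs are unused, but cuDNN still writes to them.
// The amax pointers may alias other output tensors (e.g. dSoftmaxOffset) because the
// allocator can reuse memory for these dummy scalars. When an amax pointer aliases
// dSoftmaxOffset, redirect the amax writes to a scratch buffer so that they do not
// overwrite the dSoftmaxOffset result.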
float* amax_dqkv_scratch = nullptr;
if (input_Q->scaling_mode == NVTE_MXFP8_1D_SCALING && softmax_type != NVTE_VANILLA_SOFTMAX &&
output_dSoftmaxOffset->data.dptr != nullptr &&
(devPtrAmaxdQ == output_dSoftmaxOffset->data.dptr ||
devPtrAmaxdK == output_dSoftmaxOffset->data.dptr ||
devPtrAmaxdV == output_dSoftmaxOffset->data.dptr)) {
NVTE_CHECK_CUDA(cudaMallocAsync(&amax_dqkv_scratch, 3 * sizeof(float), stream));
devPtrAmaxdQ = &amax_dqkv_scratch[0];
devPtrAmaxdK = &amax_dqkv_scratch[1];
devPtrAmaxdV = &amax_dqkv_scratch[2];
}

void* devPtrcuSeqlensQ =
reinterpret_cast<void*>(reinterpret_cast<int32_t*>(cu_seqlens_q->data.dptr));
Expand Down Expand Up @@ -2954,13 +2968,22 @@ void fused_attn_fp8_bwd(
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {workspace_size};
workspace->data.dtype = DType::kByte;
if (amax_dqkv_scratch != nullptr) {
NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream));
}
return;
}
} else if (workspace_size == 0) {
workspace->data.shape = {1};
workspace->data.dtype = DType::kByte;
if (amax_dqkv_scratch != nullptr) {
NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream));
}
return;
}
if (amax_dqkv_scratch != nullptr) {
NVTE_CHECK_CUDA(cudaFreeAsync(amax_dqkv_scratch, stream));
}
}
#endif // end of CUDNN>=8900
} // namespace transformer_engine