Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/flashmask_v2/flash_bwd_launch_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
if constexpr (Is_causal || Is_local || Has_softcap) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Deterministic, Is_blockmask_, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
} else {
if ((params.seqlen_q >= 1024 || params.seqlen_k >= 1024) && !(Has_lt_end && Has_ut_start)) {
if (params.seqlen_q >= 1024 || params.seqlen_k >= 1024) {
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Deterministic, Is_blockmask_, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

} else {
run_mha_bwd_dispatch<Arch, T, 64, 64, 128, Is_causal, Is_local, Has_softcap, Is_flashmask_, Has_lt_end, Has_ut_start, Deterministic, Is_blockmask_, 2, 2, false, true, false, 2, 1, 2, 1, false>(params, stream);
Expand Down
32 changes: 28 additions & 4 deletions csrc/flashmask_v2/flash_fwd_kernel_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,32 @@ class FlashAttnFwdSm90 {
// static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);

static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32);
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160);
// static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? 24 : 32);
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? 240 : 160);

// Head dimension re-exported from the mainloop collective; the register
// requirement constants that follow branch on it (kHeadDim <= 64 vs. larger).
static constexpr int kHeadDim = CollectiveMainloop::kHeadDim;

// Register count passed to warpgroup_reg_dealloc by the n_block generator warps
// (the non-zero warps of warp group 0).
//   kHeadDim <= 64 : always 56
//   otherwise      : 56 / 24 / 32 for NumMmaWarpGroups == 1 / == 2 / anything else
// NOTE(review): for kHeadDim <= 64 this value (56) differs from
// LoadRegisterRequirement (32). If the remaining warp of warp group 0 still
// deallocates with LoadRegisterRequirement, two different setmaxnreg counts are
// issued inside one warpgroup -- confirm this is intended and well-defined.
static constexpr uint32_t NBlockRegisterRequirement =
    (kHeadDim <= 64)
        ? 56
        : (NumMmaWarpGroups == 1 ? 56
                                 : (NumMmaWarpGroups == 2 ? 24 : 32));
// Register count for the loader side (its use site is not visible in this hunk).
//   kHeadDim <= 64 : always 32
//   otherwise      : 56 / 24 / 32 for NumMmaWarpGroups == 1 / == 2 / anything else
static constexpr uint32_t LoadRegisterRequirement =
    (kHeadDim <= 64)
        ? 32
        : (NumMmaWarpGroups == 1 ? 56
                                 : (NumMmaWarpGroups == 2 ? 24 : 32));
// Register count for the MMA side (its use site is not visible in this hunk).
//   kHeadDim <= 64 : always 224
//   otherwise      : 256 / 240 / 160 for NumMmaWarpGroups == 1 / == 2 / anything else
static constexpr uint32_t MmaRegisterRequirement =
    (kHeadDim <= 64)
        ? 224
        : (NumMmaWarpGroups == 1 ? 256
                                 : (NumMmaWarpGroups == 2 ? 240 : 160));

// If you want to print from the producer warp, you'd need to increase the number of registers
// Otherwise you'll get a CUDA error.
Expand Down Expand Up @@ -272,7 +296,7 @@ class FlashAttnFwdSm90 {
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));

if (warp_group_idx == 0 && warp_idx_in_warpgroup != 0) { // n_block generator
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
cutlass::arch::warpgroup_reg_dealloc<NBlockRegisterRequirement>();
cutlass::PipelineState<CollectiveMainloop::kNBlockStages> n_block_pipe_write = cutlass::make_producer_start_state<MainloopPipelineNBlock>();
// Manually specify the scheduler role: producer. For StaticPersistentTileSch, passing template args won't change the behavior
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
Expand Down Expand Up @@ -556,4 +580,4 @@ class FlashAttnFwdSm90 {

};

} // namespace flash
} // namespace flash
133 changes: 83 additions & 50 deletions csrc/flashmask_v2/mainloop_bwd_sm90_tma_gmma_ws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1123,8 +1123,6 @@ struct CollectiveMainloopBwdSm90 {
);
// int32_t flashmask_mem_[8]s;
// load_n_block_info(n_block, flashmask_mem_, params);

int m_block = m_block_min;
// if(thread_idx == 0) printf("m_block:%d",m_block);
// get_next_m_block(n_block,m_block,partially_masked,m_block_max - 1,params);

Expand Down Expand Up @@ -1330,76 +1328,111 @@ struct CollectiveMainloopBwdSm90 {
// this helps quite a bit to not have to do causal masking for most of the iterations.

auto mask_fn = [&](auto& tSrS, int m_block) { mask.template apply<true /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
int loop_end = m_block_max;
if constexpr(!Is_causal){
if constexpr (Has_ut_start) {
loop_end = std::min(flashmask_mem_[5]/*ut_start_nblockmin*/, m_block_max);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;

// State machine reduces bwd_step call sites from 7 to 1, lowering register pressure.
// Only beneficial when all 9 states are active (!Is_causal && Has_ut_start && Has_lt_end).
// Other cases use the original multi-for-loop (fewer call sites, compiler can optimize each loop).
if constexpr (!Is_causal && Has_ut_start && Has_lt_end) {
// State machine version: single bwd_step call site for all 9 states.
// state -----------------> 0 1 (2 3) 4 5 (6 7) 8
// is partially masked ---> F T ( T ) F T ( T ) F
// smem_pos = state ^ 5 --> 5 4 (7 6) 1 0 (3 2) N/A
// Combined states (2,3)/(6,7): smem_index = nblockmin, smem_index-1 = nblockmax.
// States 0 and 4 use flashmask_mem_[x] - 1 as loop_end: the loop bound is inclusive (m_block <= loop_end), so subtracting 1 gives strict `<` semantics.
int loop_end = m_block_min - 1;
CUTLASS_PRAGMA_NO_UNROLL
for (int state = -1, m_block = m_block_min; ; m_block++) {
while (m_block > loop_end) {
++ state;
if (state >= 9) break;
const int smem_index = state ^ 0x5;
if (state < 8) {
if ((state & 0x03) == 0x02) {
m_block = std::max(flashmask_mem_[smem_index], m_block);
loop_end = std::min(flashmask_mem_[smem_index - 1], m_block_max - 1);
++ state;
} else {
loop_end = std::min(flashmask_mem_[smem_index] - !(state & 0x03), m_block_max - 1);
}
} else {
loop_end = m_block_max - 1;
}
}
if (state >= 9) break;
if constexpr (Is_blockmask) {
if (blockmask_smem_[m_block / params.m_factor]) {
bwd_step(m_block, mask_fn, (state & 0x3) > 0, flashmask_index_smem_);
}
} else {
bwd_step(m_block, mask_fn, (state & 0x3) > 0, flashmask_index_smem_);
}
}
} else {
// Original multi-for-loop version.
int m_block = m_block_min;
int loop_end = m_block_max;
if constexpr(!Is_causal){
if constexpr (Has_ut_start) {
loop_end = std::min(flashmask_mem_[5]/*ut_start_nblockmin*/, m_block_max);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
loop_end = flashmask_mem_[4]/*ut_start_nblockmax*/;
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; ++m_block) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
// if(threadIdx.x == 128) printf("consumer0 m_block,n_block: %d, %d\n", m_block,n_block);
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
loop_end = flashmask_mem_[4]/*ut_start_nblockmax*/;
m_block = std::max(m_block, flashmask_mem_[7]/*ut_end_nblockmin*/);
loop_end = std::min(flashmask_mem_[6]/*ut_end_nblockmax*/, m_block_max - 1);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; ++m_block) {
for (; m_block <= loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
// if(threadIdx.x == 128) printf("consumer0 m_block,n_block: %d, %d\n", m_block,n_block);
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
}
m_block = std::max(m_block, flashmask_mem_[7]/*ut_end_nblockmin*/);
loop_end = std::min(flashmask_mem_[6]/*ut_end_nblockmax*/, m_block_max - 1);
loop_end = std::min(flashmask_mem_[1]/*lt_start_nblockmin*/, m_block_max);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; m_block++) {
for (; m_block < loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
// if(threadIdx.x == 128) printf("consumer-u-2 m_block,n_block,m_block_max,flashmask_mem_[2]: %d, %d, %d,%d\n", m_block,n_block,m_block_max,flashmask_mem_[6]);
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
}
loop_end = std::min(flashmask_mem_[1]/*lt_start_nblockmin*/, m_block_max);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
// if(threadIdx.x == 128) printf("consumer-l-0 m_block,n_block: %d, %d\n", m_block,n_block);
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
//partial_maskloop_end
loop_end = std::min(m_block_max - 1, flashmask_mem_[0]/*lt_start_nblockmax*/);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
// if(threadIdx.x == 128) printf("consumer-l-1 m_block,n_block, flashmask_mem_[0]: %d, %d, %d\n", m_block,n_block,flashmask_mem_[0]);
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
if constexpr (Has_lt_end) {
m_block = std::max(m_block, flashmask_mem_[3]/*lt_end_nblockmin*/);
//partial_maskloop_end
loop_end = std::min(flashmask_mem_[2]/*lt_end_nblockmax*/, m_block_max - 1);
loop_end = std::min(m_block_max - 1, flashmask_mem_[0]/*lt_start_nblockmax*/);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
// if(threadIdx.x == 128) printf("consumer2 m_block,n_block,m_block_max,flashmask_mem_[2]: %d, %d, %d,%d\n", m_block,n_block,m_block_max,flashmask_mem_[2]);
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < m_block_max; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
if constexpr (Has_lt_end) {
m_block = std::max(m_block, flashmask_mem_[3]/*lt_end_nblockmin*/);
loop_end = std::min(flashmask_mem_[2]/*lt_end_nblockmax*/, m_block_max - 1);
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block <= loop_end; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
bwd_step(m_block, mask_fn, true, flashmask_index_smem_);
}
CUTLASS_PRAGMA_NO_UNROLL
for (; m_block < m_block_max; m_block++) {
if constexpr (Is_blockmask){
if(!blockmask_smem_[m_block / params.m_factor]) continue;
}
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
bwd_step(m_block, mask_fn, false, flashmask_index_smem_);
}
}

Expand Down
2 changes: 1 addition & 1 deletion csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ struct CollectiveMainloopFwdSm90 {

// These are tuned for speed. They don't affect correctness.
static constexpr bool UseSchedulerBarrier = (IntraWGOverlap
? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim <= 128 : kHeadDim >= 128)
? (NumMmaWarpGroups >= 2) && (!Is_FP8 ? kHeadDim < 128 : kHeadDim >= 128)
: NumMmaWarpGroups == 2)
&& !LargeHeadDimV;
static constexpr bool RescaleOBeforeGemm = kHeadDim > 128 && (!Is_FP8 || V_colmajor) && IntraWGOverlap;
Expand Down
7 changes: 4 additions & 3 deletions csrc/flashmask_v2/tile_size.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@ constexpr std::tuple<int, int, bool, bool> tile_size_fwd_sm90(
return {64, 64, false, true};
}
if (headdim <= 64) {
bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512
// bool same_hdim = (headdim == headdim_v); // if not same hdim, we're targeting hdimv=512
// return {same_hdim ? 192 : 64, same_hdim ? 128 : 64, same_hdim, same_hdim};
// With this workaround in Cutlass 3.8, tile size 192 x 128 got slower for non-causal, for reasons not yet understood
// https://github.com/NVIDIA/cutlass/blob/833f6990e031b48b4cd2fcf55e0849c51ef6bac2/include/cute/container/tuple.hpp#L131
// Switch to tile size 192 x 192 for now
bool const use_blockN_128 = is_causal || is_local;
// bool const use_blockN_128 = is_causal || is_local;
// return {same_hdim ? 192 : 64, same_hdim ? (use_blockN_128 ? 128 : 192) : 64, same_hdim && use_blockN_128, same_hdim};
return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim};
// return {192, use_blockN_128 ? 80 : 144, same_hdim && use_blockN_128, same_hdim};
// Good for long seqlen (>= 4k) but suffers from tile quantization at short seqlen
// return {192, is_causal || is_local ? 192 : 176, true, false};
return {128, 128, true, true};
} else if (headdim <= 96) {
return {192, is_local || paged_kv_non_TMA ? 128 : 144, false, true};
} else if (headdim <= 128) {
Expand Down